Skip to content

Commit

Permalink
Configurable number of retries
Browse files Browse the repository at this point in the history
  • Loading branch information
alpe committed Jan 16, 2024
1 parent b74a395 commit 296c6c4
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 30 deletions.
4 changes: 3 additions & 1 deletion cmd/lingo/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,13 @@ func run() error {
var metricsAddr string
var probeAddr string
var concurrencyPerReplica int
var maxRetriesOnErr int

flag.StringVar(&metricsAddr, "metrics-bind-address", ":8082", "The address the metric endpoint binds to.")
flag.StringVar(&probeAddr, "health-probe-bind-address", ":8081", "The address the probe endpoint binds to.")
flag.IntVar(&concurrencyPerReplica, "concurrency", concurrency, "the number of simultaneous requests that can be processed by each replica")
flag.IntVar(&scaleDownDelay, "scale-down-delay", scaleDownDelay, "seconds to wait before scaling down")
flag.IntVar(&maxRetriesOnErr, "max-retries", 0, "max number of retries on a http error code: 502,503,504")
opts := zap.Options{
Development: true,
}
Expand Down Expand Up @@ -154,7 +156,7 @@ func run() error {

proxy.MustRegister(metricsRegistry)
var proxyHandler http.Handler = proxy.NewHandler(deploymentManager, endpointManager, queueManager)
proxyHandler = proxy.NewRetryMiddleware(3, proxyHandler)
proxyHandler = proxy.NewRetryMiddleware(maxRetriesOnErr, proxyHandler)
proxyServer := &http.Server{Addr: ":8080", Handler: proxyHandler}

statsHandler := &stats.Handler{
Expand Down
69 changes: 46 additions & 23 deletions pkg/proxy/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,19 +12,34 @@ import (
var _ http.Handler = &RetryMiddleware{}

type RetryMiddleware struct {
nextHandler http.Handler
MaxRetries int
rSource *rand.Rand
nextHandler http.Handler
maxRetries int
rSource *rand.Rand
retryStatusCodes map[int]struct{}
}

func NewRetryMiddleware(maxRetries int, other http.Handler) *RetryMiddleware {
if maxRetries < 1 {
panic("invalid retries")
// NewRetryMiddleware creates a new HTTP middleware that adds retry functionality.
// It takes the maximum number of retries, the next handler in the middleware chain,
// and an optional list of retryable status codes.
// If the maximum number of retries is 0, it returns the next handler without adding any retries.
// If the list of retryable status codes is empty, it uses a default set of status codes (http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout).
// The function creates a RetryMiddleware struct with the given parameters and returns it as an http.Handler.
func NewRetryMiddleware(maxRetries int, other http.Handler, optRetryStatusCodes ...int) http.Handler {
if maxRetries == 0 {
return other
}
if len(optRetryStatusCodes) == 0 {
optRetryStatusCodes = []int{http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout}
}
statusCodeIndex := make(map[int]struct{}, len(optRetryStatusCodes))
for _, c := range optRetryStatusCodes {
statusCodeIndex[c] = struct{}{}
}
return &RetryMiddleware{
nextHandler: other,
MaxRetries: maxRetries,
rSource: rand.New(rand.NewSource(time.Now().UnixNano())),
nextHandler: other,
maxRetries: maxRetries,
retryStatusCodes: statusCodeIndex,
rSource: rand.New(rand.NewSource(time.Now().UnixNano())),
}
}

Expand All @@ -34,12 +49,12 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req
buf: bytes.NewBuffer([]byte{}),
}
request.Body = lazyBody
var capturedResp *responseWriterDelegator
for i := 0; ; i++ {
capturedResp = &responseWriterDelegator{
ResponseWriter: writer,
headerBuf: make(http.Header),
discardErrResp: i < r.MaxRetries &&
capturedResp := &responseWriterDelegator{
isRetryableStatusCode: r.isRetryableStatusCode,
ResponseWriter: writer,
headerBuf: make(http.Header),
discardErrResp: i < r.maxRetries &&
request.Context().Err() == nil, // abort early on timeout, context cancel
}
// call next handler in chain
Expand All @@ -50,7 +65,7 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req
r.nextHandler.ServeHTTP(capturedResp, req)
lazyBody.Capture()
if !capturedResp.discardErrResp || // max retries reached
!isRetryableStatusCode(capturedResp.statusCode) {
!r.isRetryableStatusCode(capturedResp.statusCode) {
break
}
totalRetries.Inc()
Expand All @@ -60,24 +75,27 @@ func (r RetryMiddleware) ServeHTTP(writer http.ResponseWriter, request *http.Req
}
}

func isRetryableStatusCode(status int) bool {
return status == http.StatusBadGateway ||
status == http.StatusServiceUnavailable ||
status == http.StatusGatewayTimeout
// isRetryableStatusCode reports whether status is contained in the
// middleware's configured set of retryable HTTP status codes.
func (r RetryMiddleware) isRetryableStatusCode(status int) bool {
// retryStatusCodes is used as a set: only key presence matters.
_, ok := r.retryStatusCodes[status]
return ok
}

// Compile-time assertions that *responseWriterDelegator implements
// http.Flusher and io.ReaderFrom; the build fails if either method set is
// incomplete.
var (
_ http.Flusher = &responseWriterDelegator{}
_ io.ReaderFrom = &responseWriterDelegator{}
)

// responseWriterDelegator is a wrapper around http.ResponseWriter that controls
// whether a response is actually written. Depending on the status code and the
// discard setting, the header and content are skipped on write so that the
// underlying ResponseWriter can be re-used on retry.
type responseWriterDelegator struct {
http.ResponseWriter
headerBuf http.Header
wroteHeader bool
statusCode int
// always writes to responseWriter when false
discardErrResp bool
discardErrResp bool
isRetryableStatusCode func(status int) bool
}

func (r *responseWriterDelegator) Header() http.Header {
Expand All @@ -91,7 +109,7 @@ func (r *responseWriterDelegator) WriteHeader(status int) {
// any 1xx informational response should be written
r.discardErrResp = r.discardErrResp && !(status >= 100 && status < 200)
}
if r.discardErrResp && isRetryableStatusCode(status) {
if r.discardErrResp && r.isRetryableStatusCode(status) {
return
}
// copy header values to target
Expand All @@ -103,12 +121,17 @@ func (r *responseWriterDelegator) WriteHeader(status int) {
r.ResponseWriter.WriteHeader(status)
}

// Write writes data to the response.
// If the response header has not been set, it sets the default status code to 200.
// When the status code qualifies for a retry, no content is written.
//
// It returns the number of bytes written and any error encountered.
func (r *responseWriterDelegator) Write(data []byte) (int, error) {
// ensure header is set. default is 200 in Go stdlib
if !r.wroteHeader {
r.WriteHeader(http.StatusOK)
}
if r.discardErrResp && isRetryableStatusCode(r.statusCode) {
if r.discardErrResp && r.isRetryableStatusCode(r.statusCode) {
return io.Discard.Write(data)
} else {
return r.ResponseWriter.Write(data)
Expand All @@ -120,7 +143,7 @@ func (r *responseWriterDelegator) ReadFrom(re io.Reader) (int64, error) {
if !r.wroteHeader {
r.WriteHeader(http.StatusOK)
}
if r.discardErrResp && isRetryableStatusCode(r.statusCode) {
if r.discardErrResp && r.isRetryableStatusCode(r.statusCode) {
return io.Copy(io.Discard, re)
} else {
return r.ResponseWriter.(io.ReaderFrom).ReadFrom(re)
Expand Down
20 changes: 14 additions & 6 deletions tests/integration/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package integration
import (
"bytes"
"fmt"
"io"
"log"
"net/http"
"net/http/httptest"
Expand Down Expand Up @@ -145,6 +146,8 @@ func TestRetryMiddleware(t *testing.T) {
code := serverCodes[i-1]
t.Logf("Serving request from testBackend: %d; code: %d\n", i, code)
w.WriteHeader(code)
_, err := w.Write([]byte(strconv.Itoa(code)))
require.NoError(t, err)
}))

// Mock an EndpointSlice.
Expand Down Expand Up @@ -183,11 +186,9 @@ func TestRetryMiddleware(t *testing.T) {
backendRequests.Store(0)

// when single request sent
var wg sync.WaitGroup
sendRequest(t, &wg, modelName, spec.expResultCode)
wg.Wait()

// then
gotBody := <-sendRequest(t, &sync.WaitGroup{}, modelName, spec.expResultCode)
// then only the last body is written
assert.Equal(t, strconv.Itoa(spec.expResultCode), gotBody)
require.Equal(t, spec.expBackendHits, backendRequests.Load(), "ensure backend hit with retries")
})
}
Expand Down Expand Up @@ -222,9 +223,10 @@ func sendRequests(t *testing.T, wg *sync.WaitGroup, modelName string, n int, exp
}
}

func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int) {
func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int) <-chan string {
t.Helper()
wg.Add(1)
bodyRespChan := make(chan string, 1)
go func() {
defer wg.Done()

Expand All @@ -235,7 +237,13 @@ func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int
res, err := testHTTPClient.Do(req)
require.NoError(t, err)
require.Equal(t, expCode, res.StatusCode)
got, err := io.ReadAll(res.Body)
_ = res.Body.Close()
require.NoError(t, err)
bodyRespChan <- string(got)
close(bodyRespChan)
}()
return bodyRespChan
}

func completeRequests(c chan struct{}, n int) {
Expand Down

0 comments on commit 296c6c4

Please sign in to comment.