diff --git a/cmd/lingo/main.go b/cmd/lingo/main.go index 06b1e0a6..b561b0d9 100644 --- a/cmd/lingo/main.go +++ b/cmd/lingo/main.go @@ -155,8 +155,7 @@ func run() error { go autoscaler.Start() proxy.MustRegister(metricsRegistry) - var proxyHandler http.Handler = proxy.NewHandler(deploymentManager, endpointManager, queueManager) - proxyHandler = proxy.NewRetryMiddleware(maxRetriesOnErr, proxyHandler) + var proxyHandler http.Handler = proxy.NewHandler(deploymentManager, endpointManager, queueManager, maxRetriesOnErr) proxyServer := &http.Server{Addr: ":8080", Handler: proxyHandler} statsHandler := &stats.Handler{ diff --git a/pkg/proxy/handler.go b/pkg/proxy/handler.go index 4acfb57e..cec4b1f6 100644 --- a/pkg/proxy/handler.go +++ b/pkg/proxy/handler.go @@ -28,13 +28,19 @@ type deploymentSource interface { // Handler serves http requests for end-clients. // It is also responsible for triggering scale-from-zero. type Handler struct { - Deployments deploymentSource - Endpoints *endpoints.Manager - Queues *queue.Manager + Deployments deploymentSource + Endpoints *endpoints.Manager + Queues *queue.Manager + retriesOnErr int } -func NewHandler(deployments deploymentSource, endpoints *endpoints.Manager, queues *queue.Manager) *Handler { - return &Handler{Deployments: deployments, Endpoints: endpoints, Queues: queues} +func NewHandler(deployments deploymentSource, endpoints *endpoints.Manager, queues *queue.Manager, retriesOnErr int) *Handler { + return &Handler{ + Deployments: deployments, + Endpoints: endpoints, + Queues: queues, + retriesOnErr: retriesOnErr, + } } func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { @@ -120,7 +126,8 @@ func (h *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) { log.Printf("Proxying request to host %v: %v\n", host, id) // TODO: Avoid creating new reverse proxies for each request. // TODO: Consider implementing a round robin scheme. - newReverseProxy(host).ServeHTTP(w, proxyRequest) + proxy := newReverseProxy(host) + NewRetryMiddleware(h.retriesOnErr, proxy).ServeHTTP(w, proxyRequest) } // parseModel parses the model name from the request diff --git a/pkg/proxy/handler_test.go b/pkg/proxy/handler_test.go index 36fbf08b..0442bc96 100644 --- a/pkg/proxy/handler_test.go +++ b/pkg/proxy/handler_test.go @@ -35,7 +35,7 @@ func TestProxy(t *testing.T) { em, err := endpoints.NewManager(&fakeManager{}, func(deploymentName string, replicas int) {}) require.NoError(t, err) em.SetEndpoints("my-deployment", map[string]struct{}{"my-ip": {}}, map[string]int32{"my-port": 8080}) - h := NewHandler(deplMgr, em, queue.NewManager(10)) + h := NewHandler(deplMgr, em, queue.NewManager(10), 1) svr := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { em.SetEndpoints("my-deployment", map[string]struct{}{"my-other-ip": {}}, map[string]int32{"my-other-port": 8080}) diff --git a/pkg/proxy/metrics_test.go b/pkg/proxy/metrics_test.go index a3c1a7ea..79bc68ff 100644 --- a/pkg/proxy/metrics_test.go +++ b/pkg/proxy/metrics_test.go @@ -56,7 +56,7 @@ func TestMetrics(t *testing.T) { deplMgr, err := deployments.NewManager(&fakeManager{}) require.NoError(t, err) - h := NewHandler(deplMgr, nil, nil) + h := NewHandler(deplMgr, nil, nil, 2) recorder := httptest.NewRecorder() // when diff --git a/tests/integration/main_test.go b/tests/integration/main_test.go index 36d8987c..e7f41684 100644 --- a/tests/integration/main_test.go +++ b/tests/integration/main_test.go @@ -104,12 +104,8 @@ func TestMain(m *testing.M) { autoscaler.Endpoints = endpointManager go autoscaler.Start() - handler := &proxy.Handler{ - Deployments: deploymentManager, - Endpoints: endpointManager, - Queues: queueManager, - } - testServer = httptest.NewServer(proxy.NewRetryMiddleware(3, handler)) + handler := proxy.NewHandler(deploymentManager, endpointManager, queueManager, 3) + testServer = httptest.NewServer(handler) defer testServer.Close() go func() {