From 009ecaf28ec0765fed272937b97c421594a2eb35 Mon Sep 17 00:00:00 2001 From: Alex Peters Date: Wed, 31 Jan 2024 17:03:46 +0100 Subject: [PATCH] Add integration test --- tests/integration/integration_test.go | 137 +++++++++++++++++++++----- tests/integration/main_test.go | 1 + 2 files changed, 114 insertions(+), 24 deletions(-) diff --git a/tests/integration/integration_test.go b/tests/integration/integration_test.go index 10fb5e45..6325ba7a 100644 --- a/tests/integration/integration_test.go +++ b/tests/integration/integration_test.go @@ -3,6 +3,7 @@ package integration import ( "bytes" "fmt" + "io" "log" "net/http" "net/http/httptest" @@ -43,17 +44,7 @@ func TestScaleUpAndDown(t *testing.T) { })) // Mock an EndpointSlice. - testBackendURL, err := url.Parse(testBackend.URL) - require.NoError(t, err) - testBackendPort, err := strconv.Atoi(testBackendURL.Port()) - require.NoError(t, err) - require.NoError(t, testK8sClient.Create(testCtx, - endpointSlice( - modelName, - testBackendURL.Hostname(), - int32(testBackendPort), - ), - )) + withMockEndpointSlice(t, testBackend, modelName) // Wait for deployment mapping to sync. time.Sleep(3 * time.Second) @@ -103,17 +94,7 @@ func TestHandleModelUndeployment(t *testing.T) { })) // Mock an EndpointSlice. - testBackendURL, err := url.Parse(testBackend.URL) - require.NoError(t, err) - testBackendPort, err := strconv.Atoi(testBackendURL.Port()) - require.NoError(t, err) - require.NoError(t, testK8sClient.Create(testCtx, - endpointSlice( - modelName, - testBackendURL.Hostname(), - int32(testBackendPort), - ), - )) + withMockEndpointSlice(t, testBackend, modelName) // Wait for deployment mapping to sync. time.Sleep(3 * time.Second) @@ -132,7 +113,7 @@ func TestHandleModelUndeployment(t *testing.T) { require.NoError(t, testK8sClient.Delete(testCtx, deploy)) // Check that the deployment was deleted - err = testK8sClient.Get(testCtx, client.ObjectKey{ + err := testK8sClient.Get(testCtx, client.ObjectKey{ Namespace: deploy.Namespace, Name: deploy.Name, }, deploy) @@ -151,6 +132,100 @@ func TestHandleModelUndeployment(t *testing.T) { wg.Wait() } +func TestRetryMiddleware(t *testing.T) { + const modelName = "test-model-c" + deploy := testDeployment(modelName) + require.NoError(t, testK8sClient.Create(testCtx, deploy)) + + // Wait for deployment mapping to sync. + time.Sleep(3 * time.Second) + backendRequests := &atomic.Int32{} + var serverCodes []int + testBackend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + expBody := []byte(fmt.Sprintf(`{"model": %q}`, modelName)) + gotBody, err := io.ReadAll(r.Body) + require.NoError(t, err) + assert.Equal(t, expBody, gotBody) + + i := backendRequests.Add(1) + code := serverCodes[i-1] + t.Logf("Serving request from testBackend: %d; code: %d\n", i, code) + w.WriteHeader(code) + _, err = w.Write([]byte(strconv.Itoa(code))) + require.NoError(t, err) + })) + + // Mock an EndpointSlice. + withMockEndpointSlice(t, testBackend, modelName) + + specs := map[string]struct { + serverCodes []int + header []tuple + expResultCode int + expResultBody string + expBackendHits int32 + }{ + "max retries - succeeds": { + serverCodes: []int{http.StatusBadGateway, http.StatusServiceUnavailable, http.StatusGatewayTimeout, http.StatusOK}, + expResultCode: http.StatusOK, + expResultBody: "200", + expBackendHits: 4, + }, + "max retries - fails": { + serverCodes: []int{http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusServiceUnavailable, http.StatusBadGateway}, + expResultCode: http.StatusBadGateway, + expResultBody: "{\"error\":\"Bad Gateway\"}\n", // note the linebreak + expBackendHits: 4, + }, + "non retryable error code": { + serverCodes: []int{http.StatusNotImplemented}, + expResultCode: http.StatusNotImplemented, + expResultBody: "501", + expBackendHits: 1, + }, + "200 status code": { + serverCodes: []int{http.StatusOK}, + expResultCode: http.StatusOK, + expResultBody: "200", + expBackendHits: 1, + }, + "200 status code - model header": { + serverCodes: []int{http.StatusOK}, + header: []tuple{{k: "X-Model", v: modelName}}, + expResultCode: http.StatusOK, + expResultBody: "200", + expBackendHits: 1, + }, + } + for name, spec := range specs { + t.Run(name, func(t *testing.T) { + // setup + serverCodes = spec.serverCodes + backendRequests.Store(0) + + // when single request sent + gotBody := <-sendRequest(t, &sync.WaitGroup{}, modelName, spec.expResultCode, spec.header...) + // then only the last body is written + assert.Equal(t, spec.expResultBody, gotBody) + require.Equal(t, spec.expBackendHits, backendRequests.Load(), "ensure backend hit") + }) + } +} + +func withMockEndpointSlice(t *testing.T, testBackend *httptest.Server, modelName string) { + testBackendURL, err := url.Parse(testBackend.URL) + require.NoError(t, err) + testBackendPort, err := strconv.Atoi(testBackendURL.Port()) + require.NoError(t, err) + require.NoError(t, testK8sClient.Create(testCtx, + endpointSlice( + modelName, + testBackendURL.Hostname(), + int32(testBackendPort), + ), + )) +} + func requireDeploymentReplicas(t *testing.T, deploy *appsv1.Deployment, n int32) { require.EventuallyWithT(t, func(t *assert.CollectT) { err := testK8sClient.Get(testCtx, types.NamespacedName{Namespace: deploy.Namespace, Name: deploy.Name}, deploy) @@ -166,20 +241,34 @@ func sendRequests(t *testing.T, wg *sync.WaitGroup, modelName string, n int, exp } } -func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int) { +type tuple struct { + k, v string +} + +func sendRequest(t *testing.T, wg *sync.WaitGroup, modelName string, expCode int, headers ...tuple) <-chan string { t.Helper() wg.Add(1) + bodyRespChan := make(chan string, 1) go func() { defer wg.Done() + defer close(bodyRespChan) body := []byte(fmt.Sprintf(`{"model": %q}`, modelName)) req, err := http.NewRequest(http.MethodPost, testServer.URL, bytes.NewReader(body)) requireNoError(err) + for _, e := range headers { + req.Header.Add(e.k, e.v) + } res, err := testHTTPClient.Do(req) require.NoError(t, err) require.Equal(t, expCode, res.StatusCode) + got, err := io.ReadAll(res.Body) + _ = res.Body.Close() + require.NoError(t, err) + bodyRespChan <- string(got) }() + return bodyRespChan } func completeRequests(c chan struct{}, n int) { diff --git a/tests/integration/main_test.go b/tests/integration/main_test.go index 74697559..95f3c2d1 100644 --- a/tests/integration/main_test.go +++ b/tests/integration/main_test.go @@ -109,6 +109,7 @@ func TestMain(m *testing.M) { Deployments: deploymentManager, Endpoints: endpointManager, Queues: queueManager, + MaxRetries: 3, } testServer = httptest.NewServer(handler) defer testServer.Close()