From 4f42cde8e6a9f237111e629fbe0bd2ee9d4fdb88 Mon Sep 17 00:00:00 2001
From: Nick Stogner
Date: Sat, 7 Dec 2024 15:22:27 -0500
Subject: [PATCH] Fix unit tests

---
 internal/loadbalancer/group.go              |  7 +++--
 internal/loadbalancer/group_test.go         | 31 +++++++++++++++------
 internal/loadbalancer/load_balancer_test.go |  6 ++--
 internal/modelproxy/request.go              | 14 +++++-----
 4 files changed, 39 insertions(+), 19 deletions(-)

diff --git a/internal/loadbalancer/group.go b/internal/loadbalancer/group.go
index d3d19b18..4528109c 100644
--- a/internal/loadbalancer/group.go
+++ b/internal/loadbalancer/group.go
@@ -13,9 +13,11 @@ import (
 func newEndpointGroup() *group {
 	g := &group{
 		endpoints:         make(map[string]endpoint),
+		totalInFlight:     &atomic.Int64{},
 		chwblReplication:  100,
 		chwblHashes:       map[uint64]string{},
 		chwblSortedHashes: []uint64{},
+		bcast:             make(chan struct{}),
 	}
 	return g
 }
@@ -103,8 +105,8 @@ func (g *group) getAllAddrs() []string {
 	defer g.mtx.RUnlock()
 
 	var hosts []string
-	for ip := range g.endpoints {
-		hosts = append(hosts, ip)
+	for _, ep := range g.endpoints {
+		hosts = append(hosts, ep.address)
 	}
 
 	return hosts
@@ -115,6 +117,7 @@ func (g *group) reconcileEndpoints(observed map[string]endpoint) {
 	for name, observedEp := range observed {
 		if currentEp, ok := g.endpoints[name]; ok {
 			currentEp.adapters = observedEp.adapters
+			g.endpoints[name] = currentEp
 		} else {
 			g.endpoints[name] = endpoint{
 				inFlight: &atomic.Int64{},
diff --git a/internal/loadbalancer/group_test.go b/internal/loadbalancer/group_test.go
index a56b698f..f7fdacba 100644
--- a/internal/loadbalancer/group_test.go
+++ b/internal/loadbalancer/group_test.go
@@ -8,31 +8,46 @@ import (
 
 	"github.com/stretchr/testify/assert"
 	"github.com/stretchr/testify/require"
+	v1 "github.com/substratusai/kubeai/api/v1"
 	"k8s.io/apimachinery/pkg/util/rand"
 )
 
 func TestConcurrentAccess(t *testing.T) {
-	const myModel = "myModel"
+	const (
+		myModel = "myModel"
+		myAddr  = "10.0.0.1:8000"
+	)
 
 	testCases := map[string]struct {
 		readerCount int
 		writerCount int
 	}{
-		"lot of reader": {readerCount: 1_000, writerCount: 1},
-		"lot of writer": {readerCount: 1, writerCount: 1_000},
-		"lot of both":   {readerCount: 1_000, writerCount: 1_000},
+		"one reader_one_writer": {readerCount: 1, writerCount: 1},
+		"lot of reader":         {readerCount: 1_000, writerCount: 1},
+		"lot of writer":         {readerCount: 1, writerCount: 1_000},
+		"lot of both":           {readerCount: 1_000, writerCount: 1_000},
 	}
 
 	for name, spec := range testCases {
 		randomReadFn := []func(g *group){
-			func(g *group) { g.getBestAddr(context.Background(), AddressRequest{}, false) },
+			func(g *group) {
+				ip, f, err := g.getBestAddr(context.Background(), AddressRequest{
+					Model: myModel,
+					LoadBalancing: v1.LoadBalancing{
+						Strategy: v1.LeastLoadStrategy,
+					},
+				}, false)
+				require.NoError(t, err)
+				defer f()
+				assert.Equal(t, myAddr, ip)
+			},
 			func(g *group) { g.getAllAddrs() },
 		}
 		t.Run(name, func(t *testing.T) {
-			// setup endpoint with one service so that requests are not waiting
+			// setup endpoint with one endpoint so that requests are not waiting
 			group := newEndpointGroup()
 			group.reconcileEndpoints(
-				map[string]endpoint{myModel: {}},
+				map[string]endpoint{myModel: {address: myAddr}},
 			)
 
 			var startWg, doneWg sync.WaitGroup
@@ -52,7 +67,7 @@ func TestConcurrentAccess(t *testing.T) {
 		startTogether(spec.readerCount, func() { randomReadFn[rand.Intn(len(randomReadFn)-1)](group) })
 		startTogether(spec.writerCount, func() {
 			group.reconcileEndpoints(
-				map[string]endpoint{rand.String(1): {}},
+				map[string]endpoint{myModel: {address: myAddr}},
 			)
 		})
 		doneWg.Wait()
diff --git a/internal/loadbalancer/load_balancer_test.go b/internal/loadbalancer/load_balancer_test.go
index caa1b809..4ce4d7f5 100644
--- a/internal/loadbalancer/load_balancer_test.go
+++ b/internal/loadbalancer/load_balancer_test.go
@@ -20,8 +20,6 @@ func TestAwaitBestHost(t *testing.T) {
 		myAddrWithAdapter = "10.0.0.2:8000"
 	)
 
-	manager := &LoadBalancer{endpoints: make(map[string]*group, 1)}
-
 	testCases := map[string]struct {
 		model    string
 		adapter  string
@@ -63,6 +61,10 @@ func TestAwaitBestHost(t *testing.T) {
 
 	for name, spec := range testCases {
 		t.Run(name, func(t *testing.T) {
+			manager := &LoadBalancer{
+				endpoints: make(map[string]*group, 1),
+			}
+
 			manager.getEndpoints(myModel).reconcileEndpoints(spec.endpoints)
 
 			ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond)
diff --git a/internal/modelproxy/request.go b/internal/modelproxy/request.go
index 2c7d81cd..8b664309 100644
--- a/internal/modelproxy/request.go
+++ b/internal/modelproxy/request.go
@@ -23,18 +23,18 @@ type proxyRequest struct {
 }
 
 func newProxyRequest(r *http.Request) (*proxyRequest, error) {
+	pr := &proxyRequest{
+		http:   r,
+		status: http.StatusOK,
+	}
+
 	apiReq, err := apiutils.ParseRequest(r.Body, r.Header)
 	if err != nil {
-		return nil, err
+		return pr, err
 	}
 	// The content length might have changed after the body was read and rewritten.
 	r.ContentLength = apiReq.ContentLength
-
-	pr := &proxyRequest{
-		Request: apiReq,
-		http:    r,
-		status:  http.StatusOK,
-	}
+	pr.Request = apiReq
 
 	return pr, nil
 }