diff --git a/pkg/endpoints/endpoints.go b/pkg/endpoints/endpoints.go
index e98649f4..e5112a92 100644
--- a/pkg/endpoints/endpoints.go
+++ b/pkg/endpoints/endpoints.go
@@ -24,9 +24,8 @@ type endpointGroup struct {
 	ports     map[string]int32
 	endpoints map[string]endpoint
 
-	bmtx      sync.Mutex
-	bcast     chan struct{}   // closed when there's a broadcast
-	listeners []chan struct{} // keeps track of all listeners
+	bmtx  sync.RWMutex  // guards bcast
+	bcast chan struct{} // closed and replaced on every broadcast
 }
 
 // getBestHost returns the best host for the given port name. It blocks until there are available endpoints
@@ -45,12 +44,15 @@ func (e *endpointGroup) getBestHost(ctx context.Context, portName string) (strin
 	// await endpoints exists
 	for len(e.endpoints) == 0 {
+		// broadcastEndpoints swaps e.bcast under bmtx; read it under the read
+		// lock before waiting so the field access is not a data race.
+		e.bmtx.RLock()
+		bcast := e.bcast
+		e.bmtx.RUnlock()
 		e.mtx.RUnlock()
-		_, err := execWithCtxAbort(ctx, func() any {
-			<-e.waitForEndpoints(ctx)
-			return nil
-		})
-		if err != nil {
-			return "", err
+		select {
+		case <-bcast:
+		case <-ctx.Done():
+			return "", ctx.Err()
 		}
 		e.mtx.RLock()
 	}
@@ -117,31 +119,12 @@ func (g *endpointGroup) setIPs(ips map[string]struct{}, ports map[string]int32)
 	}
 }
 
-func (e *endpointGroup) waitForEndpoints(ctx context.Context) <-chan struct{} {
-	e.bmtx.Lock()
-	defer e.bmtx.Unlock()
-
-	ch := make(chan struct{})
-	e.listeners = append(e.listeners, ch)
-
-	go execWithCtxAbort(ctx, func() any {
-		<-e.bcast
-		close(ch)
-		return nil
-	})
-
-	return ch
-}
-
+// broadcastEndpoints wakes every goroutine waiting on bcast by closing it, then
+// installs a fresh channel for the next round of waiters.
 func (g *endpointGroup) broadcastEndpoints() {
 	g.bmtx.Lock()
 	defer g.bmtx.Unlock()
 
 	close(g.bcast)
 	g.bcast = make(chan struct{})
-
-	for _, ch := range g.listeners {
-		close(ch)
-	}
-	clear(g.listeners)
 }
diff --git a/pkg/endpoints/endpoints_test.go b/pkg/endpoints/endpoints_test.go
index 810e4064..888bf2f0 100644
--- a/pkg/endpoints/endpoints_test.go
+++ b/pkg/endpoints/endpoints_test.go
@@ -1,9 +1,14 @@
 package endpoints
 
 import (
+	"context"
 	"sync"
+	"sync/atomic"
 	"testing"
 
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+
 	"k8s.io/apimachinery/pkg/util/rand"
 )
 
@@ -15,13 +20,18 @@ 
func TestConcurrentAccess(t *testing.T) {
 		readerCount int
 		writerCount int
 	}{
-		"lot of reader": {readerCount: 10_000, writerCount: 1},
-		"lot of writer": {readerCount: 1, writerCount: 10_000},
-		"lot of both":   {readerCount: 10_000, writerCount: 10_000},
+		"lot of reader": {readerCount: 1_000, writerCount: 1},
+		"lot of writer": {readerCount: 1, writerCount: 1_000},
+		"lot of both":   {readerCount: 1_000, writerCount: 1_000},
 	}
-
 	for name, spec := range testCases {
+		randomReadFn := []func(g *endpointGroup){
+			func(g *endpointGroup) { g.getBestHost(context.Background(), myPort) },
+			func(g *endpointGroup) { g.getAllHosts(myPort) },
+			func(g *endpointGroup) { g.lenIPs() },
+		}
 		t.Run(name, func(t *testing.T) {
+			// setup endpoint with one service so that requests are not waiting
 			endpoint := newEndpointGroup()
 			endpoint.setIPs(
 				map[string]struct{}{myService: {}},
@@ -41,7 +51,8 @@ func TestConcurrentAccess(t *testing.T) {
 					}()
 				}
 			}
-			startTogether(spec.readerCount, func() { endpoint.getBestHost(nil, myPort) })
+			// when: every reader runs one randomly chosen read function
+			startTogether(spec.readerCount, func() { randomReadFn[rand.Intn(len(randomReadFn))](endpoint) })
 			startTogether(spec.writerCount, func() {
 				endpoint.setIPs(
 					map[string]struct{}{rand.String(1): {}},
@@ -52,3 +63,62 @@ func TestConcurrentAccess(t *testing.T) {
 		})
 	}
 }
+
+// TestBlockAndWaitForEndpoints starts 100 getBestHost callers on a freshly
+// created group and checks that all of them return after one setIPs call.
+func TestBlockAndWaitForEndpoints(t *testing.T) {
+	var completed atomic.Int32
+	var startWg, doneWg sync.WaitGroup
+	startTogether := func(n int, f func()) {
+		startWg.Add(n)
+		doneWg.Add(n)
+		for i := 0; i < n; i++ {
+			go func() {
+				startWg.Done()
+				startWg.Wait()
+				f()
+				completed.Add(1)
+				doneWg.Done()
+			}()
+		}
+	}
+	endpoint := newEndpointGroup()
+	ctx := context.TODO()
+	startTogether(100, func() {
+		endpoint.getBestHost(ctx, rand.String(4))
+	})
+	startWg.Wait()
+
+	// when broadcast triggered
+	endpoint.setIPs(
+		map[string]struct{}{rand.String(4): {}},
+		map[string]int32{rand.String(4): 1},
+	)
+	// then
+	doneWg.Wait()
+	assert.Equal(t, int32(100), completed.Load())
+}
+
+// TestAbortOnCtxCancel checks that getBestHost on a freshly created group
+// returns an error once its context is cancelled.
+func TestAbortOnCtxCancel(t *testing.T) {
+	ctx, cancel := context.WithCancel(context.Background())
+
+	var startWg, doneWg sync.WaitGroup
+	startWg.Add(1)
+	doneWg.Add(1)
+	// err is written by the goroutine and read only after doneWg.Wait(), so the
+	// require assertion runs on the test goroutine, where FailNow is allowed.
+	var err error
+	go func() {
+		defer doneWg.Done()
+		startWg.Wait()
+		endpoint := newEndpointGroup()
+		_, err = endpoint.getBestHost(ctx, rand.String(4))
+	}()
+	startWg.Done()
+	cancel()
+
+	doneWg.Wait()
+	require.Error(t, err)
+}