Skip to content

Commit

Permalink
Fix double close and polish code
Browse files Browse the repository at this point in the history
  • Loading branch information
alpe committed Dec 22, 2023
1 parent 4d3035b commit 5f02710
Show file tree
Hide file tree
Showing 2 changed files with 74 additions and 35 deletions.
36 changes: 6 additions & 30 deletions pkg/endpoints/endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,8 @@ type endpointGroup struct {
ports map[string]int32
endpoints map[string]endpoint

bmtx sync.Mutex
bcast chan struct{} // closed when there's a broadcast
listeners []chan struct{} // keeps track of all listeners
bmtx sync.RWMutex
bcast chan struct{} // closed when there's a broadcast
}

// getBestHost returns the best host for the given port name. It blocks until there are available endpoints
Expand All @@ -45,12 +44,10 @@ func (e *endpointGroup) getBestHost(ctx context.Context, portName string) (strin
// await endpoints exists
for len(e.endpoints) == 0 {
e.mtx.RUnlock()
_, err := execWithCtxAbort(ctx, func() any {
<-e.waitForEndpoints(ctx)
return nil
})
if err != nil {
return "", err
select {
case <-e.bcast:
case <-ctx.Done():
return "", ctx.Err()
}
e.mtx.RLock()
}
Expand Down Expand Up @@ -117,31 +114,10 @@ func (g *endpointGroup) setIPs(ips map[string]struct{}, ports map[string]int32)
}
}

// waitForEndpoints returns a channel that is closed on the next endpoint
// broadcast. Callers should select on the returned channel together with
// ctx.Done() to abort early; ctx is kept in the signature for backward
// compatibility even though no goroutine is spawned anymore.
//
// Fix: the previous implementation registered ch in e.listeners (closed by
// broadcastEndpoints) AND closed ch from a goroutine watching e.bcast, so
// the same channel could be closed twice and panic. Returning the current
// broadcast channel directly leaves exactly one closer (broadcastEndpoints)
// and avoids leaking a goroutine per waiter.
func (e *endpointGroup) waitForEndpoints(ctx context.Context) <-chan struct{} {
	e.bmtx.Lock()
	defer e.bmtx.Unlock()

	// e.bcast is replaced after every broadcast, so the channel captured
	// here is closed exactly once, by the next broadcastEndpoints call.
	return e.bcast
}

// broadcastEndpoints wakes every waiter by closing the current broadcast
// channel and immediately installing a fresh one for future waiters.
//
// Fix: it must NOT also close the channels registered in g.listeners —
// each listener channel is already closed by the goroutine in
// waitForEndpoints when g.bcast closes, so closing them here too caused a
// double close (panic). The registry is only cleared for bookkeeping.
func (g *endpointGroup) broadcastEndpoints() {
	g.bmtx.Lock()
	defer g.bmtx.Unlock()

	// Closing g.bcast is the single broadcast signal; replace it so the
	// next round of waiters gets an open channel.
	close(g.bcast)
	g.bcast = make(chan struct{})

	// Drop stale listener entries without closing them (their owners close
	// them in reaction to g.bcast closing above).
	clear(g.listeners)
}
73 changes: 68 additions & 5 deletions pkg/endpoints/endpoints_test.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
package endpoints

import (
"context"
"sync"
"sync/atomic"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"k8s.io/apimachinery/pkg/util/rand"
)

Expand All @@ -15,13 +20,18 @@ func TestConcurrentAccess(t *testing.T) {
readerCount int
writerCount int
}{
"lot of reader": {readerCount: 10_000, writerCount: 1},
"lot of writer": {readerCount: 1, writerCount: 10_000},
"lot of both": {readerCount: 10_000, writerCount: 10_000},
"lot of reader": {readerCount: 1_000, writerCount: 1},
"lot of writer": {readerCount: 1, writerCount: 1_000},
"lot of both": {readerCount: 1_000, writerCount: 1_000},
}

for name, spec := range testCases {
randomReadFn := []func(g *endpointGroup){
func(g *endpointGroup) { g.getBestHost(nil, myPort) },
func(g *endpointGroup) { g.getAllHosts(myPort) },
func(g *endpointGroup) { g.lenIPs() },
}
t.Run(name, func(t *testing.T) {
// setup endpoint with one service so that requests are not waiting
endpoint := newEndpointGroup()
endpoint.setIPs(
map[string]struct{}{myService: {}},
Expand All @@ -41,7 +51,8 @@ func TestConcurrentAccess(t *testing.T) {
}()
}
}
startTogether(spec.readerCount, func() { endpoint.getBestHost(nil, myPort) })
// when
startTogether(spec.readerCount, func() { randomReadFn[rand.Intn(len(randomReadFn)-1)](endpoint) })
startTogether(spec.writerCount, func() {
endpoint.setIPs(
map[string]struct{}{rand.String(1): {}},
Expand All @@ -52,3 +63,55 @@ func TestConcurrentAccess(t *testing.T) {
})
}
}

// TestBlockAndWaitForEndpoints verifies that getBestHost blocks while the
// group has no endpoints and that a single setIPs broadcast releases every
// concurrently blocked caller.
func TestBlockAndWaitForEndpoints(t *testing.T) {
	var finished atomic.Int32
	var readyWg, finishedWg sync.WaitGroup
	launchAll := func(n int, fn func()) {
		readyWg.Add(n)
		finishedWg.Add(n)
		for i := 0; i < n; i++ {
			go func() {
				// Rendezvous so all goroutines begin together.
				readyWg.Done()
				readyWg.Wait()
				fn()
				finished.Add(1)
				finishedWg.Done()
			}()
		}
	}
	group := newEndpointGroup()
	ctx := context.TODO()
	launchAll(100, func() {
		group.getBestHost(ctx, rand.String(4))
	})
	readyWg.Wait()

	// when broadcast triggered
	group.setIPs(
		map[string]struct{}{rand.String(4): {}},
		map[string]int32{rand.String(4): 1},
	)
	// then
	finishedWg.Wait()
	assert.Equal(t, int32(100), finished.Load())
}

// TestAbortOnCtxCancel verifies that getBestHost returns an error once the
// caller's context is cancelled instead of blocking forever on an empty
// endpoint group.
func TestAbortOnCtxCancel(t *testing.T) {
	ctx, cancel := context.WithCancel(context.Background())

	var startWg, doneWg sync.WaitGroup
	startWg.Add(1)
	doneWg.Add(1)
	go func(t *testing.T) {
		startWg.Wait()
		endpoint := newEndpointGroup()
		_, err := endpoint.getBestHost(ctx, rand.String(4))
		// Fix: use assert, not require, inside this goroutine. require.Error
		// calls t.FailNow, which the testing package forbids from any
		// goroutine other than the one running the test; assert reports via
		// t.Errorf, which is goroutine-safe.
		assert.Error(t, err)
		doneWg.Done()
	}(t)
	startWg.Done()
	cancel()

	doneWg.Wait()
}

0 comments on commit 5f02710

Please sign in to comment.