Skip to content

Commit

Permalink
fix: AnyOf return first not err result or first err if all err
Browse files Browse the repository at this point in the history
  • Loading branch information
jizhuozhi committed Aug 4, 2024
1 parent 539f28b commit dcd1626
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 5 deletions.
15 changes: 13 additions & 2 deletions api.go
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,24 @@ type AnyResult[T any] struct {
}

func AnyOf[T any](fs ...*Future[T]) *Future[AnyResult[T]] {
var counter int32
var done uint32
var errIndex int32 = -1
s := &state[AnyResult[T]]{}
for i, f := range fs {
i := i
f.state.subscribe(func(val T, err error) {
if atomic.CompareAndSwapUint32(&done, 0, 1) {
s.set(AnyResult[T]{Index: i, Val: val, Err: err}, nil)
if err == nil {
if atomic.CompareAndSwapUint32(&done, 0, 1) {
s.set(AnyResult[T]{Index: i, Val: val, Err: err}, nil)
}
} else {
atomic.CompareAndSwapInt32(&errIndex, -1, int32(i))
if atomic.AddInt32(&counter, 1) == int32(len(fs)) {
idx := atomic.LoadInt32(&errIndex)
fval, ferr := fs[idx].Get()
s.set(AnyResult[T]{Index: int(idx), Val: fval, Err: ferr}, nil)
}
}
})
}
Expand Down
9 changes: 6 additions & 3 deletions api_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,9 @@ func TestAnyOf(t *testing.T) {
i := i
fs[i] = Async(func() (int, error) {
time.Sleep(time.Duration(vals[i]) * time.Millisecond)
if i != target && rand.Intn(2) == 0 { // random error
return 0, errFoo
}
return vals[i], nil
})
}
Expand All @@ -246,7 +249,7 @@ func TestAnyOf(t *testing.T) {
assert.Equal(t, nil, r.Err)
}

func TestAnyOfWhenErrFirst(t *testing.T) {
func TestAnyOfWhenAllErr(t *testing.T) {
target := rand.Intn(10)
vals := make([]int, 10)
for i := 0; i < len(vals); i++ {
Expand All @@ -261,11 +264,11 @@ func TestAnyOfWhenErrFirst(t *testing.T) {
for i := 0; i < 10; i++ {
i := i
fs[i] = Async(func() (int, error) {
time.Sleep(time.Duration(vals[i]) * time.Millisecond)
if i == target {
return 0, errFoo
}
return vals[i], nil
time.Sleep(time.Duration(vals[i]) * time.Millisecond)
return 0, errFoo
})
}
f := AnyOf(fs...)
Expand Down

0 comments on commit dcd1626

Please sign in to comment.