diff --git a/api.go b/api.go index 1b966fe..a85bd85 100644 --- a/api.go +++ b/api.go @@ -70,13 +70,24 @@ type AnyResult[T any] struct { } func AnyOf[T any](fs ...*Future[T]) *Future[AnyResult[T]] { + var counter int32 var done uint32 + var errIndex int32 = -1 s := &state[AnyResult[T]]{} for i, f := range fs { i := i f.state.subscribe(func(val T, err error) { - if atomic.CompareAndSwapUint32(&done, 0, 1) { - s.set(AnyResult[T]{Index: i, Val: val, Err: err}, nil) + if err == nil { + if atomic.CompareAndSwapUint32(&done, 0, 1) { + s.set(AnyResult[T]{Index: i, Val: val, Err: err}, nil) + } + } else { + atomic.CompareAndSwapInt32(&errIndex, -1, int32(i)) + if atomic.AddInt32(&counter, 1) == int32(len(fs)) { + idx := atomic.LoadInt32(&errIndex) + fval, ferr := fs[idx].Get() + s.set(AnyResult[T]{Index: int(idx), Val: fval, Err: ferr}, nil) + } } }) } diff --git a/api_test.go b/api_test.go index 7ab4823..c74fc48 100644 --- a/api_test.go +++ b/api_test.go @@ -235,6 +235,9 @@ func TestAnyOf(t *testing.T) { i := i fs[i] = Async(func() (int, error) { time.Sleep(time.Duration(vals[i]) * time.Millisecond) + if i != target && rand.Intn(2) == 0 { // random error + return 0, errFoo + } return vals[i], nil }) } @@ -246,7 +249,7 @@ func TestAnyOf(t *testing.T) { assert.Equal(t, nil, r.Err) } -func TestAnyOfWhenErrFirst(t *testing.T) { +func TestAnyOfWhenAllErr(t *testing.T) { target := rand.Intn(10) vals := make([]int, 10) for i := 0; i < len(vals); i++ { @@ -261,11 +264,11 @@ func TestAnyOfWhenErrFirst(t *testing.T) { for i := 0; i < 10; i++ { i := i fs[i] = Async(func() (int, error) { - time.Sleep(time.Duration(vals[i]) * time.Millisecond) if i == target { return 0, errFoo } - return vals[i], nil + time.Sleep(time.Duration(vals[i]) * time.Millisecond) + return 0, errFoo }) } f := AnyOf(fs...)