diff --git a/api.go b/api.go index 5908bc3..3c51796 100644 --- a/api.go +++ b/api.go @@ -108,28 +108,31 @@ func ToAny[T any](f *Future[T]) *Future[any] { }) } -func AllOf[T any](fs ...*Future[T]) *Future[struct{}] { +func AllOf[T any](fs ...*Future[T]) *Future[[]T] { if len(fs) == 0 { - return Done(struct{}{}) + return Done[[]T](nil) } var done uint32 - s := &state[struct{}]{} + s := &state[[]T]{} c := int32(len(fs)) - for _, f := range fs { + results := make([]T, len(fs)) + for i, f := range fs { + i := i f.state.subscribe(func(val T, err error) { if err != nil { if atomic.CompareAndSwapUint32(&done, 0, 1) { - s.set(struct{}{}, err) + s.set(nil, err) } } else { + results[i] = val if atomic.AddInt32(&c, -1) == 0 { - s.set(struct{}{}, nil) + s.set(results, nil) } } }) } - return &Future[struct{}]{state: s} + return &Future[[]T]{state: s} } func Timeout[T any](f *Future[T], d time.Duration) *Future[T] { diff --git a/api_test.go b/api_test.go index 7539b79..467ad31 100644 --- a/api_test.go +++ b/api_test.go @@ -327,9 +327,9 @@ func TestAllOf(t *testing.T) { } f := AllOf(fs...) - _, err := f.Get() + results, err := f.Get() + assert.Equal(t, vals, results) assert.NoError(t, err) - for i := 0; i < 10; i++ { ff := fs[i] val, err := ff.Get()