From a1cf2dc39ee02878064d8084dd5aabaa3db7f9cb Mon Sep 17 00:00:00 2001 From: guregu Date: Tue, 17 Dec 2024 04:24:32 +0900 Subject: [PATCH] fix Query.One + Filter behavior (#248) --- query.go | 44 ++++++++++++++++++-------------------------- query_test.go | 25 +++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 26 deletions(-) diff --git a/query.go b/query.go index ed82375..bc6b553 100644 --- a/query.go +++ b/query.go @@ -239,34 +239,19 @@ func (q *Query) One(ctx context.Context, out interface{}) error { } // If not, try a Query. - req := q.queryInput() - - var res *dynamodb.QueryOutput - err := q.table.db.retry(ctx, func() error { - var err error - res, err = q.table.db.client.Query(ctx, req) - q.cc.incRequests() - if err != nil { - return err - } - - switch { - case len(res.Items) == 0: - return ErrNotFound - case len(res.Items) > 1 && q.limit != 1: - return ErrTooMany - case res.LastEvaluatedKey != nil && q.searchLimit != 0: - return ErrTooMany - } - - return nil - }) - if err != nil { + iter := q.Iter().(*queryIter) + ok := iter.Next(ctx, out) + if err := iter.Err(); err != nil { return err } - q.cc.add(res.ConsumedCapacity) - - return unmarshalItem(res.Items[0], out) + if !ok { + return ErrNotFound + } + // Best effort: do we have any pending unused items? + if iter.hasMore() { + return ErrTooMany + } + return nil } // Count executes this request, returning the number of results. @@ -422,6 +407,13 @@ func (itr *queryIter) Next(ctx context.Context, out interface{}) bool { return itr.err == nil } +func (itr *queryIter) hasMore() bool { + if itr.query.limit > 0 && itr.n == itr.query.limit { + return false + } + return itr.output != nil && itr.idx < len(itr.output.Items) +} + // Err returns the error encountered, if any. // You should check this after Next is finished. func (itr *queryIter) Err() error { diff --git a/query_test.go b/query_test.go index d4ac428..4942067 100644 --- a/query_test.go +++ b/query_test.go @@ -2,6 +2,7 @@ package dynamo import ( "context" + "errors" "reflect" "testing" "time" @@ -111,6 +112,30 @@ func TestGetAllCount(t *testing.T) { t.Errorf("bad result for get one. %v ≠ %v", one, item) } + // trigger ErrTooMany + one = widget{} + err = table.Get("UserID", 42).Range("Time", Greater, "0").Consistent(true).One(ctx, &one) + if !errors.Is(err, ErrTooMany) { + t.Errorf("bad error from get one. %v ≠ %v", err, ErrTooMany) + } + + // suppress ErrTooMany with Limit(1) + one = widget{} + err = table.Get("UserID", 42).Range("Time", Greater, "0").Consistent(true).Limit(1).One(ctx, &one) + if err != nil { + t.Error("unexpected error:", err) + } + if one.UserID == 0 { + t.Errorf("bad result for get one: %v", one) + } + + // trigger ErrNotFound via SearchLimit + Filter + One + one = widget{} + err = table.Get("UserID", 42).Range("Time", Greater, "0").Filter("Msg = ?", item.Msg).Consistent(true).SearchLimit(1).One(ctx, &one) + if !errors.Is(err, ErrNotFound) { + t.Errorf("bad error from get one. %v ≠ %v", err, ErrNotFound) + } + // GetItem + Project one = widget{} projected := widget{