From 2d6d9a1cf71408990fe2dcd826aff675fc11981f Mon Sep 17 00:00:00 2001 From: Lukasz Antoniak Date: Wed, 13 Nov 2024 13:20:55 +0100 Subject: [PATCH] CASSGO-22: Refactor Query and Batch to be immutable --- cassandra_test.go | 111 ++++++++++++++----- conn.go | 75 +++++++------ conn_test.go | 12 +- control.go | 20 +++- example_batch_test.go | 2 +- integration_test.go | 2 +- policies.go | 9 +- policies_test.go | 15 +-- query_executor.go | 48 ++++---- session.go | 249 ++++++++++++++++++++++++++++++------------ session_test.go | 59 ++++------ 11 files changed, 392 insertions(+), 210 deletions(-) diff --git a/cassandra_test.go b/cassandra_test.go index ec6969190..315716bb9 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -482,7 +482,7 @@ func TestCAS(t *testing.T) { insertBatch := session.Batch(LoggedBatch) insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))") insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))") - if err := session.ExecuteBatch(insertBatch); err != nil { + if _, err := session.ExecuteBatch(insertBatch); err != nil { t.Fatal("insert:", err) } @@ -616,7 +616,7 @@ func TestBatch(t *testing.T) { batch.Query(`INSERT INTO batch_table (id) VALUES (?)`, i) } - if err := session.ExecuteBatch(batch); err != nil { + if _, err := session.ExecuteBatch(batch); err != nil { t.Fatal("execute batch:", err) } @@ -652,7 +652,7 @@ func TestUnpreparedBatch(t *testing.T) { batch.Query(`UPDATE batch_unprepared SET c = c + 1 WHERE id = 1`) } - if err := session.ExecuteBatch(batch); err != nil { + if _, err := session.ExecuteBatch(batch); err != nil { t.Fatal("execute batch:", err) } @@ -688,7 +688,7 @@ func TestBatchLimit(t *testing.T) { for i := 0; i < 65537; i++ { batch.Query(`INSERT INTO batch_table2 (id) VALUES (?)`, i) } - if err := session.ExecuteBatch(batch); err != ErrTooManyStmts { + if _, err := session.ExecuteBatch(batch); err != ErrTooManyStmts { t.Fatal("gocql attempted to execute a batch larger than the support limit of statements.") } @@ -740,7 +740,7 @@ func TestTooManyQueryArgs(t *testing.T) { batch := session.Batch(UnloggedBatch) batch.Query("INSERT INTO too_many_query_args (id, value) VALUES (?, ?)", 1, 2, 3) - err = session.ExecuteBatch(batch) + _, err = session.ExecuteBatch(batch) if err == nil { t.Fatal("'`INSERT INTO too_many_query_args (id, value) VALUES (?, ?)`, 1, 2, 3' should return an error") @@ -772,7 +772,7 @@ func TestNotEnoughQueryArgs(t *testing.T) { batch := session.Batch(UnloggedBatch) batch.Query("INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)", 1, 2) - err = session.ExecuteBatch(batch) + _, err = session.ExecuteBatch(batch) if err == nil { t.Fatal("'`INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)`, 1, 2' should return an error") @@ -1110,8 +1110,9 @@ func Test_RetryPolicyIdempotence(t *testing.T) { q.RetryPolicy(&MyRetryPolicy{}) q.Consistency(All) - _ = q.Exec() - require.Equal(t, tc.expectedNumberOfTries, q.Attempts()) + iter := q.Iter() + _ = iter.Close() + require.Equal(t, tc.expectedNumberOfTries, iter.Attempts()) }) } } @@ -1395,7 +1396,7 @@ func TestBatchQueryInfo(t *testing.T) { batch := session.Batch(LoggedBatch) batch.Bind("INSERT INTO batch_query_info (id, cluster, value) VALUES (?, ?,?)", write) - if err := session.ExecuteBatch(batch); err != nil { + if _, err := session.ExecuteBatch(batch); err != nil { t.Fatalf("batch insert into batch_query_info failed, err '%v'", err) } @@ -1481,7 +1482,7 @@ func TestPrepare_MissingSchemaPrepare(t *testing.T) { defer s.Close() insertQry := s.Query("INSERT INTO invalidschemaprep (val) VALUES (?)", 5) - if err := conn.executeQuery(ctx, insertQry).err; err == nil { + if err := conn.executeQuery(ctx, insertQry, nil).err; err == nil { t.Fatal("expected error, but got nil.") } @@ -1489,7 +1490,7 @@ func TestPrepare_MissingSchemaPrepare(t *testing.T) { t.Fatal("create table:", err) } - if err := conn.executeQuery(ctx, insertQry).err; err != nil { + if err := conn.executeQuery(ctx, insertQry, nil).err; err != nil { t.Fatal(err) // unconfigured columnfamily } } @@ -1503,7 +1504,7 @@ func TestPrepare_ReprepareStatement(t *testing.T) { stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement") query := session.Query(stmt, "bar") - if err := conn.executeQuery(ctx, query).Close(); err != nil { + if err := conn.executeQuery(ctx, query, nil).Close(); err != nil { t.Fatalf("Failed to execute query for reprepare statement: %v", err) } } @@ -1867,14 +1868,15 @@ func TestQueryStats(t *testing.T) { session := createSession(t) defer session.Close() qry := session.Query("SELECT * FROM system.peers") - if err := qry.Exec(); err != nil { + iter := qry.Iter() + if err := iter.Close(); err != nil { t.Fatalf("query failed. %v", err) } else { - if qry.Attempts() < 1 { + if iter.Attempts() < 1 { t.Fatal("expected at least 1 attempt, but got 0") } - if qry.Latency() <= 0 { - t.Fatalf("expected latency to be greater than 0, but got %v instead.", qry.Latency()) + if iter.Latency() <= 0 { + t.Fatalf("expected latency to be greater than 0, but got %v instead.", iter.Latency()) } } } @@ -1908,14 +1910,14 @@ func TestBatchStats(t *testing.T) { b.Query("INSERT INTO batchStats (id) VALUES (?)", 1) b.Query("INSERT INTO batchStats (id) VALUES (?)", 2) - if err := session.ExecuteBatch(b); err != nil { + if iter, err := session.ExecuteBatch(b); err != nil { t.Fatalf("query failed. %v", err) } else { - if b.Attempts() < 1 { + if iter.Attempts() < 1 { t.Fatal("expected at least 1 attempt, but got 0") } - if b.Latency() <= 0 { - t.Fatalf("expected latency to be greater than 0, but got %v instead.", b.Latency()) + if iter.Latency() <= 0 { + t.Fatalf("expected latency to be greater than 0, but got %v instead.", iter.Latency()) } } } @@ -1965,7 +1967,7 @@ func TestBatchObserve(t *testing.T) { batch.Query(fmt.Sprintf(`INSERT INTO batch_observe_table (id,other) VALUES (?,%d)`, i), i) } - if err := session.ExecuteBatch(batch); err != nil { + if _, err := session.ExecuteBatch(batch); err != nil { t.Fatal("execute batch:", err) } if observedBatch == nil { @@ -2876,6 +2878,65 @@ func TestManualQueryPaging(t *testing.T) { } } +func TestQueryImmutability(t *testing.T) { + const rowsToInsert = 5 + + session := createSession(t) + defer session.Close() + + if err := createTable(session, "CREATE TABLE gocql_test.testAutomaticPaging (id int, count int, PRIMARY KEY (id))"); err != nil { + t.Fatal(err) + } + + for i := 0; i < rowsToInsert; i++ { + err := session.Query("INSERT INTO testAutomaticPaging(id, count) VALUES(?, ?)", i, i*i).Exec() + if err != nil { + t.Fatal(err) + } + } + + query := session.Query("SELECT id, count FROM testAutomaticPaging").PageSize(2) + var id, count, fetched1, fetched2 int + + iter1 := query.Iter() + iter2 := query.Iter() + scanner1 := iter1.Scanner() + scanner2 := iter2.Scanner() + for scanner1.Next() { + err := scanner1.Scan(&id, &count) + if err != nil { + t.Fatalf(err.Error()) + } + if fetched1%2 == 0 { + // move two iterators at different pace, to verify that one does not impact the other + if !scanner2.Next() { + t.Fatalf("unexpected end of pagination after %d entries", fetched2) + } else { + fetched2++ + } + } + if count != (id * id) { + t.Fatalf("got wrong value from iteration: got %d expected %d", count, id*id) + } + require.True(t, query.pageState == nil, "initial page state was not set") + require.True(t, iter1.PageState() != nil, "page state is handled by the iterator") + + fetched1++ + } + + if err := iter1.Close(); err != nil { + t.Fatal(err) + } + if err := iter2.Close(); err != nil { + t.Fatal(err) + } + + if fetched1 != rowsToInsert { + t.Fatalf("expected to fetch %d rows got %d", rowsToInsert, fetched1) + } + require.Equal(t, math.Ceil(rowsToInsert/2.0), float64(fetched2)) +} + func TestLexicalUUIDType(t *testing.T) { session := createSession(t) defer session.Close() @@ -3291,14 +3352,14 @@ func TestUnsetColBatch(t *testing.T) { b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, UnsetValue, "") b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 2, 2, UnsetValue) - if err := session.ExecuteBatch(b); err != nil { + if iter, err := session.ExecuteBatch(b); err != nil { t.Fatalf("query failed. %v", err) } else { - if b.Attempts() < 1 { + if iter.Attempts() < 1 { t.Fatal("expected at least 1 attempt, but got 0") } - if b.Latency() <= 0 { - t.Fatalf("expected latency to be greater than 0, but got %v instead.", b.Latency()) + if iter.Latency() <= 0 { + t.Fatalf("expected latency to be greater than 0, but got %v instead.", iter.Latency()) } } var id, mInt, count int diff --git a/conn.go b/conn.go index ae02bd71c..871a318c9 100644 --- a/conn.go +++ b/conn.go @@ -1322,7 +1322,7 @@ func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error return nil } -func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { +func (c *Conn) executeQuery(ctx context.Context, qry *Query, it *Iter) *Iter { params := queryParams{ consistency: qry.cons, } @@ -1332,7 +1332,9 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { params.defaultTimestamp = qry.defaultTimestamp params.defaultTimestampValue = qry.defaultTimestampValue - if len(qry.pageState) > 0 { + if it != nil && it.next != nil && len(it.next.pageState) > 0 { + params.pagingState = it.next.pageState + } else if len(qry.pageState) > 0 { params.pagingState = qry.pageState } if qry.pageSize > 0 { @@ -1352,7 +1354,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { var err error info, err = c.prepareStatement(ctx, qry.stmt, qry.trace) if err != nil { - return &Iter{err: err} + return NewIterErr(qry, err) } values := qry.values @@ -1365,12 +1367,12 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { }) if err != nil { - return &Iter{err: err} + return NewIterErr(qry, err) } } if len(values) != info.request.actualColCount { - return &Iter{err: fmt.Errorf("gocql: expected %d values send got %d", info.request.actualColCount, len(values))} + return NewIterErr(qry, fmt.Errorf("gocql: expected %d values send got %d", info.request.actualColCount, len(values))) } params.values = make([]queryValues, len(values)) @@ -1379,7 +1381,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { value := values[i] typ := info.request.columns[i].TypeInfo if err := marshalQueryValue(typ, value, v); err != nil { - return &Iter{err: err} + return NewIterErr(qry, err) } } @@ -1406,12 +1408,12 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { framer, err := c.exec(ctx, frame, qry.trace) if err != nil { - return &Iter{err: err} + return NewIterErr(qry, err) } resp, err := framer.parseFrame() if err != nil { - return &Iter{err: err} + return NewIterErr(qry, err) } if len(framer.traceID) > 0 && qry.trace != nil { @@ -1420,12 +1422,14 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { switch x := resp.(type) { case *resultVoidFrame: - return &Iter{framer: framer} + return newIterFramer(qry, framer) case *resultRowsFrame: iter := &Iter{ + qry: qry, meta: x.meta, framer: framer, numRows: x.numRows, + metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, } if params.skipMeta { @@ -1433,21 +1437,17 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { iter.meta = info.response iter.meta.pagingState = copyBytes(x.meta.pagingState) } else { - return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")} + return newIterErrFramer(qry, errors.New("gocql: did not receive metadata but prepared info is nil"), framer) } } else { iter.meta = x.meta } if x.meta.morePages() && !qry.disableAutoPage { - newQry := new(Query) - *newQry = *qry - newQry.pageState = copyBytes(x.meta.pagingState) - newQry.metrics = &queryMetrics{m: make(map[string]*hostMetrics)} - iter.next = &nextIter{ - qry: newQry, - pos: int((1 - qry.prefetch) * float64(x.numRows)), + iter: iter, + pageState: copyBytes(x.meta.pagingState), + pos: int((1 - qry.prefetch) * float64(x.numRows)), } if iter.next.pos < 1 { @@ -1457,9 +1457,9 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { return iter case *resultKeyspaceFrame: - return &Iter{framer: framer} + return newIterFramer(qry, framer) case *schemaChangeKeyspace, *schemaChangeTable, *schemaChangeFunction, *schemaChangeAggregate, *schemaChangeType: - iter := &Iter{framer: framer} + iter := newIterFramer(qry, framer) if err := c.awaitSchemaAgreement(ctx); err != nil { // TODO: should have this behind a flag c.logger.Println(err) @@ -1471,14 +1471,15 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { case *RequestErrUnprepared: stmtCacheKey := c.session.stmtsLRU.keyFor(c.host.HostID(), c.currentKeyspace, qry.stmt) c.session.stmtsLRU.evictPreparedID(stmtCacheKey, x.StatementId) - return c.executeQuery(ctx, qry) + return c.executeQuery(ctx, qry, nil) case error: - return &Iter{err: x, framer: framer} + return newIterErrFramer(qry, x, framer) default: - return &Iter{ - err: NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x), - framer: framer, - } + return newIterErrFramer( + qry, + NewErrProtocol("Unknown type in response to execute query (%T): %s", x, x), + framer, + ) } } @@ -1532,7 +1533,7 @@ func (c *Conn) UseKeyspace(keyspace string) error { func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { if c.version == protoVersion1 { - return &Iter{err: ErrUnsupported} + return NewIterErr(batch, ErrUnsupported) } n := len(batch.Entries) @@ -1555,7 +1556,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { if len(entry.Args) > 0 || entry.binding != nil { info, err := c.prepareStatement(batch.Context(), entry.Stmt, batch.trace) if err != nil { - return &Iter{err: err} + return NewIterErr(batch, err) } var values []interface{} @@ -1569,12 +1570,12 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { PKeyColumns: info.request.pkeyColumns, }) if err != nil { - return &Iter{err: err} + return NewIterErr(batch, err) } } if len(values) != info.request.actualColCount { - return &Iter{err: fmt.Errorf("gocql: batch statement %d expected %d values send got %d", i, info.request.actualColCount, len(values))} + return NewIterErr(batch, fmt.Errorf("gocql: batch statement %d expected %d values send got %d", i, info.request.actualColCount, len(values))) } b.preparedID = info.id @@ -1587,7 +1588,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { value := values[j] typ := info.request.columns[j].TypeInfo if err := marshalQueryValue(typ, value, v); err != nil { - return &Iter{err: err} + return NewIterErr(batch, err) } } } else { @@ -1597,12 +1598,12 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { framer, err := c.exec(batch.Context(), req, batch.trace) if err != nil { - return &Iter{err: err} + return NewIterErr(batch, err) } resp, err := framer.parseFrame() if err != nil { - return &Iter{err: err, framer: framer} + return newIterErrFramer(batch, err, framer) } if len(framer.traceID) > 0 && batch.trace != nil { @@ -1611,7 +1612,7 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { switch x := resp.(type) { case *resultVoidFrame: - return &Iter{} + return newIter(batch) case *RequestErrUnprepared: stmt, found := stmts[string(x.StatementId)] if found { @@ -1621,16 +1622,18 @@ func (c *Conn) executeBatch(ctx context.Context, batch *Batch) *Iter { return c.executeBatch(ctx, batch) case *resultRowsFrame: iter := &Iter{ + qry: batch, meta: x.meta, framer: framer, numRows: x.numRows, + metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, } return iter case error: - return &Iter{err: x, framer: framer} + return newIterErrFramer(batch, x, framer) default: - return &Iter{err: NewErrProtocol("Unknown type in response to batch statement: %s", x), framer: framer} + return newIterErrFramer(batch, NewErrProtocol("Unknown type in response to batch statement: %s", x), framer) } } @@ -1640,7 +1643,7 @@ func (c *Conn) query(ctx context.Context, statement string, values ...interface{ q.disableSkipMetadata = true // we want to keep the query on this connection q.conn = c - return c.executeQuery(ctx, q) + return c.executeQuery(ctx, q, nil) } func (c *Conn) querySystemPeers(ctx context.Context, version cassVersion) *Iter { diff --git a/conn_test.go b/conn_test.go index 8706683ff..e4857a7ad 100644 --- a/conn_test.go +++ b/conn_test.go @@ -389,12 +389,13 @@ func TestQueryRetry(t *testing.T) { rt := &SimpleRetryPolicy{NumRetries: 1} qry := db.Query("kill").RetryPolicy(rt) - if err := qry.Exec(); err == nil { + iter := qry.Iter() + if err := iter.Close(); err == nil { t.Fatalf("expected error") } requests := atomic.LoadInt64(&srv.nKillReq) - attempts := qry.Attempts() + attempts := iter.Attempts() if requests != int64(attempts) { t.Fatalf("expected requests %v to match query attempts %v", requests, attempts) } @@ -436,13 +437,14 @@ func TestQueryMultinodeWithMetrics(t *testing.T) { rt := &SimpleRetryPolicy{NumRetries: 3} observer := &testQueryObserver{metrics: make(map[string]*hostMetrics), verbose: false, logger: log} qry := db.Query("kill").RetryPolicy(rt).Observer(observer).Idempotent(true) - if err := qry.Exec(); err == nil { + iter := qry.Iter() + if err := iter.Close(); err == nil { t.Fatalf("expected error") } for i, ip := range addresses { host := &HostInfo{connectAddress: net.ParseIP(ip)} - queryMetric := qry.metrics.hostMetrics(host) + queryMetric := iter.metrics.hostMetrics(host) observedMetrics := observer.GetMetrics(host) requests := int(atomic.LoadInt64(&nodes[i].nKillReq)) @@ -462,7 +464,7 @@ func TestQueryMultinodeWithMetrics(t *testing.T) { } } // the query will only be attempted once, but is being retried - attempts := qry.Attempts() + attempts := iter.Attempts() if attempts != rt.NumRetries { t.Fatalf("failed to retry the query %v time(s). Query executed %v times", rt.NumRetries, attempts) } diff --git a/control.go b/control.go index b30b44ea3..7cd3555eb 100644 --- a/control.go +++ b/control.go @@ -493,7 +493,7 @@ func (c *controlConn) withConnHost(fn func(*connHost) *Iter) *Iter { return fn(ch) } - return &Iter{err: errNoControl} + return NewIterErr(nil, errNoControl) } func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter { @@ -506,21 +506,31 @@ func (c *controlConn) withConn(fn func(*Conn) *Iter) *Iter { func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter) { q := c.session.Query(statement, values...).Consistency(One).RoutingKey([]byte{}).Trace(nil) + var prev *Iter for { iter = c.withConn(func(conn *Conn) *Iter { // we want to keep the query on the control connection q.conn = conn - return conn.executeQuery(context.TODO(), q) + return conn.executeQuery(context.TODO(), q, nil) }) if gocqlDebug && iter.err != nil { c.session.logger.Printf("control: error executing %q: %v\n", statement, iter.err) } - q.AddAttempts(1, c.getConn().host) - if iter.err == nil || !c.retry.Attempt(q) { + iter.AddAttempts(1, c.getConn().host) + // merge state of the previous iterator, so that mutable state (e.g. metrics) is carried over + iter.merge(prev) + if iter.err == nil { break } + // clone to make the query attributes updatable by retry policy and original immutable + iter.qry = q.Clone() + if !c.retry.Attempt(iter) { + break + } + q = iter.qry.(*Query) + prev = iter } return @@ -528,7 +538,7 @@ func (c *controlConn) query(statement string, values ...interface{}) (iter *Iter func (c *controlConn) awaitSchemaAgreement() error { return c.withConn(func(conn *Conn) *Iter { - return &Iter{err: conn.awaitSchemaAgreement(context.TODO())} + return NewIterErr(nil, conn.awaitSchemaAgreement(context.TODO())) }).err } diff --git a/example_batch_test.go b/example_batch_test.go index b27085ccc..2bf14781a 100644 --- a/example_batch_test.go +++ b/example_batch_test.go @@ -61,7 +61,7 @@ func Example_batch() { Idempotent: true, }) - err = session.ExecuteBatch(b) + _, err = session.ExecuteBatch(b) if err != nil { log.Fatal(err) } diff --git a/integration_test.go b/integration_test.go index 61ffbf504..776a6ee61 100644 --- a/integration_test.go +++ b/integration_test.go @@ -221,7 +221,7 @@ func TestCustomPayloadMessages(t *testing.T) { b := session.Batch(LoggedBatch) b.CustomPayload = customPayload b.Query("INSERT INTO testCustomPayloadMessages(id,value) VALUES(1, 1)") - if err := session.ExecuteBatch(b); err != nil { + if _, err := session.ExecuteBatch(b); err != nil { t.Fatalf("query failed. %v", err) } } diff --git a/policies.go b/policies.go index 1157da87b..aafc8e14b 100644 --- a/policies.go +++ b/policies.go @@ -118,6 +118,7 @@ func (c *cowHostList) remove(ip net.IP) bool { // RetryableQuery is an interface that represents a query or batch statement that // exposes the correct functions for the retry policy logic to evaluate correctly. type RetryableQuery interface { + GetQuery() ExecutableQuery Attempts() int SetConsistency(c Consistency) GetConsistency() Consistency @@ -208,6 +209,10 @@ func (e *ExponentialBackoffRetryPolicy) GetRetryType(err error) RetryType { return RetryNextHost } +func (e *ExponentialBackoffRetryPolicy) napTime(attempts int) time.Duration { + return getExponentialTime(e.Min, e.Max, attempts) +} + // DowngradingConsistencyRetryPolicy: Next retry will be with the next consistency level // provided in the slice // @@ -263,10 +268,6 @@ func (d *DowngradingConsistencyRetryPolicy) GetRetryType(err error) RetryType { } } -func (e *ExponentialBackoffRetryPolicy) napTime(attempts int) time.Duration { - return getExponentialTime(e.Min, e.Max, attempts) -} - type HostStateNotifier interface { AddHost(host *HostInfo) RemoveHost(host *HostInfo) diff --git a/policies_test.go b/policies_test.go index 231c2a7e2..c2655650d 100644 --- a/policies_test.go +++ b/policies_test.go @@ -283,7 +283,7 @@ func TestCOWList_Add(t *testing.T) { // TestSimpleRetryPolicy makes sure that we only allow 1 + numRetries attempts func TestSimpleRetryPolicy(t *testing.T) { - q := &Query{routingInfo: &queryRoutingInfo{}} + i := &Iter{} // this should allow a total of 3 tries. rt := &SimpleRetryPolicy{NumRetries: 2} @@ -301,11 +301,11 @@ func TestSimpleRetryPolicy(t *testing.T) { } for _, c := range cases { - q.metrics = preFilledQueryMetrics(map[string]*hostMetrics{"127.0.0.1": {Attempts: c.attempts}}) - if c.allow && !rt.Attempt(q) { + i.metrics = preFilledQueryMetrics(map[string]*hostMetrics{"127.0.0.1": {Attempts: c.attempts}}) + if c.allow && !rt.Attempt(i) { t.Fatalf("should allow retry after %d attempts", c.attempts) } - if !c.allow && rt.Attempt(q) { + if !c.allow && rt.Attempt(i) { t.Fatalf("should not allow retry after %d attempts", c.attempts) } } @@ -342,6 +342,7 @@ func TestExponentialBackoffPolicy(t *testing.T) { func TestDowngradingConsistencyRetryPolicy(t *testing.T) { q := &Query{cons: LocalQuorum, routingInfo: &queryRoutingInfo{}} + i := &Iter{qry: q} rewt0 := &RequestErrWriteTimeout{ Received: 0, @@ -385,14 +386,14 @@ func TestDowngradingConsistencyRetryPolicy(t *testing.T) { } for _, c := range cases { - q.metrics = preFilledQueryMetrics(map[string]*hostMetrics{"127.0.0.1": {Attempts: c.attempts}}) + i.metrics = preFilledQueryMetrics(map[string]*hostMetrics{"127.0.0.1": {Attempts: c.attempts}}) if c.retryType != rt.GetRetryType(c.err) { t.Fatalf("retry type should be %v", c.retryType) } - if c.allow && !rt.Attempt(q) { + if c.allow && !rt.Attempt(i) { t.Fatalf("should allow retry after %d attempts", c.attempts) } - if !c.allow && rt.Attempt(q) { + if !c.allow && rt.Attempt(i) { t.Fatalf("should not allow retry after %d attempts", c.attempts) } } diff --git a/query_executor.go b/query_executor.go index d6be02e53..5bbda2dd3 100644 --- a/query_executor.go +++ b/query_executor.go @@ -33,7 +33,7 @@ import ( type ExecutableQuery interface { borrowForExecution() // Used to ensure that the query stays alive for lifetime of a particular execution goroutine. releaseAfterExecution() // Used when a goroutine finishes its execution attempts, either with ok result or an error. - execute(ctx context.Context, conn *Conn) *Iter + execute(ctx context.Context, conn *Conn, iter *Iter) *Iter attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) retryPolicy() RetryPolicy speculativeExecutionPolicy() SpeculativeExecutionPolicy @@ -43,8 +43,11 @@ type ExecutableQuery interface { IsIdempotent() bool withContext(context.Context) ExecutableQuery + Clone() ExecutableQuery - RetryableQuery + SetConsistency(c Consistency) + GetConsistency() Consistency + Context() context.Context } type queryExecutor struct { @@ -52,9 +55,9 @@ type queryExecutor struct { policy HostSelectionPolicy } -func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, conn *Conn) *Iter { +func (q *queryExecutor) attemptQuery(ctx context.Context, qry ExecutableQuery, it *Iter, conn *Conn) *Iter { start := time.Now() - iter := qry.execute(ctx, conn) + iter := qry.execute(ctx, conn, it) end := time.Now() qry.attempt(q.pool.keyspace, end, start, iter, conn.host) @@ -67,13 +70,14 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S ticker := time.NewTicker(sp.Delay()) defer ticker.Stop() + qry = qry.Clone() for i := 0; i < sp.Attempts(); i++ { select { case <-ticker.C: qry.borrowForExecution() // ensure liveness in case of executing Query to prevent races with Query.Release(). - go q.run(ctx, qry, hostIter, results) + go q.run(ctx, qry, nil, hostIter, results) case <-ctx.Done(): - return &Iter{err: ctx.Err()} + return NewIterErr(qry, ctx.Err()) case iter := <-results: return iter } @@ -82,14 +86,14 @@ func (q *queryExecutor) speculate(ctx context.Context, qry ExecutableQuery, sp S return nil } -func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { +func (q *queryExecutor) executeQuery(qry ExecutableQuery, it *Iter) (*Iter, error) { hostIter := q.policy.Pick(qry) // check if the query is not marked as idempotent, if // it is, we force the policy to NonSpeculative sp := qry.speculativeExecutionPolicy() if !qry.IsIdempotent() || sp.Attempts() == 0 { - return q.do(qry.Context(), qry, hostIter), nil + return q.do(qry.Context(), qry, it, hostIter), nil } // When speculative execution is enabled, we could be accessing the host iterator from multiple goroutines below. @@ -109,7 +113,7 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { // Launch the main execution qry.borrowForExecution() // ensure liveness in case of executing Query to prevent races with Query.Release(). - go q.run(ctx, qry, hostIter, results) + go q.run(ctx, qry, it, hostIter, results) // The speculative executions are launched _in addition_ to the main // execution, on a timer. So Speculation{2} would make 3 executions running @@ -122,11 +126,11 @@ func (q *queryExecutor) executeQuery(qry ExecutableQuery) (*Iter, error) { case iter := <-results: return iter, nil case <-ctx.Done(): - return &Iter{err: ctx.Err()}, nil + return NewIterErr(qry, ctx.Err()), nil } } -func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter NextHost) *Iter { +func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, it *Iter, hostIter NextHost) *Iter { selectedHost := hostIter() rt := qry.retryPolicy() @@ -151,12 +155,14 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne continue } - iter = q.attemptQuery(ctx, qry, conn) + ni := q.attemptQuery(ctx, qry, it, conn) + ni.merge(iter) + iter = ni iter.host = selectedHost.Info() // Update host switch iter.err { case context.Canceled, context.DeadlineExceeded, ErrNotFound: - // those errors represents logical errors, they should not count + // those errors represent logical errors, they should not count // toward removing a node from the pool selectedHost.Mark(nil) return iter @@ -170,11 +176,12 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne return iter } - attemptsReached := !rt.Attempt(qry) + // clone to make the query attributes updatable by retry policy and original immutable + iter.qry = qry.Clone() + attemptsReached := !rt.Attempt(iter) retryType := rt.GetRetryType(iter.err) var stopRetries bool - // If query is unsuccessful, check the error with RetryPolicy to retry switch retryType { case Retry: @@ -189,27 +196,28 @@ func (q *queryExecutor) do(ctx context.Context, qry ExecutableQuery, hostIter Ne stopRetries = true default: // Undefined? Return nil and error, this will panic in the requester - return &Iter{err: ErrUnknownRetryType} + return NewIterErrFromIter(qry, ErrUnknownRetryType, iter) } if stopRetries || attemptsReached { return iter } + qry = iter.qry lastErr = iter.err continue } if lastErr != nil { - return &Iter{err: lastErr} + return NewIterErrFromIter(qry, lastErr, iter) } - return &Iter{err: ErrNoConnections} + return NewIterErrFromIter(qry, ErrNoConnections, iter) } -func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, hostIter NextHost, results chan<- *Iter) { +func (q *queryExecutor) run(ctx context.Context, qry ExecutableQuery, it *Iter, hostIter NextHost, results chan<- *Iter) { select { - case results <- q.do(ctx, qry, hostIter): + case results <- q.do(ctx, qry, it, hostIter): case <-ctx.Done(): } qry.releaseAfterExecution() diff --git a/session.go b/session.go index d04a13672..1278a65d1 100644 --- a/session.go +++ b/session.go @@ -377,7 +377,7 @@ func (s *Session) AwaitSchemaAgreement(ctx context.Context) error { return errNoControl } return s.control.withConn(func(conn *Conn) *Iter { - return &Iter{err: conn.awaitSchemaAgreement(ctx)} + return NewIterErr(nil, conn.awaitSchemaAgreement(ctx)) }).err } @@ -537,15 +537,15 @@ func (s *Session) initialized() bool { return initialized } -func (s *Session) executeQuery(qry *Query) (it *Iter) { +func (s *Session) executeQuery(qry *Query, it *Iter) *Iter { // fail fast if s.Closed() { - return &Iter{err: ErrSessionClosed} + return NewIterErr(qry, ErrSessionClosed) } - iter, err := s.executor.executeQuery(qry) + iter, err := s.executor.executeQuery(qry, it) if err != nil { - return &Iter{err: err} + return NewIterErr(qry, err) } if iter == nil { panic("nil iter") @@ -727,7 +727,7 @@ func (s *Session) routingKeyInfo(ctx context.Context, stmt string) (*routingKeyI return routingKeyInfo, nil } -func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter { +func (b *Batch) execute(ctx context.Context, conn *Conn, it *Iter) *Iter { return conn.executeBatch(ctx, b) } @@ -741,19 +741,19 @@ func (b *Batch) Exec() error { func (s *Session) executeBatch(batch *Batch) *Iter { // fail fast if s.Closed() { - return &Iter{err: ErrSessionClosed} + return NewIterErr(batch, ErrSessionClosed) } // Prevent the execution of the batch if greater than the limit // Currently batches have a limit of 65536 queries. // https://datastax-oss.atlassian.net/browse/JAVA-229 if batch.Size() > BatchSizeMaximum { - return &Iter{err: ErrTooManyStmts} + return NewIterErr(batch, ErrTooManyStmts) } - iter, err := s.executor.executeQuery(batch) + iter, err := s.executor.executeQuery(batch, nil) if err != nil { - return &Iter{err: err} + return NewIterErr(batch, err) } return iter @@ -761,9 +761,9 @@ func (s *Session) executeBatch(batch *Batch) *Iter { // ExecuteBatch executes a batch operation and returns nil if successful // otherwise an error is returned describing the failure. -func (s *Session) ExecuteBatch(batch *Batch) error { +func (s *Session) ExecuteBatch(batch *Batch) (*Iter, error) { iter := s.executeBatch(batch) - return iter.Close() + return iter, iter.Close() } // ExecuteBatchCAS executes a batch operation and returns true if successful and @@ -816,6 +816,11 @@ type hostMetrics struct { TotalLatency int64 } +func (h *hostMetrics) merge(o *hostMetrics) { + h.TotalLatency += o.TotalLatency + h.Attempts += o.Attempts +} + type queryMetrics struct { l sync.RWMutex m map[string]*hostMetrics @@ -824,6 +829,20 @@ type queryMetrics struct { totalAttempts int } +func (qm *queryMetrics) merge(o *queryMetrics) { + o.l.Lock() + qm.totalAttempts += o.totalAttempts + for h, m := range o.m { + _, exists := qm.m[h] + if exists { + qm.m[h].merge(m) + } else { + qm.m[h] = m + } + } + o.l.Unlock() +} + // preFilledQueryMetrics initializes new queryMetrics based on per-host supplied data. func preFilledQueryMetrics(m map[string]*hostMetrics) *queryMetrics { qm := &queryMetrics{m: m} @@ -929,7 +948,6 @@ type Query struct { context context.Context idempotent bool customPayload map[string][]byte - metrics *queryMetrics refCount uint32 disableAutoPage bool @@ -945,6 +963,37 @@ type Query struct { routingInfo *queryRoutingInfo } +func (q *Query) Clone() ExecutableQuery { + return &Query{ + stmt: q.stmt, + values: q.values, + cons: q.cons, + pageSize: q.pageSize, + routingKey: q.routingKey, + pageState: q.pageState, + prefetch: q.prefetch, + trace: q.trace, + observer: q.observer, + session: q.session, + conn: q.conn, + rt: q.rt, + spec: q.spec, + binding: q.binding, + serialCons: q.serialCons, + defaultTimestamp: q.defaultTimestamp, + defaultTimestampValue: q.defaultTimestampValue, + disableSkipMetadata: q.disableSkipMetadata, + context: q.context, + idempotent: q.idempotent, + customPayload: q.customPayload, + refCount: q.refCount, + disableAutoPage: q.disableAutoPage, + getKeyspace: q.getKeyspace, + skipPrepare: q.skipPrepare, + routingInfo: q.routingInfo, + } +} + type queryRoutingInfo struct { // mu protects contents of queryRoutingInfo. mu sync.RWMutex @@ -967,7 +1016,6 @@ func (q *Query) defaultsFromSession() { q.serialCons = s.cfg.SerialConsistency q.defaultTimestamp = s.cfg.DefaultTimestamp q.idempotent = s.cfg.DefaultIdempotence - q.metrics = &queryMetrics{m: make(map[string]*hostMetrics)} q.spec = &NonSpeculativeExecution{} s.mu.RUnlock() @@ -989,24 +1037,6 @@ func (q Query) String() string { return fmt.Sprintf("[query statement=%q values=%+v consistency=%s]", q.stmt, q.values, q.cons) } -// Attempts returns the number of times the query was executed. -func (q *Query) Attempts() int { - return q.metrics.attempts() -} - -func (q *Query) AddAttempts(i int, host *HostInfo) { - q.metrics.attempt(i, 0, host, false) -} - -// Latency returns the average amount of nanoseconds per attempt of the query. -func (q *Query) Latency() int64 { - return q.metrics.latency() -} - -func (q *Query) AddLatency(l int64, host *HostInfo) { - q.metrics.attempt(0, time.Duration(l)*time.Nanosecond, host, false) -} - // Consistency sets the consistency level for this query. If no consistency // level have been set, the default consistency level of the cluster // is used. @@ -1114,13 +1144,13 @@ func (q *Query) Cancel() { // TODO: delete } -func (q *Query) execute(ctx context.Context, conn *Conn) *Iter { - return conn.executeQuery(ctx, q) +func (q *Query) execute(ctx context.Context, conn *Conn, it *Iter) *Iter { + return conn.executeQuery(ctx, q, it) } func (q *Query) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { latency := end.Sub(start) - attempt, metricsForHost := q.metrics.attempt(1, latency, host, q.observer != nil) + attempt, metricsForHost := iter.metrics.attempt(1, latency, host, q.observer != nil) if q.observer != nil { q.observer.ObserveQuery(q.Context(), ObservedQuery{ @@ -1315,14 +1345,14 @@ func isUseStatement(stmt string) bool { // over all results. func (q *Query) Iter() *Iter { if isUseStatement(q.stmt) { - return &Iter{err: ErrUseStmt} + return NewIterErr(q, ErrUseStmt) } // if the query was specifically run on a connection then re-use that // connection when fetching the next results if q.conn != nil { - return q.conn.executeQuery(q.Context(), q) + return q.conn.executeQuery(q.Context(), q, nil) } - return q.session.executeQuery(q) + return q.session.executeQuery(q, nil) } // MapScan executes the query, copies the columns of the first selected @@ -1434,6 +1464,7 @@ func (q *Query) releaseAfterExecution() { // were returned by a query. The iterator might send additional queries to the // database during the iteration if paging was enabled. type Iter struct { + qry ExecutableQuery err error pos int meta resultMetadata @@ -1441,10 +1472,72 @@ type Iter struct { next *nextIter host *HostInfo + metrics *queryMetrics + framer *framer closed int32 } +func newIter(qry ExecutableQuery) *Iter { + return &Iter{ + qry: qry, + metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, + } +} + +func newIterFramer(qry ExecutableQuery, f *framer) *Iter { + return &Iter{ + qry: qry, + framer: f, + metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, + } +} + +func NewIterErr(qry ExecutableQuery, e error) *Iter { + return newIterErrFramer(qry, e, nil) +} + +func NewIterErrFromIter(qry ExecutableQuery, e error, iter *Iter) *Iter { + i := newIterErrFramer(qry, e, nil) + if iter != nil { + i.metrics = iter.metrics + } + return i +} + +func newIterErrFramer(qry ExecutableQuery, e error, f *framer) *Iter { + return &Iter{ + qry: qry, + err: e, + framer: f, + metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, + } +} + +// merge state of two iterators +func (iter *Iter) merge(other *Iter) { + if other != nil { + iter.metrics.merge(other.metrics) + } +} + +// GetQuery returns single CQL query or batch associated with this iterator +func (iter *Iter) GetQuery() ExecutableQuery { + return iter.qry +} + +func (iter *Iter) GetConsistency() Consistency { + return iter.qry.GetConsistency() +} + +func (iter *Iter) SetConsistency(c Consistency) { + iter.qry.SetConsistency(c) +} + +func (iter *Iter) Context() context.Context { + return iter.qry.Context() +} + // Host returns the host which the query was sent to. func (iter *Iter) Host() *HostInfo { return iter.host @@ -1455,6 +1548,24 @@ func (iter *Iter) Columns() []ColumnInfo { return iter.meta.columns } +// Attempts returns the number of times the query was executed. +func (iter *Iter) Attempts() int { + return iter.metrics.attempts() +} + +func (iter *Iter) AddAttempts(i int, host *HostInfo) { + iter.metrics.attempt(i, 0, host, false) +} + +// Latency returns the average amount of nanoseconds per attempt of the query. +func (iter *Iter) Latency() int64 { + return iter.metrics.latency() +} + +func (iter *Iter) AddLatency(l int64, host *HostInfo) { + iter.metrics.attempt(0, time.Duration(l)*time.Nanosecond, host, false) +} + type Scanner interface { // Next advances the row pointer to point at the next row, the row is valid until // the next call of Next. It returns true if there is a row which is available to be @@ -1705,11 +1816,12 @@ func (iter *Iter) NumRows() int { // nextIter holds state for fetching a single page in an iterator. // single page might be attempted multiple times due to retries. type nextIter struct { - qry *Query - pos int - oncea sync.Once - once sync.Once - next *Iter + iter *Iter + pageState []byte + pos int + oncea sync.Once + once sync.Once + next *Iter } func (n *nextIter) fetchAsync() { @@ -1722,10 +1834,11 @@ func (n *nextIter) fetch() *Iter { n.once.Do(func() { // if the query was specifically run on a connection then re-use that // connection when fetching the next results - if n.qry.conn != nil { - n.next = n.qry.conn.executeQuery(n.qry.Context(), n.qry) + qry := n.iter.qry.(*Query) + if qry.conn != nil { + n.next = qry.conn.executeQuery(qry.Context(), qry, n.iter) } else { - n.next = n.qry.session.executeQuery(n.qry) + n.next = qry.session.executeQuery(qry, n.iter) } }) return n.next @@ -1748,7 +1861,6 @@ type Batch struct { context context.Context cancelBatch func() keyspace string - metrics *queryMetrics // routingInfo is a pointer because Query can be copied and copyable struct can't hold a mutex. routingInfo *queryRoutingInfo @@ -1774,7 +1886,6 @@ func (s *Session) Batch(typ BatchType) *Batch { Cons: s.cons, defaultTimestamp: s.cfg.DefaultTimestamp, keyspace: s.cfg.Keyspace, - metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, spec: &NonSpeculativeExecution{}, routingInfo: &queryRoutingInfo{}, } @@ -1783,6 +1894,28 @@ func (s *Session) Batch(typ BatchType) *Batch { return batch } +func (b *Batch) Clone() ExecutableQuery { + return &Batch{ + Type: b.Type, + Entries: b.Entries, + Cons: b.Cons, + routingKey: b.routingKey, + CustomPayload: b.CustomPayload, + rt: b.rt, + spec: b.spec, + trace: b.trace, + observer: b.observer, + session: b.session, + serialCons: b.serialCons, + defaultTimestamp: b.defaultTimestamp, + defaultTimestampValue: b.defaultTimestampValue, + context: b.context, + cancelBatch: b.cancelBatch, + keyspace: b.keyspace, + routingInfo: b.routingInfo, + } +} + // Trace enables tracing of this batch. Look at the documentation of the // Tracer interface to learn more about tracing. func (b *Batch) Trace(trace Tracer) *Batch { @@ -1806,24 +1939,6 @@ func (b *Batch) Table() string { return b.routingInfo.table } -// Attempts returns the number of attempts made to execute the batch. -func (b *Batch) Attempts() int { - return b.metrics.attempts() -} - -func (b *Batch) AddAttempts(i int, host *HostInfo) { - b.metrics.attempt(i, 0, host, false) -} - -// Latency returns the average number of nanoseconds to execute a single attempt of the batch. -func (b *Batch) Latency() int64 { - return b.metrics.latency() -} - -func (b *Batch) AddLatency(l int64, host *HostInfo) { - b.metrics.attempt(0, time.Duration(l)*time.Nanosecond, host, false) -} - // GetConsistency returns the currently configured consistency level for the batch // operation. func (b *Batch) GetConsistency() Consistency { @@ -1947,7 +2062,7 @@ func (b *Batch) WithTimestamp(timestamp int64) *Batch { func (b *Batch) attempt(keyspace string, end, start time.Time, iter *Iter, host *HostInfo) { latency := end.Sub(start) - attempt, metricsForHost := b.metrics.attempt(1, latency, host, b.observer != nil) + attempt, metricsForHost := iter.metrics.attempt(1, latency, host, b.observer != nil) if b.observer == nil { return diff --git a/session_test.go b/session_test.go index 8633f9957..b272d8e8c 100644 --- a/session_test.go +++ b/session_test.go @@ -91,14 +91,14 @@ func TestSessionAPI(t *testing.T) { t.Fatalf("expected qry.stmt to be 'test', got '%v'", boundQry.stmt) } - itr := s.executeQuery(qry) + itr := s.executeQuery(qry, nil) if itr.err != ErrNoConnections { t.Fatalf("expected itr.err to be '%v', got '%v'", ErrNoConnections, itr.err) } testBatch := s.Batch(LoggedBatch) testBatch.Query("test") - err := s.ExecuteBatch(testBatch) + _, err := s.ExecuteBatch(testBatch) if err != ErrNoConnections { t.Fatalf("expected session.ExecuteBatch to return '%v', got '%v'", ErrNoConnections, err) @@ -111,7 +111,7 @@ func TestSessionAPI(t *testing.T) { //Should just return cleanly s.Close() - err = s.ExecuteBatch(testBatch) + _, err = s.ExecuteBatch(testBatch) if err != ErrSessionClosed { t.Fatalf("expected session.ExecuteBatch to return '%v', got '%v'", ErrSessionClosed, err) } @@ -123,29 +123,33 @@ func (f funcQueryObserver) ObserveQuery(ctx context.Context, o ObservedQuery) { f(ctx, o) } -func TestQueryBasicAPI(t *testing.T) { - qry := &Query{routingInfo: &queryRoutingInfo{}} +func TestIterBasicAPI(t *testing.T) { + iter := &Iter{} // Initiate host ip := "127.0.0.1" - qry.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 0, TotalLatency: 0}}) - if qry.Latency() != 0 { - t.Fatalf("expected Query.Latency() to return 0, got %v", qry.Latency()) + iter.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 0, TotalLatency: 0}}) + if iter.Latency() != 0 { + t.Fatalf("expected Iter.Latency() to return 0, got %v", iter.Latency()) } - qry.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 2, TotalLatency: 4}}) - if qry.Attempts() != 2 { - t.Fatalf("expected Query.Attempts() to return 2, got %v", qry.Attempts()) + iter.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 2, TotalLatency: 4}}) + if iter.Attempts() != 2 { + t.Fatalf("expected Iter.Attempts() to return 2, got %v", iter.Attempts()) } - if qry.Latency() != 2 { - t.Fatalf("expected Query.Latency() to return 2, got %v", qry.Latency()) + if iter.Latency() != 2 { + t.Fatalf("expected Iter.Latency() to return 2, got %v", iter.Latency()) } - qry.AddAttempts(2, &HostInfo{hostname: ip, connectAddress: net.ParseIP(ip), port: 9042}) - if qry.Attempts() != 4 { - t.Fatalf("expected Query.Attempts() to return 4, got %v", qry.Attempts()) + iter.AddAttempts(2, &HostInfo{hostname: ip, connectAddress: net.ParseIP(ip), port: 9042}) + if iter.Attempts() != 4 { + t.Fatalf("expected Iter.Attempts() to return 4, got %v", iter.Attempts()) } +} + +func TestQueryBasicAPI(t *testing.T) { + qry := &Query{routingInfo: &queryRoutingInfo{}} qry.Consistency(All) if qry.GetConsistency() != All { @@ -232,29 +236,6 @@ func TestBatchBasicAPI(t *testing.T) { t.Fatalf("expected batch.Type to be '%v', got '%v'", LoggedBatch, b.Type) } - ip := "127.0.0.1" - - // Test attempts - b.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 1}}) - if b.Attempts() != 1 { - t.Fatalf("expected batch.Attempts() to return %v, got %v", 1, b.Attempts()) - } - - b.AddAttempts(2, &HostInfo{hostname: ip, connectAddress: net.ParseIP(ip), port: 9042}) - if b.Attempts() != 3 { - t.Fatalf("expected batch.Attempts() to return %v, got %v", 3, b.Attempts()) - } - - // Test latency - if b.Latency() != 0 { - t.Fatalf("expected batch.Latency() to be 0, got %v", b.Latency()) - } - - b.metrics = preFilledQueryMetrics(map[string]*hostMetrics{ip: {Attempts: 1, TotalLatency: 4}}) - if b.Latency() != 4 { - t.Fatalf("expected batch.Latency() to return %v, got %v", 4, b.Latency()) - } - // Test Consistency b.Cons = One if b.GetConsistency() != One {