From 5161ea8e4919a0ab5f80144d15818aac29a594c0 Mon Sep 17 00:00:00 2001 From: tengu-alt Date: Fri, 1 Nov 2024 10:44:42 +0200 Subject: [PATCH] Exec() method for batch was added & Query() method was refactored --- batch_test.go | 16 +++++++++------- cassandra_test.go | 34 +++++++++++++++++----------------- doc.go | 2 +- example_batch_test.go | 14 ++++++++++++-- example_lwt_batch_test.go | 4 ++-- integration_test.go | 2 +- session.go | 31 ++++++++++++++++++++++++++++++- session_test.go | 6 +++--- 8 files changed, 75 insertions(+), 34 deletions(-) diff --git a/batch_test.go b/batch_test.go index 25f8c8364..44b52663f 100644 --- a/batch_test.go +++ b/batch_test.go @@ -47,9 +47,9 @@ func TestBatch_Errors(t *testing.T) { t.Fatal(err) } - b := session.NewBatch(LoggedBatch) - b.Query("SELECT * FROM batch_errors WHERE id=2 AND val=?", nil) - if err := session.ExecuteBatch(b); err == nil { + b := session.Batch(LoggedBatch) + b = b.Query("SELECT * FROM gocql_test.batch_errors WHERE id=2 AND val=?", nil) + if err := b.Exec(); err == nil { t.Fatal("expected to get error for invalid query in batch") } } @@ -68,15 +68,17 @@ func TestBatch_WithTimestamp(t *testing.T) { micros := time.Now().UnixNano()/1e3 - 1000 - b := session.NewBatch(LoggedBatch) + b := session.Batch(LoggedBatch) b.WithTimestamp(micros) - b.Query("INSERT INTO batch_ts (id, val) VALUES (?, ?)", 1, "val") - if err := session.ExecuteBatch(b); err != nil { + b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 1, "val") + b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 2, "val") + + if err := b.Exec(); err != nil { t.Fatal(err) } var storedTs int64 - if err := session.Query(`SELECT writetime(val) FROM batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil { + if err := session.Query(`SELECT writetime(val) FROM gocql_test.batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil { t.Fatal(err) } diff --git a/cassandra_test.go b/cassandra_test.go index 797a7cf7f..59b331f18 100644 --- a/cassandra_test.go +++ b/cassandra_test.go @@ -44,7 +44,7 @@ import ( "time" "unicode" - inf "gopkg.in/inf.v0" + "gopkg.in/inf.v0" ) func TestEmptyHosts(t *testing.T) { @@ -453,7 +453,7 @@ func TestCAS(t *testing.T) { t.Fatal("truncate:", err) } - successBatch := session.NewBatch(LoggedBatch) + successBatch := session.Batch(LoggedBatch) successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified) if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil { t.Fatal("insert:", err) @@ -461,7 +461,7 @@ func TestCAS(t *testing.T) { t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS) } - successBatch = session.NewBatch(LoggedBatch) + successBatch = session.Batch(LoggedBatch) successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title+"_foo", revid, modified) casMap := make(map[string]interface{}) if applied, _, err := session.MapExecuteBatchCAS(successBatch, casMap); err != nil { @@ -470,7 +470,7 @@ func TestCAS(t *testing.T) { t.Fatal("insert should have been applied") } - failBatch := session.NewBatch(LoggedBatch) + failBatch := session.Batch(LoggedBatch) failBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified) if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil { t.Fatal("insert:", err) @@ -478,14 +478,14 @@ func TestCAS(t *testing.T) { t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS) } - insertBatch := session.NewBatch(LoggedBatch) + insertBatch := session.Batch(LoggedBatch) insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))") insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))") if err := session.ExecuteBatch(insertBatch); err != nil { t.Fatal("insert:", err) } - failBatch = session.NewBatch(LoggedBatch) + failBatch = session.Batch(LoggedBatch) failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=2c3af400-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());") failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());") if applied, iter, err := session.ExecuteBatchCAS(failBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil { @@ -610,7 +610,7 @@ func TestBatch(t *testing.T) { t.Fatal("create table:", err) } - batch := session.NewBatch(LoggedBatch) + batch := session.Batch(LoggedBatch) for i := 0; i < 100; i++ { batch.Query(`INSERT INTO batch_table (id) VALUES (?)`, i) } @@ -642,9 +642,9 @@ func TestUnpreparedBatch(t *testing.T) { var batch *Batch if session.cfg.ProtoVersion == 2 { - batch = session.NewBatch(CounterBatch) + batch = session.Batch(CounterBatch) } else { - batch = session.NewBatch(UnloggedBatch) + batch = session.Batch(UnloggedBatch) } for i := 0; i < 100; i++ { @@ -683,7 +683,7 @@ func TestBatchLimit(t *testing.T) { t.Fatal("create table:", err) } - batch := session.NewBatch(LoggedBatch) + batch := session.Batch(LoggedBatch) for i := 0; i < 65537; i++ { batch.Query(`INSERT INTO batch_table2 (id) VALUES (?)`, i) } @@ -737,7 +737,7 @@ func TestTooManyQueryArgs(t *testing.T) { t.Fatal("'`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2' should return an error") } - batch := session.NewBatch(UnloggedBatch) + batch := session.Batch(UnloggedBatch) batch.Query("INSERT INTO too_many_query_args (id, value) VALUES (?, ?)", 1, 2, 3) err = session.ExecuteBatch(batch) @@ -769,7 +769,7 @@ func TestNotEnoughQueryArgs(t *testing.T) { t.Fatal("'`SELECT * FROM not_enough_query_args WHERE id = ? and cluster = ?`, 1' should return an error") } - batch := session.NewBatch(UnloggedBatch) + batch := session.Batch(UnloggedBatch) batch.Query("INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)", 1, 2) err = session.ExecuteBatch(batch) @@ -1342,7 +1342,7 @@ func TestBatchQueryInfo(t *testing.T) { return values, nil } - batch := session.NewBatch(LoggedBatch) + batch := session.Batch(LoggedBatch) batch.Bind("INSERT INTO batch_query_info (id, cluster, value) VALUES (?, ?,?)", write) if err := session.ExecuteBatch(batch); err != nil { @@ -1470,7 +1470,7 @@ func TestPrepare_ReprepareBatch(t *testing.T) { } stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch") - batch := session.NewBatch(UnloggedBatch) + batch := session.Batch(UnloggedBatch) batch.Query(stmt, "bar") if err := conn.executeBatch(ctx, batch).Close(); err != nil { t.Fatalf("Failed to execute query for reprepare statement: %v", err) @@ -1854,7 +1854,7 @@ func TestBatchStats(t *testing.T) { t.Fatalf("failed to create table with error '%v'", err) } - b := session.NewBatch(LoggedBatch) + b := session.Batch(LoggedBatch) b.Query("INSERT INTO batchStats (id) VALUES (?)", 1) b.Query("INSERT INTO batchStats (id) VALUES (?)", 2) @@ -1897,7 +1897,7 @@ func TestBatchObserve(t *testing.T) { var observedBatch *observation - batch := session.NewBatch(LoggedBatch) + batch := session.Batch(LoggedBatch) batch.Observer(funcBatchObserver(func(ctx context.Context, o ObservedBatch) { if observedBatch != nil { t.Fatal("batch observe called more than once") @@ -3236,7 +3236,7 @@ func TestUnsetColBatch(t *testing.T) { t.Fatalf("failed to create table with error '%v'", err) } - b := session.NewBatch(LoggedBatch) + b := session.Batch(LoggedBatch) b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, 1, UnsetValue) b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, UnsetValue, "") b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 2, 2, UnsetValue) diff --git a/doc.go b/doc.go index f23e812c5..e71d91a7f 100644 --- a/doc.go +++ b/doc.go @@ -300,7 +300,7 @@ // # Batches // // The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql. -// Use Session.NewBatch to create a new batch and then fill-in details of individual queries. +// Use Session.Batch to create a new batch and then fill-in details of individual queries. // Then execute the batch with Session.ExecuteBatch. // // Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have diff --git a/example_batch_test.go b/example_batch_test.go index 2695e48bd..b27085ccc 100644 --- a/example_batch_test.go +++ b/example_batch_test.go @@ -29,7 +29,7 @@ import ( "fmt" "log" - gocql "github.com/gocql/gocql" + "github.com/gocql/gocql" ) // Example_batch demonstrates how to execute a batch of statements. @@ -49,7 +49,7 @@ func Example_batch() { ctx := context.Background() - b := session.NewBatch(gocql.UnloggedBatch).WithContext(ctx) + b := session.Batch(gocql.UnloggedBatch).WithContext(ctx) b.Entries = append(b.Entries, gocql.BatchEntry{ Stmt: "INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", Args: []interface{}{1, 2, "1.2"}, @@ -60,11 +60,19 @@ func Example_batch() { Args: []interface{}{1, 3, "1.3"}, Idempotent: true, }) + err = session.ExecuteBatch(b) if err != nil { log.Fatal(err) } + err = b.Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 4, "1.4"). + Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 5, "1.5"). + Exec() + if err != nil { + log.Fatal(err) + } + scanner := session.Query("SELECT pk, ck, description FROM example.batches").Iter().Scanner() for scanner.Next() { var pk, ck int32 @@ -77,4 +85,6 @@ func Example_batch() { } // 1 2 1.2 // 1 3 1.3 + // 1 4 1.4 + // 1 5 1.5 } diff --git a/example_lwt_batch_test.go b/example_lwt_batch_test.go index 916367eb3..c3cc8383d 100644 --- a/example_lwt_batch_test.go +++ b/example_lwt_batch_test.go @@ -29,7 +29,7 @@ import ( "fmt" "log" - gocql "github.com/gocql/gocql" + "github.com/gocql/gocql" ) // ExampleSession_MapExecuteBatchCAS demonstrates how to execute a batch lightweight transaction. @@ -62,7 +62,7 @@ func ExampleSession_MapExecuteBatchCAS() { } executeBatch := func(ck2Version int) { - b := session.NewBatch(gocql.LoggedBatch) + b := session.Batch(gocql.LoggedBatch) b.Entries = append(b.Entries, gocql.BatchEntry{ Stmt: "UPDATE my_lwt_batch_table SET value=? WHERE pk=? AND ck=? IF version=?", Args: []interface{}{"b", "pk1", "ck1", 1}, diff --git a/integration_test.go b/integration_test.go index 3622dfbd6..61ffbf504 100644 --- a/integration_test.go +++ b/integration_test.go @@ -218,7 +218,7 @@ func TestCustomPayloadMessages(t *testing.T) { iter.Close() // Batch Message - b := session.NewBatch(LoggedBatch) + b := session.Batch(LoggedBatch) b.CustomPayload = customPayload b.Query("INSERT INTO testCustomPayloadMessages(id,value) VALUES(1, 1)") if err := session.ExecuteBatch(b); err != nil { diff --git a/session.go b/session.go index a600b95f3..047b401ae 100644 --- a/session.go +++ b/session.go @@ -731,6 +731,11 @@ func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter { return conn.executeBatch(ctx, b) } +func (b *Batch) Exec() error { + iter := b.session.executeBatch(b) + return iter.Close() +} + func (s *Session) executeBatch(batch *Batch) *Iter { // fail fast if s.Closed() { @@ -1760,6 +1765,8 @@ func NewBatch(typ BatchType) *Batch { } // NewBatch creates a new batch operation using defaults defined in the cluster +// +// Deprcated: use session.Batch instead func (s *Session) NewBatch(typ BatchType) *Batch { s.mu.RLock() batch := &Batch{ @@ -1781,6 +1788,27 @@ func (s *Session) NewBatch(typ BatchType) *Batch { return batch } +func (s *Session) Batch(typ BatchType) *Batch { + s.mu.RLock() + batch := &Batch{ + Type: typ, + rt: s.cfg.RetryPolicy, + serialCons: s.cfg.SerialConsistency, + trace: s.trace, + observer: s.batchObserver, + session: s, + Cons: s.cons, + defaultTimestamp: s.cfg.DefaultTimestamp, + keyspace: s.cfg.Keyspace, + metrics: &queryMetrics{m: make(map[string]*hostMetrics)}, + spec: &NonSpeculativeExecution{}, + routingInfo: &queryRoutingInfo{}, + } + + s.mu.RUnlock() + return batch +} + // Trace enables tracing of this batch. Look at the documentation of the // Tracer interface to learn more about tracing. func (b *Batch) Trace(trace Tracer) *Batch { @@ -1860,8 +1888,9 @@ func (b *Batch) SpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Batch } // Query adds the query to the batch operation -func (b *Batch) Query(stmt string, args ...interface{}) { +func (b *Batch) Query(stmt string, args ...interface{}) *Batch { b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args}) + return b } // Bind adds the query to the batch operation and correlates it with a binding callback diff --git a/session_test.go b/session_test.go index 0319a8a4c..850e88531 100644 --- a/session_test.go +++ b/session_test.go @@ -96,7 +96,7 @@ func TestSessionAPI(t *testing.T) { t.Fatalf("expected itr.err to be '%v', got '%v'", ErrNoConnections, itr.err) } - testBatch := s.NewBatch(LoggedBatch) + testBatch := s.Batch(LoggedBatch) testBatch.Query("test") err := s.ExecuteBatch(testBatch) @@ -219,7 +219,7 @@ func TestBatchBasicAPI(t *testing.T) { s.pool = cfg.PoolConfig.buildPool(s) // Test UnloggedBatch - b := s.NewBatch(UnloggedBatch) + b := s.Batch(UnloggedBatch) if b.Type != UnloggedBatch { t.Fatalf("expceted batch.Type to be '%v', got '%v'", UnloggedBatch, b.Type) } else if b.rt != cfg.RetryPolicy { @@ -227,7 +227,7 @@ func TestBatchBasicAPI(t *testing.T) { } // Test LoggedBatch - b = s.NewBatch(LoggedBatch) + b = s.Batch(LoggedBatch) if b.Type != LoggedBatch { t.Fatalf("expected batch.Type to be '%v', got '%v'", LoggedBatch, b.Type) }