Skip to content

Commit

Permalink
Change Batch API to be consistent with Query()
Browse files Browse the repository at this point in the history
Exec() method for batch was added & Query() method was refactored.
Batch for now behaves the same way as query.

patch by Oleksandr Luzhniy; reviewed by João Reis, Danylo Savchenko, Bohdan Siryk, Jackson Fleming, for CASSGO-7
  • Loading branch information
tengu-alt authored and sylwiaszunejko committed Jan 23, 2025
1 parent d0a8589 commit b2d210a
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 34 deletions.
16 changes: 9 additions & 7 deletions batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ func TestBatch_Errors(t *testing.T) {
t.Fatal(err)
}

b := session.NewBatch(LoggedBatch)
b.Query("SELECT * FROM batch_errors WHERE id=2 AND val=?", nil)
if err := session.ExecuteBatch(b); err == nil {
b := session.Batch(LoggedBatch)
b = b.Query("SELECT * FROM gocql_test.batch_errors WHERE id=2 AND val=?", nil)
if err := b.Exec(); err == nil {
t.Fatal("expected to get error for invalid query in batch")
}
}
Expand All @@ -44,15 +44,17 @@ func TestBatch_WithTimestamp(t *testing.T) {

micros := time.Now().UnixNano()/1e3 - 1000

b := session.NewBatch(LoggedBatch)
b := session.Batch(LoggedBatch)
b.WithTimestamp(micros)
b.Query("INSERT INTO batch_ts (id, val) VALUES (?, ?)", 1, "val")
if err := session.ExecuteBatch(b); err != nil {
b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 1, "val")
b = b.Query("INSERT INTO gocql_test.batch_ts (id, val) VALUES (?, ?)", 2, "val")

if err := b.Exec(); err != nil {
t.Fatal(err)
}

var storedTs int64
if err := session.Query(`SELECT writetime(val) FROM batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil {
if err := session.Query(`SELECT writetime(val) FROM gocql_test.batch_ts WHERE id = ?`, 1).Scan(&storedTs); err != nil {
t.Fatal(err)
}

Expand Down
34 changes: 17 additions & 17 deletions cassandra_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ import (
"time"
"unicode"

inf "gopkg.in/inf.v0"
"gopkg.in/inf.v0"
)

func TestEmptyHosts(t *testing.T) {
Expand Down Expand Up @@ -565,15 +565,15 @@ func TestCAS(t *testing.T) {
t.Fatal("truncate:", err)
}

successBatch := session.NewBatch(LoggedBatch)
successBatch := session.Batch(LoggedBatch)
successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified)
if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
t.Fatal("insert:", err)
} else if !applied {
t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
}

successBatch = session.NewBatch(LoggedBatch)
successBatch = session.Batch(LoggedBatch)
successBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title+"_foo", revid, modified)
casMap := make(map[string]interface{})
if applied, _, err := session.MapExecuteBatchCAS(successBatch, casMap); err != nil {
Expand All @@ -582,22 +582,22 @@ func TestCAS(t *testing.T) {
t.Fatal("insert should have been applied")
}

failBatch := session.NewBatch(LoggedBatch)
failBatch := session.Batch(LoggedBatch)
failBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES (?, ?, ?) IF NOT EXISTS", title, revid, modified)
if applied, _, err := session.ExecuteBatchCAS(successBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
t.Fatal("insert:", err)
} else if applied {
t.Fatalf("insert should have been applied: title=%v revID=%v modified=%v", titleCAS, revidCAS, modifiedCAS)
}

insertBatch := session.NewBatch(LoggedBatch)
insertBatch := session.Batch(LoggedBatch)
insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 2c3af400-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
insertBatch.Query("INSERT INTO cas_table (title, revid, last_modified) VALUES ('_foo', 3e4ad2f1-73a4-11e5-9381-29463d90c3f0, DATEOF(NOW()))")
if err := session.ExecuteBatch(insertBatch); err != nil {
t.Fatal("insert:", err)
}

failBatch = session.NewBatch(LoggedBatch)
failBatch = session.Batch(LoggedBatch)
failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=2c3af400-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());")
failBatch.Query("UPDATE cas_table SET last_modified = DATEOF(NOW()) WHERE title='_foo' AND revid=3e4ad2f1-73a4-11e5-9381-29463d90c3f0 IF last_modified=DATEOF(NOW());")
if applied, iter, err := session.ExecuteBatchCAS(failBatch, &titleCAS, &revidCAS, &modifiedCAS); err != nil {
Expand Down Expand Up @@ -722,7 +722,7 @@ func TestBatch(t *testing.T) {
t.Fatal("create table:", err)
}

batch := session.NewBatch(LoggedBatch)
batch := session.Batch(LoggedBatch)
for i := 0; i < 100; i++ {
batch.Query(`INSERT INTO batch_table (id) VALUES (?)`, i)
}
Expand Down Expand Up @@ -754,9 +754,9 @@ func TestUnpreparedBatch(t *testing.T) {

var batch *Batch
if session.cfg.ProtoVersion == 2 {
batch = session.NewBatch(CounterBatch)
batch = session.Batch(CounterBatch)
} else {
batch = session.NewBatch(UnloggedBatch)
batch = session.Batch(UnloggedBatch)
}

for i := 0; i < 100; i++ {
Expand Down Expand Up @@ -795,7 +795,7 @@ func TestBatchLimit(t *testing.T) {
t.Fatal("create table:", err)
}

batch := session.NewBatch(LoggedBatch)
batch := session.Batch(LoggedBatch)
for i := 0; i < 65537; i++ {
batch.Query(`INSERT INTO batch_table2 (id) VALUES (?)`, i)
}
Expand Down Expand Up @@ -849,7 +849,7 @@ func TestTooManyQueryArgs(t *testing.T) {
t.Fatal("'`SELECT * FROM too_many_query_args WHERE id = ?`, 1, 2' should return an error")
}

batch := session.NewBatch(UnloggedBatch)
batch := session.Batch(UnloggedBatch)
batch.Query("INSERT INTO too_many_query_args (id, value) VALUES (?, ?)", 1, 2, 3)
err = session.ExecuteBatch(batch)

Expand Down Expand Up @@ -881,7 +881,7 @@ func TestNotEnoughQueryArgs(t *testing.T) {
t.Fatal("'`SELECT * FROM not_enough_query_args WHERE id = ? and cluster = ?`, 1' should return an error")
}

batch := session.NewBatch(UnloggedBatch)
batch := session.Batch(UnloggedBatch)
batch.Query("INSERT INTO not_enough_query_args (id, cluster, value) VALUES (?, ?, ?)", 1, 2)
err = session.ExecuteBatch(batch)

Expand Down Expand Up @@ -1454,7 +1454,7 @@ func TestBatchQueryInfo(t *testing.T) {
return values, nil
}

batch := session.NewBatch(LoggedBatch)
batch := session.Batch(LoggedBatch)
batch.Bind("INSERT INTO batch_query_info (id, cluster, value) VALUES (?, ?,?)", write)

if err := session.ExecuteBatch(batch); err != nil {
Expand Down Expand Up @@ -1582,7 +1582,7 @@ func TestPrepare_ReprepareBatch(t *testing.T) {
}

stmt, conn := injectInvalidPreparedStatement(t, session, "test_reprepare_statement_batch")
batch := session.NewBatch(UnloggedBatch)
batch := session.Batch(UnloggedBatch)
batch.Query(stmt, "bar")
if err := conn.executeBatch(ctx, batch).Close(); err != nil {
t.Fatalf("Failed to execute query for reprepare statement: %v", err)
Expand Down Expand Up @@ -1966,7 +1966,7 @@ func TestBatchStats(t *testing.T) {
t.Fatalf("failed to create table with error '%v'", err)
}

b := session.NewBatch(LoggedBatch)
b := session.Batch(LoggedBatch)
b.Query("INSERT INTO batchStats (id) VALUES (?)", 1)
b.Query("INSERT INTO batchStats (id) VALUES (?)", 2)

Expand Down Expand Up @@ -2009,7 +2009,7 @@ func TestBatchObserve(t *testing.T) {

var observedBatch *observation

batch := session.NewBatch(LoggedBatch)
batch := session.Batch(LoggedBatch)
batch.Observer(funcBatchObserver(func(ctx context.Context, o ObservedBatch) {
if observedBatch != nil {
t.Fatal("batch observe called more than once")
Expand Down Expand Up @@ -2632,7 +2632,7 @@ func TestUnsetColBatch(t *testing.T) {
t.Fatalf("failed to create table with error '%v'", err)
}

b := session.NewBatch(LoggedBatch)
b := session.Batch(LoggedBatch)
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, 1, UnsetValue)
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 1, UnsetValue, "")
b.Query("INSERT INTO gocql_test.batchUnsetInsert(id, my_int, my_text) VALUES (?,?,?)", 2, 2, UnsetValue)
Expand Down
2 changes: 1 addition & 1 deletion doc.go
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@
// # Batches
//
// The CQL protocol supports sending batches of DML statements (INSERT/UPDATE/DELETE) and so does gocql.
// Use Session.NewBatch to create a new batch and then fill-in details of individual queries.
// Use Session.Batch to create a new batch and then fill-in details of individual queries.
// Then execute the batch with Session.ExecuteBatch.
//
// Logged batches ensure atomicity, either all or none of the operations in the batch will succeed, but they have
Expand Down
15 changes: 13 additions & 2 deletions example_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package gocql_test
import (
"context"
"fmt"
"github.com/gocql/gocql"
"log"

"github.com/gocql/gocql"
)

// Example_batch demonstrates how to execute a batch of statements.
Expand All @@ -24,7 +25,7 @@ func Example_batch() {

ctx := context.Background()

b := session.NewBatch(gocql.UnloggedBatch).WithContext(ctx)
b := session.Batch(gocql.UnloggedBatch).WithContext(ctx)
b.Entries = append(b.Entries, gocql.BatchEntry{
Stmt: "INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)",
Args: []interface{}{1, 2, "1.2"},
Expand All @@ -35,11 +36,19 @@ func Example_batch() {
Args: []interface{}{1, 3, "1.3"},
Idempotent: true,
})

err = session.ExecuteBatch(b)
if err != nil {
log.Fatal(err)
}

err = b.Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 4, "1.4").
Query("INSERT INTO example.batches (pk, ck, description) VALUES (?, ?, ?)", 1, 5, "1.5").
Exec()
if err != nil {
log.Fatal(err)
}

scanner := session.Query("SELECT pk, ck, description FROM example.batches").Iter().Scanner()
for scanner.Next() {
var pk, ck int32
Expand All @@ -52,4 +61,6 @@ func Example_batch() {
}
// 1 2 1.2
// 1 3 1.3
// 1 4 1.4
// 1 5 1.5
}
5 changes: 3 additions & 2 deletions example_lwt_batch_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@ package gocql_test
import (
"context"
"fmt"
"github.com/gocql/gocql"
"log"

"github.com/gocql/gocql"
)

// ExampleSession_MapExecuteBatchCAS demonstrates how to execute a batch lightweight transaction.
Expand Down Expand Up @@ -37,7 +38,7 @@ func ExampleSession_MapExecuteBatchCAS() {
}

executeBatch := func(ck2Version int) {
b := session.NewBatch(gocql.LoggedBatch)
b := session.Batch(gocql.LoggedBatch)
b.Entries = append(b.Entries, gocql.BatchEntry{
Stmt: "UPDATE my_lwt_batch_table SET value=? WHERE pk=? AND ck=? IF version=?",
Args: []interface{}{"b", "pk1", "ck1", 1},
Expand Down
2 changes: 1 addition & 1 deletion integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ func TestCustomPayloadMessages(t *testing.T) {
iter.Close()

// Batch Message
b := session.NewBatch(LoggedBatch)
b := session.Batch(LoggedBatch)
b.CustomPayload = customPayload
b.Query("INSERT INTO testCustomPayloadMessages(id,value) VALUES(1, 1)")
if err := session.ExecuteBatch(b); err != nil {
Expand Down
17 changes: 16 additions & 1 deletion session.go
Original file line number Diff line number Diff line change
Expand Up @@ -856,6 +856,13 @@ func (b *Batch) execute(ctx context.Context, conn *Conn) *Iter {
return conn.executeBatch(ctx, b)
}

// Exec executes a batch operation and returns nil if successful
// otherwise an error is returned describing the failure.
func (b *Batch) Exec() error {
iter := b.session.executeBatch(b)
return iter.Close()
}

func (s *Session) executeBatch(batch *Batch) *Iter {
// fail fast
if s.Closed() {
Expand Down Expand Up @@ -1933,7 +1940,14 @@ type Batch struct {
}

// NewBatch creates a new batch operation using defaults defined in the cluster
//
// Deprecated: use session.Batch instead
func (s *Session) NewBatch(typ BatchType) *Batch {
return s.Batch(typ)
}

// Batch creates a new batch operation using defaults defined in the cluster
func (s *Session) Batch(typ BatchType) *Batch {
s.mu.RLock()
batch := &Batch{
Type: typ,
Expand Down Expand Up @@ -2045,8 +2059,9 @@ func (b *Batch) SpeculativeExecutionPolicy(sp SpeculativeExecutionPolicy) *Batch
}

// Query adds the query to the batch operation
func (b *Batch) Query(stmt string, args ...interface{}) {
func (b *Batch) Query(stmt string, args ...interface{}) *Batch {
b.Entries = append(b.Entries, BatchEntry{Stmt: stmt, Args: args})
return b
}

// Bind adds the query to the batch operation and correlates it with a binding callback
Expand Down
6 changes: 3 additions & 3 deletions session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func TestSessionAPI(t *testing.T) {
t.Fatalf("expected itr.err to be '%v', got '%v'", ErrSessionNotReady, itr.err)
}

testBatch := s.NewBatch(LoggedBatch)
testBatch := s.Batch(LoggedBatch)
testBatch.Query("test")
err := s.ExecuteBatch(testBatch)

Expand Down Expand Up @@ -205,15 +205,15 @@ func TestBatchBasicAPI(t *testing.T) {
s.pool = cfg.PoolConfig.buildPool(s)

// Test UnloggedBatch
b := s.NewBatch(UnloggedBatch)
b := s.Batch(UnloggedBatch)
if b.Type != UnloggedBatch {
t.Fatalf("expceted batch.Type to be '%v', got '%v'", UnloggedBatch, b.Type)
} else if b.rt != cfg.RetryPolicy {
t.Fatalf("expceted batch.RetryPolicy to be '%v', got '%v'", cfg.RetryPolicy, b.rt)
}

// Test LoggedBatch
b = s.NewBatch(LoggedBatch)
b = s.Batch(LoggedBatch)
if b.Type != LoggedBatch {
t.Fatalf("expected batch.Type to be '%v', got '%v'", LoggedBatch, b.Type)
}
Expand Down

0 comments on commit b2d210a

Please sign in to comment.