Skip to content

Commit

Permalink
Removed unnecessary DB transactions use (#99)
Browse files Browse the repository at this point in the history
  • Loading branch information
begmaroman authored Jun 27, 2024
1 parent be68a80 commit 97f3683
Show file tree
Hide file tree
Showing 13 changed files with 218 additions and 1,020 deletions.
104 changes: 50 additions & 54 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@ package db
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"

"github.com/0xPolygon/cdk-data-availability/types"
"github.com/ethereum/go-ethereum/common"
Expand All @@ -18,30 +18,21 @@ var (

// DB defines functions that a DB instance should implement
type DB interface {
BeginStateTransaction(ctx context.Context) (Tx, error)

StoreLastProcessedBlock(ctx context.Context, task string, block uint64, dbTx sqlx.ExecerContext) error
StoreLastProcessedBlock(ctx context.Context, task string, block uint64) error
GetLastProcessedBlock(ctx context.Context, task string) (uint64, error)

StoreUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey, dbTx sqlx.ExecerContext) error
StoreUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey) error
GetUnresolvedBatchKeys(ctx context.Context, limit uint) ([]types.BatchKey, error)
DeleteUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey, dbTx sqlx.ExecerContext) error
DeleteUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey) error

Exists(ctx context.Context, key common.Hash) bool
GetOffChainData(ctx context.Context, key common.Hash, dbTx sqlx.QueryerContext) (types.ArgBytes, error)
ListOffChainData(ctx context.Context, keys []common.Hash, dbTx sqlx.QueryerContext) (map[common.Hash]types.ArgBytes, error)
StoreOffChainData(ctx context.Context, od []types.OffChainData, dbTx sqlx.ExecerContext) error
GetOffChainData(ctx context.Context, key common.Hash) (types.ArgBytes, error)
ListOffChainData(ctx context.Context, keys []common.Hash) (map[common.Hash]types.ArgBytes, error)
StoreOffChainData(ctx context.Context, od []types.OffChainData) error

CountOffchainData(ctx context.Context) (uint64, error)
}

// Tx is the interface that defines functions a db tx has to implement
type Tx interface {
sqlx.ExecerContext
sqlx.QueryerContext
driver.Tx
}

// DB is the database layer of the data node
type pgDB struct {
pg *sqlx.DB
Expand All @@ -54,21 +45,16 @@ func New(pg *sqlx.DB) DB {
}
}

// BeginStateTransaction begins a DB transaction. The caller is responsible for committing or rolling back the transaction
func (db *pgDB) BeginStateTransaction(ctx context.Context) (Tx, error) {
return db.pg.BeginTxx(ctx, nil)
}

// StoreLastProcessedBlock stores a record of a block processed by the synchronizer for named task
func (db *pgDB) StoreLastProcessedBlock(ctx context.Context, task string, block uint64, dbTx sqlx.ExecerContext) error {
func (db *pgDB) StoreLastProcessedBlock(ctx context.Context, task string, block uint64) error {
const storeLastProcessedBlockSQL = `
INSERT INTO data_node.sync_tasks (task, block)
VALUES ($1, $2)
ON CONFLICT (task) DO UPDATE
SET block = EXCLUDED.block, processed = NOW();
`

if _, err := db.execer(dbTx).ExecContext(ctx, storeLastProcessedBlockSQL, task, block); err != nil {
if _, err := db.pg.ExecContext(ctx, storeLastProcessedBlockSQL, task, block); err != nil {
return err
}

Expand All @@ -91,25 +77,33 @@ func (db *pgDB) GetLastProcessedBlock(ctx context.Context, task string) (uint64,
}

// StoreUnresolvedBatchKeys stores unresolved batch keys in the database
func (db *pgDB) StoreUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey, dbTx sqlx.ExecerContext) error {
func (db *pgDB) StoreUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey) error {
const storeUnresolvedBatchesSQL = `
INSERT INTO data_node.unresolved_batches (num, hash)
VALUES ($1, $2)
ON CONFLICT (num, hash) DO NOTHING;
`

execer := db.execer(dbTx)
tx, err := db.pg.BeginTxx(ctx, nil)
if err != nil {
return err
}

for _, bk := range bks {
if _, err := execer.ExecContext(
if _, err = tx.ExecContext(
ctx, storeUnresolvedBatchesSQL,
bk.Number,
bk.Hash.Hex(),
); err != nil {
if txErr := tx.Rollback(); txErr != nil {
return fmt.Errorf("%v: rollback caused by %v", txErr, err)
}

return err
}
}

return nil
return tx.Commit()
}

// GetUnresolvedBatchKeys returns the unresolved batch keys from the database
Expand Down Expand Up @@ -143,23 +137,32 @@ func (db *pgDB) GetUnresolvedBatchKeys(ctx context.Context, limit uint) ([]types
}

// DeleteUnresolvedBatchKeys deletes the unresolved batch keys from the database
func (db *pgDB) DeleteUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey, dbTx sqlx.ExecerContext) error {
func (db *pgDB) DeleteUnresolvedBatchKeys(ctx context.Context, bks []types.BatchKey) error {
const deleteUnresolvedBatchKeysSQL = `
DELETE FROM data_node.unresolved_batches
WHERE num = $1 AND hash = $2;
`

tx, err := db.pg.BeginTxx(ctx, nil)
if err != nil {
return err
}

for _, bk := range bks {
if _, err := db.execer(dbTx).ExecContext(
if _, err = tx.ExecContext(
ctx, deleteUnresolvedBatchKeysSQL,
bk.Number,
bk.Hash.Hex(),
); err != nil {
if txErr := tx.Rollback(); txErr != nil {
return fmt.Errorf("%v: rollback caused by %v", txErr, err)
}

return err
}
}

return nil
return tx.Commit()
}

// Exists checks if a key exists in offchain data table
Expand All @@ -178,29 +181,37 @@ func (db *pgDB) Exists(ctx context.Context, key common.Hash) bool {
}

// StoreOffChainData stores and array of key values in the Db
func (db *pgDB) StoreOffChainData(ctx context.Context, od []types.OffChainData, dbTx sqlx.ExecerContext) error {
func (db *pgDB) StoreOffChainData(ctx context.Context, od []types.OffChainData) error {
const storeOffChainDataSQL = `
INSERT INTO data_node.offchain_data (key, value)
VALUES ($1, $2)
ON CONFLICT (key) DO NOTHING;
`

execer := db.execer(dbTx)
tx, err := db.pg.BeginTxx(ctx, nil)
if err != nil {
return err
}

for _, d := range od {
if _, err := execer.ExecContext(
if _, err = tx.ExecContext(
ctx, storeOffChainDataSQL,
d.Key.Hex(),
common.Bytes2Hex(d.Value),
); err != nil {
if txErr := tx.Rollback(); txErr != nil {
return fmt.Errorf("%v: rollback caused by %v", txErr, err)
}

return err
}
}

return nil
return tx.Commit()
}

// GetOffChainData returns the value identified by the key
func (db *pgDB) GetOffChainData(ctx context.Context, key common.Hash, dbTx sqlx.QueryerContext) (types.ArgBytes, error) {
func (db *pgDB) GetOffChainData(ctx context.Context, key common.Hash) (types.ArgBytes, error) {
const getOffchainDataSQL = `
SELECT value
FROM data_node.offchain_data
Expand All @@ -211,18 +222,19 @@ func (db *pgDB) GetOffChainData(ctx context.Context, key common.Hash, dbTx sqlx.
hexValue string
)

if err := db.querier(dbTx).QueryRowxContext(ctx, getOffchainDataSQL, key.Hex()).Scan(&hexValue); err != nil {
if err := db.pg.QueryRowxContext(ctx, getOffchainDataSQL, key.Hex()).Scan(&hexValue); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, ErrStateNotSynchronized
}

return nil, err
}

return common.FromHex(hexValue), nil
}

// ListOffChainData returns values identified by the given keys
func (db *pgDB) ListOffChainData(ctx context.Context, keys []common.Hash, dbTx sqlx.QueryerContext) (map[common.Hash]types.ArgBytes, error) {
func (db *pgDB) ListOffChainData(ctx context.Context, keys []common.Hash) (map[common.Hash]types.ArgBytes, error) {
if len(keys) == 0 {
return nil, nil
}
Expand All @@ -246,7 +258,7 @@ func (db *pgDB) ListOffChainData(ctx context.Context, keys []common.Hash, dbTx s
// sqlx.In returns queries with the `?` bindvar, we can rebind it for our backend
query = db.pg.Rebind(query)

rows, err := db.querier(dbTx).QueryxContext(ctx, query, args...)
rows, err := db.pg.QueryxContext(ctx, query, args...)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -280,19 +292,3 @@ func (db *pgDB) CountOffchainData(ctx context.Context) (uint64, error) {

return count, nil
}

func (db *pgDB) execer(dbTx sqlx.ExecerContext) sqlx.ExecerContext {
if dbTx != nil {
return dbTx
}

return db.pg
}

func (db *pgDB) querier(dbTx sqlx.QueryerContext) sqlx.QueryerContext {
if dbTx != nil {
return dbTx
}

return db.pg
}
48 changes: 27 additions & 21 deletions db/db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ func Test_DB_StoreLastProcessedBlock(t *testing.T) {

dbPG := New(wdb)

err = dbPG.StoreLastProcessedBlock(context.Background(), tt.task, tt.block, wdb)
err = dbPG.StoreLastProcessedBlock(context.Background(), tt.task, tt.block)
if tt.returnErr != nil {
require.ErrorIs(t, err, tt.returnErr)
} else {
Expand Down Expand Up @@ -116,7 +116,7 @@ func Test_DB_GetLastProcessedBlock(t *testing.T) {

dbPG := New(wdb)

err = dbPG.StoreLastProcessedBlock(context.Background(), tt.task, tt.block, wdb)
err = dbPG.StoreLastProcessedBlock(context.Background(), tt.task, tt.block)
require.NoError(t, err)

actual, err := dbPG.GetLastProcessedBlock(context.Background(), tt.task)
Expand Down Expand Up @@ -179,6 +179,7 @@ func Test_DB_StoreUnresolvedBatchKeys(t *testing.T) {

defer db.Close()

mock.ExpectBegin()
for _, o := range tt.bk {
expected := mock.ExpectExec(`INSERT INTO data_node\.unresolved_batches \(num, hash\) VALUES \(\$1, \$2\) ON CONFLICT \(num, hash\) DO NOTHING`).
WithArgs(o.Number, o.Hash.Hex())
Expand All @@ -188,12 +189,17 @@ func Test_DB_StoreUnresolvedBatchKeys(t *testing.T) {
expected.WillReturnResult(sqlmock.NewResult(int64(len(tt.bk)), int64(len(tt.bk))))
}
}
if tt.returnErr == nil {
mock.ExpectCommit()
} else {
mock.ExpectRollback()
}

wdb := sqlx.NewDb(db, "postgres")

dbPG := New(wdb)

err = dbPG.StoreUnresolvedBatchKeys(context.Background(), tt.bk, wdb)
err = dbPG.StoreUnresolvedBatchKeys(context.Background(), tt.bk)
if tt.returnErr != nil {
require.ErrorIs(t, err, tt.returnErr)
} else {
Expand Down Expand Up @@ -304,6 +310,7 @@ func Test_DB_DeleteUnresolvedBatchKeys(t *testing.T) {

defer db.Close()

mock.ExpectBegin()
for _, bk := range tt.bks {
expected := mock.ExpectExec(`DELETE FROM data_node\.unresolved_batches WHERE num = \$1 AND hash = \$2`).
WithArgs(bk.Number, bk.Hash.Hex())
Expand All @@ -313,12 +320,17 @@ func Test_DB_DeleteUnresolvedBatchKeys(t *testing.T) {
expected.WillReturnResult(sqlmock.NewResult(int64(len(tt.bks)), int64(len(tt.bks))))
}
}
if tt.returnErr != nil {
mock.ExpectRollback()
} else {
mock.ExpectCommit()
}

wdb := sqlx.NewDb(db, "postgres")

dbPG := New(wdb)

err = dbPG.DeleteUnresolvedBatchKeys(context.Background(), tt.bks, wdb)
err = dbPG.DeleteUnresolvedBatchKeys(context.Background(), tt.bks)
if tt.returnErr != nil {
require.ErrorIs(t, err, tt.returnErr)
} else {
Expand Down Expand Up @@ -377,6 +389,7 @@ func Test_DB_StoreOffChainData(t *testing.T) {

defer db.Close()

mock.ExpectBegin()
for _, o := range tt.od {
expected := mock.ExpectExec(`INSERT INTO data_node\.offchain_data \(key, value\) VALUES \(\$1, \$2\) ON CONFLICT \(key\) DO NOTHING`).
WithArgs(o.Key.Hex(), common.Bytes2Hex(o.Value))
Expand All @@ -386,12 +399,17 @@ func Test_DB_StoreOffChainData(t *testing.T) {
expected.WillReturnResult(sqlmock.NewResult(int64(len(tt.od)), int64(len(tt.od))))
}
}
if tt.returnErr == nil {
mock.ExpectCommit()
} else {
mock.ExpectRollback()
}

wdb := sqlx.NewDb(db, "postgres")

dbPG := New(wdb)

err = dbPG.StoreOffChainData(context.Background(), tt.od, wdb)
err = dbPG.StoreOffChainData(context.Background(), tt.od)
if tt.returnErr != nil {
require.ErrorIs(t, err, tt.returnErr)
} else {
Expand Down Expand Up @@ -467,7 +485,7 @@ func Test_DB_GetOffChainData(t *testing.T) {

dbPG := New(wdb)

data, err := dbPG.GetOffChainData(context.Background(), tt.key, wdb)
data, err := dbPG.GetOffChainData(context.Background(), tt.key)
if tt.returnErr != nil {
require.ErrorIs(t, err, tt.returnErr)
} else {
Expand Down Expand Up @@ -586,7 +604,7 @@ func Test_DB_ListOffChainData(t *testing.T) {

dbPG := New(wdb)

data, err := dbPG.ListOffChainData(context.Background(), tt.keys, wdb)
data, err := dbPG.ListOffChainData(context.Background(), tt.keys)
if tt.returnErr != nil {
require.ErrorIs(t, err, tt.returnErr)
} else {
Expand Down Expand Up @@ -760,13 +778,7 @@ func seedOffchainData(t *testing.T, db *sqlx.DB, mock sqlmock.Sqlmock, od []type
}
mock.ExpectCommit()

tx, err := db.BeginTxx(context.Background(), nil)
require.NoError(t, err)

err = New(db).StoreOffChainData(context.Background(), od, tx)
require.NoError(t, err)

err = tx.Commit()
err := New(db).StoreOffChainData(context.Background(), od)
require.NoError(t, err)
}

Expand All @@ -781,12 +793,6 @@ func seedUnresolvedBatchKeys(t *testing.T, db *sqlx.DB, mock sqlmock.Sqlmock, bk
}
mock.ExpectCommit()

tx, err := db.BeginTxx(context.Background(), nil)
require.NoError(t, err)

err = New(db).StoreUnresolvedBatchKeys(context.Background(), bk, tx)
require.NoError(t, err)

err = tx.Commit()
err := New(db).StoreUnresolvedBatchKeys(context.Background(), bk)
require.NoError(t, err)
}
Loading

0 comments on commit 97f3683

Please sign in to comment.