diff --git a/pkg/backfill/backfill.go b/pkg/backfill/backfill.go index 41e86d2d8..ea9923d05 100644 --- a/pkg/backfill/backfill.go +++ b/pkg/backfill/backfill.go @@ -7,11 +7,9 @@ import ( "database/sql" "errors" "fmt" - "strings" "time" - "github.com/lib/pq" - + "github.com/xataio/pgroll/pkg/backfill/templates" "github.com/xataio/pgroll/pkg/db" "github.com/xataio/pgroll/pkg/schema" ) @@ -59,7 +57,13 @@ func (bf *Backfill) Start(ctx context.Context, table *schema.Table) error { } // Create a batcher for the table. - b := newBatcher(table, bf.batchSize) + b := batcher{ + BatchConfig: templates.BatchConfig{ + TableName: table.Name, + PrimaryKey: identityColumns, + BatchSize: bf.batchSize, + }, + } // Update each batch of rows, invoking callbacks for each one. for batch := 0; ; batch++ { @@ -158,30 +162,30 @@ func getIdentityColumns(table *schema.Table) []string { return nil } +// A batcher is responsible for updating a batch of rows in a table. +// It holds the state necessary to update the next batch of rows. type batcher struct { - statementBuilder *batchStatementBuilder - lastValues []string -} - -func newBatcher(table *schema.Table, batchSize int) *batcher { - return &batcher{ - statementBuilder: newBatchStatementBuilder(table.Name, getIdentityColumns(table), batchSize), - lastValues: make([]string, len(getIdentityColumns(table))), - } + templates.BatchConfig } func (b *batcher) updateBatch(ctx context.Context, conn db.DB) error { return conn.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error { // Build the query to update the next batch of rows - query := b.statementBuilder.buildQuery(b.lastValues) + sql, err := templates.BuildSQL(b.BatchConfig) + if err != nil { + return err + } // Execute the query to update the next batch of rows and update the last PK // value for the next batch - wrapper := make([]any, len(b.lastValues)) - for i := range b.lastValues { - wrapper[i] = &b.lastValues[i] + if b.LastValue == nil { + b.LastValue = make([]string, len(b.PrimaryKey)) + } + wrapper := make([]any, len(b.LastValue)) + for i := range b.LastValue { + wrapper[i] = &b.LastValue[i] } - err := tx.QueryRowContext(ctx, query).Scan(wrapper...) + err = tx.QueryRowContext(ctx, sql).Scan(wrapper...) if err != nil { return err } @@ -189,78 +193,3 @@ func (b *batcher) updateBatch(ctx context.Context, conn db.DB) error { return nil }) } - -type batchStatementBuilder struct { - tableName string - identityColumns []string - batchSize int -} - -func newBatchStatementBuilder(tableName string, identityColumnNames []string, batchSize int) *batchStatementBuilder { - quotedCols := make([]string, len(identityColumnNames)) - for i, col := range identityColumnNames { - quotedCols[i] = pq.QuoteIdentifier(col) - } - return &batchStatementBuilder{ - tableName: pq.QuoteIdentifier(tableName), - identityColumns: quotedCols, - batchSize: batchSize, - } -} - -// buildQuery builds the query used to update the next batch of rows. -func (sb *batchStatementBuilder) buildQuery(lastValues []string) string { - return fmt.Sprintf("WITH batch AS (%[1]s), update AS (%[2]s) %[3]s", - sb.buildBatchSubQuery(lastValues), - sb.buildUpdateBatchSubQuery(), - sb.buildLastValueQuery()) -} - -// fetch the next batch of PK of rows to update -func (sb *batchStatementBuilder) buildBatchSubQuery(lastValues []string) string { - whereClause := "" - if len(lastValues) != 0 && lastValues[0] != "" { - whereClause = fmt.Sprintf("WHERE (%s) > (%s)", - strings.Join(sb.identityColumns, ", "), strings.Join(quoteLiteralList(lastValues), ", ")) - } - - return fmt.Sprintf("SELECT %[1]s FROM %[2]s %[3]s ORDER BY %[1]s LIMIT %[4]d FOR NO KEY UPDATE", - strings.Join(sb.identityColumns, ", "), sb.tableName, whereClause, sb.batchSize) -} - -func quoteLiteralList(l []string) []string { - quoted := make([]string, len(l)) - for i, v := range l { - quoted[i] = pq.QuoteLiteral(v) - } - return quoted -} - -// update the rows in the batch -func (sb *batchStatementBuilder) buildUpdateBatchSubQuery() string { - conditions := make([]string, len(sb.identityColumns)) - for i, col := range sb.identityColumns { - conditions[i] = fmt.Sprintf("%[1]s.%[2]s = batch.%[2]s", sb.tableName, col) - } - updateWhereClause := "WHERE " + strings.Join(conditions, " AND ") - - setStmt := fmt.Sprintf("%[1]s = %[2]s.%[1]s", sb.identityColumns[0], sb.tableName) - for i := 1; i < len(sb.identityColumns); i++ { - setStmt += fmt.Sprintf(", %[1]s = %[2]s.%[1]s", sb.identityColumns[i], sb.tableName) - } - updateReturning := sb.tableName + "." + sb.identityColumns[0] - for i := 1; i < len(sb.identityColumns); i++ { - updateReturning += ", " + sb.tableName + "." + sb.identityColumns[i] - } - return fmt.Sprintf("UPDATE %[1]s SET %[2]s FROM batch %[3]s RETURNING %[4]s", - sb.tableName, setStmt, updateWhereClause, updateReturning) -} - -// fetch the last values of the PK column -func (sb *batchStatementBuilder) buildLastValueQuery() string { - lastValues := make([]string, len(sb.identityColumns)) - for i, col := range sb.identityColumns { - lastValues[i] = "LAST_VALUE(" + col + ") OVER()" - } - return fmt.Sprintf("SELECT %[1]s FROM update", strings.Join(lastValues, ", ")) -} diff --git a/pkg/backfill/backfill_test.go b/pkg/backfill/backfill_test.go deleted file mode 100644 index 2fd304575..000000000 --- a/pkg/backfill/backfill_test.go +++ /dev/null @@ -1,54 +0,0 @@ -// SPDX-License-Identifier: Apache-2.0 - -package backfill - -import ( - "testing" - - "github.com/stretchr/testify/assert" -) - -func TestBatchStatementBuilder(t *testing.T) { - tests := map[string]struct { - tableName string - identityColumns []string - batchSize int - lasValues []string - expected string - }{ - "single identity column no last value": { - tableName: "table_name", - identityColumns: []string{"id"}, - batchSize: 10, - expected: `WITH batch AS (SELECT "id" FROM "table_name" ORDER BY "id" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id" FROM batch WHERE "table_name"."id" = batch."id" RETURNING "table_name"."id") SELECT LAST_VALUE("id") OVER() FROM update`, - }, - "multiple identity columns no last value": { - tableName: "table_name", - identityColumns: []string{"id", "zip"}, - batchSize: 10, - expected: `WITH batch AS (SELECT "id", "zip" FROM "table_name" ORDER BY "id", "zip" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id", "zip" = "table_name"."zip" FROM batch WHERE "table_name"."id" = batch."id" AND "table_name"."zip" = batch."zip" RETURNING "table_name"."id", "table_name"."zip") SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER() FROM update`, - }, - "single identity column with last value": { - tableName: "table_name", - identityColumns: []string{"id"}, - batchSize: 10, - lasValues: []string{"1"}, - expected: `WITH batch AS (SELECT "id" FROM "table_name" WHERE ("id") > ('1') ORDER BY "id" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id" FROM batch WHERE "table_name"."id" = batch."id" RETURNING "table_name"."id") SELECT LAST_VALUE("id") OVER() FROM update`, - }, - "multiple identity columns with last value": { - tableName: "table_name", - identityColumns: []string{"id", "zip"}, - batchSize: 10, - lasValues: []string{"1", "1234"}, - expected: `WITH batch AS (SELECT "id", "zip" FROM "table_name" WHERE ("id", "zip") > ('1', '1234') ORDER BY "id", "zip" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id", "zip" = "table_name"."zip" FROM batch WHERE "table_name"."id" = batch."id" AND "table_name"."zip" = batch."zip" RETURNING "table_name"."id", "table_name"."zip") SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER() FROM update`, - }, - } - - for name, test := range tests { - t.Run(name, func(t *testing.T) { - builder := newBatchStatementBuilder(test.tableName, test.identityColumns, test.batchSize) - actual := builder.buildQuery(test.lasValues) - assert.Equal(t, test.expected, actual) - }) - } -} diff --git a/pkg/backfill/templates/build.go b/pkg/backfill/templates/build.go new file mode 100644 index 000000000..9e6fe706a --- /dev/null +++ b/pkg/backfill/templates/build.go @@ -0,0 +1,86 @@ +// SPDX-License-Identifier: Apache-2.0 + +package templates + +import ( + "bytes" + "strings" + "text/template" + + "github.com/lib/pq" +) + +type BatchConfig struct { + TableName string + PrimaryKey []string + LastValue []string + BatchSize int +} + +func BuildSQL(cfg BatchConfig) (string, error) { + return executeTemplate("sql", SQL, cfg) +} + +func executeTemplate(name, content string, cfg BatchConfig) (string, error) { + ql := pq.QuoteLiteral + qi := pq.QuoteIdentifier + + tmpl := template.Must(template.New(name). + Funcs(template.FuncMap{ + "ql": ql, + "qi": qi, + "commaSeparate": func(slice []string) string { + return strings.Join(slice, ", ") + }, + "quoteIdentifiers": func(slice []string) []string { + quoted := make([]string, len(slice)) + for i, s := range slice { + quoted[i] = qi(s) + } + return quoted + }, + "quoteLiterals": func(slice []string) []string { + quoted := make([]string, len(slice)) + for i, s := range slice { + quoted[i] = ql(s) + } + return quoted + }, + "updateSetClause": func(tableName string, columns []string) string { + quoted := make([]string, len(columns)) + for i, c := range columns { + quoted[i] = qi(c) + " = " + qi(tableName) + "." + qi(c) + } + return strings.Join(quoted, ", ") + }, + "updateWhereClause": func(tableName string, columns []string) string { + quoted := make([]string, len(columns)) + for i, c := range columns { + quoted[i] = qi(tableName) + "." + qi(c) + " = batch." + qi(c) + } + return strings.Join(quoted, " AND ") + }, + "updateReturnClause": func(tableName string, columns []string) string { + quoted := make([]string, len(columns)) + for i, c := range columns { + quoted[i] = qi(tableName) + "." + qi(c) + } + return strings.Join(quoted, ", ") + }, + "selectLastValue": func(columns []string) string { + quoted := make([]string, len(columns)) + for i, c := range columns { + quoted[i] = "LAST_VALUE(" + qi(c) + ") OVER()" + } + return strings.Join(quoted, ", ") + }, + }). + Parse(content)) + + buf := bytes.Buffer{} + if err := tmpl.Execute(&buf, cfg); err != nil { + return "", err + } + + return buf.String(), nil +} diff --git a/pkg/backfill/templates/build_test.go b/pkg/backfill/templates/build_test.go new file mode 100644 index 000000000..4e9128c1e --- /dev/null +++ b/pkg/backfill/templates/build_test.go @@ -0,0 +1,142 @@ +// SPDX-License-Identifier: Apache-2.0 + +package templates + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestBatchStatementBuilder(t *testing.T) { + tests := map[string]struct { + config BatchConfig + expected string + }{ + "single identity column no last value": { + config: BatchConfig{ + TableName: "table_name", + PrimaryKey: []string{"id"}, + BatchSize: 10, + }, + expected: expectSingleIDColumnNoLastValue, + }, + "multiple identity columns no last value": { + config: BatchConfig{ + TableName: "table_name", + PrimaryKey: []string{"id", "zip"}, + BatchSize: 10, + }, + expected: multipleIDColumnsNoLastValue, + }, + "single identity column with last value": { + config: BatchConfig{ + TableName: "table_name", + PrimaryKey: []string{"id"}, + LastValue: []string{"1"}, + BatchSize: 10, + }, + expected: singleIDColumnWithLastValue, + }, + "multiple identity columns with last value": { + config: BatchConfig{ + TableName: "table_name", + PrimaryKey: []string{"id", "zip"}, + LastValue: []string{"1", "1234"}, + BatchSize: 10, + }, + expected: multipleIDColumnsWithLastValue, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + actual, err := BuildSQL(test.config) + assert.NoError(t, err) + + assert.Equal(t, test.expected, actual) + }) + } +} + +const expectSingleIDColumnNoLastValue = `WITH batch AS +( + SELECT "id" + FROM "table_name" + ORDER BY "id" + LIMIT 10 + FOR NO KEY UPDATE +), +update AS +( + UPDATE "table_name" + SET "id" = "table_name"."id" + FROM batch + WHERE "table_name"."id" = batch."id" + RETURNING "table_name"."id" +) +SELECT LAST_VALUE("id") OVER() +FROM update +` + +const multipleIDColumnsNoLastValue = `WITH batch AS +( + SELECT "id", "zip" + FROM "table_name" + ORDER BY "id", "zip" + LIMIT 10 + FOR NO KEY UPDATE +), +update AS +( + UPDATE "table_name" + SET "id" = "table_name"."id", "zip" = "table_name"."zip" + FROM batch + WHERE "table_name"."id" = batch."id" AND "table_name"."zip" = batch."zip" + RETURNING "table_name"."id", "table_name"."zip" +) +SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER() +FROM update +` + +const singleIDColumnWithLastValue = `WITH batch AS +( + SELECT "id" + FROM "table_name" + WHERE ("id") > ('1') + ORDER BY "id" + LIMIT 10 + FOR NO KEY UPDATE +), +update AS +( + UPDATE "table_name" + SET "id" = "table_name"."id" + FROM batch + WHERE "table_name"."id" = batch."id" + RETURNING "table_name"."id" +) +SELECT LAST_VALUE("id") OVER() +FROM update +` + +const multipleIDColumnsWithLastValue = `WITH batch AS +( + SELECT "id", "zip" + FROM "table_name" + WHERE ("id", "zip") > ('1', '1234') + ORDER BY "id", "zip" + LIMIT 10 + FOR NO KEY UPDATE +), +update AS +( + UPDATE "table_name" + SET "id" = "table_name"."id", "zip" = "table_name"."zip" + FROM batch + WHERE "table_name"."id" = batch."id" AND "table_name"."zip" = batch."zip" + RETURNING "table_name"."id", "table_name"."zip" +) +SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER() +FROM update +` diff --git a/pkg/backfill/templates/sql.go b/pkg/backfill/templates/sql.go new file mode 100644 index 000000000..9a11f9dc9 --- /dev/null +++ b/pkg/backfill/templates/sql.go @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: Apache-2.0 + +package templates + +const SQL = `WITH batch AS +( + SELECT {{ commaSeparate (quoteIdentifiers .PrimaryKey) }} + FROM {{ .TableName | qi}} + {{ if .LastValue -}} + WHERE ({{ commaSeparate (quoteIdentifiers .PrimaryKey) }}) > ({{ commaSeparate (quoteLiterals .LastValue) }}) + {{ end -}} + ORDER BY {{ commaSeparate (quoteIdentifiers .PrimaryKey) }} + LIMIT {{ .BatchSize }} + FOR NO KEY UPDATE +), +update AS +( + UPDATE {{ .TableName | qi }} + SET {{ updateSetClause .TableName .PrimaryKey }} + FROM batch + WHERE {{ updateWhereClause .TableName .PrimaryKey }} + RETURNING {{ updateReturnClause .TableName .PrimaryKey }} +) +SELECT {{ selectLastValue .PrimaryKey }} +FROM update +`