Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Use text/template for backfill query generation #632

Merged
merged 2 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
115 changes: 22 additions & 93 deletions pkg/backfill/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,9 @@ import (
"database/sql"
"errors"
"fmt"
"strings"
"time"

"github.com/lib/pq"

"github.com/xataio/pgroll/pkg/backfill/templates"
"github.com/xataio/pgroll/pkg/db"
"github.com/xataio/pgroll/pkg/schema"
)
Expand Down Expand Up @@ -59,7 +57,13 @@ func (bf *Backfill) Start(ctx context.Context, table *schema.Table) error {
}

// Create a batcher for the table.
b := newBatcher(table, bf.batchSize)
b := batcher{
BatchConfig: templates.BatchConfig{
TableName: table.Name,
PrimaryKey: identityColumns,
BatchSize: bf.batchSize,
},
}

// Update each batch of rows, invoking callbacks for each one.
for batch := 0; ; batch++ {
Expand Down Expand Up @@ -158,109 +162,34 @@ func getIdentityColumns(table *schema.Table) []string {
return nil
}

// A batcher is responsible for updating a batch of rows in a table.
// It holds the state necessary to update the next batch of rows.
type batcher struct {
statementBuilder *batchStatementBuilder
lastValues []string
}

func newBatcher(table *schema.Table, batchSize int) *batcher {
return &batcher{
statementBuilder: newBatchStatementBuilder(table.Name, getIdentityColumns(table), batchSize),
lastValues: make([]string, len(getIdentityColumns(table))),
}
templates.BatchConfig
}

func (b *batcher) updateBatch(ctx context.Context, conn db.DB) error {
return conn.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
// Build the query to update the next batch of rows
query := b.statementBuilder.buildQuery(b.lastValues)
sql, err := templates.BuildSQL(b.BatchConfig)
if err != nil {
return err
}

// Execute the query to update the next batch of rows and update the last PK
// value for the next batch
wrapper := make([]any, len(b.lastValues))
for i := range b.lastValues {
wrapper[i] = &b.lastValues[i]
if b.LastValue == nil {
b.LastValue = make([]string, len(b.PrimaryKey))
}
wrapper := make([]any, len(b.LastValue))
for i := range b.LastValue {
wrapper[i] = &b.LastValue[i]
}
err := tx.QueryRowContext(ctx, query).Scan(wrapper...)
err = tx.QueryRowContext(ctx, sql).Scan(wrapper...)
if err != nil {
return err
}

return nil
})
}

// batchStatementBuilder holds the inputs needed to build the SQL statement
// that updates one batch of rows. The name fields are stored already quoted
// by newBatchStatementBuilder, so the build* methods splice them into SQL
// text verbatim.
type batchStatementBuilder struct {
// tableName is the table name, quoted as an SQL identifier.
tableName string
// identityColumns are the identity column names, each quoted as an SQL
// identifier.
identityColumns []string
// batchSize is used as the LIMIT of the batch selection sub-query.
batchSize int
}

// newBatchStatementBuilder returns a batchStatementBuilder for the given
// table. The table name and every identity column name are quoted with
// pq.QuoteIdentifier up front, so the build* methods can use them as-is.
func newBatchStatementBuilder(tableName string, identityColumnNames []string, batchSize int) *batchStatementBuilder {
	builder := &batchStatementBuilder{
		tableName:       pq.QuoteIdentifier(tableName),
		identityColumns: make([]string, 0, len(identityColumnNames)),
		batchSize:       batchSize,
	}
	for _, name := range identityColumnNames {
		builder.identityColumns = append(builder.identityColumns, pq.QuoteIdentifier(name))
	}
	return builder
}

// buildQuery assembles the full statement for one batch: a "batch" CTE that
// selects the next rows, an "update" CTE that updates them, and a final
// SELECT that reads the last identity values from the updated rows.
// lastValues is forwarded to buildBatchSubQuery.
func (sb *batchStatementBuilder) buildQuery(lastValues []string) string {
	batchSQL := sb.buildBatchSubQuery(lastValues)
	updateSQL := sb.buildUpdateBatchSubQuery()
	lastValueSQL := sb.buildLastValueQuery()

	return "WITH batch AS (" + batchSQL + "), update AS (" + updateSQL + ") " + lastValueSQL
}

// buildBatchSubQuery returns the SELECT that picks the identity columns of
// the next batch of rows, ordered by those columns, limited to batchSize and
// locked FOR NO KEY UPDATE.
//
// When lastValues is non-empty and its first element is not the empty string,
// the selection is restricted to rows whose identity columns compare greater
// than lastValues (row-wise comparison); otherwise no WHERE clause is emitted.
func (sb *batchStatementBuilder) buildBatchSubQuery(lastValues []string) string {
	columnList := strings.Join(sb.identityColumns, ", ")

	var filter string
	if len(lastValues) > 0 && lastValues[0] != "" {
		literalList := strings.Join(quoteLiteralList(lastValues), ", ")
		filter = "WHERE (" + columnList + ") > (" + literalList + ")"
	}

	return fmt.Sprintf("SELECT %[1]s FROM %[2]s %[3]s ORDER BY %[1]s LIMIT %[4]d FOR NO KEY UPDATE",
		columnList, sb.tableName, filter, sb.batchSize)
}

// quoteLiteralList returns a new slice holding each element of l passed
// through pq.QuoteLiteral, in the same order. The input slice is not modified.
func quoteLiteralList(l []string) []string {
	literals := make([]string, 0, len(l))
	for idx := range l {
		literals = append(literals, pq.QuoteLiteral(l[idx]))
	}
	return literals
}

// buildUpdateBatchSubQuery returns the UPDATE statement that touches every
// row selected by the "batch" CTE and returns the table-qualified identity
// columns of the updated rows.
//
// Each identity column is assigned its own current value, so the statement
// rewrites the row without changing its data (presumably so that row-level
// triggers run for it; confirm against the callers). Rows are matched to the
// batch by equality on every identity column.
//
// The clauses are built in a single pass and joined, so an empty
// identityColumns slice no longer causes an index-out-of-range panic; like
// the sibling builders it simply yields a statement with empty clauses.
func (sb *batchStatementBuilder) buildUpdateBatchSubQuery() string {
	n := len(sb.identityColumns)
	assignments := make([]string, 0, n)
	conditions := make([]string, 0, n)
	returning := make([]string, 0, n)

	for _, col := range sb.identityColumns {
		qualified := sb.tableName + "." + col

		assignments = append(assignments, col+" = "+qualified)
		conditions = append(conditions, qualified+" = batch."+col)
		returning = append(returning, qualified)
	}

	return fmt.Sprintf("UPDATE %[1]s SET %[2]s FROM batch WHERE %[3]s RETURNING %[4]s",
		sb.tableName,
		strings.Join(assignments, ", "),
		strings.Join(conditions, " AND "),
		strings.Join(returning, ", "))
}

// buildLastValueQuery returns the SELECT that reads, from the "update" CTE,
// one LAST_VALUE(col) OVER() expression per identity column, in the order
// the columns are stored on the builder.
func (sb *batchStatementBuilder) buildLastValueQuery() string {
	exprs := make([]string, 0, len(sb.identityColumns))
	for _, name := range sb.identityColumns {
		exprs = append(exprs, fmt.Sprintf("LAST_VALUE(%s) OVER()", name))
	}
	return "SELECT " + strings.Join(exprs, ", ") + " FROM update"
}
54 changes: 0 additions & 54 deletions pkg/backfill/backfill_test.go

This file was deleted.

86 changes: 86 additions & 0 deletions pkg/backfill/templates/build.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
// SPDX-License-Identifier: Apache-2.0

package templates

import (
"bytes"
"strings"
"text/template"

"github.com/lib/pq"
)

// BatchConfig is the data handed to the SQL template rendered by BuildSQL.
// Its exported fields are what the template can reference.
type BatchConfig struct {
// TableName is the name of the table whose rows are processed in batches.
// NOTE(review): presumably unquoted, since the template helpers quote
// identifiers themselves; confirm against the template.
TableName string
// PrimaryKey lists the identity column names of the table.
PrimaryKey []string
// LastValue holds one value per PrimaryKey column. The backfill code
// allocates it with len(PrimaryKey) and scans the result of each batch
// query into it before building the next query.
LastValue []string
// BatchSize is the configured number of rows per batch.
BatchSize int
}

// BuildSQL renders the template text held in SQL using cfg as the template
// data. It returns the rendered statement, or the error reported by
// executeTemplate.
func BuildSQL(cfg BatchConfig) (string, error) {
	const templateName = "sql"

	rendered, err := executeTemplate(templateName, SQL, cfg)
	return rendered, err
}

// executeTemplate parses content as a text/template called name, executes it
// with cfg as the template data and returns the rendered text.
//
// The template can call the following helper functions:
//   - ql, qi: pq.QuoteLiteral and pq.QuoteIdentifier for a single string.
//   - commaSeparate: joins a string slice with ", ".
//   - quoteIdentifiers, quoteLiterals: quote every element of a string slice.
//   - updateSetClause: `col = table.col` for each column, joined with ", ".
//   - updateWhereClause: `table.col = batch.col` for each column, joined
//     with " AND ".
//   - updateReturnClause: `table.col` for each column, joined with ", ".
//   - selectLastValue: `LAST_VALUE(col) OVER()` for each column, joined
//     with ", ".
//
// Table and column names passed to the clause helpers are quoted with qi.
//
// An error is returned when the template fails to parse or to execute; a
// malformed template is reported to the caller rather than causing a panic.
func executeTemplate(name, content string, cfg BatchConfig) (string, error) {
	ql := pq.QuoteLiteral
	qi := pq.QuoteIdentifier

	// quoteEach returns a new slice with quote applied to every element.
	quoteEach := func(items []string, quote func(string) string) []string {
		quoted := make([]string, len(items))
		for i, s := range items {
			quoted[i] = quote(s)
		}
		return quoted
	}

	// joinColumns renders every column (already quoted as an identifier)
	// with render and joins the pieces with sep.
	joinColumns := func(columns []string, sep string, render func(quotedColumn string) string) string {
		parts := make([]string, len(columns))
		for i, c := range columns {
			parts[i] = render(qi(c))
		}
		return strings.Join(parts, sep)
	}

	funcs := template.FuncMap{
		"ql": ql,
		"qi": qi,
		"commaSeparate": func(slice []string) string {
			return strings.Join(slice, ", ")
		},
		"quoteIdentifiers": func(slice []string) []string {
			return quoteEach(slice, qi)
		},
		"quoteLiterals": func(slice []string) []string {
			return quoteEach(slice, ql)
		},
		"updateSetClause": func(tableName string, columns []string) string {
			table := qi(tableName)
			return joinColumns(columns, ", ", func(col string) string {
				return col + " = " + table + "." + col
			})
		},
		"updateWhereClause": func(tableName string, columns []string) string {
			table := qi(tableName)
			return joinColumns(columns, " AND ", func(col string) string {
				return table + "." + col + " = batch." + col
			})
		},
		"updateReturnClause": func(tableName string, columns []string) string {
			table := qi(tableName)
			return joinColumns(columns, ", ", func(col string) string {
				return table + "." + col
			})
		},
		"selectLastValue": func(columns []string) string {
			return joinColumns(columns, ", ", func(col string) string {
				return "LAST_VALUE(" + col + ") OVER()"
			})
		},
	}

	// Funcs must be registered before Parse so the template can reference
	// the helpers by name.
	tmpl, err := template.New(name).Funcs(funcs).Parse(content)
	if err != nil {
		return "", err
	}

	var buf bytes.Buffer
	if err := tmpl.Execute(&buf, cfg); err != nil {
		return "", err
	}

	return buf.String(), nil
}
Loading