diff --git a/pkg/backfill/backfill.go b/pkg/backfill/backfill.go
index 41e86d2d8..ea9923d05 100644
--- a/pkg/backfill/backfill.go
+++ b/pkg/backfill/backfill.go
@@ -7,11 +7,9 @@ import (
 	"database/sql"
 	"errors"
 	"fmt"
-	"strings"
 	"time"
 
-	"github.com/lib/pq"
-
+	"github.com/xataio/pgroll/pkg/backfill/templates"
 	"github.com/xataio/pgroll/pkg/db"
 	"github.com/xataio/pgroll/pkg/schema"
 )
@@ -59,7 +57,13 @@ func (bf *Backfill) Start(ctx context.Context, table *schema.Table) error {
 	}
 
 	// Create a batcher for the table.
-	b := newBatcher(table, bf.batchSize)
+	b := batcher{
+		BatchConfig: templates.BatchConfig{
+			TableName:  table.Name,
+			PrimaryKey: identityColumns,
+			BatchSize:  bf.batchSize,
+		},
+	}
 
 	// Update each batch of rows, invoking callbacks for each one.
 	for batch := 0; ; batch++ {
@@ -158,30 +162,30 @@ func getIdentityColumns(table *schema.Table) []string {
 	return nil
 }
 
+// A batcher is responsible for updating a batch of rows in a table.
+// It holds the state necessary to update the next batch of rows.
 type batcher struct {
-	statementBuilder *batchStatementBuilder
-	lastValues       []string
-}
-
-func newBatcher(table *schema.Table, batchSize int) *batcher {
-	return &batcher{
-		statementBuilder: newBatchStatementBuilder(table.Name, getIdentityColumns(table), batchSize),
-		lastValues:       make([]string, len(getIdentityColumns(table))),
-	}
+	templates.BatchConfig
 }
 
 func (b *batcher) updateBatch(ctx context.Context, conn db.DB) error {
 	return conn.WithRetryableTransaction(ctx, func(ctx context.Context, tx *sql.Tx) error {
 		// Build the query to update the next batch of rows
-		query := b.statementBuilder.buildQuery(b.lastValues)
+		sql, err := templates.BuildSQL(b.BatchConfig)
+		if err != nil {
+			return err
+		}
 
 		// Execute the query to update the next batch of rows and update the last PK
 		// value for the next batch
-		wrapper := make([]any, len(b.lastValues))
-		for i := range b.lastValues {
-			wrapper[i] = &b.lastValues[i]
+		if b.LastValue == nil {
+			b.LastValue = make([]string, len(b.PrimaryKey))
+		}
+		wrapper := make([]any, len(b.LastValue))
+		for i := range b.LastValue {
+			wrapper[i] = &b.LastValue[i]
 		}
-		err := tx.QueryRowContext(ctx, query).Scan(wrapper...)
+		err = tx.QueryRowContext(ctx, sql).Scan(wrapper...)
 		if err != nil {
 			return err
 		}
@@ -189,78 +193,3 @@ func (b *batcher) updateBatch(ctx context.Context, conn db.DB) error {
 		return nil
 	})
 }
-
-type batchStatementBuilder struct {
-	tableName       string
-	identityColumns []string
-	batchSize       int
-}
-
-func newBatchStatementBuilder(tableName string, identityColumnNames []string, batchSize int) *batchStatementBuilder {
-	quotedCols := make([]string, len(identityColumnNames))
-	for i, col := range identityColumnNames {
-		quotedCols[i] = pq.QuoteIdentifier(col)
-	}
-	return &batchStatementBuilder{
-		tableName:       pq.QuoteIdentifier(tableName),
-		identityColumns: quotedCols,
-		batchSize:       batchSize,
-	}
-}
-
-// buildQuery builds the query used to update the next batch of rows.
-func (sb *batchStatementBuilder) buildQuery(lastValues []string) string {
-	return fmt.Sprintf("WITH batch AS (%[1]s), update AS (%[2]s) %[3]s",
-		sb.buildBatchSubQuery(lastValues),
-		sb.buildUpdateBatchSubQuery(),
-		sb.buildLastValueQuery())
-}
-
-// fetch the next batch of PK of rows to update
-func (sb *batchStatementBuilder) buildBatchSubQuery(lastValues []string) string {
-	whereClause := ""
-	if len(lastValues) != 0 && lastValues[0] != "" {
-		whereClause = fmt.Sprintf("WHERE (%s) > (%s)",
-			strings.Join(sb.identityColumns, ", "), strings.Join(quoteLiteralList(lastValues), ", "))
-	}
-
-	return fmt.Sprintf("SELECT %[1]s FROM %[2]s %[3]s ORDER BY %[1]s LIMIT %[4]d FOR NO KEY UPDATE",
-		strings.Join(sb.identityColumns, ", "), sb.tableName, whereClause, sb.batchSize)
-}
-
-func quoteLiteralList(l []string) []string {
-	quoted := make([]string, len(l))
-	for i, v := range l {
-		quoted[i] = pq.QuoteLiteral(v)
-	}
-	return quoted
-}
-
-// update the rows in the batch
-func (sb *batchStatementBuilder) buildUpdateBatchSubQuery() string {
-	conditions := make([]string, len(sb.identityColumns))
-	for i, col := range sb.identityColumns {
-		conditions[i] = fmt.Sprintf("%[1]s.%[2]s = batch.%[2]s", sb.tableName, col)
-	}
-	updateWhereClause := "WHERE " + strings.Join(conditions, " AND ")
-
-	setStmt := fmt.Sprintf("%[1]s = %[2]s.%[1]s", sb.identityColumns[0], sb.tableName)
-	for i := 1; i < len(sb.identityColumns); i++ {
-		setStmt += fmt.Sprintf(", %[1]s = %[2]s.%[1]s", sb.identityColumns[i], sb.tableName)
-	}
-	updateReturning := sb.tableName + "." + sb.identityColumns[0]
-	for i := 1; i < len(sb.identityColumns); i++ {
-		updateReturning += ", " + sb.tableName + "." + sb.identityColumns[i]
-	}
-	return fmt.Sprintf("UPDATE %[1]s SET %[2]s FROM batch %[3]s RETURNING %[4]s",
-		sb.tableName, setStmt, updateWhereClause, updateReturning)
-}
-
-// fetch the last values of the PK column
-func (sb *batchStatementBuilder) buildLastValueQuery() string {
-	lastValues := make([]string, len(sb.identityColumns))
-	for i, col := range sb.identityColumns {
-		lastValues[i] = "LAST_VALUE(" + col + ") OVER()"
-	}
-	return fmt.Sprintf("SELECT %[1]s FROM update", strings.Join(lastValues, ", "))
-}
diff --git a/pkg/backfill/backfill_test.go b/pkg/backfill/backfill_test.go
deleted file mode 100644
index 2fd304575..000000000
--- a/pkg/backfill/backfill_test.go
+++ /dev/null
@@ -1,54 +0,0 @@
-// SPDX-License-Identifier: Apache-2.0
-
-package backfill
-
-import (
-	"testing"
-
-	"github.com/stretchr/testify/assert"
-)
-
-func TestBatchStatementBuilder(t *testing.T) {
-	tests := map[string]struct {
-		tableName       string
-		identityColumns []string
-		batchSize       int
-		lasValues       []string
-		expected        string
-	}{
-		"single identity column no last value": {
-			tableName:       "table_name",
-			identityColumns: []string{"id"},
-			batchSize:       10,
-			expected:        `WITH batch AS (SELECT "id" FROM "table_name"  ORDER BY "id" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id" FROM batch WHERE "table_name"."id" = batch."id" RETURNING "table_name"."id") SELECT LAST_VALUE("id") OVER() FROM update`,
-		},
-		"multiple identity columns no last value": {
-			tableName:       "table_name",
-			identityColumns: []string{"id", "zip"},
-			batchSize:       10,
-			expected:        `WITH batch AS (SELECT "id", "zip" FROM "table_name"  ORDER BY "id", "zip" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id", "zip" = "table_name"."zip" FROM batch WHERE "table_name"."id" = batch."id" AND "table_name"."zip" = batch."zip" RETURNING "table_name"."id", "table_name"."zip") SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER() FROM update`,
-		},
-		"single identity column with last value": {
-			tableName:       "table_name",
-			identityColumns: []string{"id"},
-			batchSize:       10,
-			lasValues:       []string{"1"},
-			expected:        `WITH batch AS (SELECT "id" FROM "table_name" WHERE ("id") > ('1') ORDER BY "id" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id" FROM batch WHERE "table_name"."id" = batch."id" RETURNING "table_name"."id") SELECT LAST_VALUE("id") OVER() FROM update`,
-		},
-		"multiple identity columns with last value": {
-			tableName:       "table_name",
-			identityColumns: []string{"id", "zip"},
-			batchSize:       10,
-			lasValues:       []string{"1", "1234"},
-			expected:        `WITH batch AS (SELECT "id", "zip" FROM "table_name" WHERE ("id", "zip") > ('1', '1234') ORDER BY "id", "zip" LIMIT 10 FOR NO KEY UPDATE), update AS (UPDATE "table_name" SET "id" = "table_name"."id", "zip" = "table_name"."zip" FROM batch WHERE "table_name"."id" = batch."id" AND "table_name"."zip" = batch."zip" RETURNING "table_name"."id", "table_name"."zip") SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER() FROM update`,
-		},
-	}
-
-	for name, test := range tests {
-		t.Run(name, func(t *testing.T) {
-			builder := newBatchStatementBuilder(test.tableName, test.identityColumns, test.batchSize)
-			actual := builder.buildQuery(test.lasValues)
-			assert.Equal(t, test.expected, actual)
-		})
-	}
-}
diff --git a/pkg/backfill/templates/build.go b/pkg/backfill/templates/build.go
new file mode 100644
index 000000000..9e6fe706a
--- /dev/null
+++ b/pkg/backfill/templates/build.go
@@ -0,0 +1,86 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package templates
+
+import (
+	"bytes"
+	"strings"
+	"text/template"
+
+	"github.com/lib/pq"
+)
+
+type BatchConfig struct {
+	TableName  string
+	PrimaryKey []string
+	LastValue  []string
+	BatchSize  int
+}
+
+func BuildSQL(cfg BatchConfig) (string, error) {
+	return executeTemplate("sql", SQL, cfg)
+}
+
+func executeTemplate(name, content string, cfg BatchConfig) (string, error) {
+	ql := pq.QuoteLiteral
+	qi := pq.QuoteIdentifier
+
+	tmpl := template.Must(template.New(name).
+		Funcs(template.FuncMap{
+			"ql": ql,
+			"qi": qi,
+			"commaSeparate": func(slice []string) string {
+				return strings.Join(slice, ", ")
+			},
+			"quoteIdentifiers": func(slice []string) []string {
+				quoted := make([]string, len(slice))
+				for i, s := range slice {
+					quoted[i] = qi(s)
+				}
+				return quoted
+			},
+			"quoteLiterals": func(slice []string) []string {
+				quoted := make([]string, len(slice))
+				for i, s := range slice {
+					quoted[i] = ql(s)
+				}
+				return quoted
+			},
+			"updateSetClause": func(tableName string, columns []string) string {
+				quoted := make([]string, len(columns))
+				for i, c := range columns {
+					quoted[i] = qi(c) + " = " + qi(tableName) + "." + qi(c)
+				}
+				return strings.Join(quoted, ", ")
+			},
+			"updateWhereClause": func(tableName string, columns []string) string {
+				quoted := make([]string, len(columns))
+				for i, c := range columns {
+					quoted[i] = qi(tableName) + "." + qi(c) + " = batch." + qi(c)
+				}
+				return strings.Join(quoted, " AND ")
+			},
+			"updateReturnClause": func(tableName string, columns []string) string {
+				quoted := make([]string, len(columns))
+				for i, c := range columns {
+					quoted[i] = qi(tableName) + "." + qi(c)
+				}
+				return strings.Join(quoted, ", ")
+			},
+			"selectLastValue": func(columns []string) string {
+				quoted := make([]string, len(columns))
+				for i, c := range columns {
+					quoted[i] = "LAST_VALUE(" + qi(c) + ") OVER()"
+				}
+				return strings.Join(quoted, ", ")
+			},
+		}).
+		Parse(content))
+
+	buf := bytes.Buffer{}
+	if err := tmpl.Execute(&buf, cfg); err != nil {
+		return "", err
+	}
+
+	return buf.String(), nil
+}
diff --git a/pkg/backfill/templates/build_test.go b/pkg/backfill/templates/build_test.go
new file mode 100644
index 000000000..4e9128c1e
--- /dev/null
+++ b/pkg/backfill/templates/build_test.go
@@ -0,0 +1,142 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package templates
+
+import (
+	"testing"
+
+	"github.com/stretchr/testify/assert"
+)
+
+func TestBatchStatementBuilder(t *testing.T) {
+	tests := map[string]struct {
+		config   BatchConfig
+		expected string
+	}{
+		"single identity column no last value": {
+			config: BatchConfig{
+				TableName:  "table_name",
+				PrimaryKey: []string{"id"},
+				BatchSize:  10,
+			},
+			expected: expectSingleIDColumnNoLastValue,
+		},
+		"multiple identity columns no last value": {
+			config: BatchConfig{
+				TableName:  "table_name",
+				PrimaryKey: []string{"id", "zip"},
+				BatchSize:  10,
+			},
+			expected: multipleIDColumnsNoLastValue,
+		},
+		"single identity column with last value": {
+			config: BatchConfig{
+				TableName:  "table_name",
+				PrimaryKey: []string{"id"},
+				LastValue:  []string{"1"},
+				BatchSize:  10,
+			},
+			expected: singleIDColumnWithLastValue,
+		},
+		"multiple identity columns with last value": {
+			config: BatchConfig{
+				TableName:  "table_name",
+				PrimaryKey: []string{"id", "zip"},
+				LastValue:  []string{"1", "1234"},
+				BatchSize:  10,
+			},
+			expected: multipleIDColumnsWithLastValue,
+		},
+	}
+
+	for name, test := range tests {
+		t.Run(name, func(t *testing.T) {
+			actual, err := BuildSQL(test.config)
+			assert.NoError(t, err)
+
+			assert.Equal(t, test.expected, actual)
+		})
+	}
+}
+
+const expectSingleIDColumnNoLastValue = `WITH batch AS
+(
+  SELECT "id"
+  FROM "table_name"
+  ORDER BY "id"
+  LIMIT 10
+  FOR NO KEY UPDATE
+),
+update AS
+(
+  UPDATE "table_name"
+  SET "id" = "table_name"."id"
+  FROM batch
+  WHERE "table_name"."id" = batch."id"
+  RETURNING "table_name"."id"
+)
+SELECT LAST_VALUE("id") OVER()
+FROM update
+`
+
+const multipleIDColumnsNoLastValue = `WITH batch AS
+(
+  SELECT "id", "zip"
+  FROM "table_name"
+  ORDER BY "id", "zip"
+  LIMIT 10
+  FOR NO KEY UPDATE
+),
+update AS
+(
+  UPDATE "table_name"
+  SET "id" = "table_name"."id", "zip" = "table_name"."zip"
+  FROM batch
+  WHERE "table_name"."id" = batch."id" AND "table_name"."zip" = batch."zip"
+  RETURNING "table_name"."id", "table_name"."zip"
+)
+SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER()
+FROM update
+`
+
+const singleIDColumnWithLastValue = `WITH batch AS
+(
+  SELECT "id"
+  FROM "table_name"
+  WHERE ("id") > ('1')
+  ORDER BY "id"
+  LIMIT 10
+  FOR NO KEY UPDATE
+),
+update AS
+(
+  UPDATE "table_name"
+  SET "id" = "table_name"."id"
+  FROM batch
+  WHERE "table_name"."id" = batch."id"
+  RETURNING "table_name"."id"
+)
+SELECT LAST_VALUE("id") OVER()
+FROM update
+`
+
+const multipleIDColumnsWithLastValue = `WITH batch AS
+(
+  SELECT "id", "zip"
+  FROM "table_name"
+  WHERE ("id", "zip") > ('1', '1234')
+  ORDER BY "id", "zip"
+  LIMIT 10
+  FOR NO KEY UPDATE
+),
+update AS
+(
+  UPDATE "table_name"
+  SET "id" = "table_name"."id", "zip" = "table_name"."zip"
+  FROM batch
+  WHERE "table_name"."id" = batch."id" AND "table_name"."zip" = batch."zip"
+  RETURNING "table_name"."id", "table_name"."zip"
+)
+SELECT LAST_VALUE("id") OVER(), LAST_VALUE("zip") OVER()
+FROM update
+`
diff --git a/pkg/backfill/templates/sql.go b/pkg/backfill/templates/sql.go
new file mode 100644
index 000000000..9a11f9dc9
--- /dev/null
+++ b/pkg/backfill/templates/sql.go
@@ -0,0 +1,26 @@
+// SPDX-License-Identifier: Apache-2.0
+
+package templates
+
+const SQL = `WITH batch AS
+(
+  SELECT {{ commaSeparate (quoteIdentifiers .PrimaryKey) }}
+  FROM {{ .TableName | qi}}
+  {{ if .LastValue -}}
+  WHERE ({{ commaSeparate (quoteIdentifiers .PrimaryKey) }}) > ({{ commaSeparate (quoteLiterals .LastValue) }})
+  {{ end -}}
+  ORDER BY {{ commaSeparate (quoteIdentifiers .PrimaryKey) }}
+  LIMIT {{ .BatchSize }}
+  FOR NO KEY UPDATE
+),
+update AS
+(
+  UPDATE {{ .TableName | qi }}
+  SET {{ updateSetClause .TableName .PrimaryKey }}
+  FROM batch
+  WHERE {{ updateWhereClause .TableName .PrimaryKey }}
+  RETURNING {{ updateReturnClause .TableName .PrimaryKey }}
+)
+SELECT {{ selectLastValue .PrimaryKey }}
+FROM update
+`