Skip to content

Commit

Permalink
Merge pull request #2136 from ninedraft/optimize-sanitize
Browse files Browse the repository at this point in the history
Reduce SQL sanitizer allocations
  • Loading branch information
jackc authored Dec 31, 2024
2 parents 4ff0a45 + e452f80 commit ca04098
Show file tree
Hide file tree
Showing 6 changed files with 387 additions and 28 deletions.
1 change: 0 additions & 1 deletion conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1417,5 +1417,4 @@ func TestErrNoRows(t *testing.T) {
require.Equal(t, "no rows in result set", pgx.ErrNoRows.Error())

require.ErrorIs(t, pgx.ErrNoRows, sql.ErrNoRows, "pgx.ErrNowRows must match sql.ErrNoRows")
require.ErrorIs(t, pgx.ErrNoRows, pgx.ErrNoRows, "sql.ErrNowRows must match pgx.ErrNoRows")
}
60 changes: 60 additions & 0 deletions internal/sanitize/benchmmark.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
#!/usr/bin/env bash

current_branch=$(git rev-parse --abbrev-ref HEAD)
if [ "$current_branch" == "HEAD" ]; then
current_branch=$(git rev-parse HEAD)
fi

restore_branch() {
echo "Restoring original branch/commit: $current_branch"
git checkout "$current_branch"
}
trap restore_branch EXIT

# Check if there are uncommitted changes
if ! git diff --quiet || ! git diff --cached --quiet; then
echo "There are uncommitted changes. Please commit or stash them before running this script."
exit 1
fi

# Ensure that at least one commit argument is passed
if [ "$#" -lt 1 ]; then
echo "Usage: $0 <commit1> <commit2> ... <commitN>"
exit 1
fi

commits=("$@")
benchmarks_dir=benchmarks

if ! mkdir -p "${benchmarks_dir}"; then
echo "Unable to create dir for benchmarks data"
exit 1
fi

# Benchmark results
bench_files=()

# Run benchmark for each listed commit
for i in "${!commits[@]}"; do
commit="${commits[i]}"
git checkout "$commit" || {
echo "Failed to checkout $commit"
exit 1
}

# Sanitized commmit message
commit_message=$(git log -1 --pretty=format:"%s" | tr -c '[:alnum:]-_' '_')

# Benchmark data will go there
bench_file="${benchmarks_dir}/${i}_${commit_message}.bench"

if ! go test -bench=. -count=10 >"$bench_file"; then
echo "Benchmarking failed for commit $commit"
exit 1
fi

bench_files+=("$bench_file")
done

# go install golang.org/x/perf/cmd/benchstat[@latest]
benchstat "${bench_files[@]}"
183 changes: 156 additions & 27 deletions internal/sanitize/sanitize.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@ import (
"bytes"
"encoding/hex"
"fmt"
"slices"
"strconv"
"strings"
"sync"
"time"
"unicode/utf8"
)
Expand All @@ -24,53 +26,75 @@ type Query struct {
// https://github.com/jackc/pgx/issues/1380
const replacementcharacterwidth = 3

const maxBufSize = 16384 // 16 Ki

var bufPool = &pool[*bytes.Buffer]{
new: func() *bytes.Buffer {
return &bytes.Buffer{}
},
reset: func(b *bytes.Buffer) bool {
n := b.Len()
b.Reset()
return n < maxBufSize
},
}

var null = []byte("null")

func (q *Query) Sanitize(args ...any) (string, error) {
argUse := make([]bool, len(args))
buf := &bytes.Buffer{}
buf := bufPool.get()
defer bufPool.put(buf)

for _, part := range q.Parts {
var str string
switch part := part.(type) {
case string:
str = part
buf.WriteString(part)
case int:
argIdx := part - 1

var p []byte
if argIdx < 0 {
return "", fmt.Errorf("first sql argument must be > 0")
}

if argIdx >= len(args) {
return "", fmt.Errorf("insufficient arguments")
}

// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
buf.WriteByte(' ')

arg := args[argIdx]
switch arg := arg.(type) {
case nil:
str = "null"
p = null
case int64:
str = strconv.FormatInt(arg, 10)
p = strconv.AppendInt(buf.AvailableBuffer(), arg, 10)
case float64:
str = strconv.FormatFloat(arg, 'f', -1, 64)
p = strconv.AppendFloat(buf.AvailableBuffer(), arg, 'f', -1, 64)
case bool:
str = strconv.FormatBool(arg)
p = strconv.AppendBool(buf.AvailableBuffer(), arg)
case []byte:
str = QuoteBytes(arg)
p = QuoteBytes(buf.AvailableBuffer(), arg)
case string:
str = QuoteString(arg)
p = QuoteString(buf.AvailableBuffer(), arg)
case time.Time:
str = arg.Truncate(time.Microsecond).Format("'2006-01-02 15:04:05.999999999Z07:00:00'")
p = arg.Truncate(time.Microsecond).
AppendFormat(buf.AvailableBuffer(), "'2006-01-02 15:04:05.999999999Z07:00:00'")
default:
return "", fmt.Errorf("invalid arg type: %T", arg)
}
argUse[argIdx] = true

buf.Write(p)

// Prevent SQL injection via Line Comment Creation
// https://github.com/jackc/pgx/security/advisories/GHSA-m7wr-2xf7-cm9p
str = " " + str + " "
buf.WriteByte(' ')
default:
return "", fmt.Errorf("invalid Part type: %T", part)
}
buf.WriteString(str)
}

for i, used := range argUse {
Expand All @@ -82,26 +106,99 @@ func (q *Query) Sanitize(args ...any) (string, error) {
}

func NewQuery(sql string) (*Query, error) {
l := &sqlLexer{
src: sql,
stateFn: rawState,
query := &Query{}
query.init(sql)

return query, nil
}

var sqlLexerPool = &pool[*sqlLexer]{
new: func() *sqlLexer {
return &sqlLexer{}
},
reset: func(sl *sqlLexer) bool {
*sl = sqlLexer{}
return true
},
}

func (q *Query) init(sql string) {
parts := q.Parts[:0]
if parts == nil {
// dirty, but fast heuristic to preallocate for ~90% usecases
n := strings.Count(sql, "$") + strings.Count(sql, "--") + 1
parts = make([]Part, 0, n)
}

l := sqlLexerPool.get()
defer sqlLexerPool.put(l)

l.src = sql
l.stateFn = rawState
l.parts = parts

for l.stateFn != nil {
l.stateFn = l.stateFn(l)
}

query := &Query{Parts: l.parts}

return query, nil
q.Parts = l.parts
}

func QuoteString(str string) string {
return "'" + strings.ReplaceAll(str, "'", "''") + "'"
func QuoteString(dst []byte, str string) []byte {
const quote = '\''

// Preallocate space for the worst case scenario
dst = slices.Grow(dst, len(str)*2+2)

// Add opening quote
dst = append(dst, quote)

// Iterate through the string without allocating
for i := 0; i < len(str); i++ {
if str[i] == quote {
dst = append(dst, quote, quote)
} else {
dst = append(dst, str[i])
}
}

// Add closing quote
dst = append(dst, quote)

return dst
}

func QuoteBytes(buf []byte) string {
return `'\x` + hex.EncodeToString(buf) + "'"
func QuoteBytes(dst, buf []byte) []byte {
if len(buf) == 0 {
return append(dst, `'\x'`...)
}

// Calculate required length
requiredLen := 3 + hex.EncodedLen(len(buf)) + 1

// Ensure dst has enough capacity
if cap(dst)-len(dst) < requiredLen {
newDst := make([]byte, len(dst), len(dst)+requiredLen)
copy(newDst, dst)
dst = newDst
}

// Record original length and extend slice
origLen := len(dst)
dst = dst[:origLen+requiredLen]

// Add prefix
dst[origLen] = '\''
dst[origLen+1] = '\\'
dst[origLen+2] = 'x'

// Encode bytes directly into dst
hex.Encode(dst[origLen+3:len(dst)-1], buf)

// Add suffix
dst[len(dst)-1] = '\''

return dst
}

type sqlLexer struct {
Expand Down Expand Up @@ -319,13 +416,45 @@ func multilineCommentState(l *sqlLexer) stateFn {
}
}

var queryPool = &pool[*Query]{
new: func() *Query {
return &Query{}
},
reset: func(q *Query) bool {
n := len(q.Parts)
q.Parts = q.Parts[:0]
return n < 64 // drop too large queries
},
}

// SanitizeSQL replaces placeholder values with args. It quotes and escapes args
// as necessary. This function is only safe when standard_conforming_strings is
// on.
func SanitizeSQL(sql string, args ...any) (string, error) {
query, err := NewQuery(sql)
if err != nil {
return "", err
}
query := queryPool.get()
query.init(sql)
defer queryPool.put(query)

return query.Sanitize(args...)
}

type pool[E any] struct {
p sync.Pool
new func() E
reset func(E) bool
}

func (pool *pool[E]) get() E {
v, ok := pool.p.Get().(E)
if !ok {
v = pool.new()
}

return v
}

func (p *pool[E]) put(v E) {
if p.reset(v) {
p.p.Put(v)
}
}
62 changes: 62 additions & 0 deletions internal/sanitize/sanitize_bench_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
// sanitize_benchmark_test.go
package sanitize_test

import (
"testing"
"time"

"github.com/jackc/pgx/v5/internal/sanitize"
)

var benchmarkSanitizeResult string

const benchmarkQuery = "" +
`SELECT *
FROM "water_containers"
WHERE NOT "id" = $1 -- int64
AND "tags" NOT IN $2 -- nil
AND "volume" > $3 -- float64
AND "transportable" = $4 -- bool
AND position($5 IN "sign") -- bytes
AND "label" LIKE $6 -- string
AND "created_at" > $7; -- time.Time`

var benchmarkArgs = []any{
int64(12345),
nil,
float64(500),
true,
[]byte("8BADF00D"),
"kombucha's han'dy awokowa",
time.Date(2015, 10, 1, 0, 0, 0, 0, time.UTC),
}

func BenchmarkSanitize(b *testing.B) {
query, err := sanitize.NewQuery(benchmarkQuery)
if err != nil {
b.Fatalf("failed to create query: %v", err)
}

b.ResetTimer()
b.ReportAllocs()

for i := 0; i < b.N; i++ {
benchmarkSanitizeResult, err = query.Sanitize(benchmarkArgs...)
if err != nil {
b.Fatalf("failed to sanitize query: %v", err)
}
}
}

var benchmarkNewSQLResult string

func BenchmarkSanitizeSQL(b *testing.B) {
b.ReportAllocs()
var err error
for i := 0; i < b.N; i++ {
benchmarkNewSQLResult, err = sanitize.SanitizeSQL(benchmarkQuery, benchmarkArgs...)
if err != nil {
b.Fatalf("failed to sanitize SQL: %v", err)
}
}
}
Loading

0 comments on commit ca04098

Please sign in to comment.