From 80707f7f60ad33ca4127afe9439c23eee55ce9d5 Mon Sep 17 00:00:00 2001 From: Dirkjan Bussink Date: Mon, 24 Feb 2025 12:14:59 +0100 Subject: [PATCH] Implement temporal comparisons (#17826) Signed-off-by: Dirkjan Bussink --- go/vt/vtgate/evalengine/cached_size.go | 2 +- go/vt/vtgate/evalengine/compiler.go | 35 +- go/vt/vtgate/evalengine/compiler_asm.go | 53 +- go/vt/vtgate/evalengine/compiler_asm_push.go | 29 + go/vt/vtgate/evalengine/compiler_test.go | 7 +- go/vt/vtgate/evalengine/eval_temporal.go | 99 ++- go/vt/vtgate/evalengine/expr_bvar.go | 4 +- go/vt/vtgate/evalengine/expr_collate.go | 2 +- go/vt/vtgate/evalengine/expr_column.go | 4 +- go/vt/vtgate/evalengine/fn_compare.go | 284 +++++++- go/vt/vtgate/evalengine/fn_compare_test.go | 80 +++ go/vt/vtgate/evalengine/fn_time.go | 48 +- .../evalengine/integration/comparison_test.go | 12 +- go/vt/vtgate/evalengine/testcases/cases.go | 656 +++++++++--------- go/vt/vtgate/evalengine/testcases/helpers.go | 2 +- 15 files changed, 930 insertions(+), 387 deletions(-) create mode 100644 go/vt/vtgate/evalengine/fn_compare_test.go diff --git a/go/vt/vtgate/evalengine/cached_size.go b/go/vt/vtgate/evalengine/cached_size.go index fc7a02e84c1..985891860c5 100644 --- a/go/vt/vtgate/evalengine/cached_size.go +++ b/go/vt/vtgate/evalengine/cached_size.go @@ -1411,7 +1411,7 @@ func (cached *builtinMultiComparison) CachedSize(alloc bool) int64 { } size := int64(0) if alloc { - size += int64(48) + size += int64(64) } // field CallExpr vitess.io/vitess/go/vt/vtgate/evalengine.CallExpr size += cached.CallExpr.CachedSize(false) diff --git a/go/vt/vtgate/evalengine/compiler.go b/go/vt/vtgate/evalengine/compiler.go index b0a7edd285d..c69df3a300f 100644 --- a/go/vt/vtgate/evalengine/compiler.go +++ b/go/vt/vtgate/evalengine/compiler.go @@ -335,7 +335,7 @@ func (c *compiler) compileToDecimal(ct ctype, offset int) ctype { c.asm.Convert_id(offset) case sqltypes.Uint64: c.asm.Convert_ud(offset) - case sqltypes.Datetime, sqltypes.Time: + case sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: scale = ct.Size size = ct.Size + decimalSizeBase fallthrough @@ -345,6 +345,28 @@ func (c *compiler) compileToDecimal(ct ctype, offset int) ctype { return ctype{Type: sqltypes.Decimal, Flag: ct.Flag, Col: collationNumeric, Scale: scale, Size: size} } +func (c *compiler) compileToTemporal(doct ctype, typ sqltypes.Type, offset, prec int) ctype { + switch doct.Type { + case typ: + if int(doct.Size) == prec { + return doct + } + fallthrough + default: + switch typ { + case sqltypes.Date: + c.asm.Convert_xD(offset, c.sqlmode.AllowZeroDate()) + case sqltypes.Datetime: + c.asm.Convert_xDT(offset, prec, c.sqlmode.AllowZeroDate()) + case sqltypes.Timestamp: + c.asm.Convert_xDTs(offset, prec, c.sqlmode.AllowZeroDate()) + case sqltypes.Time: + c.asm.Convert_xT(offset, prec) + } + } + return ctype{Type: typ, Col: collationBinary, Flag: flagNullable} +} + func (c *compiler) compileToDate(doct ctype, offset int) ctype { switch doct.Type { case sqltypes.Date: @@ -366,6 +388,17 @@ func (c *compiler) compileToDateTime(doct ctype, offset, prec int) ctype { return ctype{Type: sqltypes.Datetime, Size: int32(prec), Col: collationBinary, Flag: flagNullable} } +func (c *compiler) compileToTimestamp(doct ctype, offset, prec int) ctype { + switch doct.Type { + case sqltypes.Timestamp: + c.asm.Convert_tp(offset, prec) + return doct + default: + c.asm.Convert_xDTs(offset, prec, c.sqlmode.AllowZeroDate()) + } + return ctype{Type: sqltypes.Timestamp, Size: int32(prec), Col: collationBinary, Flag: flagNullable} +} + func (c *compiler) compileToTime(doct ctype, offset, prec int) ctype { switch doct.Type { case sqltypes.Time: diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 7dda215353f..d13d22e76cc 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -767,11 +767,11 @@ func (asm *assembler) CmpDates() { }, "CMP DATE(SP-2), DATE(SP-1)") } -func (asm *assembler) Collate(col collations.ID) { +func (asm *assembler) Collate(col collations.TypedCollation) { asm.emit(func(env *ExpressionEnv) int { a := env.vm.stack[env.vm.sp-1].(*evalBytes) a.tt = int16(sqltypes.VarChar) - a.col.Collation = col + a.col = col return 1 }, "COLLATE VARCHAR(SP-1), %d", col) } @@ -1170,6 +1170,21 @@ func (asm *assembler) Convert_xDT(offset, prec int, allowZero bool) { }, "CONV (SP-%d), DATETIME", offset) } +func (asm *assembler) Convert_xDTs(offset, prec int, allowZero bool) { + asm.emit(func(env *ExpressionEnv) int { + // Need to explicitly check here or we otherwise + // store a nil wrapper in an interface vs. a direct + // nil. + dt := evalToTimestamp(env.vm.stack[env.vm.sp-offset], prec, env.now, allowZero) + if dt == nil { + env.vm.stack[env.vm.sp-offset] = nil + } else { + env.vm.stack[env.vm.sp-offset] = dt + } + return 1 + }, "CONV (SP-%d), TIMESTAMP", offset) +} + func (asm *assembler) Convert_xT(offset, prec int) { asm.emit(func(env *ExpressionEnv) int { t := evalToTime(env.vm.stack[env.vm.sp-offset], prec) @@ -2670,6 +2685,40 @@ func (asm *assembler) Fn_MULTICMP_u(args int, lessThan bool) { }, "FN MULTICMP UINT64(SP-%d)...UINT64(SP-1)", args) } +func (asm *assembler) Fn_MULTICMP_temporal(args int, lessThan bool) { + asm.adjustStack(-(args - 1)) + + asm.emit(func(env *ExpressionEnv) int { + var x *evalTemporal + x, _ = env.vm.stack[env.vm.sp-args].(*evalTemporal) + for sp := env.vm.sp - args + 1; sp < env.vm.sp; sp++ { + if env.vm.stack[sp] == nil { + if lessThan { + x = nil + } + continue + } + y := env.vm.stack[sp].(*evalTemporal) + if lessThan == (y.compare(x) < 0) { + x = y + } + } + env.vm.stack[env.vm.sp-args] = x + env.vm.sp -= args - 1 + return 1 + }, "FN MULTICMP TEMPORAL(SP-%d)...TEMPORAL(SP-1)", args) +} + +func (asm *assembler) Fn_MULTICMP_temporal_fallback(f multiComparisonFunc, args int, cmp, prec int) { + asm.adjustStack(-(args - 1)) + + asm.emit(func(env *ExpressionEnv) int { + env.vm.stack[env.vm.sp-args], env.vm.err = f(env, env.vm.stack[env.vm.sp-args:env.vm.sp], cmp, prec) + env.vm.sp -= args - 1 + return 1 + }, "FN MULTICMP_FALLBACK TEMPORAL(SP-%d)...TEMPORAL(SP-1)", args) +} + func (asm *assembler) Fn_REPEAT(base sqltypes.Type, fallback sqltypes.Type) { asm.adjustStack(-1) diff --git a/go/vt/vtgate/evalengine/compiler_asm_push.go b/go/vt/vtgate/evalengine/compiler_asm_push.go index 8f2b5d9f28b..404c8870f87 100644 --- a/go/vt/vtgate/evalengine/compiler_asm_push.go +++ b/go/vt/vtgate/evalengine/compiler_asm_push.go @@ -362,6 +362,23 @@ func (asm *assembler) PushColumn_datetime(offset int) { }, "PUSH DATETIME(:%d)", offset) } +func push_timestamp(env *ExpressionEnv, raw []byte) int { + env.vm.stack[env.vm.sp], env.vm.err = parseTimestamp(raw) + env.vm.sp++ + return 1 +} + +func (asm *assembler) PushColumn_timestamp(offset int) { + asm.adjustStack(1) + asm.emit(func(env *ExpressionEnv) int { + col := env.Row[offset] + if col.IsNull() { + return push_null(env) + } + return push_timestamp(env, col.Raw()) + }, "PUSH TIMESTAMP(:%d)", offset) +} + func (asm *assembler) PushBVar_datetime(key string) { asm.adjustStack(1) asm.emit(func(env *ExpressionEnv) int { @@ -374,6 +391,18 @@ func (asm *assembler) PushBVar_datetime(key string) { }, "PUSH DATETIME(:%q)", key) } +func (asm *assembler) PushBVar_timestamp(key string) { + asm.adjustStack(1) + asm.emit(func(env *ExpressionEnv) int { + var bvar *querypb.BindVariable + bvar, env.vm.err = env.lookupBindVar(key) + if env.vm.err != nil { + return 0 + } + return push_timestamp(env, bvar.Value) + }, "PUSH TIMESTAMP(:%q)", key) +} + func push_date(env *ExpressionEnv, raw []byte) int { env.vm.stack[env.vm.sp], env.vm.err = parseDate(raw) env.vm.sp++ diff --git a/go/vt/vtgate/evalengine/compiler_test.go b/go/vt/vtgate/evalengine/compiler_test.go index 343bb0cd043..3a53fbfd4c8 100644 --- a/go/vt/vtgate/evalengine/compiler_test.go +++ b/go/vt/vtgate/evalengine/compiler_test.go @@ -24,12 +24,12 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" - "github.com/olekukonko/tablewriter" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/mysql/collations" + "vitess.io/vitess/go/mysql/collations/colldata" "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" "vitess.io/vitess/go/vt/sqlparser" @@ -119,7 +119,7 @@ func TestCompilerReference(t *testing.T) { var supported, total int env := evalengine.EmptyExpressionEnv(venv) - tc.Run(func(query string, row []sqltypes.Value) { + tc.Run(func(query string, row []sqltypes.Value, _ bool) { env.Row = row total++ testCompilerCase(t, query, venv, tc.Schema, env) @@ -171,6 +171,7 @@ func testCompilerCase(t *testing.T, query string, venv *vtenv.Environment, schem eval := expected.String() comp := res.String() assert.Equalf(t, eval, comp, "bad evaluation from compiler:\nSQL: %s\nEval: %s\nComp: %s", query, eval, comp) + assert.Equalf(t, expected.Collation(), res.Collation(), "bad collation from compiler:\nSQL: %s\nEval: %s\nComp: %s", query, colldata.Lookup(expected.Collation()).Name(), colldata.Lookup(res.Collation()).Name()) case vmErr == nil: t.Errorf("failed evaluation from evalengine:\nSQL: %s\nError: %s", query, evalErr) case evalErr == nil: diff --git a/go/vt/vtgate/evalengine/eval_temporal.go b/go/vt/vtgate/evalengine/eval_temporal.go index d73485441c3..2adab98afc4 100644 --- a/go/vt/vtgate/evalengine/eval_temporal.go +++ b/go/vt/vtgate/evalengine/eval_temporal.go @@ -29,7 +29,7 @@ func (e *evalTemporal) ToRawBytes() []byte { switch e.t { case sqltypes.Date: return e.dt.Date.Format() - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return e.dt.Format(e.prec) case sqltypes.Time: return e.dt.Time.Format(e.prec) @@ -54,7 +54,7 @@ func (e *evalTemporal) toInt64() int64 { switch e.SQLType() { case sqltypes.Date: return e.dt.Date.FormatInt64() - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return e.dt.FormatInt64() case sqltypes.Time: return e.dt.Time.FormatInt64() @@ -67,7 +67,7 @@ func (e *evalTemporal) toFloat() float64 { switch e.SQLType() { case sqltypes.Date: return float64(e.dt.Date.FormatInt64()) - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return e.dt.FormatFloat64() case sqltypes.Time: return e.dt.Time.FormatFloat64() @@ -80,7 +80,7 @@ func (e *evalTemporal) toDecimal() decimal.Decimal { switch e.SQLType() { case sqltypes.Date: return decimal.NewFromInt(e.dt.Date.FormatInt64()) - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return e.dt.FormatDecimal() case sqltypes.Time: return e.dt.Time.FormatDecimal() @@ -93,7 +93,7 @@ func (e *evalTemporal) toJSON() *evalJSON { switch e.SQLType() { case sqltypes.Date: return json.NewDate(hack.String(e.dt.Date.Format())) - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return json.NewDateTime(hack.String(e.dt.Format(datetime.DefaultPrecision))) case sqltypes.Time: return json.NewTime(hack.String(e.dt.Time.Format(datetime.DefaultPrecision))) @@ -104,7 +104,7 @@ func (e *evalTemporal) toJSON() *evalJSON { func (e *evalTemporal) toDateTime(l int, now time.Time) *evalTemporal { switch e.SQLType() { - case sqltypes.Datetime, sqltypes.Date: + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Round(l), prec: uint8(l)} case sqltypes.Time: return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Time.Round(l).ToDateTime(now), prec: uint8(l)} @@ -113,9 +113,23 @@ func (e *evalTemporal) toDateTime(l int, now time.Time) *evalTemporal { } } +func (e *evalTemporal) toTimestamp(l int, now time.Time) *evalTemporal { + switch e.SQLType() { + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: + return &evalTemporal{t: sqltypes.Timestamp, dt: e.dt.Round(l), prec: uint8(l)} + case sqltypes.Time: + return &evalTemporal{t: sqltypes.Timestamp, dt: e.dt.Time.Round(l).ToDateTime(now), prec: uint8(l)} + default: + panic("unreachable") + } +} + func (e *evalTemporal) toTime(l int) *evalTemporal { + if l == -1 { + l = int(e.prec) + } switch e.SQLType() { - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: dt := datetime.DateTime{Time: e.dt.Time.Round(l)} return &evalTemporal{t: sqltypes.Time, dt: dt, prec: uint8(l)} case sqltypes.Date: @@ -130,7 +144,7 @@ func (e *evalTemporal) toTime(l int) *evalTemporal { func (e *evalTemporal) toDate(now time.Time) *evalTemporal { switch e.SQLType() { - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: dt := datetime.DateTime{Date: e.dt.Date} return &evalTemporal{t: sqltypes.Date, dt: dt} case sqltypes.Date: @@ -148,6 +162,13 @@ func (e *evalTemporal) isZero() bool { return e.dt.IsZero() } +func (e *evalTemporal) compare(other *evalTemporal) int { + if other == nil { + return 1 + } + return e.dt.Compare(other.dt) +} + func (e *evalTemporal) addInterval(interval *datetime.Interval, coll collations.ID, now time.Time) eval { var tmp *evalTemporal var ok bool @@ -179,6 +200,13 @@ func newEvalDateTime(dt datetime.DateTime, l int, allowZero bool) *evalTemporal return &evalTemporal{t: sqltypes.Datetime, dt: dt.Round(l), prec: uint8(l)} } +func newEvalTimestamp(dt datetime.DateTime, l int, allowZero bool) *evalTemporal { + if !allowZero && dt.IsZero() { + return nil + } + return &evalTemporal{t: sqltypes.Timestamp, dt: dt.Round(l), prec: uint8(l)} +} + func newEvalDate(d datetime.Date, allowZero bool) *evalTemporal { if !allowZero && d.IsZero() { return nil @@ -210,6 +238,14 @@ func parseDateTime(s []byte) (*evalTemporal, error) { return newEvalDateTime(t, l, true), nil } +func parseTimestamp(s []byte) (*evalTemporal, error) { + t, l, ok := datetime.ParseDateTime(hack.String(s), -1) + if !ok { + return nil, errIncorrectTemporal("TIMESTAMP", s) + } + return newEvalTimestamp(t, l, true), nil +} + func parseTime(s []byte) (*evalTemporal, error) { t, l, state := datetime.ParseTime(hack.String(s), -1) if state != datetime.TimeOK { @@ -387,6 +423,53 @@ func evalToDateTime(e eval, l int, now time.Time, allowZero bool) *evalTemporal return nil } +func evalToTimestamp(e eval, l int, now time.Time, allowZero bool) *evalTemporal { + switch e := e.(type) { + case *evalTemporal: + return e.toTimestamp(precision(l, int(e.prec)), now) + case *evalBytes: + if t, l, _ := datetime.ParseDateTime(e.string(), l); !t.IsZero() { + return newEvalTimestamp(t, l, allowZero) + } + if d, _ := datetime.ParseDate(e.string()); !d.IsZero() { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalInt64: + if t, ok := datetime.ParseDateTimeInt64(e.i); ok { + return newEvalTimestamp(t, precision(l, 0), allowZero) + } + if d, ok := datetime.ParseDateInt64(e.i); ok { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalUint64: + if t, ok := datetime.ParseDateTimeInt64(int64(e.u)); ok { + return newEvalTimestamp(t, precision(l, 0), allowZero) + } + if d, ok := datetime.ParseDateInt64(int64(e.u)); ok { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalFloat: + if t, l, ok := datetime.ParseDateTimeFloat(e.f, l); ok { + return newEvalTimestamp(t, l, allowZero) + } + if d, ok := datetime.ParseDateFloat(e.f); ok { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalDecimal: + if t, l, ok := datetime.ParseDateTimeDecimal(e.dec, e.length, l); ok { + return newEvalTimestamp(t, l, allowZero) + } + if d, ok := datetime.ParseDateDecimal(e.dec); ok { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalJSON: + if dt, ok := e.DateTime(); ok { + return newEvalTimestamp(dt, precision(l, datetime.DefaultPrecision), allowZero) + } + } + return nil +} + func evalToDate(e eval, now time.Time, allowZero bool) *evalTemporal { switch e := e.(type) { case *evalTemporal: diff --git a/go/vt/vtgate/evalengine/expr_bvar.go b/go/vt/vtgate/evalengine/expr_bvar.go index 50f231dbe9c..23b40949e83 100644 --- a/go/vt/vtgate/evalengine/expr_bvar.go +++ b/go/vt/vtgate/evalengine/expr_bvar.go @@ -154,8 +154,10 @@ func (bvar *BindVariable) compile(c *compiler) (ctype, error) { c.asm.PushNull() case tt == sqltypes.TypeJSON: c.asm.PushBVar_json(bvar.Key) - case tt == sqltypes.Datetime || tt == sqltypes.Timestamp: + case tt == sqltypes.Datetime: c.asm.PushBVar_datetime(bvar.Key) + case tt == sqltypes.Timestamp: + c.asm.PushBVar_timestamp(bvar.Key) case tt == sqltypes.Date: c.asm.PushBVar_date(bvar.Key) case tt == sqltypes.Time: diff --git a/go/vt/vtgate/evalengine/expr_collate.go b/go/vt/vtgate/evalengine/expr_collate.go index be0eb78882b..b381acf6356 100644 --- a/go/vt/vtgate/evalengine/expr_collate.go +++ b/go/vt/vtgate/evalengine/expr_collate.go @@ -118,7 +118,7 @@ func (expr *CollateExpr) compile(c *compiler) (ctype, error) { } fallthrough case sqltypes.VarBinary: - c.asm.Collate(expr.TypedCollation.Collation) + c.asm.Collate(expr.TypedCollation) default: c.asm.Convert_xc(1, sqltypes.VarChar, expr.TypedCollation.Collation, nil) } diff --git a/go/vt/vtgate/evalengine/expr_column.go b/go/vt/vtgate/evalengine/expr_column.go index e52c522d973..7df113ee5d2 100644 --- a/go/vt/vtgate/evalengine/expr_column.go +++ b/go/vt/vtgate/evalengine/expr_column.go @@ -145,8 +145,10 @@ func (column *Column) compile(c *compiler) (ctype, error) { c.asm.PushNull() case tt == sqltypes.TypeJSON: c.asm.PushColumn_json(column.Offset) - case tt == sqltypes.Datetime || tt == sqltypes.Timestamp: + case tt == sqltypes.Datetime: c.asm.PushColumn_datetime(column.Offset) + case tt == sqltypes.Timestamp: + c.asm.PushColumn_timestamp(column.Offset) case tt == sqltypes.Date: c.asm.PushColumn_date(column.Offset) case tt == sqltypes.Time: diff --git a/go/vt/vtgate/evalengine/fn_compare.go b/go/vt/vtgate/evalengine/fn_compare.go index 1deec6752ef..1084a240bd8 100644 --- a/go/vt/vtgate/evalengine/fn_compare.go +++ b/go/vt/vtgate/evalengine/fn_compare.go @@ -22,6 +22,7 @@ import ( "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/collations/charset" "vitess.io/vitess/go/mysql/collations/colldata" + datetime2 "vitess.io/vitess/go/mysql/datetime" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/proto/vtrpc" "vitess.io/vitess/go/vt/vterrors" @@ -32,11 +33,12 @@ type ( CallExpr } - multiComparisonFunc func(collationEnv *collations.Environment, args []eval, cmp int) (eval, error) + multiComparisonFunc func(env *ExpressionEnv, args []eval, cmp, prec int) (eval, error) builtinMultiComparison struct { CallExpr - cmp int + cmp int + prec int } ) @@ -93,7 +95,7 @@ func (b *builtinCoalesce) compile(c *compiler) (ctype, error) { return ctype{Type: ta.result(), Flag: f, Col: ca.result()}, nil } -func getMultiComparisonFunc(args []eval) multiComparisonFunc { +func (call *builtinMultiComparison) getMultiComparisonFunc(args []eval) multiComparisonFunc { var ( integersI int integersU int @@ -101,6 +103,11 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc { decimals int text int binary int + temporal int + datetime int + timestamp int + date int + time int ) /* @@ -114,7 +121,7 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc { for _, arg := range args { if arg == nil { - return func(collationEnv *collations.Environment, args []eval, cmp int) (eval, error) { + return func(_ *ExpressionEnv, _ []eval, _, _ int) (eval, error) { return nil, nil } } @@ -126,18 +133,86 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc { integersU++ case *evalFloat: floats++ + call.prec = datetime2.DefaultPrecision case *evalDecimal: decimals++ + call.prec = max(call.prec, int(arg.length)) case *evalBytes: switch arg.SQLType() { case sqltypes.Text, sqltypes.VarChar: text++ + call.prec = max(call.prec, datetime2.DefaultPrecision) case sqltypes.Blob, sqltypes.Binary, sqltypes.VarBinary: binary++ + if !arg.isHexOrBitLiteral() { + call.prec = max(call.prec, datetime2.DefaultPrecision) + } + } + case *evalTemporal: + temporal++ + call.prec = max(call.prec, int(arg.prec)) + switch arg.SQLType() { + case sqltypes.Datetime: + datetime++ + case sqltypes.Timestamp: + timestamp++ + case sqltypes.Date: + date++ + case sqltypes.Time: + time++ } } } + if temporal == len(args) { + switch { + case datetime > 0: + return compareAllTemporal(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToDateTime(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }) + case timestamp > 0: + return compareAllTemporal(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToTimestamp(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }) + case date > 0 && time > 0: + // When all types are temporal, we convert the case + // of having a date and time all to datetime. + // This is contrary to the case where we have a non-temporal + // type in the list, since MySQL doesn't do that. + return compareAllTemporal(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToDateTime(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }) + case date > 0: + return compareAllTemporal(func(env *ExpressionEnv, arg eval, _ int) *evalTemporal { + return evalToDate(arg, env.now, env.sqlmode.AllowZeroDate()) + }) + case time > 0: + return compareAllTemporal(func(_ *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToTime(arg, prec) + }) + } + } + + switch { + case datetime > 0: + return compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToDateTime(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }) + case timestamp > 0: + return compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToTimestamp(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }) + case date > 0: + return compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, _ int) *evalTemporal { + return evalToDate(arg, env.now, env.sqlmode.AllowZeroDate()) + }) + case time > 0: + // So for time, there's actually no conversion and + // internal comparisons as time. So we don't pass it + // a conversion function. + return compareAllTemporalAsString(nil) + } + if integersI+integersU == len(args) { if integersI == len(args) { return compareAllInteger_i @@ -165,7 +240,93 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc { panic("unexpected argument type") } -func compareAllInteger_u(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllTemporal(f func(env *ExpressionEnv, arg eval, prec int) *evalTemporal) multiComparisonFunc { + return func(env *ExpressionEnv, args []eval, cmp, prec int) (eval, error) { + var x *evalTemporal + for _, arg := range args { + conv := f(env, arg, prec) + if x == nil { + x = conv + continue + } + if (cmp < 0) == (conv.compare(x) < 0) { + x = conv + } + } + return x, nil + } +} + +func compareAllTemporalAsString(f func(env *ExpressionEnv, arg eval, prec int) *evalTemporal) multiComparisonFunc { + return func(env *ExpressionEnv, args []eval, cmp, prec int) (eval, error) { + validArgs := make([]*evalTemporal, 0, len(args)) + var ca collationAggregation + for _, arg := range args { + if err := ca.add(evalCollation(arg), env.collationEnv); err != nil { + return nil, err + } + if f != nil { + conv := f(env, arg, prec) + validArgs = append(validArgs, conv) + } + } + tc := ca.result() + if tc.Coercibility == collations.CoerceNumeric { + tc = typedCoercionCollation(sqltypes.VarChar, env.collationEnv.DefaultConnectionCharset()) + } + if f != nil { + idx := compareTemporalInternal(validArgs, cmp) + if idx >= 0 { + arg := args[idx] + if _, ok := arg.(*evalTemporal); ok { + arg = validArgs[idx] + } + return evalToVarchar(arg, tc.Collation, false) + } + } + txt, err := compareAllText(env, args, cmp, prec) + if err != nil { + return nil, err + } + return evalToVarchar(txt, tc.Collation, false) + } +} + +func compareTemporalInternal(args []*evalTemporal, cmp int) int { + if cmp < 0 { + // If we have any failed conversions and want to have the smallest value, + // we can't find that so we return -1 to indicate that. + // This will result in a fallback to do a string comparison. + for _, arg := range args { + if arg == nil { + return -1 + } + } + } + + x := 0 + for i, arg := range args[1:] { + if arg == nil { + continue + } + if (cmp < 0) == (compareTemporal(args, i+1, x) < 0) { + x = i + 1 + } + } + return x +} + +func compareTemporal(args []*evalTemporal, idx1, idx2 int) int { + if idx1 < 0 { + return 1 + } + if idx2 < 0 { + return -1 + } + return args[idx1].compare(args[idx2]) +} + +func compareAllInteger_u(_ *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { x := args[0].(*evalUint64) for _, arg := range args[1:] { y := arg.(*evalUint64) @@ -176,7 +337,7 @@ func compareAllInteger_u(_ *collations.Environment, args []eval, cmp int) (eval, return x, nil } -func compareAllInteger_i(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllInteger_i(_ *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { x := args[0].(*evalInt64) for _, arg := range args[1:] { y := arg.(*evalInt64) @@ -187,7 +348,7 @@ func compareAllInteger_i(_ *collations.Environment, args []eval, cmp int) (eval, return x, nil } -func compareAllFloat(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllFloat(_ *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { candidateF, ok := evalToFloat(args[0]) if !ok { return nil, errDecimalOutOfRange @@ -212,7 +373,7 @@ func evalDecimalPrecision(e eval) int32 { return 0 } -func compareAllDecimal(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllDecimal(_ *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { decExtreme := evalToDecimal(args[0], 0, 0).dec precExtreme := evalDecimalPrecision(args[0]) @@ -229,12 +390,12 @@ func compareAllDecimal(_ *collations.Environment, args []eval, cmp int) (eval, e return newEvalDecimalWithPrec(decExtreme, precExtreme), nil } -func compareAllText(collationEnv *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllText(env *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { var charsets = make([]charset.Charset, 0, len(args)) var ca collationAggregation for _, arg := range args { col := evalCollation(arg) - if err := ca.add(col, collationEnv); err != nil { + if err := ca.add(col, env.collationEnv); err != nil { return nil, err } charsets = append(charsets, colldata.Lookup(col.Collation).Charset()) @@ -262,7 +423,7 @@ func compareAllText(collationEnv *collations.Environment, args []eval, cmp int) return newEvalText(b1, tc), nil } -func compareAllBinary(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllBinary(_ *ExpressionEnv, args []eval, cmp, _ int) (eval, error) { candidateB := args[0].ToRawBytes() for _, arg := range args[1:] { @@ -280,7 +441,7 @@ func (call *builtinMultiComparison) eval(env *ExpressionEnv) (eval, error) { if err != nil { return nil, err } - return getMultiComparisonFunc(args)(env.collationEnv, args, call.cmp) + return call.getMultiComparisonFunc(args)(env, args, call.cmp, call.prec) } func (call *builtinMultiComparison) compile_c(c *compiler, args []ctype) (ctype, error) { @@ -314,14 +475,20 @@ func (call *builtinMultiComparison) compile_d(c *compiler, args []ctype) (ctype, func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { var ( - signed int - unsigned int - floats int - decimals int - text int - binary int - args []ctype - nullable bool + signed int + unsigned int + floats int + decimals int + temporal int + date int + datetime int + timestamp int + time int + text int + binary int + args []ctype + nullable bool + prec int ) /* @@ -349,12 +516,34 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { unsigned++ case sqltypes.Float64: floats++ + prec = max(prec, datetime2.DefaultPrecision) case sqltypes.Decimal: decimals++ + prec = max(prec, int(tt.Scale)) case sqltypes.Text, sqltypes.VarChar: text++ + prec = max(prec, datetime2.DefaultPrecision) case sqltypes.Blob, sqltypes.Binary, sqltypes.VarBinary: binary++ + if !tt.isHexOrBitLiteral() { + prec = max(prec, datetime2.DefaultPrecision) + } + case sqltypes.Date: + temporal++ + date++ + prec = max(prec, int(tt.Size)) + case sqltypes.Datetime: + temporal++ + datetime++ + prec = max(prec, int(tt.Size)) + case sqltypes.Timestamp: + temporal++ + timestamp++ + prec = max(prec, int(tt.Size)) + case sqltypes.Time: + temporal++ + time++ + prec = max(prec, int(tt.Size)) case sqltypes.Null: nullable = true default: @@ -366,6 +555,61 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { if nullable { f |= flagNullable } + if temporal == len(args) { + var typ sqltypes.Type + switch { + case datetime > 0: + typ = sqltypes.Datetime + case timestamp > 0: + typ = sqltypes.Timestamp + case date > 0 && time > 0: + // When all types are temporal, we convert the case + // of having a date and time all to datetime. + // This is contrary to the case where we have a non-temporal + // type in the list, since MySQL doesn't do that. + typ = sqltypes.Datetime + case date > 0: + typ = sqltypes.Date + case time > 0: + typ = sqltypes.Time + } + for i, tt := range args { + if tt.Type != typ || int(tt.Size) != prec { + c.compileToTemporal(tt, typ, len(args)-i, prec) + } + } + c.asm.Fn_MULTICMP_temporal(len(args), call.cmp < 0) + return ctype{Type: typ, Flag: f, Col: collationBinary}, nil + } else if temporal > 0 { + var ca collationAggregation + for _, arg := range args { + if err := ca.add(arg.Col, c.env.CollationEnv()); err != nil { + return ctype{}, err + } + } + + tc := ca.result() + if tc.Coercibility == collations.CoerceNumeric { + tc = typedCoercionCollation(sqltypes.VarChar, c.collation) + } + switch { + case datetime > 0: + c.asm.Fn_MULTICMP_temporal_fallback(compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToDateTime(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }), len(args), call.cmp, prec) + case timestamp > 0: + c.asm.Fn_MULTICMP_temporal_fallback(compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToTimestamp(arg, prec, env.now, env.sqlmode.AllowZeroDate()) + }), len(args), call.cmp, prec) + case date > 0: + c.asm.Fn_MULTICMP_temporal_fallback(compareAllTemporalAsString(func(env *ExpressionEnv, arg eval, prec int) *evalTemporal { + return evalToDate(arg, env.now, env.sqlmode.AllowZeroDate()) + }), len(args), call.cmp, prec) + case time > 0: + c.asm.Fn_MULTICMP_temporal_fallback(compareAllTemporalAsString(nil), len(args), call.cmp, prec) + } + return ctype{Type: sqltypes.VarChar, Flag: f, Col: tc}, nil + } if signed+unsigned == len(args) { if signed == len(args) { c.asm.Fn_MULTICMP_i(len(args), call.cmp < 0) diff --git a/go/vt/vtgate/evalengine/fn_compare_test.go b/go/vt/vtgate/evalengine/fn_compare_test.go new file mode 100644 index 00000000000..def40d8365c --- /dev/null +++ b/go/vt/vtgate/evalengine/fn_compare_test.go @@ -0,0 +1,80 @@ +/* +Copyright 2025 The Vitess Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package evalengine + +import ( + "testing" + "time" + + "github.com/stretchr/testify/assert" + + "vitess.io/vitess/go/mysql/datetime" +) + +func TestCompareTemporal(t *testing.T) { + tests := []struct { + name string + val1 *evalTemporal + val2 *evalTemporal + result int + }{ + { + name: "equal values", + val1: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + val2: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + result: 0, + }, + { + name: "larger value", + val1: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + val2: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + result: -1, + }, + { + name: "smaller value", + val1: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2022, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + val2: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + result: 1, + }, + { + name: "first nil value", + val1: nil, + val2: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + result: 1, + }, + + { + name: "second nil value", + val1: newEvalDateTime(datetime.NewDateTimeFromStd(time.Date(2021, 1, 1, 0, 0, 0, 0, time.UTC)), 6, false), + val2: nil, + result: -1, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + idx1 := 0 + idx2 := 1 + if tt.val1 == nil { + idx1 = -1 + } + if tt.val2 == nil { + idx2 = -1 + } + assert.Equal(t, tt.result, compareTemporal([]*evalTemporal{tt.val1, tt.val2}, idx1, idx2)) + }) + } +} diff --git a/go/vt/vtgate/evalengine/fn_time.go b/go/vt/vtgate/evalengine/fn_time.go index 322b89faafb..c50a7a265f5 100644 --- a/go/vt/vtgate/evalengine/fn_time.go +++ b/go/vt/vtgate/evalengine/fn_time.go @@ -335,7 +335,7 @@ func (call *builtinDateFormat) compile(c *compiler) (ctype, error) { skip1 := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Datetime, sqltypes.Date: + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: default: c.asm.Convert_xDT(1, datetime.DefaultPrecision, false) } @@ -451,7 +451,7 @@ func (call *builtinConvertTz) compile(c *compiler) (ctype, error) { var prec int32 switch n.Type { - case sqltypes.Datetime, sqltypes.Date: + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: prec = n.Size case sqltypes.Decimal: prec = n.Scale @@ -533,7 +533,7 @@ func (call *builtinDayOfMonth) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, true) } @@ -566,7 +566,7 @@ func (call *builtinDayOfWeek) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -599,7 +599,7 @@ func (call *builtinDayOfYear) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -742,7 +742,7 @@ func (call *builtinFromUnixtime) compile(c *compiler) (ctype, error) { case sqltypes.Decimal: prec = arg.Size c.asm.Fn_FROM_UNIXTIME_d() - case sqltypes.Datetime, sqltypes.Date, sqltypes.Time: + case sqltypes.Datetime, sqltypes.Date, sqltypes.Time, sqltypes.Timestamp: prec = arg.Size if prec == 0 { c.asm.Convert_Ti(1) @@ -814,7 +814,7 @@ func (call *builtinHour) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime, sqltypes.Time: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: default: c.asm.Convert_xT(1, -1) } @@ -1160,7 +1160,7 @@ func (call *builtinMicrosecond) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime, sqltypes.Time: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: default: c.asm.Convert_xT(1, -1) } @@ -1193,7 +1193,7 @@ func (call *builtinMinute) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime, sqltypes.Time: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: default: c.asm.Convert_xT(1, -1) } @@ -1226,7 +1226,7 @@ func (call *builtinMonth) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, true) } @@ -1264,7 +1264,7 @@ func (call *builtinMonthName) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -1309,7 +1309,7 @@ func (call *builtinLastDay) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, c.sqlmode.AllowZeroDate()) } @@ -1344,7 +1344,7 @@ func (call *builtinToDays) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -1481,7 +1481,7 @@ func (call *builtinTimeToSec) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime, sqltypes.Time: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: default: c.asm.Convert_xT(1, -1) } @@ -1516,7 +1516,7 @@ func (call *builtinToSeconds) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xDT(1, -1, false) } @@ -1549,7 +1549,7 @@ func (call *builtinQuarter) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, true) } @@ -1582,7 +1582,7 @@ func (call *builtinSecond) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime, sqltypes.Time: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Time, sqltypes.Timestamp: default: c.asm.Convert_xT(1, -1) } @@ -1617,7 +1617,7 @@ func (call *builtinTime) compile(c *compiler) (ctype, error) { var prec int32 switch arg.Type { case sqltypes.Time: - case sqltypes.Datetime, sqltypes.Date: + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: prec = arg.Size c.asm.Convert_xT(1, -1) case sqltypes.Decimal: @@ -1717,7 +1717,7 @@ func (call *builtinUnixTimestamp) compile(c *compiler) (ctype, error) { c.asm.Fn_UNIX_TIMESTAMP1() c.asm.jumpDestination(skip) switch arg.Type { - case sqltypes.Datetime, sqltypes.Time, sqltypes.Decimal: + case sqltypes.Datetime, sqltypes.Time, sqltypes.Decimal, sqltypes.Timestamp: if arg.Size == 0 { return ctype{Type: sqltypes.Int64, Col: collationNumeric, Flag: arg.Flag}, nil } @@ -1782,7 +1782,7 @@ func (call *builtinWeek) compile(c *compiler) (ctype, error) { var skip2 *jump switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -1827,7 +1827,7 @@ func (call *builtinWeekDay) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -1863,7 +1863,7 @@ func (call *builtinWeekOfYear) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } @@ -1906,7 +1906,7 @@ func (call *builtinYear) compile(c *compiler) (ctype, error) { skip := c.compileNullCheck1(arg) switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, true) } @@ -1956,7 +1956,7 @@ func (call *builtinYearWeek) compile(c *compiler) (ctype, error) { var skip2 *jump switch arg.Type { - case sqltypes.Date, sqltypes.Datetime: + case sqltypes.Date, sqltypes.Datetime, sqltypes.Timestamp: default: c.asm.Convert_xD(1, false) } diff --git a/go/vt/vtgate/evalengine/integration/comparison_test.go b/go/vt/vtgate/evalengine/integration/comparison_test.go index d559cb8ab1d..0e15869a125 100644 --- a/go/vt/vtgate/evalengine/integration/comparison_test.go +++ b/go/vt/vtgate/evalengine/integration/comparison_test.go @@ -82,12 +82,12 @@ func normalizeValue(v sqltypes.Value, coll collations.ID) sqltypes.Value { return v } -func compareRemoteExprEnv(t *testing.T, collationEnv *collations.Environment, env *evalengine.ExpressionEnv, conn *mysql.Conn, expr string, fields []*querypb.Field, cmp *testcases.Comparison) { +func compareRemoteExprEnv(t *testing.T, collationEnv *collations.Environment, env *evalengine.ExpressionEnv, conn *mysql.Conn, expr string, fields []*querypb.Field, cmp *testcases.Comparison, skipCollationCheck bool) { t.Helper() localQuery := "SELECT " + expr remoteQuery := "SELECT " + expr - if debugCheckCollations { + if debugCheckCollations && !skipCollationCheck { remoteQuery = fmt.Sprintf("SELECT %s, COLLATION(%s)", expr, expr) } if len(fields) > 0 { @@ -146,7 +146,7 @@ func compareRemoteExprEnv(t *testing.T, collationEnv *collations.Environment, en var localCollation, remoteCollation collations.ID if localErr == nil { v := local.Value(collations.MySQL8().DefaultConnectionCharset()) - if debugCheckCollations { + if debugCheckCollations && !skipCollationCheck { if v.IsNull() { localCollation = collations.CollationBinaryID } else { @@ -166,7 +166,7 @@ func compareRemoteExprEnv(t *testing.T, collationEnv *collations.Environment, en } else { remoteVal = remote.Rows[0][0] } - if debugCheckCollations { + if debugCheckCollations && !skipCollationCheck { if remote.Rows[0][0].IsNull() { // TODO: passthrough proper collations for nullable fields remoteCollation = collations.CollationBinaryID @@ -275,9 +275,9 @@ func TestMySQL(t *testing.T) { Username: "vt_dba", }) env := evalengine.NewExpressionEnv(ctx, nil, &vcursor{env: venv}) - tc.Run(func(query string, row []sqltypes.Value) { + tc.Run(func(query string, row []sqltypes.Value, skipCollationCheck bool) { env.Row = row - compareRemoteExprEnv(t, collationEnv, env, conn, query, tc.Schema, tc.Compare) + compareRemoteExprEnv(t, collationEnv, env, conn, query, tc.Schema, tc.Compare, skipCollationCheck) }) }) } diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index ff6c0c0f311..5469873b10e 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -178,18 +178,18 @@ var Cases = []TestCase{ func JSONPathOperations(yield Query) { for _, obj := range inputJSONObjects { - yield(fmt.Sprintf("JSON_KEYS('%s')", obj), nil) + yield(fmt.Sprintf("JSON_KEYS('%s')", obj), nil, false) for _, path1 := range inputJSONPaths { - yield(fmt.Sprintf("JSON_EXTRACT('%s', '%s')", obj, path1), nil) - yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'one', '%s')", obj, path1), nil) - yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'all', '%s')", obj, path1), nil) - yield(fmt.Sprintf("JSON_KEYS('%s', '%s')", obj, path1), nil) + yield(fmt.Sprintf("JSON_EXTRACT('%s', '%s')", obj, path1), nil, false) + yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'one', '%s')", obj, path1), nil, false) + yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'all', '%s')", obj, path1), nil, false) + yield(fmt.Sprintf("JSON_KEYS('%s', '%s')", obj, path1), nil, false) for _, path2 := range inputJSONPaths { - yield(fmt.Sprintf("JSON_EXTRACT('%s', '%s', '%s')", obj, path1, path2), nil) - yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'one', '%s', '%s')", obj, path1, path2), nil) - yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'all', '%s', '%s')", obj, path1, path2), nil) + yield(fmt.Sprintf("JSON_EXTRACT('%s', '%s', '%s')", obj, path1, path2), nil, false) + yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'one', '%s', '%s')", obj, path1, path2), nil, false) + yield(fmt.Sprintf("JSON_CONTAINS_PATH('%s', 'all', '%s', '%s')", obj, path1, path2), nil, false) } } } @@ -197,21 +197,21 @@ func JSONPathOperations(yield Query) { func JSONArray(yield Query) { for _, a := range inputJSONPrimitives { - yield(fmt.Sprintf("JSON_ARRAY(%s)", a), nil) + yield(fmt.Sprintf("JSON_ARRAY(%s)", a), nil, false) for _, b := range inputJSONPrimitives { - yield(fmt.Sprintf("JSON_ARRAY(%s, %s)", a, b), nil) + yield(fmt.Sprintf("JSON_ARRAY(%s, %s)", a, b), nil, false) } } - yield("JSON_ARRAY()", nil) + yield("JSON_ARRAY()", nil, false) } func JSONObject(yield Query) { for _, a := range inputJSONPrimitives { for _, b := range inputJSONPrimitives { - yield(fmt.Sprintf("JSON_OBJECT(%s, %s)", a, b), nil) + yield(fmt.Sprintf("JSON_OBJECT(%s, %s)", a, b), nil, false) } } - yield("JSON_OBJECT()", nil) + yield("JSON_OBJECT()", nil, false) } func CharsetConversionOperators(yield Query) { @@ -228,7 +228,7 @@ func CharsetConversionOperators(yield Query) { for _, pfx := range introducers { for _, lhs := range contents { for _, rhs := range charsets { - yield(fmt.Sprintf("HEX(CONVERT(%s %s USING %s))", pfx, lhs, rhs), nil) + yield(fmt.Sprintf("HEX(CONVERT(%s %s USING %s))", pfx, lhs, rhs), nil, false) } } } @@ -250,7 +250,7 @@ func CaseExprWithPredicate(yield Query) { for _, pred1 := range predicates { for _, val1 := range elements { for _, elseVal := range elements { - yield(fmt.Sprintf("case when %s then %s else %s end", pred1, val1, elseVal), nil) + yield(fmt.Sprintf("case when %s then %s else %s end", pred1, val1, elseVal), nil, false) } } } @@ -259,7 +259,7 @@ func CaseExprWithPredicate(yield Query) { genSubsets(elements, 3, func(values []string) { yield(fmt.Sprintf("case when %s then %s when %s then %s when %s then %s end", predicates[0], values[0], predicates[1], values[1], predicates[2], values[2], - ), nil) + ), nil, false) }) }) } @@ -279,13 +279,13 @@ func FnCeil(yield Query) { } for _, num := range ceilInputs { - yield(fmt.Sprintf("CEIL(%s)", num), nil) - yield(fmt.Sprintf("CEILING(%s)", num), nil) + yield(fmt.Sprintf("CEIL(%s)", num), nil, false) + yield(fmt.Sprintf("CEILING(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("CEIL(%s)", num), nil) - yield(fmt.Sprintf("CEILING(%s)", num), nil) + yield(fmt.Sprintf("CEIL(%s)", num), nil, false) + yield(fmt.Sprintf("CEILING(%s)", num), nil, false) } } @@ -304,11 +304,11 @@ func FnFloor(yield Query) { } for _, num := range floorInputs { - yield(fmt.Sprintf("FLOOR(%s)", num), nil) + yield(fmt.Sprintf("FLOOR(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("FLOOR(%s)", num), nil) + yield(fmt.Sprintf("FLOOR(%s)", num), nil, false) } } @@ -327,280 +327,280 @@ func FnAbs(yield Query) { } for _, num := range absInputs { - yield(fmt.Sprintf("ABS(%s)", num), nil) + yield(fmt.Sprintf("ABS(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("ABS(%s)", num), nil) + yield(fmt.Sprintf("ABS(%s)", num), nil, false) } } func FnPi(yield Query) { - yield("PI()+0.000000000000000000", nil) + yield("PI()+0.000000000000000000", nil, false) } func FnAcos(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("ACOS(%s)", num), nil) + yield(fmt.Sprintf("ACOS(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("ACOS(%s)", num), nil) + yield(fmt.Sprintf("ACOS(%s)", num), nil, false) } } func FnAsin(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("ASIN(%s)", num), nil) + yield(fmt.Sprintf("ASIN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("ASIN(%s)", num), nil) + yield(fmt.Sprintf("ASIN(%s)", num), nil, false) } } func FnAtan(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("ATAN(%s)", num), nil) + yield(fmt.Sprintf("ATAN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("ATAN(%s)", num), nil) + yield(fmt.Sprintf("ATAN(%s)", num), nil, false) } } func FnAtan2(yield Query) { for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("ATAN(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("ATAN2(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ATAN(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("ATAN2(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range inputBitwise { - yield(fmt.Sprintf("ATAN(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("ATAN2(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ATAN(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("ATAN2(%s, %s)", num1, num2), nil, false) } } } func FnCos(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("COS(%s)", num), nil) + yield(fmt.Sprintf("COS(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("COS(%s)", num), nil) + yield(fmt.Sprintf("COS(%s)", num), nil, false) } } func FnCot(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("COT(%s)", num), nil) + yield(fmt.Sprintf("COT(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("COT(%s)", num), nil) + yield(fmt.Sprintf("COT(%s)", num), nil, false) } } func FnSin(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("SIN(%s)", num), nil) + yield(fmt.Sprintf("SIN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("SIN(%s)", num), nil) + yield(fmt.Sprintf("SIN(%s)", num), nil, false) } } func FnTan(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("TAN(%s)", num), nil) + yield(fmt.Sprintf("TAN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("TAN(%s)", num), nil) + yield(fmt.Sprintf("TAN(%s)", num), nil, false) } } func FnDegrees(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("DEGREES(%s)", num), nil) + yield(fmt.Sprintf("DEGREES(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("DEGREES(%s)", num), nil) + yield(fmt.Sprintf("DEGREES(%s)", num), nil, false) } } func FnRadians(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("RADIANS(%s)", num), nil) + yield(fmt.Sprintf("RADIANS(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("RADIANS(%s)", num), nil) + yield(fmt.Sprintf("RADIANS(%s)", num), nil, false) } } func FnExp(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("EXP(%s)", num), nil) + yield(fmt.Sprintf("EXP(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("EXP(%s)", num), nil) + yield(fmt.Sprintf("EXP(%s)", num), nil, false) } } func FnLn(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("LN(%s)", num), nil) + yield(fmt.Sprintf("LN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("LN(%s)", num), nil) + yield(fmt.Sprintf("LN(%s)", num), nil, false) } } func FnLog(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("LOG(%s)", num), nil) + yield(fmt.Sprintf("LOG(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("LOG(%s)", num), nil) + yield(fmt.Sprintf("LOG(%s)", num), nil, false) } for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range radianInputs { - yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("LOG(%s, %s)", num1, num2), nil, false) } } } func FnLog10(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("LOG10(%s)", num), nil) + yield(fmt.Sprintf("LOG10(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("LOG10(%s)", num), nil) + yield(fmt.Sprintf("LOG10(%s)", num), nil, false) } } func FnMod(yield Query) { for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range radianInputs { - yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("MOD(%s, %s)", num1, num2), nil, false) } } } func FnLog2(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("LOG2(%s)", num), nil) + yield(fmt.Sprintf("LOG2(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("LOG2(%s)", num), nil) + yield(fmt.Sprintf("LOG2(%s)", num), nil, false) } } func FnPow(yield Query) { for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range radianInputs { - yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil) - yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("POW(%s, %s)", num1, num2), nil, false) + yield(fmt.Sprintf("POWER(%s, %s)", num1, num2), nil, false) } } } func FnSign(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("SIGN(%s)", num), nil) + yield(fmt.Sprintf("SIGN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("SIGN(%s)", num), nil) + yield(fmt.Sprintf("SIGN(%s)", num), nil, false) } } func FnSqrt(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("SQRT(%s)", num), nil) + yield(fmt.Sprintf("SQRT(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("SQRT(%s)", num), nil) + yield(fmt.Sprintf("SQRT(%s)", num), nil, false) } } func FnRound(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("ROUND(%s)", num), nil) + yield(fmt.Sprintf("ROUND(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("ROUND(%s)", num), nil) + yield(fmt.Sprintf("ROUND(%s)", num), nil, false) } for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range radianInputs { - yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("ROUND(%s, %s)", num1, num2), nil, false) } } } @@ -608,34 +608,34 @@ func FnRound(yield Query) { func FnTruncate(yield Query) { for _, num1 := range radianInputs { for _, num2 := range radianInputs { - yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil, false) } } for _, num1 := range inputBitwise { for _, num2 := range radianInputs { - yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil, false) } for _, num2 := range inputBitwise { - yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil) + yield(fmt.Sprintf("TRUNCATE(%s, %s)", num1, num2), nil, false) } } } func FnCrc32(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("CRC32(%s)", num), nil) + yield(fmt.Sprintf("CRC32(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("CRC32(%s)", num), nil) + yield(fmt.Sprintf("CRC32(%s)", num), nil, false) } for _, num := range inputConversions { - yield(fmt.Sprintf("CRC32(%s)", num), nil) + yield(fmt.Sprintf("CRC32(%s)", num), nil, false) } } @@ -643,10 +643,10 @@ func FnConv(yield Query) { for _, num1 := range radianInputs { for _, num2 := range radianInputs { for _, num3 := range radianInputs { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } for _, num3 := range inputBitwise { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } } } @@ -654,10 +654,10 @@ func FnConv(yield Query) { for _, num1 := range radianInputs { for _, num2 := range inputBitwise { for _, num3 := range radianInputs { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } for _, num3 := range inputBitwise { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } } } @@ -665,10 +665,10 @@ func FnConv(yield Query) { for _, num1 := range inputBitwise { for _, num2 := range inputBitwise { for _, num3 := range radianInputs { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } for _, num3 := range inputBitwise { - yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil) + yield(fmt.Sprintf("CONV(%s, %s, %s)", num1, num2, num3), nil, false) } } } @@ -676,50 +676,50 @@ func FnConv(yield Query) { func FnBin(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("BIN(%s)", num), nil) + yield(fmt.Sprintf("BIN(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("BIN(%s)", num), nil) + yield(fmt.Sprintf("BIN(%s)", num), nil, false) } } func FnOct(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("OCT(%s)", num), nil) + yield(fmt.Sprintf("OCT(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("OCT(%s)", num), nil) + yield(fmt.Sprintf("OCT(%s)", num), nil, false) } } func FnMD5(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("MD5(%s)", num), nil) + yield(fmt.Sprintf("MD5(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("MD5(%s)", num), nil) + yield(fmt.Sprintf("MD5(%s)", num), nil, false) } for _, num := range inputConversions { - yield(fmt.Sprintf("MD5(%s)", num), nil) + yield(fmt.Sprintf("MD5(%s)", num), nil, false) } } func FnSHA1(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("SHA1(%s)", num), nil) - yield(fmt.Sprintf("SHA(%s)", num), nil) + yield(fmt.Sprintf("SHA1(%s)", num), nil, false) + yield(fmt.Sprintf("SHA(%s)", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("SHA1(%s)", num), nil) - yield(fmt.Sprintf("SHA(%s)", num), nil) + yield(fmt.Sprintf("SHA1(%s)", num), nil, false) + yield(fmt.Sprintf("SHA(%s)", num), nil, false) } for _, num := range inputConversions { - yield(fmt.Sprintf("SHA1(%s)", num), nil) - yield(fmt.Sprintf("SHA(%s)", num), nil) + yield(fmt.Sprintf("SHA1(%s)", num), nil, false) + yield(fmt.Sprintf("SHA(%s)", num), nil, false) } } @@ -727,28 +727,28 @@ func FnSHA2(yield Query) { bitLengths := []string{"0", "224", "256", "384", "512", "1", "0.1", "256.1e0", "1-1", "128+128"} for _, bits := range bitLengths { for _, num := range radianInputs { - yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil) + yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil) + yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil, false) } for _, num := range inputConversions { - yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil) + yield(fmt.Sprintf("SHA2(%s, %s)", num, bits), nil, false) } } } func FnRandomBytes(yield Query) { for _, num := range radianInputs { - yield(fmt.Sprintf("LENGTH(RANDOM_BYTES(%s))", num), nil) - yield(fmt.Sprintf("COLLATION(RANDOM_BYTES(%s))", num), nil) + yield(fmt.Sprintf("LENGTH(RANDOM_BYTES(%s))", num), nil, false) + yield(fmt.Sprintf("COLLATION(RANDOM_BYTES(%s))", num), nil, false) } for _, num := range inputBitwise { - yield(fmt.Sprintf("LENGTH(RANDOM_BYTES(%s))", num), nil) - yield(fmt.Sprintf("COLLATION(RANDOM_BYTES(%s))", num), nil) + yield(fmt.Sprintf("LENGTH(RANDOM_BYTES(%s))", num), nil, false) + yield(fmt.Sprintf("COLLATION(RANDOM_BYTES(%s))", num), nil, false) } } @@ -762,7 +762,7 @@ func CaseExprWithValue(yield Query) { if !(bugs{}).CanCompare(cmpbase, val1) { continue } - yield(fmt.Sprintf("case %s when %s then 1 else 0 end", cmpbase, val1), nil) + yield(fmt.Sprintf("case %s when %s then 1 else 0 end", cmpbase, val1), nil, false) } } } @@ -775,7 +775,7 @@ func If(yield Query) { for _, cmpbase := range elements { for _, val1 := range elements { for _, val2 := range elements { - yield(fmt.Sprintf("if(%s, %s, %s)", cmpbase, val1, val2), nil) + yield(fmt.Sprintf("if(%s, %s, %s)", cmpbase, val1, val2), nil, false) } } } @@ -796,17 +796,17 @@ func Base64(yield Query) { } for _, lhs := range inputs { - yield(fmt.Sprintf("FROM_BASE64(%s)", lhs), nil) - yield(fmt.Sprintf("TO_BASE64(%s)", lhs), nil) + yield(fmt.Sprintf("FROM_BASE64(%s)", lhs), nil, false) + yield(fmt.Sprintf("TO_BASE64(%s)", lhs), nil, false) } } func Conversion(yield Query) { for _, lhs := range inputConversions { for _, rhs := range inputConversionTypes { - yield(fmt.Sprintf("CAST(%s AS %s)", lhs, rhs), nil) - yield(fmt.Sprintf("CONVERT(%s, %s)", lhs, rhs), nil) - yield(fmt.Sprintf("CAST(CAST(%s AS JSON) AS %s)", lhs, rhs), nil) + yield(fmt.Sprintf("CAST(%s AS %s)", lhs, rhs), nil, false) + yield(fmt.Sprintf("CONVERT(%s, %s)", lhs, rhs), nil, false) + yield(fmt.Sprintf("CAST(CAST(%s AS JSON) AS %s)", lhs, rhs), nil, false) } } } @@ -815,8 +815,8 @@ func LargeDecimals(yield Query) { var largepi = inputPi + inputPi for pos := 0; pos < len(largepi); pos++ { - yield(fmt.Sprintf("%s.%s", largepi[:pos], largepi[pos:]), nil) - yield(fmt.Sprintf("-%s.%s", largepi[:pos], largepi[pos:]), nil) + yield(fmt.Sprintf("%s.%s", largepi[:pos], largepi[pos:]), nil, false) + yield(fmt.Sprintf("-%s.%s", largepi[:pos], largepi[pos:]), nil, false) } } @@ -824,8 +824,8 @@ func LargeIntegers(yield Query) { var largepi = inputPi + inputPi for pos := 1; pos < len(largepi); pos++ { - yield(largepi[:pos], nil) - yield(fmt.Sprintf("-%s", largepi[:pos]), nil) + yield(largepi[:pos], nil, false) + yield(fmt.Sprintf("-%s", largepi[:pos]), nil, false) } } @@ -833,7 +833,7 @@ func DecimalClamping(yield Query) { for pos := 0; pos < len(inputPi); pos++ { for m := 0; m < min(len(inputPi), 67); m += 2 { for d := 0; d <= min(m, 33); d += 2 { - yield(fmt.Sprintf("CAST(%s.%s AS DECIMAL(%d, %d))", inputPi[:pos], inputPi[pos:], m, d), nil) + yield(fmt.Sprintf("CAST(%s.%s AS DECIMAL(%d, %d))", inputPi[:pos], inputPi[pos:], m, d), nil, false) } } } @@ -842,7 +842,7 @@ func DecimalClamping(yield Query) { func BitwiseOperatorsUnary(yield Query) { for _, op := range []string{"~", "BIT_COUNT"} { for _, rhs := range inputBitwise { - yield(fmt.Sprintf("%s(%s)", op, rhs), nil) + yield(fmt.Sprintf("%s(%s)", op, rhs), nil, false) } } } @@ -851,13 +851,13 @@ func BitwiseOperators(yield Query) { for _, op := range []string{"&", "|", "^", "<<", ">>"} { for _, lhs := range inputBitwise { for _, rhs := range inputBitwise { - yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil, false) } } for _, lhs := range inputConversions { for _, rhs := range inputConversions { - yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil, false) } } } @@ -910,7 +910,7 @@ func WeightString(yield Query) { } for _, i := range inputs { - yield(fmt.Sprintf("WEIGHT_STRING(%s)", i), nil) + yield(fmt.Sprintf("WEIGHT_STRING(%s)", i), nil, false) } } @@ -927,18 +927,18 @@ func FloatFormatting(yield Query) { } for _, f := range floats { - yield(fmt.Sprintf("%s + 0.0e0", f), nil) - yield(fmt.Sprintf("-%s", f), nil) + yield(fmt.Sprintf("%s + 0.0e0", f), nil, false) + yield(fmt.Sprintf("-%s", f), nil, false) } for i := 0; i < 64; i++ { v := uint64(1) << i - yield(fmt.Sprintf("%d + 0.0e0", v), nil) - yield(fmt.Sprintf("%d + 0.0e0", v+1), nil) - yield(fmt.Sprintf("%d + 0.0e0", ^v), nil) - yield(fmt.Sprintf("-%de0", v), nil) - yield(fmt.Sprintf("-%de0", v+1), nil) - yield(fmt.Sprintf("-%de0", ^v), nil) + yield(fmt.Sprintf("%d + 0.0e0", v), nil, false) + yield(fmt.Sprintf("%d + 0.0e0", v+1), nil, false) + yield(fmt.Sprintf("%d + 0.0e0", ^v), nil, false) + yield(fmt.Sprintf("-%de0", v), nil, false) + yield(fmt.Sprintf("-%de0", v+1), nil, false) + yield(fmt.Sprintf("-%de0", ^v), nil, false) } } @@ -962,7 +962,7 @@ func UnderscoreAndPercentage(yield Query) { `'poke\_mon' = 'poke\_mon'`, } for _, query := range queries { - yield(query, nil) + yield(query, nil, false) } } @@ -993,7 +993,7 @@ func Types(yield Query) { } for _, query := range queries { - yield(query, nil) + yield(query, nil, false) } } @@ -1003,13 +1003,13 @@ func Arithmetic(yield Query) { for _, op := range operators { for _, lhs := range inputConversions { for _, rhs := range inputConversions { - yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil, false) } } for _, lhs := range inputBitwise { for _, rhs := range inputBitwise { - yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil, false) } } } @@ -1025,9 +1025,9 @@ func HexArithmetic(yield Query) { for _, lhs := range cases { for _, rhs := range cases { - yield(fmt.Sprintf("%s + %s", lhs, rhs), nil) + yield(fmt.Sprintf("%s + %s", lhs, rhs), nil, false) // compare with negative values too - yield(fmt.Sprintf("-%s + -%s", lhs, rhs), nil) + yield(fmt.Sprintf("-%s + -%s", lhs, rhs), nil, false) } } } @@ -1055,7 +1055,7 @@ func NumericTypes(yield Query) { } for _, rhs := range numbers { - yield(rhs, nil) + yield(rhs, nil, false) } } @@ -1072,13 +1072,13 @@ func NegateArithmetic(yield Query) { } for _, rhs := range cases { - yield(fmt.Sprintf("- %s", rhs), nil) - yield(fmt.Sprintf("-%s", rhs), nil) + yield(fmt.Sprintf("- %s", rhs), nil, false) + yield(fmt.Sprintf("-%s", rhs), nil, false) } for _, rhs := range inputConversions { - yield(fmt.Sprintf("- %s", rhs), nil) - yield(fmt.Sprintf("-%s", rhs), nil) + yield(fmt.Sprintf("- %s", rhs), nil, false) + yield(fmt.Sprintf("-%s", rhs), nil, false) } } @@ -1092,7 +1092,7 @@ func CollationOperations(yield Query) { } for _, expr := range cases { - yield(expr, nil) + yield(expr, nil, false) } } @@ -1115,7 +1115,7 @@ func LikeComparison(yield Query) { for _, lhs := range left { for _, rhs := range right { for _, op := range []string{"LIKE", "NOT LIKE"} { - yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil) + yield(fmt.Sprintf("%s %s %s", lhs, op, rhs), nil, false) } } } @@ -1149,7 +1149,7 @@ func StrcmpComparison(yield Query) { for _, lhs := range inputs { for _, rhs := range inputs { - yield(fmt.Sprintf("STRCMP(%s, %s)", lhs, rhs), nil) + yield(fmt.Sprintf("STRCMP(%s, %s)", lhs, rhs), nil, false) } } } @@ -1168,7 +1168,7 @@ func MultiComparisons(yield Query) { `"0"`, `"-1"`, `"1"`, `_utf8mb4 'foobar'`, `_utf8mb4 'FOOBAR'`, `_binary '0'`, `_binary '-1'`, `_binary '1'`, - `0x0`, `0x1`, `-0x0`, `-0x1`, + `0x0`, `0x1`, "_utf8mb4 'Abc' COLLATE utf8mb4_0900_as_ci", "_utf8mb4 'aBC' COLLATE utf8mb4_0900_as_ci", "_utf8mb4 'ǍḄÇ' COLLATE utf8mb4_0900_as_ci", @@ -1183,17 +1183,37 @@ func MultiComparisons(yield Query) { "_utf8mb4 'ノ東京の' COLLATE utf8mb4_ja_0900_as_cs", "_utf8mb4 'の東京ノ' COLLATE utf8mb4_ja_0900_as_cs_ks", "_utf8mb4 'ノ東京の' COLLATE utf8mb4_ja_0900_as_cs_ks", + `date'2024-02-18'`, + `date'2023-02-01'`, + `date'2100-02-01'`, + `timestamp'2020-12-31 23:59:59'`, + `timestamp'2025-01-01 00:00:00.123456'`, + `time'23:59:59.5432'`, + `time'120:59:59'`, } for _, method := range []string{"LEAST", "GREATEST"} { + skip := func(arg []string) bool { + skipCollations := false + for _, a := range arg { + if strings.Contains(a, "date'") || strings.Contains(a, "time'") || strings.Contains(a, "timestamp'") { + skipCollations = true + break + } + } + return skipCollations + } + genSubsets(numbers, 2, func(arg []string) { - yield(fmt.Sprintf("%s(%s, %s)", method, arg[0], arg[1]), nil) - yield(fmt.Sprintf("%s(%s, %s)", method, arg[1], arg[0]), nil) + skipCollations := skip(arg) + yield(fmt.Sprintf("%s(%s, %s)", method, arg[0], arg[1]), nil, skipCollations) + yield(fmt.Sprintf("%s(%s, %s)", method, arg[1], arg[0]), nil, skipCollations) }) genSubsets(numbers, 3, func(arg []string) { - yield(fmt.Sprintf("%s(%s, %s, %s)", method, arg[0], arg[1], arg[2]), nil) - yield(fmt.Sprintf("%s(%s, %s, %s)", method, arg[2], arg[1], arg[0]), nil) + skipCollations := skip(arg) + yield(fmt.Sprintf("%s(%s, %s, %s)", method, arg[0], arg[1], arg[2]), nil, skipCollations) + yield(fmt.Sprintf("%s(%s, %s, %s)", method, arg[2], arg[1], arg[0]), nil, skipCollations) }) } } @@ -1213,7 +1233,7 @@ func IntervalStatement(yield Query) { for _, arg1 := range inputs { for _, arg2 := range inputs { for _, arg3 := range inputs { - yield(fmt.Sprintf("INTERVAL(%s, %s, %s, %s)", base, arg1, arg2, arg3), nil) + yield(fmt.Sprintf("INTERVAL(%s, %s, %s, %s)", base, arg1, arg2, arg3), nil, false) } } } @@ -1238,7 +1258,7 @@ func IsStatement(yield Query) { for _, l := range left { for _, r := range right { - yield(fmt.Sprintf("%s IS %s", l, r), nil) + yield(fmt.Sprintf("%s IS %s", l, r), nil, false) } } } @@ -1247,7 +1267,7 @@ func NotStatement(yield Query) { var ops = []string{"NOT", "!"} for _, op := range ops { for _, i := range inputConversions { - yield(fmt.Sprintf("%s %s", op, i), nil) + yield(fmt.Sprintf("%s %s", op, i), nil, false) } } } @@ -1257,7 +1277,7 @@ func LogicalStatement(yield Query) { for _, op := range ops { for _, l := range inputConversions { for _, r := range inputConversions { - yield(fmt.Sprintf("%s %s %s", l, op, r), nil) + yield(fmt.Sprintf("%s %s %s", l, op, r), nil, false) } } } @@ -1275,7 +1295,7 @@ func TupleComparisons(yield Query) { for _, op := range operators { for i := 0; i < len(tuples); i++ { for j := 0; j < len(tuples); j++ { - yield(fmt.Sprintf("%s %s %s", tuples[i], op, tuples[j]), nil) + yield(fmt.Sprintf("%s %s %s", tuples[i], op, tuples[j]), nil, false) } } } @@ -1286,13 +1306,13 @@ func Comparisons(yield Query) { for _, op := range operators { for _, l := range inputComparisonElement { for _, r := range inputComparisonElement { - yield(fmt.Sprintf("%s %s %s", l, op, r), nil) + yield(fmt.Sprintf("%s %s %s", l, op, r), nil, false) } } for _, l := range inputConversions { for _, r := range inputConversions { - yield(fmt.Sprintf("%s %s %s", l, op, r), nil) + yield(fmt.Sprintf("%s %s %s", l, op, r), nil, false) } } } @@ -1331,9 +1351,9 @@ func JSONExtract(yield Query) { expr2 := fmt.Sprintf("cast(%s as char) <=> %s", expr0, expr1) for _, row := range rows { - yield(expr0, []sqltypes.Value{row}) - yield(expr1, []sqltypes.Value{row}) - yield(expr2, []sqltypes.Value{row}) + yield(expr0, []sqltypes.Value{row}, false) + yield(expr1, []sqltypes.Value{row}, false) + yield(expr2, []sqltypes.Value{row}, false) } } } @@ -1350,7 +1370,7 @@ func FnField(yield Query) { for _, s1 := range inputStrings { for _, s2 := range inputStrings { for _, s3 := range inputStrings { - yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil, false) } } } @@ -1358,7 +1378,7 @@ func FnField(yield Query) { for _, s1 := range radianInputs { for _, s2 := range radianInputs { for _, s3 := range radianInputs { - yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil, false) } } } @@ -1367,7 +1387,7 @@ func FnField(yield Query) { for _, s1 := range inputStrings { for _, s2 := range radianInputs { for _, s3 := range inputStrings { - yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil, false) } } } @@ -1376,7 +1396,7 @@ func FnField(yield Query) { for _, s1 := range inputBitwise { for _, s2 := range inputBitwise { for _, s3 := range inputBitwise { - yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil) + yield(fmt.Sprintf("FIELD(%s, %s, %s)", s1, s2, s3), nil, false) } } } @@ -1386,21 +1406,21 @@ func FnField(yield Query) { "FIELD('Gg', 'Aa', 'Bb', 'Cc', 'Dd', 'Ff')", } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } func FnElt(yield Query) { for _, s1 := range inputStrings { for _, n := range inputBitwise { - yield(fmt.Sprintf("ELT(%s, %s)", n, s1), nil) + yield(fmt.Sprintf("ELT(%s, %s)", n, s1), nil, false) } } for _, s1 := range inputStrings { for _, s2 := range inputStrings { for _, n := range inputBitwise { - yield(fmt.Sprintf("ELT(%s, %s, %s)", n, s1, s2), nil) + yield(fmt.Sprintf("ELT(%s, %s, %s)", n, s1, s2), nil, false) } } } @@ -1414,7 +1434,7 @@ func FnElt(yield Query) { for _, s2 := range inputStrings { for _, s3 := range inputStrings { for _, n := range validIndex { - yield(fmt.Sprintf("ELT(%s, %s, %s, %s)", n, s1, s2, s3), nil) + yield(fmt.Sprintf("ELT(%s, %s, %s, %s)", n, s1, s2, s3), nil, false) } } } @@ -1426,7 +1446,7 @@ func FnElt(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } @@ -1435,7 +1455,7 @@ func FnInsert(yield Query) { for _, ns := range insertStrings { for _, l := range inputBitwise { for _, p := range inputBitwise { - yield(fmt.Sprintf("INSERT(%s, %s, %s, %s)", s, p, l, ns), nil) + yield(fmt.Sprintf("INSERT(%s, %s, %s, %s)", s, p, l, ns), nil, false) } } } @@ -1448,53 +1468,53 @@ func FnInsert(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } func FnLower(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("LOWER(%s)", str), nil) - yield(fmt.Sprintf("LCASE(%s)", str), nil) + yield(fmt.Sprintf("LOWER(%s)", str), nil, false) + yield(fmt.Sprintf("LCASE(%s)", str), nil, false) } } func FnUpper(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("UPPER(%s)", str), nil) - yield(fmt.Sprintf("UCASE(%s)", str), nil) + yield(fmt.Sprintf("UPPER(%s)", str), nil, false) + yield(fmt.Sprintf("UCASE(%s)", str), nil, false) } } func FnCharLength(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("CHAR_LENGTH(%s)", str), nil) - yield(fmt.Sprintf("CHARACTER_LENGTH(%s)", str), nil) + yield(fmt.Sprintf("CHAR_LENGTH(%s)", str), nil, false) + yield(fmt.Sprintf("CHARACTER_LENGTH(%s)", str), nil, false) } } func FnLength(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("LENGTH(%s)", str), nil) - yield(fmt.Sprintf("OCTET_LENGTH(%s)", str), nil) + yield(fmt.Sprintf("LENGTH(%s)", str), nil, false) + yield(fmt.Sprintf("OCTET_LENGTH(%s)", str), nil, false) } } func FnBitLength(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("BIT_LENGTH(%s)", str), nil) + yield(fmt.Sprintf("BIT_LENGTH(%s)", str), nil, false) } } func FnAscii(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("ASCII(%s)", str), nil) + yield(fmt.Sprintf("ASCII(%s)", str), nil, false) } } func FnReverse(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("REVERSE(%s)", str), nil) + yield(fmt.Sprintf("REVERSE(%s)", str), nil, false) } } @@ -1516,13 +1536,13 @@ func FnSpace(yield Query) { } for _, c := range counts { - yield(fmt.Sprintf("SPACE(%s)", c), nil) + yield(fmt.Sprintf("SPACE(%s)", c), nil, false) } } func FnOrd(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("ORD(%s)", str), nil) + yield(fmt.Sprintf("ORD(%s)", str), nil, false) } } @@ -1530,7 +1550,7 @@ func FnRepeat(yield Query) { counts := []string{"-1", "1.9", "3", "1073741825", "'1.9'"} for _, str := range inputStrings { for _, cnt := range counts { - yield(fmt.Sprintf("REPEAT(%s, %s)", str, cnt), nil) + yield(fmt.Sprintf("REPEAT(%s, %s)", str, cnt), nil, false) } } } @@ -1539,7 +1559,7 @@ func FnLeft(yield Query) { counts := []string{"-1", "1.9", "3", "10", "'1.9'"} for _, str := range inputStrings { for _, cnt := range counts { - yield(fmt.Sprintf("LEFT(%s, %s)", str, cnt), nil) + yield(fmt.Sprintf("LEFT(%s, %s)", str, cnt), nil, false) } } } @@ -1549,7 +1569,7 @@ func FnLpad(yield Query) { for _, str := range inputStrings { for _, cnt := range counts { for _, pad := range inputStrings { - yield(fmt.Sprintf("LPAD(%s, %s, %s)", str, cnt, pad), nil) + yield(fmt.Sprintf("LPAD(%s, %s, %s)", str, cnt, pad), nil, false) } } } @@ -1559,7 +1579,7 @@ func FnRight(yield Query) { counts := []string{"-1", "1.9", "3", "10", "'1.9'"} for _, str := range inputStrings { for _, cnt := range counts { - yield(fmt.Sprintf("RIGHT(%s, %s)", str, cnt), nil) + yield(fmt.Sprintf("RIGHT(%s, %s)", str, cnt), nil, false) } } } @@ -1569,7 +1589,7 @@ func FnRpad(yield Query) { for _, str := range inputStrings { for _, cnt := range counts { for _, pad := range inputStrings { - yield(fmt.Sprintf("RPAD(%s, %s, %s)", str, cnt, pad), nil) + yield(fmt.Sprintf("RPAD(%s, %s, %s)", str, cnt, pad), nil, false) } } } @@ -1577,33 +1597,33 @@ func FnRpad(yield Query) { func FnLTrim(yield Query) { for _, str := range inputTrimStrings { - yield(fmt.Sprintf("LTRIM(%s)", str), nil) + yield(fmt.Sprintf("LTRIM(%s)", str), nil, false) } } func FnRTrim(yield Query) { for _, str := range inputTrimStrings { - yield(fmt.Sprintf("RTRIM(%s)", str), nil) + yield(fmt.Sprintf("RTRIM(%s)", str), nil, false) } } func FnTrim(yield Query) { for _, str := range inputTrimStrings { - yield(fmt.Sprintf("TRIM(%s)", str), nil) + yield(fmt.Sprintf("TRIM(%s)", str), nil, false) } modes := []string{"LEADING", "TRAILING", "BOTH"} for _, str := range inputTrimStrings { for _, mode := range modes { - yield(fmt.Sprintf("TRIM(%s FROM %s)", mode, str), nil) + yield(fmt.Sprintf("TRIM(%s FROM %s)", mode, str), nil, false) } } for _, str := range inputTrimStrings { for _, pat := range inputTrimStrings { - yield(fmt.Sprintf("TRIM(%s FROM %s)", pat, str), nil) + yield(fmt.Sprintf("TRIM(%s FROM %s)", pat, str), nil, false) for _, mode := range modes { - yield(fmt.Sprintf("TRIM(%s %s FROM %s)", mode, pat, str), nil) + yield(fmt.Sprintf("TRIM(%s %s FROM %s)", mode, pat, str), nil, false) } } } @@ -1628,15 +1648,15 @@ func FnSubstr(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } for _, str := range inputStrings { for _, i := range radianInputs { - yield(fmt.Sprintf("SUBSTRING(%s, %s)", str, i), nil) + yield(fmt.Sprintf("SUBSTRING(%s, %s)", str, i), nil, false) for _, j := range radianInputs { - yield(fmt.Sprintf("SUBSTRING(%s, %s, %s)", str, i, j), nil) + yield(fmt.Sprintf("SUBSTRING(%s, %s, %s)", str, i, j), nil, false) } } } @@ -1654,17 +1674,17 @@ func FnLocate(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } for _, substr := range locateStrings { for _, str := range locateStrings { - yield(fmt.Sprintf("LOCATE(%s, %s)", substr, str), nil) - yield(fmt.Sprintf("INSTR(%s, %s)", str, substr), nil) - yield(fmt.Sprintf("POSITION(%s IN %s)", str, substr), nil) + yield(fmt.Sprintf("LOCATE(%s, %s)", substr, str), nil, false) + yield(fmt.Sprintf("INSTR(%s, %s)", str, substr), nil, false) + yield(fmt.Sprintf("POSITION(%s IN %s)", str, substr), nil, false) for _, i := range radianInputs { - yield(fmt.Sprintf("LOCATE(%s, %s, %s)", substr, str, i), nil) + yield(fmt.Sprintf("LOCATE(%s, %s, %s)", substr, str, i), nil, false) } } } @@ -1685,13 +1705,13 @@ func FnReplace(yield Query) { } for _, q := range cases { - yield(q, nil) + yield(q, nil, false) } for _, substr := range inputStrings { for _, str := range inputStrings { for _, i := range inputStrings { - yield(fmt.Sprintf("REPLACE(%s, %s, %s)", substr, str, i), nil) + yield(fmt.Sprintf("REPLACE(%s, %s, %s)", substr, str, i), nil, false) } } } @@ -1699,19 +1719,19 @@ func FnReplace(yield Query) { func FnConcat(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("CONCAT(%s)", str), nil) + yield(fmt.Sprintf("CONCAT(%s)", str), nil, false) } for _, str1 := range inputConversions { for _, str2 := range inputConversions { - yield(fmt.Sprintf("CONCAT(%s, %s)", str1, str2), nil) + yield(fmt.Sprintf("CONCAT(%s, %s)", str1, str2), nil, false) } } for _, str1 := range inputStrings { for _, str2 := range inputStrings { for _, str3 := range inputStrings { - yield(fmt.Sprintf("CONCAT(%s, %s, %s)", str1, str2, str3), nil) + yield(fmt.Sprintf("CONCAT(%s, %s, %s)", str1, str2, str3), nil, false) } } } @@ -1719,13 +1739,13 @@ func FnConcat(yield Query) { func FnConcatWs(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("CONCAT_WS(%s, NULL)", str), nil) + yield(fmt.Sprintf("CONCAT_WS(%s, NULL)", str), nil, false) } for _, str1 := range inputConversions { for _, str2 := range inputStrings { for _, str3 := range inputStrings { - yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil) + yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil, false) } } } @@ -1733,7 +1753,7 @@ func FnConcatWs(yield Query) { for _, str1 := range inputStrings { for _, str2 := range inputConversions { for _, str3 := range inputStrings { - yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil) + yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil, false) } } } @@ -1741,7 +1761,7 @@ func FnConcatWs(yield Query) { for _, str1 := range inputStrings { for _, str2 := range inputStrings { for _, str3 := range inputConversions { - yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil) + yield(fmt.Sprintf("CONCAT_WS(%s, %s, %s)", str1, str2, str3), nil, false) } } } @@ -1760,13 +1780,13 @@ func FnChar(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } for _, i1 := range radianInputs { for _, i2 := range inputBitwise { for _, i3 := range inputConversions { - yield(fmt.Sprintf("CHAR(%s, %s, %s)", i1, i2, i3), nil) + yield(fmt.Sprintf("CHAR(%s, %s, %s)", i1, i2, i3), nil, false) } } } @@ -1774,15 +1794,15 @@ func FnChar(yield Query) { func FnHex(yield Query) { for _, str := range inputStrings { - yield(fmt.Sprintf("hex(%s)", str), nil) + yield(fmt.Sprintf("hex(%s)", str), nil, false) } for _, str := range inputConversions { - yield(fmt.Sprintf("hex(%s)", str), nil) + yield(fmt.Sprintf("hex(%s)", str), nil, false) } for _, str := range inputBitwise { - yield(fmt.Sprintf("hex(%s)", str), nil) + yield(fmt.Sprintf("hex(%s)", str), nil, false) } } @@ -1802,7 +1822,7 @@ func FnUnhex(yield Query) { } for _, lhs := range inputs { - yield(fmt.Sprintf("UNHEX(%s)", lhs), nil) + yield(fmt.Sprintf("UNHEX(%s)", lhs), nil, false) } } @@ -1814,15 +1834,15 @@ func InStatement(yield Query) { if !(bugs{}).CanCompare(inputs...) { return } - yield(fmt.Sprintf("%s IN (%s, %s)", inputs[0], inputs[1], inputs[2]), nil) - yield(fmt.Sprintf("%s IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil) - yield(fmt.Sprintf("%s IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil) - yield(fmt.Sprintf("%s IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil) + yield(fmt.Sprintf("%s IN (%s, %s)", inputs[0], inputs[1], inputs[2]), nil, false) + yield(fmt.Sprintf("%s IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil, false) + yield(fmt.Sprintf("%s IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil, false) + yield(fmt.Sprintf("%s IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil, false) - yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[0], inputs[1], inputs[2]), nil) - yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil) - yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil) - yield(fmt.Sprintf("%s NOT IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil) + yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[0], inputs[1], inputs[2]), nil, false) + yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[2], inputs[1], inputs[0]), nil, false) + yield(fmt.Sprintf("%s NOT IN (%s, %s)", inputs[1], inputs[0], inputs[2]), nil, false) + yield(fmt.Sprintf("%s NOT IN (%s, %s, %s)", inputs[0], inputs[1], inputs[2], inputs[0]), nil, false) }) } @@ -1845,7 +1865,7 @@ func FnNow(yield Query) { "SYSDATE(1)", "SYSDATE(2)", "SYSDATE(3)", "SYSDATE(4)", "SYSDATE(5)", } for _, fn := range fns { - yield(fn, nil) + yield(fn, nil, false) } } @@ -1857,7 +1877,7 @@ func FnInfo(yield Query) { "VERSION()", } for _, fn := range fns { - yield(fn, nil) + yield(fn, nil, false) } } @@ -1871,7 +1891,7 @@ func FnDateFormat(yield Query) { format := buf.String() for _, d := range inputConversions { - yield(fmt.Sprintf("DATE_FORMAT(%s, %q)", d, format), nil) + yield(fmt.Sprintf("DATE_FORMAT(%s, %q)", d, format), nil, false) } } @@ -1897,7 +1917,7 @@ func FnConvertTz(yield Query) { for _, tzFrom := range timezoneInputs { for _, tzTo := range timezoneInputs { q := fmt.Sprintf("CONVERT_TZ(%s, '%s', '%s')", num1, tzFrom, tzTo) - yield(q, nil) + yield(q, nil, false) } } } @@ -1905,26 +1925,26 @@ func FnConvertTz(yield Query) { func FnDate(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("DATE(%s)", d), nil) + yield(fmt.Sprintf("DATE(%s)", d), nil, false) } } func FnDayOfMonth(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("DAYOFMONTH(%s)", d), nil) - yield(fmt.Sprintf("DAY(%s)", d), nil) + yield(fmt.Sprintf("DAYOFMONTH(%s)", d), nil, false) + yield(fmt.Sprintf("DAY(%s)", d), nil, false) } } func FnDayOfWeek(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("DAYOFWEEK(%s)", d), nil) + yield(fmt.Sprintf("DAYOFWEEK(%s)", d), nil, false) } } func FnDayOfYear(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("DAYOFYEAR(%s)", d), nil) + yield(fmt.Sprintf("DAYOFYEAR(%s)", d), nil, false) } } @@ -1938,21 +1958,21 @@ func FnFromUnixtime(yield Query) { format := buf.String() for _, d := range inputConversions { - yield(fmt.Sprintf("FROM_UNIXTIME(%s)", d), nil) - yield(fmt.Sprintf("FROM_UNIXTIME(%s, %q)", d, format), nil) + yield(fmt.Sprintf("FROM_UNIXTIME(%s)", d), nil, false) + yield(fmt.Sprintf("FROM_UNIXTIME(%s, %q)", d, format), nil, false) } } func FnHour(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("HOUR(%s)", d), nil) + yield(fmt.Sprintf("HOUR(%s)", d), nil, false) } } func FnMakedate(yield Query) { for _, y := range inputConversions { for _, d := range inputConversions { - yield(fmt.Sprintf("MAKEDATE(%s, %s)", y, d), nil) + yield(fmt.Sprintf("MAKEDATE(%s, %s)", y, d), nil, false) } } } @@ -1969,7 +1989,7 @@ func FnMaketime(yield Query) { } for _, m := range minutes { for _, s := range inputConversions { - yield(fmt.Sprintf("MAKETIME(%s, %s, %s)", h, m, s), nil) + yield(fmt.Sprintf("MAKETIME(%s, %s, %s)", h, m, s), nil, false) } } } @@ -1977,31 +1997,31 @@ func FnMaketime(yield Query) { func FnMicroSecond(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("MICROSECOND(%s)", d), nil) + yield(fmt.Sprintf("MICROSECOND(%s)", d), nil, false) } } func FnMinute(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("MINUTE(%s)", d), nil) + yield(fmt.Sprintf("MINUTE(%s)", d), nil, false) } } func FnMonth(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("MONTH(%s)", d), nil) + yield(fmt.Sprintf("MONTH(%s)", d), nil, false) } } func FnMonthName(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("MONTHNAME(%s)", d), nil) + yield(fmt.Sprintf("MONTHNAME(%s)", d), nil, false) } } func FnLastDay(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("LAST_DAY(%s)", d), nil) + yield(fmt.Sprintf("LAST_DAY(%s)", d), nil, false) } dates := []string{ @@ -2018,13 +2038,13 @@ func FnLastDay(yield Query) { } for _, d := range dates { - yield(fmt.Sprintf("LAST_DAY(%s)", d), nil) + yield(fmt.Sprintf("LAST_DAY(%s)", d), nil, false) } } func FnToDays(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("TO_DAYS(%s)", d), nil) + yield(fmt.Sprintf("TO_DAYS(%s)", d), nil, false) } dates := []string{ @@ -2042,13 +2062,13 @@ func FnToDays(yield Query) { } for _, d := range dates { - yield(fmt.Sprintf("TO_DAYS(%s)", d), nil) + yield(fmt.Sprintf("TO_DAYS(%s)", d), nil, false) } } func FnFromDays(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("FROM_DAYS(%s)", d), nil) + yield(fmt.Sprintf("FROM_DAYS(%s)", d), nil, false) } days := []string{ @@ -2064,13 +2084,13 @@ func FnFromDays(yield Query) { } for _, d := range days { - yield(fmt.Sprintf("FROM_DAYS(%s)", d), nil) + yield(fmt.Sprintf("FROM_DAYS(%s)", d), nil, false) } } func FnSecToTime(yield Query) { for _, s := range inputConversions { - yield(fmt.Sprintf("SEC_TO_TIME(%s)", s), nil) + yield(fmt.Sprintf("SEC_TO_TIME(%s)", s), nil, false) } mysqlDocSamples := []string{ @@ -2079,13 +2099,13 @@ func FnSecToTime(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } func FnTimeToSec(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("TIME_TO_SEC(%s)", d), nil) + yield(fmt.Sprintf("TIME_TO_SEC(%s)", d), nil, false) } time := []string{ @@ -2103,13 +2123,13 @@ func FnTimeToSec(yield Query) { } for _, t := range time { - yield(fmt.Sprintf("TIME_TO_SEC(%s)", t), nil) + yield(fmt.Sprintf("TIME_TO_SEC(%s)", t), nil, false) } } func FnToSeconds(yield Query) { for _, t := range inputConversions { - yield(fmt.Sprintf("TO_SECONDS(%s)", t), nil) + yield(fmt.Sprintf("TO_SECONDS(%s)", t), nil, false) } timeInputs := []string{ @@ -2127,7 +2147,7 @@ func FnToSeconds(yield Query) { } for _, t := range timeInputs { - yield(fmt.Sprintf("TO_SECONDS(%s)", t), nil) + yield(fmt.Sprintf("TO_SECONDS(%s)", t), nil, false) } mysqlDocSamples := []string{ @@ -2137,25 +2157,25 @@ func FnToSeconds(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } func FnQuarter(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("QUARTER(%s)", d), nil) + yield(fmt.Sprintf("QUARTER(%s)", d), nil, false) } } func FnSecond(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("SECOND(%s)", d), nil) + yield(fmt.Sprintf("SECOND(%s)", d), nil, false) } } func FnTime(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("TIME(%s)", d), nil) + yield(fmt.Sprintf("TIME(%s)", d), nil, false) } times := []string{ "'00:00:00'", @@ -2174,68 +2194,68 @@ func FnTime(yield Query) { } for _, d := range times { - yield(fmt.Sprintf("TIME(%s)", d), nil) + yield(fmt.Sprintf("TIME(%s)", d), nil, false) } } func FnUnixTimestamp(yield Query) { - yield("UNIX_TIMESTAMP()", nil) + yield("UNIX_TIMESTAMP()", nil, false) for _, d := range inputConversions { - yield(fmt.Sprintf("UNIX_TIMESTAMP(%s)", d), nil) - yield(fmt.Sprintf("UNIX_TIMESTAMP(%s) + 1", d), nil) + yield(fmt.Sprintf("UNIX_TIMESTAMP(%s)", d), nil, false) + yield(fmt.Sprintf("UNIX_TIMESTAMP(%s) + 1", d), nil, false) } } func FnWeek(yield Query) { for i := 0; i < 16; i++ { for _, d := range inputConversions { - yield(fmt.Sprintf("WEEK(%s, %d)", d, i), nil) + yield(fmt.Sprintf("WEEK(%s, %d)", d, i), nil, false) } } for _, d := range inputConversions { - yield(fmt.Sprintf("WEEK(%s)", d), nil) + yield(fmt.Sprintf("WEEK(%s)", d), nil, false) } } func FnWeekDay(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("WEEKDAY(%s)", d), nil) + yield(fmt.Sprintf("WEEKDAY(%s)", d), nil, false) } } func FnWeekOfYear(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("WEEKOFYEAR(%s)", d), nil) + yield(fmt.Sprintf("WEEKOFYEAR(%s)", d), nil, false) } } func FnYear(yield Query) { for _, d := range inputConversions { - yield(fmt.Sprintf("YEAR(%s)", d), nil) + yield(fmt.Sprintf("YEAR(%s)", d), nil, false) } } func FnYearWeek(yield Query) { for i := 0; i < 8; i++ { for _, d := range inputConversions { - yield(fmt.Sprintf("YEARWEEK(%s, %d)", d, i), nil) + yield(fmt.Sprintf("YEARWEEK(%s, %d)", d, i), nil, false) } } for _, d := range inputConversions { - yield(fmt.Sprintf("YEARWEEK(%s)", d), nil) + yield(fmt.Sprintf("YEARWEEK(%s)", d), nil, false) } } func FnPeriodAdd(yield Query) { for _, p := range inputBitwise { for _, m := range inputBitwise { - yield(fmt.Sprintf("PERIOD_ADD(%s, %s)", p, m), nil) + yield(fmt.Sprintf("PERIOD_ADD(%s, %s)", p, m), nil, false) } } for _, p := range inputPeriods { for _, m := range inputBitwise { - yield(fmt.Sprintf("PERIOD_ADD(%s, %s)", p, m), nil) + yield(fmt.Sprintf("PERIOD_ADD(%s, %s)", p, m), nil, false) } } @@ -2244,19 +2264,19 @@ func FnPeriodAdd(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } func FnPeriodDiff(yield Query) { for _, p1 := range inputBitwise { for _, p2 := range inputBitwise { - yield(fmt.Sprintf("PERIOD_DIFF(%s, %s)", p1, p2), nil) + yield(fmt.Sprintf("PERIOD_DIFF(%s, %s)", p1, p2), nil, false) } } for _, p1 := range inputPeriods { for _, p2 := range inputPeriods { - yield(fmt.Sprintf("PERIOD_DIFF(%s, %s)", p1, p2), nil) + yield(fmt.Sprintf("PERIOD_DIFF(%s, %s)", p1, p2), nil, false) } } @@ -2265,59 +2285,59 @@ func FnPeriodDiff(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } func FnInetAton(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("INET_ATON(%s)", d), nil) + yield(fmt.Sprintf("INET_ATON(%s)", d), nil, false) } } func FnInetNtoa(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("INET_NTOA(%s)", d), nil) - yield(fmt.Sprintf("INET_NTOA(INET_ATON(%s))", d), nil) + yield(fmt.Sprintf("INET_NTOA(%s)", d), nil, false) + yield(fmt.Sprintf("INET_NTOA(INET_ATON(%s))", d), nil, false) } } func FnInet6Aton(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("INET6_ATON(%s)", d), nil) + yield(fmt.Sprintf("INET6_ATON(%s)", d), nil, false) } } func FnInet6Ntoa(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("INET6_NTOA(%s)", d), nil) - yield(fmt.Sprintf("INET6_NTOA(INET6_ATON(%s))", d), nil) + yield(fmt.Sprintf("INET6_NTOA(%s)", d), nil, false) + yield(fmt.Sprintf("INET6_NTOA(INET6_ATON(%s))", d), nil, false) } } func FnIsIPv4(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("IS_IPV4(%s)", d), nil) + yield(fmt.Sprintf("IS_IPV4(%s)", d), nil, false) } } func FnIsIPv4Compat(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("IS_IPV4_COMPAT(%s)", d), nil) - yield(fmt.Sprintf("IS_IPV4_COMPAT(INET6_ATON(%s))", d), nil) + yield(fmt.Sprintf("IS_IPV4_COMPAT(%s)", d), nil, false) + yield(fmt.Sprintf("IS_IPV4_COMPAT(INET6_ATON(%s))", d), nil, false) } } func FnIsIPv4Mapped(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("IS_IPV4_MAPPED(%s)", d), nil) - yield(fmt.Sprintf("IS_IPV4_MAPPED(INET6_ATON(%s))", d), nil) + yield(fmt.Sprintf("IS_IPV4_MAPPED(%s)", d), nil, false) + yield(fmt.Sprintf("IS_IPV4_MAPPED(INET6_ATON(%s))", d), nil, false) } } func FnIsIPv6(yield Query) { for _, d := range ipInputs { - yield(fmt.Sprintf("IS_IPV6(%s)", d), nil) + yield(fmt.Sprintf("IS_IPV6(%s)", d), nil, false) } } @@ -2335,27 +2355,27 @@ func FnBinToUUID(yield Query) { "'2'", } for _, d := range uuidInputs { - yield(fmt.Sprintf("BIN_TO_UUID(%s)", d), nil) + yield(fmt.Sprintf("BIN_TO_UUID(%s)", d), nil, false) } for _, d := range uuidInputs { for _, a := range args { - yield(fmt.Sprintf("BIN_TO_UUID(%s, %s)", d, a), nil) + yield(fmt.Sprintf("BIN_TO_UUID(%s, %s)", d, a), nil, false) } } } func FnIsUUID(yield Query) { for _, d := range uuidInputs { - yield(fmt.Sprintf("IS_UUID(%s)", d), nil) + yield(fmt.Sprintf("IS_UUID(%s)", d), nil, false) } } func FnUUID(yield Query) { - yield("LENGTH(UUID())", nil) - yield("COLLATION(UUID())", nil) - yield("IS_UUID(UUID())", nil) - yield("LENGTH(UUID_TO_BIN(UUID())", nil) + yield("LENGTH(UUID())", nil, false) + yield("COLLATION(UUID())", nil, false) + yield("IS_UUID(UUID())", nil, false) + yield("LENGTH(UUID_TO_BIN(UUID())", nil, false) } func FnUUIDToBin(yield Query) { @@ -2372,12 +2392,12 @@ func FnUUIDToBin(yield Query) { "'2'", } for _, d := range uuidInputs { - yield(fmt.Sprintf("UUID_TO_BIN(%s)", d), nil) + yield(fmt.Sprintf("UUID_TO_BIN(%s)", d), nil, false) } for _, d := range uuidInputs { for _, a := range args { - yield(fmt.Sprintf("UUID_TO_BIN(%s, %s)", d, a), nil) + yield(fmt.Sprintf("UUID_TO_BIN(%s, %s)", d, a), nil, false) } } } @@ -2418,15 +2438,15 @@ func DateMath(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } for _, d := range dates { for _, i := range inputIntervals { for _, v := range intervalValues { - yield(fmt.Sprintf("DATE_ADD(%s, INTERVAL %s %s)", d, v, i), nil) - yield(fmt.Sprintf("DATE_SUB(%s, INTERVAL %s %s)", d, v, i), nil) - yield(fmt.Sprintf("TIMESTAMPADD(%v, %s, %s)", i, v, d), nil) + yield(fmt.Sprintf("DATE_ADD(%s, INTERVAL %s %s)", d, v, i), nil, false) + yield(fmt.Sprintf("DATE_SUB(%s, INTERVAL %s %s)", d, v, i), nil, false) + yield(fmt.Sprintf("TIMESTAMPADD(%v, %s, %s)", i, v, d), nil, false) } } } @@ -2481,15 +2501,15 @@ func RegexpLike(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } for _, i := range regexInputs { for _, p := range regexInputs { - yield(fmt.Sprintf("%s REGEXP %s", i, p), nil) - yield(fmt.Sprintf("%s NOT REGEXP %s", i, p), nil) + yield(fmt.Sprintf("%s REGEXP %s", i, p), nil, false) + yield(fmt.Sprintf("%s NOT REGEXP %s", i, p), nil, false) for _, m := range regexMatchStrings { - yield(fmt.Sprintf("REGEXP_LIKE(%s, %s, %s)", i, p, m), nil) + yield(fmt.Sprintf("REGEXP_LIKE(%s, %s, %s)", i, p, m), nil, false) } } } @@ -2565,7 +2585,7 @@ func RegexpInstr(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } @@ -2632,7 +2652,7 @@ func RegexpSubstr(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } @@ -2712,6 +2732,6 @@ func RegexpReplace(yield Query) { } for _, q := range mysqlDocSamples { - yield(q, nil) + yield(q, nil, false) } } diff --git a/go/vt/vtgate/evalengine/testcases/helpers.go b/go/vt/vtgate/evalengine/testcases/helpers.go index 71602e12c1c..db5ad6475b4 100644 --- a/go/vt/vtgate/evalengine/testcases/helpers.go +++ b/go/vt/vtgate/evalengine/testcases/helpers.go @@ -30,7 +30,7 @@ import ( querypb "vitess.io/vitess/go/vt/proto/query" ) -type Query func(query string, row []sqltypes.Value) +type Query func(query string, row []sqltypes.Value, skipCollationCheck bool) type Runner func(yield Query) type TestCase struct { Run Runner