diff --git a/go/vt/vtgate/evalengine/compiler_asm.go b/go/vt/vtgate/evalengine/compiler_asm.go index 7dda215353f..815b80a6cfc 100644 --- a/go/vt/vtgate/evalengine/compiler_asm.go +++ b/go/vt/vtgate/evalengine/compiler_asm.go @@ -2670,6 +2670,30 @@ func (asm *assembler) Fn_MULTICMP_u(args int, lessThan bool) { }, "FN MULTICMP UINT64(SP-%d)...UINT64(SP-1)", args) } +func (asm *assembler) Fn_MULTICMP_temporal(args int, lessThan bool) { + asm.adjustStack(-(args - 1)) + + asm.emit(func(env *ExpressionEnv) int { + var x *evalTemporal + for sp := env.vm.sp - args; sp < env.vm.sp; sp++ { + if env.vm.stack[sp] == nil { + continue + } + if x == nil { + x = env.vm.stack[sp].(*evalTemporal) + continue + } + y := env.vm.stack[sp].(*evalTemporal) + if lessThan == (y.compare(x) < 0) { + x = y + } + } + env.vm.stack[env.vm.sp-args] = x + env.vm.sp -= args - 1 + return 1 + }, "FN MULTICMP TEMPORAL(SP-%d)...TEMPORAL(SP-1)", args) +} + func (asm *assembler) Fn_REPEAT(base sqltypes.Type, fallback sqltypes.Type) { asm.adjustStack(-1) diff --git a/go/vt/vtgate/evalengine/eval_temporal.go b/go/vt/vtgate/evalengine/eval_temporal.go index d73485441c3..e1331ce3148 100644 --- a/go/vt/vtgate/evalengine/eval_temporal.go +++ b/go/vt/vtgate/evalengine/eval_temporal.go @@ -29,7 +29,7 @@ func (e *evalTemporal) ToRawBytes() []byte { switch e.t { case sqltypes.Date: return e.dt.Date.Format() - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return e.dt.Format(e.prec) case sqltypes.Time: return e.dt.Time.Format(e.prec) @@ -54,7 +54,7 @@ func (e *evalTemporal) toInt64() int64 { switch e.SQLType() { case sqltypes.Date: return e.dt.Date.FormatInt64() - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return e.dt.FormatInt64() case sqltypes.Time: return e.dt.Time.FormatInt64() @@ -67,7 +67,7 @@ func (e *evalTemporal) toFloat() float64 { switch e.SQLType() { case sqltypes.Date: return float64(e.dt.Date.FormatInt64()) - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return e.dt.FormatFloat64() case sqltypes.Time: return e.dt.Time.FormatFloat64() @@ -80,7 +80,7 @@ func (e *evalTemporal) toDecimal() decimal.Decimal { switch e.SQLType() { case sqltypes.Date: return decimal.NewFromInt(e.dt.Date.FormatInt64()) - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return e.dt.FormatDecimal() case sqltypes.Time: return e.dt.Time.FormatDecimal() @@ -93,7 +93,7 @@ func (e *evalTemporal) toJSON() *evalJSON { switch e.SQLType() { case sqltypes.Date: return json.NewDate(hack.String(e.dt.Date.Format())) - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: return json.NewDateTime(hack.String(e.dt.Format(datetime.DefaultPrecision))) case sqltypes.Time: return json.NewTime(hack.String(e.dt.Time.Format(datetime.DefaultPrecision))) @@ -103,8 +103,11 @@ func (e *evalTemporal) toJSON() *evalJSON { } func (e *evalTemporal) toDateTime(l int, now time.Time) *evalTemporal { + if l == -1 { + l = int(e.prec) + } switch e.SQLType() { - case sqltypes.Datetime, sqltypes.Date: + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Round(l), prec: uint8(l)} case sqltypes.Time: return &evalTemporal{t: sqltypes.Datetime, dt: e.dt.Time.Round(l).ToDateTime(now), prec: uint8(l)} @@ -113,9 +116,26 @@ func (e *evalTemporal) toDateTime(l int, now time.Time) *evalTemporal { } } +func (e *evalTemporal) toTimestamp(l int, now time.Time) *evalTemporal { + if l == -1 { + l = int(e.prec) + } + switch e.SQLType() { + case sqltypes.Datetime, sqltypes.Date, sqltypes.Timestamp: + return &evalTemporal{t: sqltypes.Timestamp, dt: e.dt.Round(l), prec: uint8(l)} + case sqltypes.Time: + return &evalTemporal{t: sqltypes.Timestamp, dt: e.dt.Time.Round(l).ToDateTime(now), prec: uint8(l)} + default: + panic("unreachable") + } +} + func (e *evalTemporal) toTime(l int) *evalTemporal { + if l == -1 { + l = int(e.prec) + } switch e.SQLType() { - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: dt := datetime.DateTime{Time: e.dt.Time.Round(l)} return &evalTemporal{t: sqltypes.Time, dt: dt, prec: uint8(l)} case sqltypes.Date: @@ -130,7 +150,7 @@ func (e *evalTemporal) toTime(l int) *evalTemporal { func (e *evalTemporal) toDate(now time.Time) *evalTemporal { switch e.SQLType() { - case sqltypes.Datetime: + case sqltypes.Datetime, sqltypes.Timestamp: dt := datetime.DateTime{Date: e.dt.Date} return &evalTemporal{t: sqltypes.Date, dt: dt} case sqltypes.Date: @@ -148,6 +168,10 @@ func (e *evalTemporal) isZero() bool { return e.dt.IsZero() } +func (e *evalTemporal) compare(other *evalTemporal) int { + return e.dt.Compare(other.dt) +} + func (e *evalTemporal) addInterval(interval *datetime.Interval, coll collations.ID, now time.Time) eval { var tmp *evalTemporal var ok bool @@ -179,6 +203,13 @@ func newEvalDateTime(dt datetime.DateTime, l int, allowZero bool) *evalTemporal return &evalTemporal{t: sqltypes.Datetime, dt: dt.Round(l), prec: uint8(l)} } +func newEvalTimestamp(dt datetime.DateTime, l int, allowZero bool) *evalTemporal { + if !allowZero && dt.IsZero() { + return nil + } + return &evalTemporal{t: sqltypes.Timestamp, dt: dt.Round(l), prec: uint8(l)} +} + func newEvalDate(d datetime.Date, allowZero bool) *evalTemporal { if !allowZero && d.IsZero() { return nil @@ -387,6 +418,53 @@ func evalToDateTime(e eval, l int, now time.Time, allowZero bool) *evalTemporal return nil } +func evalToTimestamp(e eval, l int, now time.Time, allowZero bool) *evalTemporal { + switch e := e.(type) { + case *evalTemporal: + return e.toTimestamp(precision(l, int(e.prec)), now) + case *evalBytes: + if t, l, _ := datetime.ParseDateTime(e.string(), l); !t.IsZero() { + return newEvalTimestamp(t, l, allowZero) + } + if d, _ := datetime.ParseDate(e.string()); !d.IsZero() { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalInt64: + if t, ok := datetime.ParseDateTimeInt64(e.i); ok { + return newEvalTimestamp(t, precision(l, 0), allowZero) + } + if d, ok := datetime.ParseDateInt64(e.i); ok { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalUint64: + if t, ok := datetime.ParseDateTimeInt64(int64(e.u)); ok { + return newEvalTimestamp(t, precision(l, 0), allowZero) + } + if d, ok := datetime.ParseDateInt64(int64(e.u)); ok { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalFloat: + if t, l, ok := datetime.ParseDateTimeFloat(e.f, l); ok { + return newEvalTimestamp(t, l, allowZero) + } + if d, ok := datetime.ParseDateFloat(e.f); ok { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalDecimal: + if t, l, ok := datetime.ParseDateTimeDecimal(e.dec, e.length, l); ok { + return newEvalTimestamp(t, l, allowZero) + } + if d, ok := datetime.ParseDateDecimal(e.dec); ok { + return newEvalTimestamp(datetime.DateTime{Date: d}, precision(l, 0), allowZero) + } + case *evalJSON: + if dt, ok := e.DateTime(); ok { + return newEvalTimestamp(dt, precision(l, datetime.DefaultPrecision), allowZero) + } + } + return nil +} + func evalToDate(e eval, now time.Time, allowZero bool) *evalTemporal { switch e := e.(type) { case *evalTemporal: diff --git a/go/vt/vtgate/evalengine/fn_compare.go b/go/vt/vtgate/evalengine/fn_compare.go index 1deec6752ef..b9b6857f0a8 100644 --- a/go/vt/vtgate/evalengine/fn_compare.go +++ b/go/vt/vtgate/evalengine/fn_compare.go @@ -19,7 +19,6 @@ package evalengine import ( "bytes" - "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/mysql/collations/charset" "vitess.io/vitess/go/mysql/collations/colldata" "vitess.io/vitess/go/sqltypes" @@ -32,7 +31,7 @@ type ( CallExpr } - multiComparisonFunc func(collationEnv *collations.Environment, args []eval, cmp int) (eval, error) + multiComparisonFunc func(env *ExpressionEnv, args []eval, cmp int) (eval, error) builtinMultiComparison struct { CallExpr @@ -101,6 +100,10 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc { decimals int text int binary int + datetime int + timestamp int + date int + time int ) /* @@ -114,7 +117,7 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc { for _, arg := range args { if arg == nil { - return func(collationEnv *collations.Environment, args []eval, cmp int) (eval, error) { + return func(_ *ExpressionEnv, _ []eval, _ int) (eval, error) { return nil, nil } } @@ -135,9 +138,33 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc { case sqltypes.Blob, sqltypes.Binary, sqltypes.VarBinary: binary++ } + case *evalTemporal: + switch arg.SQLType() { + case sqltypes.Datetime: + datetime++ + case sqltypes.Timestamp: + timestamp++ + case sqltypes.Date: + date++ + case sqltypes.Time: + time++ + } } } + if datetime > 0 { + return compareAllDatetime + } + if timestamp > 0 { + return compareAllTimestamp + } + if date > 0 { + return compareAllDate + } + if time > 0 { + return compareAllTime + } + if integersI+integersU == len(args) { if integersI == len(args) { return compareAllInteger_i @@ -165,7 +192,53 @@ func getMultiComparisonFunc(args []eval) multiComparisonFunc { panic("unexpected argument type") } -func compareAllInteger_u(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllTime(env *ExpressionEnv, args []eval, cmp int) (eval, error) { + for i := range args { + args[i] = evalToTime(args[i], -1) + } + return compareAllTemporal(args, cmp), nil +} + +func compareAllDate(env *ExpressionEnv, args []eval, cmp int) (eval, error) { + for i := range args { + args[i] = evalToDate(args[i], env.now, true) + } + return compareAllTemporal(args, cmp), nil +} + +func compareAllTimestamp(env *ExpressionEnv, args []eval, cmp int) (eval, error) { + for i := range args { + args[i] = evalToTimestamp(args[i], -1, env.now, true) + } + return compareAllTemporal(args, cmp), nil +} + +func compareAllDatetime(env *ExpressionEnv, args []eval, cmp int) (eval, error) { + for i := range args { + args[i] = evalToDateTime(args[i], -1, env.now, true) + } + return compareAllTemporal(args, cmp), nil +} + +func compareAllTemporal(args []eval, cmp int) *evalTemporal { + var x *evalTemporal + for _, arg := range args { + if arg == nil { + continue + } + if x == nil { + x = arg.(*evalTemporal) + continue + } + y := arg.(*evalTemporal) + if (cmp < 0) == (y.compare(x) < 0) { + x = y + } + } + return x +} + +func compareAllInteger_u(_ *ExpressionEnv, args []eval, cmp int) (eval, error) { x := args[0].(*evalUint64) for _, arg := range args[1:] { y := arg.(*evalUint64) @@ -176,7 +249,7 @@ func compareAllInteger_u(_ *collations.Environment, args []eval, cmp int) (eval, return x, nil } -func compareAllInteger_i(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllInteger_i(_ *ExpressionEnv, args []eval, cmp int) (eval, error) { x := args[0].(*evalInt64) for _, arg := range args[1:] { y := arg.(*evalInt64) @@ -187,7 +260,7 @@ func compareAllInteger_i(_ *collations.Environment, args []eval, cmp int) (eval, return x, nil } -func compareAllFloat(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllFloat(_ *ExpressionEnv, args []eval, cmp int) (eval, error) { candidateF, ok := evalToFloat(args[0]) if !ok { return nil, errDecimalOutOfRange @@ -212,7 +285,7 @@ func evalDecimalPrecision(e eval) int32 { return 0 } -func compareAllDecimal(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllDecimal(_ *ExpressionEnv, args []eval, cmp int) (eval, error) { decExtreme := evalToDecimal(args[0], 0, 0).dec precExtreme := evalDecimalPrecision(args[0]) @@ -229,12 +302,12 @@ func compareAllDecimal(_ *collations.Environment, args []eval, cmp int) (eval, e return newEvalDecimalWithPrec(decExtreme, precExtreme), nil } -func compareAllText(collationEnv *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllText(env *ExpressionEnv, args []eval, cmp int) (eval, error) { var charsets = make([]charset.Charset, 0, len(args)) var ca collationAggregation for _, arg := range args { col := evalCollation(arg) - if err := ca.add(col, collationEnv); err != nil { + if err := ca.add(col, env.collationEnv); err != nil { return nil, err } charsets = append(charsets, colldata.Lookup(col.Collation).Charset()) @@ -262,7 +335,7 @@ func compareAllText(collationEnv *collations.Environment, args []eval, cmp int) return newEvalText(b1, tc), nil } -func compareAllBinary(_ *collations.Environment, args []eval, cmp int) (eval, error) { +func compareAllBinary(_ *ExpressionEnv, args []eval, cmp int) (eval, error) { candidateB := args[0].ToRawBytes() for _, arg := range args[1:] { @@ -280,7 +353,7 @@ func (call *builtinMultiComparison) eval(env *ExpressionEnv) (eval, error) { if err != nil { return nil, err } - return getMultiComparisonFunc(args)(env.collationEnv, args, call.cmp) + return getMultiComparisonFunc(args)(env, args, call.cmp) } func (call *builtinMultiComparison) compile_c(c *compiler, args []ctype) (ctype, error) { @@ -314,14 +387,18 @@ func (call *builtinMultiComparison) compile_d(c *compiler, args []ctype) (ctype, func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { var ( - signed int - unsigned int - floats int - decimals int - text int - binary int - args []ctype - nullable bool + signed int + unsigned int + floats int + decimals int + date int + datetime int + timestamp int + time int + text int + binary int + args []ctype + nullable bool ) /* @@ -355,6 +432,14 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { text++ case sqltypes.Blob, sqltypes.Binary, sqltypes.VarBinary: binary++ + case sqltypes.Date: + date++ + case sqltypes.Datetime: + datetime++ + case sqltypes.Timestamp: + timestamp++ + case sqltypes.Time: + time++ case sqltypes.Null: nullable = true default: @@ -366,6 +451,42 @@ func (call *builtinMultiComparison) compile(c *compiler) (ctype, error) { if nullable { f |= flagNullable } + if datetime > 0 { + for i, tt := range args { + if tt.Type != sqltypes.Datetime { + c.compileToDateTime(tt, len(args)-i, -1) + } + } + c.asm.Fn_MULTICMP_temporal(len(args), call.cmp < 0) + return ctype{Type: sqltypes.Datetime, Flag: f, Col: collationBinary}, nil + } + if timestamp > 0 { + for i, tt := range args { + if tt.Type != sqltypes.Timestamp { + c.compileToDateTime(tt, len(args)-i, -1) + } + } + c.asm.Fn_MULTICMP_temporal(len(args), call.cmp < 0) + return ctype{Type: sqltypes.Timestamp, Flag: f, Col: collationBinary}, nil + } + if date > 0 { + for i, tt := range args { + if tt.Type != sqltypes.Date { + c.compileToDateTime(tt, len(args)-i, -1) + } + } + c.asm.Fn_MULTICMP_temporal(len(args), call.cmp < 0) + return ctype{Type: sqltypes.Date, Flag: f, Col: collationBinary}, nil + } + if time > 0 { + for i, tt := range args { + if tt.Type != sqltypes.Time { + c.compileToDateTime(tt, len(args)-i, -1) + } + } + c.asm.Fn_MULTICMP_temporal(len(args), call.cmp < 0) + return ctype{Type: sqltypes.Time, Flag: f, Col: collationBinary}, nil + } if signed+unsigned == len(args) { if signed == len(args) { c.asm.Fn_MULTICMP_i(len(args), call.cmp < 0) diff --git a/go/vt/vtgate/evalengine/integration/fuzz_test.go b/go/vt/vtgate/evalengine/integration/fuzz_test.go index a16f6164b35..32340a849fb 100644 --- a/go/vt/vtgate/evalengine/integration/fuzz_test.go +++ b/go/vt/vtgate/evalengine/integration/fuzz_test.go @@ -139,11 +139,6 @@ func evaluateLocalEvalengine(env *evalengine.ExpressionEnv, query string, fields astExpr := stmt.(*sqlparser.Select).SelectExprs.Exprs[0].(*sqlparser.AliasedExpr).Expr local, err := func() (expr evalengine.Expr, err error) { - defer func() { - if r := recover(); r != nil { - err = fmt.Errorf("PANIC during translate: %v", r) - } - }() cfg := &evalengine.Config{ ResolveColumn: evalengine.FieldResolver(fields).Column, Collation: collations.CollationUtf8mb4ID, diff --git a/go/vt/vtgate/evalengine/testcases/cases.go b/go/vt/vtgate/evalengine/testcases/cases.go index ff6c0c0f311..56455bcf3cb 100644 --- a/go/vt/vtgate/evalengine/testcases/cases.go +++ b/go/vt/vtgate/evalengine/testcases/cases.go @@ -1183,6 +1183,11 @@ func MultiComparisons(yield Query) { "_utf8mb4 'ノ東京の' COLLATE utf8mb4_ja_0900_as_cs", "_utf8mb4 'の東京ノ' COLLATE utf8mb4_ja_0900_as_cs_ks", "_utf8mb4 'ノ東京の' COLLATE utf8mb4_ja_0900_as_cs_ks", + `date'2024-02-18'`, + `date'2023-02-01'`, + `date'2100-02-01'`, + `timestamp'2020-12-31 23:59:59'`, + `timestamp'2025-01-01 00:00:00'`, } for _, method := range []string{"LEAST", "GREATEST"} {