From 19890cfb8ef7bd486e475b58a713d89fc2c29282 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Andr=C3=A9s=20Taylor?= Date: Tue, 28 Jan 2025 11:09:53 +0100 Subject: [PATCH] Optimise AST rewriting (#17623) Signed-off-by: Andres Taylor Signed-off-by: Harshit Gangal Co-authored-by: Harshit Gangal --- go/vt/sqlparser/ast_rewriting.go | 557 -------------- go/vt/sqlparser/ast_rewriting_test.go | 565 -------------- go/vt/sqlparser/bind_var_needs.go | 18 +- go/vt/sqlparser/normalizer.go | 786 ++++++++++++++++---- go/vt/sqlparser/normalizer_test.go | 612 ++++++++++++++- go/vt/sqlparser/redact_query.go | 4 +- go/vt/sqlparser/utils.go | 5 +- go/vt/vtgate/executor_test.go | 24 + go/vt/vtgate/planbuilder/builder.go | 6 +- go/vt/vtgate/planbuilder/simplifier_test.go | 12 +- go/vt/vtgate/semantics/typer_test.go | 17 +- 11 files changed, 1279 insertions(+), 1327 deletions(-) delete mode 100644 go/vt/sqlparser/ast_rewriting.go delete mode 100644 go/vt/sqlparser/ast_rewriting_test.go diff --git a/go/vt/sqlparser/ast_rewriting.go b/go/vt/sqlparser/ast_rewriting.go deleted file mode 100644 index 05e7e290fc1..00000000000 --- a/go/vt/sqlparser/ast_rewriting.go +++ /dev/null @@ -1,557 +0,0 @@ -/* -Copyright 2020 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package sqlparser - -import ( - "strconv" - "strings" - - querypb "vitess.io/vitess/go/vt/proto/query" - vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" - "vitess.io/vitess/go/vt/sysvars" - "vitess.io/vitess/go/vt/vterrors" -) - -var HasValueSubQueryBaseName = []byte("__sq_has_values") - -// SQLSelectLimitUnset default value for sql_select_limit not set. -const SQLSelectLimitUnset = -1 - -// RewriteASTResult contains the rewritten ast and meta information about it -type RewriteASTResult struct { - *BindVarNeeds - AST Statement // The rewritten AST -} - -type VSchemaViews interface { - FindView(name TableName) TableStatement -} - -// PrepareAST will normalize the query -func PrepareAST( - in Statement, - reservedVars *ReservedVars, - bindVars map[string]*querypb.BindVariable, - parameterize bool, - keyspace string, - selectLimit int, - setVarComment string, - sysVars map[string]string, - fkChecksState *bool, - views VSchemaViews, -) (*RewriteASTResult, error) { - if parameterize { - err := Normalize(in, reservedVars, bindVars) - if err != nil { - return nil, err - } - } - return RewriteAST(in, keyspace, selectLimit, setVarComment, sysVars, fkChecksState, views) -} - -// RewriteAST rewrites the whole AST, replacing function calls and adding column aliases to queries. -// SET_VAR comments are also added to the AST if required. -func RewriteAST( - in Statement, - keyspace string, - selectLimit int, - setVarComment string, - sysVars map[string]string, - fkChecksState *bool, - views VSchemaViews, -) (*RewriteASTResult, error) { - er := newASTRewriter(keyspace, selectLimit, setVarComment, sysVars, fkChecksState, views) - er.shouldRewriteDatabaseFunc = shouldRewriteDatabaseFunc(in) - result := SafeRewrite(in, er.rewriteDown, er.rewriteUp) - if er.err != nil { - return nil, er.err - } - - out, ok := result.(Statement) - if !ok { - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "statement rewriting returned a non statement: %s", String(out)) - } - - r := &RewriteASTResult{ - AST: out, - BindVarNeeds: er.bindVars, - } - return r, nil -} - -func shouldRewriteDatabaseFunc(in Statement) bool { - selct, ok := in.(*Select) - if !ok { - return false - } - if len(selct.From) != 1 { - return false - } - aliasedTable, ok := selct.From[0].(*AliasedTableExpr) - if !ok { - return false - } - tableName, ok := aliasedTable.Expr.(TableName) - if !ok { - return false - } - return tableName.Name.String() == "dual" -} - -type astRewriter struct { - bindVars *BindVarNeeds - shouldRewriteDatabaseFunc bool - err error - - // we need to know this to make a decision if we can safely rewrite JOIN USING => JOIN ON - hasStarInSelect bool - - keyspace string - selectLimit int - setVarComment string - fkChecksState *bool - sysVars map[string]string - views VSchemaViews -} - -func newASTRewriter(keyspace string, selectLimit int, setVarComment string, sysVars map[string]string, fkChecksState *bool, views VSchemaViews) *astRewriter { - return &astRewriter{ - bindVars: &BindVarNeeds{}, - keyspace: keyspace, - selectLimit: selectLimit, - setVarComment: setVarComment, - fkChecksState: fkChecksState, - sysVars: sysVars, - views: views, - } -} - -const ( - // LastInsertIDName is a reserved bind var name for last_insert_id() - LastInsertIDName = "__lastInsertId" - - // DBVarName is a reserved bind var name for database() - DBVarName = "__vtdbname" - - // FoundRowsName is a reserved bind var name for found_rows() - FoundRowsName = "__vtfrows" - - // RowCountName is a reserved bind var name for row_count() - RowCountName = "__vtrcount" - - // UserDefinedVariableName is what we prepend bind var names for user defined variables - UserDefinedVariableName = "__vtudv" -) - -func (er *astRewriter) rewriteAliasedExpr(node *AliasedExpr) (*BindVarNeeds, error) { - inner := newASTRewriter(er.keyspace, er.selectLimit, er.setVarComment, er.sysVars, nil, er.views) - inner.shouldRewriteDatabaseFunc = er.shouldRewriteDatabaseFunc - tmp := SafeRewrite(node.Expr, inner.rewriteDown, inner.rewriteUp) - newExpr, ok := tmp.(Expr) - if !ok { - return nil, vterrors.Errorf(vtrpcpb.Code_INTERNAL, "failed to rewrite AST. function expected to return Expr returned a %s", String(tmp)) - } - node.Expr = newExpr - return inner.bindVars, nil -} - -func (er *astRewriter) rewriteDown(node SQLNode, _ SQLNode) bool { - switch node := node.(type) { - case *Select: - er.visitSelect(node) - case *PrepareStmt, *ExecuteStmt: - return false // nothing to rewrite here. - } - return true -} - -func (er *astRewriter) rewriteUp(cursor *Cursor) bool { - // Add SET_VAR comment to this node if it supports it and is needed - if supportOptimizerHint, supportsOptimizerHint := cursor.Node().(SupportOptimizerHint); supportsOptimizerHint { - if er.setVarComment != "" { - newComments, err := supportOptimizerHint.GetParsedComments().AddQueryHint(er.setVarComment) - if err != nil { - er.err = err - return false - } - supportOptimizerHint.SetComments(newComments) - } - if er.fkChecksState != nil { - newComments := supportOptimizerHint.GetParsedComments().SetMySQLSetVarValue(sysvars.ForeignKeyChecks, FkChecksStateString(er.fkChecksState)) - supportOptimizerHint.SetComments(newComments) - } - } - - switch node := cursor.Node().(type) { - case *Union: - er.rewriteUnion(node) - case *FuncExpr: - er.funcRewrite(cursor, node) - case *Variable: - er.rewriteVariable(cursor, node) - case *Subquery: - er.unnestSubQueries(cursor, node) - case *NotExpr: - er.rewriteNotExpr(cursor, node) - case *AliasedTableExpr: - er.rewriteAliasedTable(cursor, node) - case *ShowBasic: - er.rewriteShowBasic(node) - case *ExistsExpr: - er.existsRewrite(cursor, node) - case DistinctableAggr: - er.rewriteDistinctableAggr(cursor, node) - } - return true -} - -func (er *astRewriter) rewriteUnion(node *Union) { - // set select limit if explicitly not set when sql_select_limit is set on the connection. - if er.selectLimit > 0 && node.Limit == nil { - node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(er.selectLimit))} - } -} - -func (er *astRewriter) rewriteAliasedTable(cursor *Cursor, node *AliasedTableExpr) { - aliasTableName, ok := node.Expr.(TableName) - if !ok { - return - } - - // Qualifier should not be added to dual table - tblName := aliasTableName.Name.String() - if tblName == "dual" { - return - } - - if SystemSchema(er.keyspace) { - if aliasTableName.Qualifier.IsEmpty() { - aliasTableName.Qualifier = NewIdentifierCS(er.keyspace) - node.Expr = aliasTableName - cursor.Replace(node) - } - return - } - - // Could we be dealing with a view? - if er.views == nil { - return - } - view := er.views.FindView(aliasTableName) - if view == nil { - return - } - - // Aha! It's a view. Let's replace it with a derived table - node.Expr = &DerivedTable{Select: Clone(view)} // TODO: this is a bit hacky. We want to update the schema def so it contains new types - if node.As.IsEmpty() { - node.As = NewIdentifierCS(tblName) - } -} - -func (er *astRewriter) rewriteShowBasic(node *ShowBasic) { - if node.Command == VariableGlobal || node.Command == VariableSession { - varsToAdd := sysvars.GetInterestingVariables() - for _, sysVar := range varsToAdd { - er.bindVars.AddSysVar(sysVar) - } - } -} - -func (er *astRewriter) rewriteNotExpr(cursor *Cursor, node *NotExpr) { - switch inner := node.Expr.(type) { - case *ComparisonExpr: - // not col = 42 => col != 42 - // not col > 42 => col <= 42 - // etc - canChange, inverse := inverseOp(inner.Operator) - if canChange { - inner.Operator = inverse - cursor.Replace(inner) - } - case *NotExpr: - // not not true => true - cursor.Replace(inner.Expr) - case BoolVal: - // not true => false - inner = !inner - cursor.Replace(inner) - } -} - -func (er *astRewriter) rewriteVariable(cursor *Cursor, node *Variable) { - // Iff we are in SET, we want to change the scope of variables if a modifier has been set - // and only on the lhs of the assignment: - // set session sql_mode = @someElse - // here we need to change the scope of `sql_mode` and not of `@someElse` - if v, isSet := cursor.Parent().(*SetExpr); isSet && v.Var == node { - return - } - // no rewriting for global scope variable. - // this should be returned from the underlying database. - switch node.Scope { - case VariableScope: - er.udvRewrite(cursor, node) - case SessionScope, NextTxScope: - er.sysVarRewrite(cursor, node) - } -} - -func (er *astRewriter) visitSelect(node *Select) { - for _, col := range node.SelectExprs { - if _, hasStar := col.(*StarExpr); hasStar { - er.hasStarInSelect = true - continue - } - - aliasedExpr, ok := col.(*AliasedExpr) - if !ok || aliasedExpr.As.NotEmpty() { - continue - } - buf := NewTrackedBuffer(nil) - aliasedExpr.Expr.Format(buf) - // select last_insert_id() -> select :__lastInsertId as `last_insert_id()` - innerBindVarNeeds, err := er.rewriteAliasedExpr(aliasedExpr) - if err != nil { - er.err = err - return - } - if innerBindVarNeeds.HasRewrites() { - aliasedExpr.As = NewIdentifierCI(buf.String()) - } - er.bindVars.MergeWith(innerBindVarNeeds) - - } - // set select limit if explicitly not set when sql_select_limit is set on the connection. - if er.selectLimit > 0 && node.Limit == nil { - node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(er.selectLimit))} - } -} - -func inverseOp(i ComparisonExprOperator) (bool, ComparisonExprOperator) { - switch i { - case EqualOp: - return true, NotEqualOp - case LessThanOp: - return true, GreaterEqualOp - case GreaterThanOp: - return true, LessEqualOp - case LessEqualOp: - return true, GreaterThanOp - case GreaterEqualOp: - return true, LessThanOp - case NotEqualOp: - return true, EqualOp - case InOp: - return true, NotInOp - case NotInOp: - return true, InOp - case LikeOp: - return true, NotLikeOp - case NotLikeOp: - return true, LikeOp - case RegexpOp: - return true, NotRegexpOp - case NotRegexpOp: - return true, RegexpOp - } - - return false, i -} - -func (er *astRewriter) sysVarRewrite(cursor *Cursor, node *Variable) { - lowered := node.Name.Lowered() - - var found bool - if er.sysVars != nil { - _, found = er.sysVars[lowered] - } - - switch lowered { - case sysvars.Autocommit.Name, - sysvars.Charset.Name, - sysvars.ClientFoundRows.Name, - sysvars.DDLStrategy.Name, - sysvars.MigrationContext.Name, - sysvars.Names.Name, - sysvars.TransactionMode.Name, - sysvars.ReadAfterWriteGTID.Name, - sysvars.ReadAfterWriteTimeOut.Name, - sysvars.SessionEnableSystemSettings.Name, - sysvars.SessionTrackGTIDs.Name, - sysvars.SessionUUID.Name, - sysvars.SkipQueryPlanCache.Name, - sysvars.Socket.Name, - sysvars.SQLSelectLimit.Name, - sysvars.Version.Name, - sysvars.VersionComment.Name, - sysvars.QueryTimeout.Name, - sysvars.Workload.Name: - found = true - } - - if found { - cursor.Replace(bindVarExpression("__vt" + lowered)) - er.bindVars.AddSysVar(lowered) - } -} - -func (er *astRewriter) udvRewrite(cursor *Cursor, node *Variable) { - udv := strings.ToLower(node.Name.CompliantName()) - cursor.Replace(bindVarExpression(UserDefinedVariableName + udv)) - er.bindVars.AddUserDefVar(udv) -} - -var funcRewrites = map[string]string{ - "last_insert_id": LastInsertIDName, - "database": DBVarName, - "schema": DBVarName, - "found_rows": FoundRowsName, - "row_count": RowCountName, -} - -func (er *astRewriter) funcRewrite(cursor *Cursor, node *FuncExpr) { - lowered := node.Name.Lowered() - if lowered == "last_insert_id" && len(node.Exprs) > 0 { - // if we are dealing with is LAST_INSERT_ID() with an argument, we don't need to rewrite it. - // with an argument, this is an identity function that will update the session state and - // sets the correct fields in the OK TCP packet that we send back - return - } - bindVar, found := funcRewrites[lowered] - if !found || (bindVar == DBVarName && !er.shouldRewriteDatabaseFunc) { - return - } - if len(node.Exprs) > 0 { - er.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Argument to %s() not supported", lowered) - return - } - cursor.Replace(bindVarExpression(bindVar)) - er.bindVars.AddFuncResult(bindVar) -} - -func (er *astRewriter) unnestSubQueries(cursor *Cursor, subquery *Subquery) { - if _, isExists := cursor.Parent().(*ExistsExpr); isExists { - return - } - sel, isSimpleSelect := subquery.Select.(*Select) - if !isSimpleSelect { - return - } - - if len(sel.SelectExprs) != 1 || - len(sel.OrderBy) != 0 || - sel.GroupBy != nil || - len(sel.From) != 1 || - sel.Where != nil || - sel.Having != nil || - sel.Limit != nil || sel.Lock != NoLock { - return - } - - aliasedTable, ok := sel.From[0].(*AliasedTableExpr) - if !ok { - return - } - table, ok := aliasedTable.Expr.(TableName) - if !ok || table.Name.String() != "dual" { - return - } - expr, ok := sel.SelectExprs[0].(*AliasedExpr) - if !ok { - return - } - _, isColName := expr.Expr.(*ColName) - if isColName { - // If we find a single col-name in a `dual` subquery, we can be pretty sure the user is returning a column - // already projected. - // `select 1 as x, (select x)` - // is perfectly valid - any aliased columns to the left are available inside subquery scopes - return - } - er.bindVars.NoteRewrite() - // we need to make sure that the inner expression also gets rewritten, - // so we fire off another rewriter traversal here - rewritten := SafeRewrite(expr.Expr, er.rewriteDown, er.rewriteUp) - - // Here we need to handle the subquery rewrite in case in occurs in an IN clause - // For example, SELECT id FROM user WHERE id IN (SELECT 1 FROM DUAL) - // Here we cannot rewrite the query to SELECT id FROM user WHERE id IN 1, since that is syntactically wrong - // We must rewrite it to SELECT id FROM user WHERE id IN (1) - // Find more cases in the test file - rewrittenExpr, isExpr := rewritten.(Expr) - _, isColTuple := rewritten.(ColTuple) - comparisonExpr, isCompExpr := cursor.Parent().(*ComparisonExpr) - // Check that the parent is a comparison operator with IN or NOT IN operation. - // Also, if rewritten is already a ColTuple (like a subquery), then we do not need this - // We also need to check that rewritten is an Expr, if it is then we can rewrite it as a ValTuple - if isCompExpr && (comparisonExpr.Operator == InOp || comparisonExpr.Operator == NotInOp) && !isColTuple && isExpr { - cursor.Replace(ValTuple{rewrittenExpr}) - return - } - - cursor.Replace(rewritten) -} - -func (er *astRewriter) existsRewrite(cursor *Cursor, node *ExistsExpr) { - sel, ok := node.Subquery.Select.(*Select) - if !ok { - return - } - - if sel.Having != nil { - // If the query has HAVING, we can't take any shortcuts - return - } - - if sel.GroupBy == nil && sel.SelectExprs.AllAggregation() { - // in these situations, we are guaranteed to always get a non-empty result, - // so we can replace the EXISTS with a literal true - cursor.Replace(BoolVal(true)) - } - - // If we are not doing HAVING, we can safely replace all select expressions with a - // single `1` and remove any grouping - sel.SelectExprs = SelectExprs{ - &AliasedExpr{Expr: NewIntLiteral("1")}, - } - sel.GroupBy = nil -} - -// rewriteDistinctableAggr removed Distinct from Max and Min Aggregations as it does not impact the result. But, makes the plan simpler. -func (er *astRewriter) rewriteDistinctableAggr(cursor *Cursor, node DistinctableAggr) { - if !node.IsDistinct() { - return - } - switch aggr := node.(type) { - case *Max, *Min: - aggr.SetDistinct(false) - er.bindVars.NoteRewrite() - } -} - -func bindVarExpression(name string) Expr { - return NewArgument(name) -} - -// SystemSchema returns true if the schema passed is system schema -func SystemSchema(schema string) bool { - return strings.EqualFold(schema, "information_schema") || - strings.EqualFold(schema, "performance_schema") || - strings.EqualFold(schema, "sys") || - strings.EqualFold(schema, "mysql") -} diff --git a/go/vt/sqlparser/ast_rewriting_test.go b/go/vt/sqlparser/ast_rewriting_test.go deleted file mode 100644 index 8b3e3d44c54..00000000000 --- a/go/vt/sqlparser/ast_rewriting_test.go +++ /dev/null @@ -1,565 +0,0 @@ -/* -Copyright 2019 The Vitess Authors. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -*/ - -package sqlparser - -import ( - "fmt" - "testing" - - "github.com/stretchr/testify/assert" - - "vitess.io/vitess/go/vt/sysvars" - - "github.com/stretchr/testify/require" -) - -type testCaseSetVar struct { - in, expected, setVarComment string -} - -type testCaseSysVar struct { - in, expected string - sysVar map[string]string -} - -type myTestCase struct { - in, expected string - liid, db, foundRows, rowCount, rawGTID, rawTimeout, sessTrackGTID bool - ddlStrategy, migrationContext, sessionUUID, sessionEnableSystemSettings bool - udv int - autocommit, foreignKeyChecks, clientFoundRows, skipQueryPlanCache, socket, queryTimeout bool - sqlSelectLimit, transactionMode, workload, version, versionComment bool -} - -func TestRewrites(in *testing.T) { - tests := []myTestCase{{ - in: "SELECT 42", - expected: "SELECT 42", - // no bindvar needs - }, { - in: "SELECT @@version", - expected: "SELECT :__vtversion as `@@version`", - version: true, - }, { - in: "SELECT @@query_timeout", - expected: "SELECT :__vtquery_timeout as `@@query_timeout`", - queryTimeout: true, - }, { - in: "SELECT @@version_comment", - expected: "SELECT :__vtversion_comment as `@@version_comment`", - versionComment: true, - }, { - in: "SELECT @@enable_system_settings", - expected: "SELECT :__vtenable_system_settings as `@@enable_system_settings`", - sessionEnableSystemSettings: true, - }, { - in: "SELECT last_insert_id()", - expected: "SELECT :__lastInsertId as `last_insert_id()`", - liid: true, - }, { - in: "SELECT database()", - expected: "SELECT :__vtdbname as `database()`", - db: true, - }, { - in: "SELECT database() from test", - expected: "SELECT database() from test", - // no bindvar needs - }, { - in: "SELECT last_insert_id() as test", - expected: "SELECT :__lastInsertId as test", - liid: true, - }, { - in: "SELECT last_insert_id() + database()", - expected: "SELECT :__lastInsertId + :__vtdbname as `last_insert_id() + database()`", - db: true, liid: true, - }, { - // unnest database() call - in: "select (select database()) from test", - expected: "select database() as `(select database() from dual)` from test", - // no bindvar needs - }, { - // unnest database() call - in: "select (select database() from dual) from test", - expected: "select database() as `(select database() from dual)` from test", - // no bindvar needs - }, { - in: "select (select database() from dual) from dual", - expected: "select :__vtdbname as `(select database() from dual)` from dual", - db: true, - }, { - // don't unnest solo columns - in: "select 1 as foobar, (select foobar)", - expected: "select 1 as foobar, (select foobar from dual) from dual", - }, { - in: "select id from user where database()", - expected: "select id from user where database()", - // no bindvar needs - }, { - in: "select table_name from information_schema.tables where table_schema = database()", - expected: "select table_name from information_schema.tables where table_schema = database()", - // no bindvar needs - }, { - in: "select schema()", - expected: "select :__vtdbname as `schema()`", - db: true, - }, { - in: "select found_rows()", - expected: "select :__vtfrows as `found_rows()`", - foundRows: true, - }, { - in: "select @`x y`", - expected: "select :__vtudvx_y as `@``x y``` from dual", - udv: 1, - }, { - in: "select id from t where id = @x and val = @y", - expected: "select id from t where id = :__vtudvx and val = :__vtudvy", - db: false, udv: 2, - }, { - in: "insert into t(id) values(@xyx)", - expected: "insert into t(id) values(:__vtudvxyx)", - db: false, udv: 1, - }, { - in: "select row_count()", - expected: "select :__vtrcount as `row_count()`", - rowCount: true, - }, { - in: "SELECT lower(database())", - expected: "SELECT lower(:__vtdbname) as `lower(database())`", - db: true, - }, { - in: "SELECT @@autocommit", - expected: "SELECT :__vtautocommit as `@@autocommit`", - autocommit: true, - }, { - in: "SELECT @@client_found_rows", - expected: "SELECT :__vtclient_found_rows as `@@client_found_rows`", - clientFoundRows: true, - }, { - in: "SELECT @@skip_query_plan_cache", - expected: "SELECT :__vtskip_query_plan_cache as `@@skip_query_plan_cache`", - skipQueryPlanCache: true, - }, { - in: "SELECT @@sql_select_limit", - expected: "SELECT :__vtsql_select_limit as `@@sql_select_limit`", - sqlSelectLimit: true, - }, { - in: "SELECT @@transaction_mode", - expected: "SELECT :__vttransaction_mode as `@@transaction_mode`", - transactionMode: true, - }, { - in: "SELECT @@workload", - expected: "SELECT :__vtworkload as `@@workload`", - workload: true, - }, { - in: "SELECT @@socket", - expected: "SELECT :__vtsocket as `@@socket`", - socket: true, - }, { - in: "select (select 42) from dual", - expected: "select 42 as `(select 42 from dual)` from dual", - }, { - in: "select * from user where col = (select 42)", - expected: "select * from user where col = 42", - }, { - in: "select * from (select 42) as t", // this is not an expression, and should not be rewritten - expected: "select * from (select 42) as t", - }, { - in: `select (select (select (select (select (select last_insert_id()))))) as x`, - expected: "select :__lastInsertId as x from dual", - liid: true, - }, { - in: `select * from user where col = @@ddl_strategy`, - expected: "select * from user where col = :__vtddl_strategy", - ddlStrategy: true, - }, { - in: `select * from user where col = @@migration_context`, - expected: "select * from user where col = :__vtmigration_context", - migrationContext: true, - }, { - in: `select * from user where col = @@read_after_write_gtid OR col = @@read_after_write_timeout OR col = @@session_track_gtids`, - expected: "select * from user where col = :__vtread_after_write_gtid or col = :__vtread_after_write_timeout or col = :__vtsession_track_gtids", - rawGTID: true, rawTimeout: true, sessTrackGTID: true, - }, { - in: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual)", - expected: "SELECT * FROM tbl WHERE id IN (1)", - }, { - in: "SELECT * FROM tbl WHERE id IN (SELECT last_insert_id() FROM dual)", - expected: "SELECT * FROM tbl WHERE id IN (:__lastInsertId)", - liid: true, - }, { - in: "SELECT * FROM tbl WHERE id IN (SELECT (SELECT 1 FROM dual WHERE 1 = 0) FROM dual)", - expected: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual WHERE 1 = 0)", - }, { - in: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual WHERE 1 = 0)", - expected: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual WHERE 1 = 0)", - }, { - in: "SELECT * FROM tbl WHERE id IN (SELECT 1,2 FROM dual)", - expected: "SELECT * FROM tbl WHERE id IN (SELECT 1,2 FROM dual)", - }, { - in: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual ORDER BY 1)", - expected: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual ORDER BY 1)", - }, { - in: "SELECT * FROM tbl WHERE id IN (SELECT id FROM user GROUP BY id)", - expected: "SELECT * FROM tbl WHERE id IN (SELECT id FROM user GROUP BY id)", - }, { - in: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual, user)", - expected: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual, user)", - }, { - in: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual limit 1)", - expected: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual limit 1)", - }, { - // SELECT * behaves different depending the join type used, so if that has been used, we won't rewrite - in: "SELECT * FROM A JOIN B USING (id1,id2,id3)", - expected: "SELECT * FROM A JOIN B USING (id1,id2,id3)", - }, { - in: "CALL proc(@foo)", - expected: "CALL proc(:__vtudvfoo)", - udv: 1, - }, { - in: "SELECT * FROM tbl WHERE NOT id = 42", - expected: "SELECT * FROM tbl WHERE id != 42", - }, { - in: "SELECT * FROM tbl WHERE not id < 12", - expected: "SELECT * FROM tbl WHERE id >= 12", - }, { - in: "SELECT * FROM tbl WHERE not id > 12", - expected: "SELECT * FROM tbl WHERE id <= 12", - }, { - in: "SELECT * FROM tbl WHERE not id <= 33", - expected: "SELECT * FROM tbl WHERE id > 33", - }, { - in: "SELECT * FROM tbl WHERE not id >= 33", - expected: "SELECT * FROM tbl WHERE id < 33", - }, { - in: "SELECT * FROM tbl WHERE not id != 33", - expected: "SELECT * FROM tbl WHERE id = 33", - }, { - in: "SELECT * FROM tbl WHERE not id in (1,2,3)", - expected: "SELECT * FROM tbl WHERE id not in (1,2,3)", - }, { - in: "SELECT * FROM tbl WHERE not id not in (1,2,3)", - expected: "SELECT * FROM tbl WHERE id in (1,2,3)", - }, { - in: "SELECT * FROM tbl WHERE not id not in (1,2,3)", - expected: "SELECT * FROM tbl WHERE id in (1,2,3)", - }, { - in: "SELECT * FROM tbl WHERE not id like '%foobar'", - expected: "SELECT * FROM tbl WHERE id not like '%foobar'", - }, { - in: "SELECT * FROM tbl WHERE not id not like '%foobar'", - expected: "SELECT * FROM tbl WHERE id like '%foobar'", - }, { - in: "SELECT * FROM tbl WHERE not id regexp '%foobar'", - expected: "SELECT * FROM tbl WHERE id not regexp '%foobar'", - }, { - in: "SELECT * FROM tbl WHERE not id not regexp '%foobar'", - expected: "select * from tbl where id regexp '%foobar'", - }, { - in: "SELECT * FROM tbl WHERE exists(select col1, col2 from other_table where foo > bar)", - expected: "SELECT * FROM tbl WHERE exists(select 1 from other_table where foo > bar)", - }, { - in: "SELECT * FROM tbl WHERE exists(select col1, col2 from other_table where foo > bar limit 100 offset 34)", - expected: "SELECT * FROM tbl WHERE exists(select 1 from other_table where foo > bar limit 100 offset 34)", - }, { - in: "SELECT * FROM tbl WHERE exists(select col1, col2, count(*) from other_table where foo > bar group by col1, col2)", - expected: "SELECT * FROM tbl WHERE exists(select 1 from other_table where foo > bar)", - }, { - in: "SELECT * FROM tbl WHERE exists(select col1, col2 from other_table where foo > bar group by col1, col2)", - expected: "SELECT * FROM tbl WHERE exists(select 1 from other_table where foo > bar)", - }, { - in: "SELECT * FROM tbl WHERE exists(select count(*) from other_table where foo > bar)", - expected: "SELECT * FROM tbl WHERE true", - }, { - in: "SELECT * FROM tbl WHERE exists(select col1, col2, count(*) from other_table where foo > bar group by col1, col2 having count(*) > 3)", - expected: "SELECT * FROM tbl WHERE exists(select col1, col2, count(*) from other_table where foo > bar group by col1, col2 having count(*) > 3)", - }, { - in: "SELECT id, name, salary FROM user_details", - expected: "SELECT id, name, salary FROM (select user.id, user.name, user_extra.salary from user join user_extra where user.id = user_extra.user_id) as user_details", - }, { - in: "select max(distinct c1), min(distinct c2), avg(distinct c3), sum(distinct c4), count(distinct c5), group_concat(distinct c6) from tbl", - expected: "select max(c1) as `max(distinct c1)`, min(c2) as `min(distinct c2)`, avg(distinct c3), sum(distinct c4), count(distinct c5), group_concat(distinct c6) from tbl", - }, { - in: "SHOW VARIABLES", - expected: "SHOW VARIABLES", - autocommit: true, - foreignKeyChecks: true, - clientFoundRows: true, - skipQueryPlanCache: true, - sqlSelectLimit: true, - transactionMode: true, - workload: true, - version: true, - versionComment: true, - ddlStrategy: true, - migrationContext: true, - sessionUUID: true, - sessionEnableSystemSettings: true, - rawGTID: true, - rawTimeout: true, - sessTrackGTID: true, - socket: true, - queryTimeout: true, - }, { - in: "SHOW GLOBAL VARIABLES", - expected: "SHOW GLOBAL VARIABLES", - autocommit: true, - foreignKeyChecks: true, - clientFoundRows: true, - skipQueryPlanCache: true, - sqlSelectLimit: true, - transactionMode: true, - workload: true, - version: true, - versionComment: true, - ddlStrategy: true, - migrationContext: true, - sessionUUID: true, - sessionEnableSystemSettings: true, - rawGTID: true, - rawTimeout: true, - sessTrackGTID: true, - socket: true, - queryTimeout: true, - }} - parser := NewTestParser() - for _, tc := range tests { - in.Run(tc.in, func(t *testing.T) { - require := require.New(t) - stmt, err := parser.Parse(tc.in) - require.NoError(err) - - result, err := RewriteAST( - stmt, - "ks", // passing `ks` just to test that no rewriting happens as it is not system schema - SQLSelectLimitUnset, - "", - nil, - nil, - &fakeViews{}, - ) - require.NoError(err) - - expected, err := parser.Parse(tc.expected) - require.NoError(err, "test expectation does not parse [%s]", tc.expected) - - s := String(expected) - assert := assert.New(t) - assert.Equal(s, String(result.AST)) - assert.Equal(tc.liid, result.NeedsFuncResult(LastInsertIDName), "should need last insert id") - assert.Equal(tc.db, result.NeedsFuncResult(DBVarName), "should need database name") - assert.Equal(tc.foundRows, result.NeedsFuncResult(FoundRowsName), "should need found rows") - assert.Equal(tc.rowCount, result.NeedsFuncResult(RowCountName), "should need row count") - assert.Equal(tc.udv, len(result.NeedUserDefinedVariables), "count of user defined variables") - assert.Equal(tc.autocommit, result.NeedsSysVar(sysvars.Autocommit.Name), "should need :__vtautocommit") - assert.Equal(tc.foreignKeyChecks, result.NeedsSysVar(sysvars.ForeignKeyChecks), "should need :__vtforeignKeyChecks") - assert.Equal(tc.clientFoundRows, result.NeedsSysVar(sysvars.ClientFoundRows.Name), "should need :__vtclientFoundRows") - assert.Equal(tc.skipQueryPlanCache, result.NeedsSysVar(sysvars.SkipQueryPlanCache.Name), "should need :__vtskipQueryPlanCache") - assert.Equal(tc.sqlSelectLimit, result.NeedsSysVar(sysvars.SQLSelectLimit.Name), "should need :__vtsqlSelectLimit") - assert.Equal(tc.transactionMode, result.NeedsSysVar(sysvars.TransactionMode.Name), "should need :__vttransactionMode") - assert.Equal(tc.workload, result.NeedsSysVar(sysvars.Workload.Name), "should need :__vtworkload") - assert.Equal(tc.queryTimeout, result.NeedsSysVar(sysvars.QueryTimeout.Name), "should need :__vtquery_timeout") - assert.Equal(tc.ddlStrategy, result.NeedsSysVar(sysvars.DDLStrategy.Name), "should need ddlStrategy") - assert.Equal(tc.migrationContext, result.NeedsSysVar(sysvars.MigrationContext.Name), "should need migrationContext") - assert.Equal(tc.sessionUUID, result.NeedsSysVar(sysvars.SessionUUID.Name), "should need sessionUUID") - assert.Equal(tc.sessionEnableSystemSettings, result.NeedsSysVar(sysvars.SessionEnableSystemSettings.Name), "should need sessionEnableSystemSettings") - assert.Equal(tc.rawGTID, result.NeedsSysVar(sysvars.ReadAfterWriteGTID.Name), "should need rawGTID") - assert.Equal(tc.rawTimeout, result.NeedsSysVar(sysvars.ReadAfterWriteTimeOut.Name), "should need rawTimeout") - assert.Equal(tc.sessTrackGTID, result.NeedsSysVar(sysvars.SessionTrackGTIDs.Name), "should need sessTrackGTID") - assert.Equal(tc.version, result.NeedsSysVar(sysvars.Version.Name), "should need Vitess version") - assert.Equal(tc.versionComment, result.NeedsSysVar(sysvars.VersionComment.Name), "should need Vitess version") - assert.Equal(tc.socket, result.NeedsSysVar(sysvars.Socket.Name), "should need :__vtsocket") - }) - } -} - -type fakeViews struct{} - -func (*fakeViews) FindView(name TableName) TableStatement { - if name.Name.String() != "user_details" { - return nil - } - parser := NewTestParser() - statement, err := parser.Parse("select user.id, user.name, user_extra.salary from user join user_extra where user.id = user_extra.user_id") - if err != nil { - return nil - } - return statement.(TableStatement) -} - -func TestRewritesWithSetVarComment(in *testing.T) { - tests := []testCaseSetVar{{ - in: "select 1", - expected: "select 1", - setVarComment: "", - }, { - in: "select 1", - expected: "select /*+ AA(a) */ 1", - setVarComment: "AA(a)", - }, { - in: "insert /* toto */ into t(id) values(1)", - expected: "insert /*+ AA(a) */ /* toto */ into t(id) values(1)", - setVarComment: "AA(a)", - }, { - in: "select /* toto */ * from t union select * from s", - expected: "select /*+ AA(a) */ /* toto */ * from t union select /*+ AA(a) */ * from s", - setVarComment: "AA(a)", - }, { - in: "vstream /* toto */ * from t1", - expected: "vstream /*+ AA(a) */ /* toto */ * from t1", - setVarComment: "AA(a)", - }, { - in: "stream /* toto */ t from t1", - expected: "stream /*+ AA(a) */ /* toto */ t from t1", - setVarComment: "AA(a)", - }, { - in: "update /* toto */ t set id = 1", - expected: "update /*+ AA(a) */ /* toto */ t set id = 1", - setVarComment: "AA(a)", - }, { - in: "delete /* toto */ from t", - expected: "delete /*+ AA(a) */ /* toto */ from t", - setVarComment: "AA(a)", - }} - - parser := NewTestParser() - for _, tc := range tests { - in.Run(tc.in, func(t *testing.T) { - require := require.New(t) - stmt, err := parser.Parse(tc.in) - require.NoError(err) - - result, err := RewriteAST(stmt, "ks", SQLSelectLimitUnset, tc.setVarComment, nil, nil, &fakeViews{}) - require.NoError(err) - - expected, err := parser.Parse(tc.expected) - require.NoError(err, "test expectation does not parse [%s]", tc.expected) - - assert.Equal(t, String(expected), String(result.AST)) - }) - } -} - -func TestRewritesSysVar(in *testing.T) { - tests := []testCaseSysVar{{ - in: "select @x = @@sql_mode", - expected: "select :__vtudvx = @@sql_mode as `@x = @@sql_mode` from dual", - }, { - in: "select @x = @@sql_mode", - expected: "select :__vtudvx = :__vtsql_mode as `@x = @@sql_mode` from dual", - sysVar: map[string]string{"sql_mode": "' '"}, - }, { - in: "SELECT @@tx_isolation", - expected: "select @@tx_isolation from dual", - }, { - in: "SELECT @@transaction_isolation", - expected: "select @@transaction_isolation from dual", - }, { - in: "SELECT @@session.transaction_isolation", - expected: "select @@session.transaction_isolation from dual", - }, { - in: "SELECT @@tx_isolation", - sysVar: map[string]string{"tx_isolation": "'READ-COMMITTED'"}, - expected: "select :__vttx_isolation as `@@tx_isolation` from dual", - }, { - in: "SELECT @@transaction_isolation", - sysVar: map[string]string{"transaction_isolation": "'READ-COMMITTED'"}, - expected: "select :__vttransaction_isolation as `@@transaction_isolation` from dual", - }, { - in: "SELECT @@session.transaction_isolation", - sysVar: map[string]string{"transaction_isolation": "'READ-COMMITTED'"}, - expected: "select :__vttransaction_isolation as `@@session.transaction_isolation` from dual", - }} - - parser := NewTestParser() - for _, tc := range tests { - in.Run(tc.in, func(t *testing.T) { - require := require.New(t) - stmt, err := parser.Parse(tc.in) - require.NoError(err) - - result, err := RewriteAST(stmt, "ks", SQLSelectLimitUnset, "", tc.sysVar, nil, &fakeViews{}) - require.NoError(err) - - expected, err := parser.Parse(tc.expected) - require.NoError(err, "test expectation does not parse [%s]", tc.expected) - - assert.Equal(t, String(expected), String(result.AST)) - }) - } -} - -func TestRewritesWithDefaultKeyspace(in *testing.T) { - tests := []myTestCase{{ - in: "SELECT 1 from x.test", - expected: "SELECT 1 from x.test", // no change - }, { - in: "SELECT x.col as c from x.test", - expected: "SELECT x.col as c from x.test", // no change - }, { - in: "SELECT 1 from test", - expected: "SELECT 1 from sys.test", - }, { - in: "SELECT 1 from test as t", - expected: "SELECT 1 from sys.test as t", - }, { - in: "SELECT 1 from `test 24` as t", - expected: "SELECT 1 from sys.`test 24` as t", - }, { - in: "SELECT 1, (select 1 from test) from x.y", - expected: "SELECT 1, (select 1 from sys.test) from x.y", - }, { - in: "SELECT 1 from (select 2 from test) t", - expected: "SELECT 1 from (select 2 from sys.test) t", - }, { - in: "SELECT 1 from test where exists(select 2 from test)", - expected: "SELECT 1 from sys.test where exists(select 1 from sys.test)", - }, { - in: "SELECT 1 from dual", - expected: "SELECT 1 from dual", - }, { - in: "SELECT (select 2 from dual) from DUAL", - expected: "SELECT 2 as `(select 2 from dual)` from DUAL", - }} - - parser := NewTestParser() - for _, tc := range tests { - in.Run(tc.in, func(t *testing.T) { - require := require.New(t) - stmt, err := parser.Parse(tc.in) - require.NoError(err) - - result, err := RewriteAST(stmt, "sys", SQLSelectLimitUnset, "", nil, nil, &fakeViews{}) - require.NoError(err) - - expected, err := parser.Parse(tc.expected) - require.NoError(err, "test expectation does not parse [%s]", tc.expected) - - assert.Equal(t, String(expected), String(result.AST)) - }) - } -} - -func TestReservedVars(t *testing.T) { - for _, prefix := range []string{"vtg", "bv"} { - t.Run("prefix_"+prefix, func(t *testing.T) { - reserved := NewReservedVars(prefix, make(BindVars)) - for i := 1; i < 1000; i++ { - require.Equal(t, fmt.Sprintf("%s%d", prefix, i), reserved.nextUnusedVar()) - } - }) - } -} diff --git a/go/vt/sqlparser/bind_var_needs.go b/go/vt/sqlparser/bind_var_needs.go index 1b26919ca03..64e5c528e97 100644 --- a/go/vt/sqlparser/bind_var_needs.go +++ b/go/vt/sqlparser/bind_var_needs.go @@ -22,14 +22,7 @@ type BindVarNeeds struct { NeedSystemVariable, // NeedUserDefinedVariables keeps track of all user defined variables a query is using NeedUserDefinedVariables []string - otherRewrites bool -} - -// MergeWith adds bind vars needs coming from sub scopes -func (bvn *BindVarNeeds) MergeWith(other *BindVarNeeds) { - bvn.NeedFunctionResult = append(bvn.NeedFunctionResult, other.NeedFunctionResult...) - bvn.NeedSystemVariable = append(bvn.NeedSystemVariable, other.NeedSystemVariable...) - bvn.NeedUserDefinedVariables = append(bvn.NeedUserDefinedVariables, other.NeedUserDefinedVariables...) + otherRewrites int } // AddFuncResult adds a function bindvar need @@ -58,14 +51,11 @@ func (bvn *BindVarNeeds) NeedsSysVar(name string) bool { } func (bvn *BindVarNeeds) NoteRewrite() { - bvn.otherRewrites = true + bvn.otherRewrites++ } -func (bvn *BindVarNeeds) HasRewrites() bool { - return bvn.otherRewrites || - len(bvn.NeedFunctionResult) > 0 || - len(bvn.NeedUserDefinedVariables) > 0 || - len(bvn.NeedSystemVariable) > 0 +func (bvn *BindVarNeeds) NumberOfRewrites() int { + return len(bvn.NeedFunctionResult) + len(bvn.NeedUserDefinedVariables) + len(bvn.NeedSystemVariable) + bvn.otherRewrites } func contains(strings []string, name string) bool { diff --git a/go/vt/sqlparser/normalizer.go b/go/vt/sqlparser/normalizer.go index 02cb11e2a97..fb3813b7019 100644 --- a/go/vt/sqlparser/normalizer.go +++ b/go/vt/sqlparser/normalizer.go @@ -18,107 +18,277 @@ package sqlparser import ( "bytes" + "fmt" + "strconv" + "strings" "vitess.io/vitess/go/mysql/datetime" "vitess.io/vitess/go/sqltypes" + querypb "vitess.io/vitess/go/vt/proto/query" vtrpcpb "vitess.io/vitess/go/vt/proto/vtrpc" + "vitess.io/vitess/go/vt/sysvars" "vitess.io/vitess/go/vt/vterrors" +) - querypb "vitess.io/vitess/go/vt/proto/query" +type ( + // BindVars represents a set of reserved bind variables extracted from a SQL statement. + BindVars map[string]struct{} + // normalizer transforms SQL statements to support parameterization and streamline query planning. + // + // It serves two primary purposes: + // 1. **Parameterization:** Allows multiple invocations of the same query with different literals by converting literals + // to bind variables. This enables efficient reuse of execution plans with varying parameters. + // 2. **Simplified Planning:** Reduces the complexity for the query planner by standardizing SQL patterns. For example, + // it ensures that table columns are consistently placed on the left side of comparison expressions. This uniformity + // minimizes the number of distinct patterns the planner must handle, enhancing planning efficiency. + normalizer struct { + bindVars map[string]*querypb.BindVariable + reserved *ReservedVars + vals map[Literal]string + err error + inDerived int + inSelect int + + bindVarNeeds *BindVarNeeds + shouldRewriteDatabaseFunc bool + hasStarInSelect bool + + keyspace string + selectLimit int + setVarComment string + fkChecksState *bool + sysVars map[string]string + views VSchemaViews + + onLeave map[*AliasedExpr]func(*AliasedExpr) + parameterize bool + } + // RewriteASTResult holds the result of rewriting the AST, including bind variable needs. + RewriteASTResult struct { + *BindVarNeeds + AST Statement // The rewritten AST + } + // VSchemaViews provides access to view definitions within the VSchema. + VSchemaViews interface { + FindView(name TableName) TableStatement + } ) -// BindVars is a set of reserved bind variables from a SQL statement -type BindVars map[string]struct{} +const ( + // SQLSelectLimitUnset indicates that sql_select_limit is not set. + SQLSelectLimitUnset = -1 + // LastInsertIDName is the bind variable name for LAST_INSERT_ID(). + LastInsertIDName = "__lastInsertId" + // DBVarName is the bind variable name for DATABASE(). + DBVarName = "__vtdbname" + // FoundRowsName is the bind variable name for FOUND_ROWS(). + FoundRowsName = "__vtfrows" + // RowCountName is the bind variable name for ROW_COUNT(). + RowCountName = "__vtrcount" + // UserDefinedVariableName is the prefix for user-defined variable bind names. + UserDefinedVariableName = "__vtudv" +) -// Normalize changes the statement to use bind values, and -// updates the bind vars to those values. The supplied prefix -// is used to generate the bind var names. The function ensures -// that there are no collisions with existing bind vars. -// Within Select constructs, bind vars are deduped. This allows -// us to identify vindex equality. Otherwise, every value is -// treated as distinct. -func Normalize(stmt Statement, reserved *ReservedVars, bindVars map[string]*querypb.BindVariable) error { - nz := newNormalizer(reserved, bindVars) - _ = SafeRewrite(stmt, nz.walkStatementDown, nz.walkStatementUp) - return nz.err +// funcRewrites lists all functions that must be rewritten. we don't want these to make it down to mysql, +// we need to handle these in the vtgate +var funcRewrites = map[string]string{ + "last_insert_id": LastInsertIDName, + "database": DBVarName, + "schema": DBVarName, + "found_rows": FoundRowsName, + "row_count": RowCountName, } -type normalizer struct { - bindVars map[string]*querypb.BindVariable - reserved *ReservedVars - vals map[Literal]string - err error - inDerived int - inSelect int -} +// PrepareAST normalizes the input SQL statement and returns the rewritten AST along with bind variable information. +func PrepareAST( + in Statement, + reservedVars *ReservedVars, + bindVars map[string]*querypb.BindVariable, + parameterize bool, + keyspace string, + selectLimit int, + setVarComment string, + sysVars map[string]string, + fkChecksState *bool, + views VSchemaViews, +) (*RewriteASTResult, error) { + nz := newNormalizer(reservedVars, bindVars, keyspace, selectLimit, setVarComment, sysVars, fkChecksState, views, parameterize) + nz.shouldRewriteDatabaseFunc = shouldRewriteDatabaseFunc(in) + out := SafeRewrite(in, nz.walkDown, nz.walkUp) + if nz.err != nil { + return nil, nz.err + } -func newNormalizer(reserved *ReservedVars, bindVars map[string]*querypb.BindVariable) *normalizer { - return &normalizer{ - bindVars: bindVars, - reserved: reserved, - vals: make(map[Literal]string), + r := &RewriteASTResult{ + AST: out.(Statement), + BindVarNeeds: nz.bindVarNeeds, } + return r, nil } -// walkStatementUp is one half of the top level walk function. -func (nz *normalizer) walkStatementUp(cursor *Cursor) bool { - if nz.err != nil { - return false - } - switch node := cursor.node.(type) { - case *DerivedTable: - nz.inDerived-- - case *Select: - nz.inSelect-- - case *Literal: - if nz.inSelect == 0 { - nz.convertLiteral(node, cursor) - return nz.err == nil - } - parent := cursor.Parent() - switch parent.(type) { - case *Order, *GroupBy: - return true - case *Limit: - nz.convertLiteral(node, cursor) - default: - nz.convertLiteralDedup(node, cursor) - } +func newNormalizer( + reserved *ReservedVars, + bindVars map[string]*querypb.BindVariable, + keyspace string, + selectLimit int, + setVarComment string, + sysVars map[string]string, + fkChecksState *bool, + views VSchemaViews, + parameterize bool, +) *normalizer { + return &normalizer{ + bindVars: bindVars, + reserved: reserved, + vals: make(map[Literal]string), + bindVarNeeds: &BindVarNeeds{}, + keyspace: keyspace, + selectLimit: selectLimit, + setVarComment: setVarComment, + fkChecksState: fkChecksState, + sysVars: sysVars, + views: views, + onLeave: make(map[*AliasedExpr]func(*AliasedExpr)), + parameterize: parameterize, } - return nz.err == nil // only continue if we haven't found any errors } -// walkStatementDown is the top level walk function. -// If it encounters a Select, it switches to a mode -// where variables are deduped. -func (nz *normalizer) walkStatementDown(node, _ SQLNode) bool { +// walkDown processes nodes when traversing down the AST. +// It handles normalization logic based on node types. +func (nz *normalizer) walkDown(node, _ SQLNode) bool { switch node := node.(type) { - // no need to normalize the statement types - case *Set, *Show, *Begin, *Commit, *Rollback, *Savepoint, DDLStatement, *SRollback, *Release, *OtherAdmin, *Analyze: + case *Begin, *Commit, *Rollback, *Savepoint, *SRollback, *Release, *OtherAdmin, *Analyze, *AssignmentExpr, + *PrepareStmt, *ExecuteStmt, *FramePoint, *ColName, TableName, *ConvertType: + // These statement don't need normalizing return false + case *Set: + // Disable parameterization within SET statements. + nz.parameterize = false case *DerivedTable: nz.inDerived++ case *Select: nz.inSelect++ - case SelectExprs: - return nz.inDerived == 0 + if nz.selectLimit > 0 && node.Limit == nil { + node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(nz.selectLimit))} + } + case *AliasedExpr: + nz.noteAliasedExprName(node) case *ComparisonExpr: nz.convertComparison(node) case *UpdateExpr: nz.convertUpdateExpr(node) - case *ColName, TableName: - // Common node types that never contain Literal or ListArgs but create a lot of object - // allocations. - return false - case *ConvertType: // we should not rewrite the type description + case *StarExpr: + nz.hasStarInSelect = true + // No rewriting needed for prepare or execute statements. return false - case *FramePoint: - // do not make a bind var for rows and range + case *ShowBasic: + if node.Command != VariableGlobal && node.Command != VariableSession { + break + } + varsToAdd := sysvars.GetInterestingVariables() + for _, sysVar := range varsToAdd { + nz.bindVarNeeds.AddSysVar(sysVar) + } + } + b := nz.err == nil + if !b { + fmt.Println(1) + } + return b +} + +// noteAliasedExprName tracks expressions without aliases to add alias if expression is rewritten +func (nz *normalizer) noteAliasedExprName(node *AliasedExpr) { + if node.As.NotEmpty() { + return + } + buf := NewTrackedBuffer(nil) + node.Expr.Format(buf) + rewrites := nz.bindVarNeeds.NumberOfRewrites() + nz.onLeave[node] = func(newAliasedExpr *AliasedExpr) { + if nz.bindVarNeeds.NumberOfRewrites() > rewrites { + newAliasedExpr.As = NewIdentifierCI(buf.String()) + } + } +} + +// walkUp processes nodes when traversing up the AST. +// It finalizes normalization logic based on node types. +func (nz *normalizer) walkUp(cursor *Cursor) bool { + // Add SET_VAR comments if applicable. + if supportOptimizerHint, supports := cursor.Node().(SupportOptimizerHint); supports { + if nz.setVarComment != "" { + newComments, err := supportOptimizerHint.GetParsedComments().AddQueryHint(nz.setVarComment) + if err != nil { + nz.err = err + return false + } + supportOptimizerHint.SetComments(newComments) + } + if nz.fkChecksState != nil { + newComments := supportOptimizerHint.GetParsedComments().SetMySQLSetVarValue(sysvars.ForeignKeyChecks, FkChecksStateString(nz.fkChecksState)) + supportOptimizerHint.SetComments(newComments) + } + } + + if nz.err != nil { return false } - return nz.err == nil // only continue if we haven't found any errors + + switch node := cursor.node.(type) { + case *DerivedTable: + nz.inDerived-- + case *Select: + nz.inSelect-- + case *AliasedExpr: + // if we are tracking this node for changes, this is the time to add the alias if needed + if onLeave, ok := nz.onLeave[node]; ok { + onLeave(node) + delete(nz.onLeave, node) + } + case *Union: + nz.rewriteUnion(node) + case *FuncExpr: + nz.funcRewrite(cursor, node) + case *Variable: + nz.rewriteVariable(cursor, node) + case *Subquery: + nz.unnestSubQueries(cursor, node) + case *NotExpr: + nz.rewriteNotExpr(cursor, node) + case *AliasedTableExpr: + nz.rewriteAliasedTable(cursor, node) + case *ShowBasic: + nz.rewriteShowBasic(node) + case *ExistsExpr: + nz.existsRewrite(cursor, node) + case DistinctableAggr: + nz.rewriteDistinctableAggr(node) + case *Literal: + nz.visitLiteral(cursor, node) + } + return nz.err == nil } +func (nz *normalizer) visitLiteral(cursor *Cursor, node *Literal) { + if !nz.shouldParameterize() { + return + } + if nz.inSelect == 0 { + nz.convertLiteral(node, cursor) + return + } + switch cursor.Parent().(type) { + case *Order, *GroupBy: + return + case *Limit: + nz.convertLiteral(node, cursor) + default: + nz.convertLiteralDedup(node, cursor) + } +} + +// validateLiteral ensures that a Literal node has a valid value based on its type. func validateLiteral(node *Literal) error { switch node.Type { case DateVal: @@ -137,37 +307,31 @@ func validateLiteral(node *Literal) error { return nil } +// convertLiteralDedup converts a Literal node to a bind variable with deduplication. func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) { - err := validateLiteral(node) - if err != nil { + if err := validateLiteral(node); err != nil { nz.err = err + return } - // If value is too long, don't dedup. - // Such values are most likely not for vindexes. - // We save a lot of CPU because we avoid building - // the key for them. + // Skip deduplication for long values. if len(node.Val) > 256 { nz.convertLiteral(node, cursor) return } - // Make the bindvar - bval := SQLToBindvar(node) + bval := literalToBindvar(node) if bval == nil { return } - // Check if there's a bindvar for that value already. - bvname, ok := nz.vals[*node] - if !ok { - // If there's no such bindvar, make a new one. + bvname, exists := nz.vals[*node] + if !exists { bvname = nz.reserved.nextUnusedVar() nz.vals[*node] = bvname nz.bindVars[bvname] = bval } - // Modify the AST node to a bindvar. arg, err := NewTypedArgumentFromLiteral(bvname, node) if err != nil { nz.err = err @@ -176,14 +340,14 @@ func (nz *normalizer) convertLiteralDedup(node *Literal, cursor *Cursor) { cursor.Replace(arg) } -// convertLiteral converts an Literal without the dedup. +// convertLiteral converts a Literal node to a bind variable without deduplication. func (nz *normalizer) convertLiteral(node *Literal, cursor *Cursor) { - err := validateLiteral(node) - if err != nil { + if err := validateLiteral(node); err != nil { nz.err = err + return } - bval := SQLToBindvar(node) + bval := literalToBindvar(node) if bval == nil { return } @@ -198,11 +362,7 @@ func (nz *normalizer) convertLiteral(node *Literal, cursor *Cursor) { cursor.Replace(arg) } -// convertComparison attempts to convert IN clauses to -// use the list bind var construct. If it fails, it returns -// with no change made. The walk function will then continue -// and iterate on converting each individual value into separate -// bind vars. +// convertComparison handles the conversion of comparison expressions to use bind variables. func (nz *normalizer) convertComparison(node *ComparisonExpr) { switch node.Operator { case InOp, NotInOp: @@ -212,14 +372,19 @@ func (nz *normalizer) convertComparison(node *ComparisonExpr) { } } +// rewriteOtherComparisons parameterizes non-IN comparison expressions. func (nz *normalizer) rewriteOtherComparisons(node *ComparisonExpr) { - newR := nz.parameterize(node.Left, node.Right) + newR := nz.normalizeComparisonWithBindVar(node.Left, node.Right) if newR != nil { node.Right = newR } } -func (nz *normalizer) parameterize(left, right Expr) Expr { +// normalizeComparisonWithBindVar attempts to replace a literal in a comparison with a bind variable. +func (nz *normalizer) normalizeComparisonWithBindVar(left, right Expr) Expr { + if !nz.shouldParameterize() { + return nil + } col, ok := left.(*ColName) if !ok { return nil @@ -228,13 +393,12 @@ func (nz *normalizer) parameterize(left, right Expr) Expr { if !ok { return nil } - err := validateLiteral(lit) - if err != nil { + if err := validateLiteral(lit); err != nil { nz.err = err return nil } - bval := SQLToBindvar(lit) + bval := literalToBindvar(lit) if bval == nil { return nil } @@ -247,18 +411,14 @@ func (nz *normalizer) parameterize(left, right Expr) Expr { return arg } +// decideBindVarName determines the appropriate bind variable name for a given literal and column. func (nz *normalizer) decideBindVarName(lit *Literal, col *ColName, bval *querypb.BindVariable) string { if len(lit.Val) <= 256 { - // first we check if we already have a bindvar for this value. if we do, we re-use that bindvar name - bvname, ok := nz.vals[*lit] - if ok { + if bvname, ok := nz.vals[*lit]; ok { return bvname } } - // If there's no such bindvar, or we have a big value, make a new one. - // Big values are most likely not for vindexes. - // We save a lot of CPU because we avoid building bvname := nz.reserved.ReserveColName(col) nz.vals[*lit] = bvname nz.bindVars[bvname] = bval @@ -266,19 +426,22 @@ func (nz *normalizer) decideBindVarName(lit *Literal, col *ColName, bval *queryp return bvname } +// rewriteInComparisons converts IN and NOT IN expressions to use list bind variables. func (nz *normalizer) rewriteInComparisons(node *ComparisonExpr) { + if !nz.shouldParameterize() { + return + } tupleVals, ok := node.Right.(ValTuple) if !ok { return } - // The RHS is a tuple of values. - // Make a list bindvar. + // Create a list bind variable for the tuple. bvals := &querypb.BindVariable{ Type: querypb.Type_TUPLE, } for _, val := range tupleVals { - bval := SQLToBindvar(val) + bval := literalToBindvar(val) if bval == nil { return } @@ -289,76 +452,74 @@ func (nz *normalizer) rewriteInComparisons(node *ComparisonExpr) { } bvname := nz.reserved.nextUnusedVar() nz.bindVars[bvname] = bvals - // Modify RHS to be a list bindvar. node.Right = ListArg(bvname) } +// convertUpdateExpr parameterizes expressions in UPDATE statements. func (nz *normalizer) convertUpdateExpr(node *UpdateExpr) { - newR := nz.parameterize(node.Name, node.Expr) + newR := nz.normalizeComparisonWithBindVar(node.Name, node.Expr) if newR != nil { node.Expr = newR } } -func SQLToBindvar(node SQLNode) *querypb.BindVariable { - if node, ok := node.(*Literal); ok { - var v sqltypes.Value - var err error - switch node.Type { - case StrVal: - v, err = sqltypes.NewValue(sqltypes.VarChar, node.Bytes()) - case IntVal: - v, err = sqltypes.NewValue(sqltypes.Int64, node.Bytes()) - case FloatVal: - v, err = sqltypes.NewValue(sqltypes.Float64, node.Bytes()) - case DecimalVal: - v, err = sqltypes.NewValue(sqltypes.Decimal, node.Bytes()) - case HexNum: - buf := make([]byte, 0, len(node.Bytes())) - buf = append(buf, "0x"...) - buf = append(buf, bytes.ToUpper(node.Bytes()[2:])...) - v, err = sqltypes.NewValue(sqltypes.HexNum, buf) - case HexVal: - // We parse the `x'7b7d'` string literal into a hex encoded string of `7b7d` in the parser - // We need to re-encode it back to the original MySQL query format before passing it on as a bindvar value to MySQL - buf := make([]byte, 0, len(node.Bytes())+3) - buf = append(buf, 'x', '\'') - buf = append(buf, bytes.ToUpper(node.Bytes())...) - buf = append(buf, '\'') - v, err = sqltypes.NewValue(sqltypes.HexVal, buf) - case BitNum: - out := make([]byte, 0, len(node.Bytes())+2) - out = append(out, '0', 'b') - out = append(out, node.Bytes()[2:]...) - v, err = sqltypes.NewValue(sqltypes.BitNum, out) - case DateVal: - v, err = sqltypes.NewValue(sqltypes.Date, node.Bytes()) - case TimeVal: - v, err = sqltypes.NewValue(sqltypes.Time, node.Bytes()) - case TimestampVal: - // This is actually a DATETIME MySQL type. The timestamp literal - // syntax is part of the SQL standard and MySQL DATETIME matches - // the type best. - v, err = sqltypes.NewValue(sqltypes.Datetime, node.Bytes()) - default: - return nil - } - if err != nil { - return nil - } - return sqltypes.ValueBindVariable(v) +// literalToBindvar converts a SQLNode to a BindVariable if possible. +func literalToBindvar(node SQLNode) *querypb.BindVariable { + lit, ok := node.(*Literal) + if !ok { + return nil } - return nil + var v sqltypes.Value + var err error + switch lit.Type { + case StrVal: + v, err = sqltypes.NewValue(sqltypes.VarChar, lit.Bytes()) + case IntVal: + v, err = sqltypes.NewValue(sqltypes.Int64, lit.Bytes()) + case FloatVal: + v, err = sqltypes.NewValue(sqltypes.Float64, lit.Bytes()) + case DecimalVal: + v, err = sqltypes.NewValue(sqltypes.Decimal, lit.Bytes()) + case HexNum: + buf := make([]byte, 0, len(lit.Bytes())) + buf = append(buf, "0x"...) + buf = append(buf, bytes.ToUpper(lit.Bytes()[2:])...) + v, err = sqltypes.NewValue(sqltypes.HexNum, buf) + case HexVal: + // Re-encode hex string literals to original MySQL format. + buf := make([]byte, 0, len(lit.Bytes())+3) + buf = append(buf, 'x', '\'') + buf = append(buf, bytes.ToUpper(lit.Bytes())...) + buf = append(buf, '\'') + v, err = sqltypes.NewValue(sqltypes.HexVal, buf) + case BitNum: + out := make([]byte, 0, len(lit.Bytes())+2) + out = append(out, '0', 'b') + out = append(out, lit.Bytes()[2:]...) + v, err = sqltypes.NewValue(sqltypes.BitNum, out) + case DateVal: + v, err = sqltypes.NewValue(sqltypes.Date, lit.Bytes()) + case TimeVal: + v, err = sqltypes.NewValue(sqltypes.Time, lit.Bytes()) + case TimestampVal: + // Use DATETIME type for TIMESTAMP literals. + v, err = sqltypes.NewValue(sqltypes.Datetime, lit.Bytes()) + default: + return nil + } + if err != nil { + return nil + } + return sqltypes.ValueBindVariable(v) } -// GetBindvars returns a map of the bind vars referenced in the statement. -func GetBindvars(stmt Statement) map[string]struct{} { +// getBindvars extracts bind variables from a SQL statement. +func getBindvars(stmt Statement) map[string]struct{} { bindvars := make(map[string]struct{}) _ = Walk(func(node SQLNode) (kontinue bool, err error) { switch node := node.(type) { case *ColName, TableName: - // Common node types that never contain expressions but create a lot of object - // allocations. + // These node types do not contain expressions. return false, nil case *Argument: bindvars[node.Name] = struct{}{} @@ -369,3 +530,312 @@ func GetBindvars(stmt Statement) map[string]struct{} { }, stmt) return bindvars } + +var HasValueSubQueryBaseName = []byte("__sq_has_values") + +// shouldRewriteDatabaseFunc determines if the database function should be rewritten based on the statement. +func shouldRewriteDatabaseFunc(in Statement) bool { + selct, ok := in.(*Select) + if !ok { + return false + } + if len(selct.From) != 1 { + return false + } + aliasedTable, ok := selct.From[0].(*AliasedTableExpr) + if !ok { + return false + } + tableName, ok := aliasedTable.Expr.(TableName) + if !ok { + return false + } + return tableName.Name.String() == "dual" +} + +// rewriteUnion sets the SELECT limit for UNION statements if not already set. +func (nz *normalizer) rewriteUnion(node *Union) { + if nz.selectLimit > 0 && node.Limit == nil { + node.Limit = &Limit{Rowcount: NewIntLiteral(strconv.Itoa(nz.selectLimit))} + } +} + +// rewriteAliasedTable handles the rewriting of aliased tables, including view substitutions. +func (nz *normalizer) rewriteAliasedTable(cursor *Cursor, node *AliasedTableExpr) { + aliasTableName, ok := node.Expr.(TableName) + if !ok { + return + } + + // Do not add qualifiers to the dual table. + tblName := aliasTableName.Name.String() + if tblName == "dual" { + return + } + + if SystemSchema(nz.keyspace) { + if aliasTableName.Qualifier.IsEmpty() { + aliasTableName.Qualifier = NewIdentifierCS(nz.keyspace) + node.Expr = aliasTableName + cursor.Replace(node) + } + return + } + + // Replace views with their underlying definitions. + if nz.views == nil { + return + } + view := nz.views.FindView(aliasTableName) + if view == nil { + return + } + + // Substitute the view with a derived table. + node.Expr = &DerivedTable{Select: Clone(view)} + if node.As.IsEmpty() { + node.As = NewIdentifierCS(tblName) + } +} + +// rewriteShowBasic handles the rewriting of SHOW statements, particularly for system variables. +func (nz *normalizer) rewriteShowBasic(node *ShowBasic) { + if node.Command == VariableGlobal || node.Command == VariableSession { + varsToAdd := sysvars.GetInterestingVariables() + for _, sysVar := range varsToAdd { + nz.bindVarNeeds.AddSysVar(sysVar) + } + } +} + +// rewriteNotExpr simplifies NOT expressions where possible. +func (nz *normalizer) rewriteNotExpr(cursor *Cursor, node *NotExpr) { + switch inner := node.Expr.(type) { + case *ComparisonExpr: + // Invert comparison operators. + if canChange, inverse := inverseOp(inner.Operator); canChange { + inner.Operator = inverse + cursor.Replace(inner) + } + case *NotExpr: + // Simplify double negation. + cursor.Replace(inner.Expr) + case BoolVal: + // Negate boolean values. + cursor.Replace(!inner) + } +} + +// rewriteVariable handles the rewriting of variable expressions to bind variables. +func (nz *normalizer) rewriteVariable(cursor *Cursor, node *Variable) { + // Do not rewrite variables on the left side of SET assignments. + if v, isSet := cursor.Parent().(*SetExpr); isSet && v.Var == node { + return + } + switch node.Scope { + case VariableScope: + nz.udvRewrite(cursor, node) + case SessionScope, NextTxScope: + nz.sysVarRewrite(cursor, node) + } +} + +// inverseOp returns the inverse operator for a given comparison operator. +func inverseOp(i ComparisonExprOperator) (bool, ComparisonExprOperator) { + switch i { + case EqualOp: + return true, NotEqualOp + case LessThanOp: + return true, GreaterEqualOp + case GreaterThanOp: + return true, LessEqualOp + case LessEqualOp: + return true, GreaterThanOp + case GreaterEqualOp: + return true, LessThanOp + case NotEqualOp: + return true, EqualOp + case InOp: + return true, NotInOp + case NotInOp: + return true, InOp + case LikeOp: + return true, NotLikeOp + case NotLikeOp: + return true, LikeOp + case RegexpOp: + return true, NotRegexpOp + case NotRegexpOp: + return true, RegexpOp + } + return false, i +} + +// sysVarRewrite replaces system variables with corresponding bind variables. +func (nz *normalizer) sysVarRewrite(cursor *Cursor, node *Variable) { + lowered := node.Name.Lowered() + + var found bool + if nz.sysVars != nil { + _, found = nz.sysVars[lowered] + } + + switch lowered { + case sysvars.Autocommit.Name, + sysvars.Charset.Name, + sysvars.ClientFoundRows.Name, + sysvars.DDLStrategy.Name, + sysvars.MigrationContext.Name, + sysvars.Names.Name, + sysvars.TransactionMode.Name, + sysvars.ReadAfterWriteGTID.Name, + sysvars.ReadAfterWriteTimeOut.Name, + sysvars.SessionEnableSystemSettings.Name, + sysvars.SessionTrackGTIDs.Name, + sysvars.SessionUUID.Name, + sysvars.SkipQueryPlanCache.Name, + sysvars.Socket.Name, + sysvars.SQLSelectLimit.Name, + sysvars.Version.Name, + sysvars.VersionComment.Name, + sysvars.QueryTimeout.Name, + sysvars.Workload.Name: + found = true + } + + if found { + cursor.Replace(NewArgument("__vt" + lowered)) + nz.bindVarNeeds.AddSysVar(lowered) + } +} + +// udvRewrite replaces user-defined variables with corresponding bind variables. +func (nz *normalizer) udvRewrite(cursor *Cursor, node *Variable) { + udv := strings.ToLower(node.Name.CompliantName()) + cursor.Replace(NewArgument(UserDefinedVariableName + udv)) + nz.bindVarNeeds.AddUserDefVar(udv) +} + +// funcRewrite replaces certain function expressions with bind variables. +func (nz *normalizer) funcRewrite(cursor *Cursor, node *FuncExpr) { + lowered := node.Name.Lowered() + if lowered == "last_insert_id" && len(node.Exprs) > 0 { + // Do not rewrite LAST_INSERT_ID() when it has arguments. + return + } + bindVar, found := funcRewrites[lowered] + if !found || (bindVar == DBVarName && !nz.shouldRewriteDatabaseFunc) { + return + } + if len(node.Exprs) > 0 { + nz.err = vterrors.Errorf(vtrpcpb.Code_UNIMPLEMENTED, "Argument to %s() not supported", lowered) + return + } + cursor.Replace(NewArgument(bindVar)) + nz.bindVarNeeds.AddFuncResult(bindVar) +} + +// unnestSubQueries attempts to simplify dual subqueries where possible. +// select (select database() from dual) from test +// => +// select database() from test +func (nz *normalizer) unnestSubQueries(cursor *Cursor, subquery *Subquery) { + if _, isExists := cursor.Parent().(*ExistsExpr); isExists { + return + } + sel, isSimpleSelect := subquery.Select.(*Select) + if !isSimpleSelect { + return + } + + if len(sel.SelectExprs) != 1 || + len(sel.OrderBy) != 0 || + sel.GroupBy != nil || + len(sel.From) != 1 || + sel.Where != nil || + sel.Having != nil || + sel.Limit != nil || sel.Lock != NoLock { + return + } + + aliasedTable, ok := sel.From[0].(*AliasedTableExpr) + if !ok { + return + } + table, ok := aliasedTable.Expr.(TableName) + if !ok || table.Name.String() != "dual" { + return + } + expr, ok := sel.SelectExprs[0].(*AliasedExpr) + if !ok { + return + } + _, isColName := expr.Expr.(*ColName) + if isColName { + // Skip if the subquery already returns a column name. + return + } + nz.bindVarNeeds.NoteRewrite() + rewritten := SafeRewrite(expr.Expr, nz.walkDown, nz.walkUp) + + // Handle special cases for IN clauses. + rewrittenExpr, isExpr := rewritten.(Expr) + _, isColTuple := rewritten.(ColTuple) + comparisonExpr, isCompExpr := cursor.Parent().(*ComparisonExpr) + if isCompExpr && (comparisonExpr.Operator == InOp || comparisonExpr.Operator == NotInOp) && !isColTuple && isExpr { + cursor.Replace(ValTuple{rewrittenExpr}) + return + } + + cursor.Replace(rewritten) +} + +// existsRewrite optimizes EXISTS expressions where possible. +func (nz *normalizer) existsRewrite(cursor *Cursor, node *ExistsExpr) { + sel, ok := node.Subquery.Select.(*Select) + if !ok { + return + } + + if sel.Having != nil { + // Cannot optimize if HAVING clause is present. + return + } + + if sel.GroupBy == nil && sel.SelectExprs.AllAggregation() { + // Replace EXISTS with a boolean true if guaranteed to be non-empty. + cursor.Replace(BoolVal(true)) + return + } + + // Simplify the subquery by selecting a constant. + // WHERE EXISTS(SELECT 1 FROM ...) + sel.SelectExprs = SelectExprs{ + &AliasedExpr{Expr: NewIntLiteral("1")}, + } + sel.GroupBy = nil +} + +// rewriteDistinctableAggr removes DISTINCT from certain aggregations to simplify the plan. +func (nz *normalizer) rewriteDistinctableAggr(node DistinctableAggr) { + if !node.IsDistinct() { + return + } + switch aggr := node.(type) { + case *Max, *Min: + aggr.SetDistinct(false) + nz.bindVarNeeds.NoteRewrite() + } +} + +func (nz *normalizer) shouldParameterize() bool { + return !(nz.inDerived > 0 && len(nz.onLeave) > 0) && nz.parameterize +} + +// SystemSchema checks if the given schema is a system schema. +func SystemSchema(schema string) bool { + return strings.EqualFold(schema, "information_schema") || + strings.EqualFold(schema, "performance_schema") || + strings.EqualFold(schema, "sys") || + strings.EqualFold(schema, "mysql") +} diff --git a/go/vt/sqlparser/normalizer_test.go b/go/vt/sqlparser/normalizer_test.go index 7919a321c91..7c3e660ac9d 100644 --- a/go/vt/sqlparser/normalizer_test.go +++ b/go/vt/sqlparser/normalizer_test.go @@ -25,8 +25,9 @@ import ( "strings" "testing" - "github.com/stretchr/testify/assert" + "vitess.io/vitess/go/vt/sysvars" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "vitess.io/vitess/go/sqltypes" @@ -448,10 +449,11 @@ func TestNormalize(t *testing.T) { t.Run(tc.in, func(t *testing.T) { stmt, err := parser.Parse(tc.in) require.NoError(t, err) - known := GetBindvars(stmt) + known := getBindvars(stmt) bv := make(map[string]*querypb.BindVariable) - require.NoError(t, Normalize(stmt, NewReservedVars(prefix, known), bv)) - assert.Equal(t, tc.outstmt, String(stmt)) + out, err := PrepareAST(stmt, NewReservedVars(prefix, known), bv, true, "ks", 0, "", map[string]string{}, nil, nil) + require.NoError(t, err) + assert.Equal(t, tc.outstmt, String(out.AST)) assert.Equal(t, tc.outbv, bv) }) } @@ -476,9 +478,10 @@ func TestNormalizeInvalidDates(t *testing.T) { t.Run(tc.in, func(t *testing.T) { stmt, err := parser.Parse(tc.in) require.NoError(t, err) - known := GetBindvars(stmt) + known := getBindvars(stmt) bv := make(map[string]*querypb.BindVariable) - require.EqualError(t, Normalize(stmt, NewReservedVars("bv", known), bv), tc.err.Error()) + _, err = PrepareAST(stmt, NewReservedVars("bv", known), bv, true, "ks", 0, "", map[string]string{}, nil, nil) + require.EqualError(t, err, tc.err.Error()) }) } } @@ -498,9 +501,10 @@ func TestNormalizeValidSQL(t *testing.T) { } bv := make(map[string]*querypb.BindVariable) known := make(BindVars) - err = Normalize(tree, NewReservedVars("vtg", known), bv) + + out, err := PrepareAST(tree, NewReservedVars("vtg", known), bv, true, "ks", 0, "", map[string]string{}, nil, nil) require.NoError(t, err) - normalizerOutput := String(tree) + normalizerOutput := String(out.AST) if normalizerOutput == "otheradmin" || normalizerOutput == "otherread" { return } @@ -529,9 +533,9 @@ func TestNormalizeOneCasae(t *testing.T) { } bv := make(map[string]*querypb.BindVariable) known := make(BindVars) - err = Normalize(tree, NewReservedVars("vtg", known), bv) + out, err := PrepareAST(tree, NewReservedVars("vtg", known), bv, true, "ks", 0, "", map[string]string{}, nil, nil) require.NoError(t, err) - normalizerOutput := String(tree) + normalizerOutput := String(out.AST) require.EqualValues(t, testOne.output, normalizerOutput) if normalizerOutput == "otheradmin" || normalizerOutput == "otherread" { return @@ -546,7 +550,7 @@ func TestGetBindVars(t *testing.T) { if err != nil { t.Fatal(err) } - got := GetBindvars(stmt) + got := getBindvars(stmt) want := map[string]struct{}{ "v1": {}, "v2": {}, @@ -559,6 +563,586 @@ func TestGetBindVars(t *testing.T) { } } +type testCaseSetVar struct { + in, expected, setVarComment string +} + +type testCaseSysVar struct { + in, expected string + sysVar map[string]string +} + +type myTestCase struct { + in, expected string + liid, db, foundRows, rowCount, rawGTID, rawTimeout, sessTrackGTID bool + ddlStrategy, migrationContext, sessionUUID, sessionEnableSystemSettings bool + udv int + autocommit, foreignKeyChecks, clientFoundRows, skipQueryPlanCache, socket, queryTimeout bool + sqlSelectLimit, transactionMode, workload, version, versionComment bool +} + +func TestRewrites(in *testing.T) { + tests := []myTestCase{{ + in: "SELECT 42", + expected: "SELECT 42", + // no bindvar needs + }, { + in: "SELECT @@version", + expected: "SELECT :__vtversion as `@@version`", + version: true, + }, { + in: "SELECT @@query_timeout", + expected: "SELECT :__vtquery_timeout as `@@query_timeout`", + queryTimeout: true, + }, { + in: "SELECT @@version_comment", + expected: "SELECT :__vtversion_comment as `@@version_comment`", + versionComment: true, + }, { + in: "SELECT @@enable_system_settings", + expected: "SELECT :__vtenable_system_settings as `@@enable_system_settings`", + sessionEnableSystemSettings: true, + }, { + in: "SELECT last_insert_id()", + expected: "SELECT :__lastInsertId as `last_insert_id()`", + liid: true, + }, { + in: "SELECT database()", + expected: "SELECT :__vtdbname as `database()`", + db: true, + }, { + in: "SELECT database() from test", + expected: "SELECT database() from test", + // no bindvar needs + }, { + in: "SELECT last_insert_id() as test", + expected: "SELECT :__lastInsertId as test", + liid: true, + }, { + in: "SELECT last_insert_id() + database()", + expected: "SELECT :__lastInsertId + :__vtdbname as `last_insert_id() + database()`", + db: true, liid: true, + }, { + // unnest database() call + in: "select (select database()) from test", + expected: "select database() as `(select database() from dual)` from test", + // no bindvar needs + }, { + // unnest database() call + in: "select (select database() from dual) from test", + expected: "select database() as `(select database() from dual)` from test", + // no bindvar needs + }, { + in: "select (select database() from dual) from dual", + expected: "select :__vtdbname as `(select database() from dual)` from dual", + db: true, + }, { + // don't unnest solo columns + in: "select 1 as foobar, (select foobar)", + expected: "select 1 as foobar, (select foobar from dual) from dual", + }, { + in: "select id from user where database()", + expected: "select id from user where database()", + // no bindvar needs + }, { + in: "select table_name from information_schema.tables where table_schema = database()", + expected: "select table_name from information_schema.tables where table_schema = database()", + // no bindvar needs + }, { + in: "select schema()", + expected: "select :__vtdbname as `schema()`", + db: true, + }, { + in: "select found_rows()", + expected: "select :__vtfrows as `found_rows()`", + foundRows: true, + }, { + in: "select @`x y`", + expected: "select :__vtudvx_y as `@``x y``` from dual", + udv: 1, + }, { + in: "select id from t where id = @x and val = @y", + expected: "select id from t where id = :__vtudvx and val = :__vtudvy", + db: false, udv: 2, + }, { + in: "insert into t(id) values(@xyx)", + expected: "insert into t(id) values(:__vtudvxyx)", + db: false, udv: 1, + }, { + in: "select row_count()", + expected: "select :__vtrcount as `row_count()`", + rowCount: true, + }, { + in: "SELECT lower(database())", + expected: "SELECT lower(:__vtdbname) as `lower(database())`", + db: true, + }, { + in: "SELECT @@autocommit", + expected: "SELECT :__vtautocommit as `@@autocommit`", + autocommit: true, + }, { + in: "SELECT @@client_found_rows", + expected: "SELECT :__vtclient_found_rows as `@@client_found_rows`", + clientFoundRows: true, + }, { + in: "SELECT @@skip_query_plan_cache", + expected: "SELECT :__vtskip_query_plan_cache as `@@skip_query_plan_cache`", + skipQueryPlanCache: true, + }, { + in: "SELECT @@sql_select_limit", + expected: "SELECT :__vtsql_select_limit as `@@sql_select_limit`", + sqlSelectLimit: true, + }, { + in: "SELECT @@transaction_mode", + expected: "SELECT :__vttransaction_mode as `@@transaction_mode`", + transactionMode: true, + }, { + in: "SELECT @@workload", + expected: "SELECT :__vtworkload as `@@workload`", + workload: true, + }, { + in: "SELECT @@socket", + expected: "SELECT :__vtsocket as `@@socket`", + socket: true, + }, { + in: "select (select 42) from dual", + expected: "select 42 as `(select 42 from dual)` from dual", + }, { + in: "select * from user where col = (select 42)", + expected: "select * from user where col = 42", + }, { + in: "select * from (select 42) as t", // this is not an expression, and should not be rewritten + expected: "select * from (select 42) as t", + }, { + in: `select (select (select (select (select (select last_insert_id()))))) as x`, + expected: "select :__lastInsertId as x from dual", + liid: true, + }, { + in: `select * from (select last_insert_id()) as t`, + expected: "select * from (select :__lastInsertId as `last_insert_id()` from dual) as t", + liid: true, + }, { + in: `select * from user where col = @@ddl_strategy`, + expected: "select * from user where col = :__vtddl_strategy", + ddlStrategy: true, + }, { + in: `select * from user where col = @@migration_context`, + expected: "select * from user where col = :__vtmigration_context", + migrationContext: true, + }, { + in: `select * from user where col = @@read_after_write_gtid OR col = @@read_after_write_timeout OR col = @@session_track_gtids`, + expected: "select * from user where col = :__vtread_after_write_gtid or col = :__vtread_after_write_timeout or col = :__vtsession_track_gtids", + rawGTID: true, rawTimeout: true, sessTrackGTID: true, + }, { + in: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual)", + expected: "SELECT * FROM tbl WHERE id IN (1)", + }, { + in: "SELECT * FROM tbl WHERE id IN (SELECT last_insert_id() FROM dual)", + expected: "SELECT * FROM tbl WHERE id IN (:__lastInsertId)", + liid: true, + }, { + in: "SELECT * FROM tbl WHERE id IN (SELECT (SELECT 1 FROM dual WHERE 1 = 0) FROM dual)", + expected: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual WHERE 1 = 0)", + }, { + in: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual WHERE 1 = 0)", + expected: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual WHERE 1 = 0)", + }, { + in: "SELECT * FROM tbl WHERE id IN (SELECT 1,2 FROM dual)", + expected: "SELECT * FROM tbl WHERE id IN (SELECT 1,2 FROM dual)", + }, { + in: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual ORDER BY 1)", + expected: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual ORDER BY 1)", + }, { + in: "SELECT * FROM tbl WHERE id IN (SELECT id FROM user GROUP BY id)", + expected: "SELECT * FROM tbl WHERE id IN (SELECT id FROM user GROUP BY id)", + }, { + in: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual, user)", + expected: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual, user)", + }, { + in: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual limit 1)", + expected: "SELECT * FROM tbl WHERE id IN (SELECT 1 FROM dual limit 1)", + }, { + // SELECT * behaves different depending the join type used, so if that has been used, we won't rewrite + in: "SELECT * FROM A JOIN B USING (id1,id2,id3)", + expected: "SELECT * FROM A JOIN B USING (id1,id2,id3)", + }, { + in: "CALL proc(@foo)", + expected: "CALL proc(:__vtudvfoo)", + udv: 1, + }, { + in: "SELECT * FROM tbl WHERE NOT id = 42", + expected: "SELECT * FROM tbl WHERE id != 42", + }, { + in: "SELECT * FROM tbl WHERE not id < 12", + expected: "SELECT * FROM tbl WHERE id >= 12", + }, { + in: "SELECT * FROM tbl WHERE not id > 12", + expected: "SELECT * FROM tbl WHERE id <= 12", + }, { + in: "SELECT * FROM tbl WHERE not id <= 33", + expected: "SELECT * FROM tbl WHERE id > 33", + }, { + in: "SELECT * FROM tbl WHERE not id >= 33", + expected: "SELECT * FROM tbl WHERE id < 33", + }, { + in: "SELECT * FROM tbl WHERE not id != 33", + expected: "SELECT * FROM tbl WHERE id = 33", + }, { + in: "SELECT * FROM tbl WHERE not id in (1,2,3)", + expected: "SELECT * FROM tbl WHERE id not in (1,2,3)", + }, { + in: "SELECT * FROM tbl WHERE not id not in (1,2,3)", + expected: "SELECT * FROM tbl WHERE id in (1,2,3)", + }, { + in: "SELECT * FROM tbl WHERE not id not in (1,2,3)", + expected: "SELECT * FROM tbl WHERE id in (1,2,3)", + }, { + in: "SELECT * FROM tbl WHERE not id like '%foobar'", + expected: "SELECT * FROM tbl WHERE id not like '%foobar'", + }, { + in: "SELECT * FROM tbl WHERE not id not like '%foobar'", + expected: "SELECT * FROM tbl WHERE id like '%foobar'", + }, { + in: "SELECT * FROM tbl WHERE not id regexp '%foobar'", + expected: "SELECT * FROM tbl WHERE id not regexp '%foobar'", + }, { + in: "SELECT * FROM tbl WHERE not id not regexp '%foobar'", + expected: "select * from tbl where id regexp '%foobar'", + }, { + in: "SELECT * FROM tbl WHERE exists(select col1, col2 from other_table where foo > bar)", + expected: "SELECT * FROM tbl WHERE exists(select 1 from other_table where foo > bar)", + }, { + in: "SELECT * FROM tbl WHERE exists(select col1, col2 from other_table where foo > bar limit 100 offset 34)", + expected: "SELECT * FROM tbl WHERE exists(select 1 from other_table where foo > bar limit 100 offset 34)", + }, { + in: "SELECT * FROM tbl WHERE exists(select col1, col2, count(*) from other_table where foo > bar group by col1, col2)", + expected: "SELECT * FROM tbl WHERE exists(select 1 from other_table where foo > bar)", + }, { + in: "SELECT * FROM tbl WHERE exists(select col1, col2 from other_table where foo > bar group by col1, col2)", + expected: "SELECT * FROM tbl WHERE exists(select 1 from other_table where foo > bar)", + }, { + in: "SELECT * FROM tbl WHERE exists(select count(*) from other_table where foo > bar)", + expected: "SELECT * FROM tbl WHERE true", + }, { + in: "SELECT * FROM tbl WHERE exists(select col1, col2, count(*) from other_table where foo > bar group by col1, col2 having count(*) > 3)", + expected: "SELECT * FROM tbl WHERE exists(select col1, col2, count(*) from other_table where foo > bar group by col1, col2 having count(*) > 3)", + }, { + in: "SELECT id, name, salary FROM user_details", + expected: "SELECT id, name, salary FROM (select user.id, user.name, user_extra.salary from user join user_extra where user.id = user_extra.user_id) as user_details", + }, { + in: "select max(distinct c1), min(distinct c2), avg(distinct c3), sum(distinct c4), count(distinct c5), group_concat(distinct c6) from tbl", + expected: "select max(c1) as `max(distinct c1)`, min(c2) as `min(distinct c2)`, avg(distinct c3), sum(distinct c4), count(distinct c5), group_concat(distinct c6) from tbl", + }, { + in: "SHOW VARIABLES", + expected: "SHOW VARIABLES", + autocommit: true, + foreignKeyChecks: true, + clientFoundRows: true, + skipQueryPlanCache: true, + sqlSelectLimit: true, + transactionMode: true, + workload: true, + version: true, + versionComment: true, + ddlStrategy: true, + migrationContext: true, + sessionUUID: true, + sessionEnableSystemSettings: true, + rawGTID: true, + rawTimeout: true, + sessTrackGTID: true, + socket: true, + queryTimeout: true, + }, { + in: "SHOW GLOBAL VARIABLES", + expected: "SHOW GLOBAL VARIABLES", + autocommit: true, + foreignKeyChecks: true, + clientFoundRows: true, + skipQueryPlanCache: true, + sqlSelectLimit: true, + transactionMode: true, + workload: true, + version: true, + versionComment: true, + ddlStrategy: true, + migrationContext: true, + sessionUUID: true, + sessionEnableSystemSettings: true, + rawGTID: true, + rawTimeout: true, + sessTrackGTID: true, + socket: true, + queryTimeout: true, + }} + parser := NewTestParser() + for _, tc := range tests { + in.Run(tc.in, func(t *testing.T) { + require := require.New(t) + stmt, known, err := parser.Parse2(tc.in) + require.NoError(err) + vars := NewReservedVars("v", known) + result, err := PrepareAST( + stmt, + vars, + map[string]*querypb.BindVariable{}, + false, + "ks", + 0, + "", + map[string]string{}, + nil, + &fakeViews{}, + ) + require.NoError(err) + + expected, err := parser.Parse(tc.expected) + require.NoError(err, "test expectation does not parse [%s]", tc.expected) + + s := String(expected) + assert := assert.New(t) + assert.Equal(s, String(result.AST)) + assert.Equal(tc.liid, result.NeedsFuncResult(LastInsertIDName), "should need last insert id") + assert.Equal(tc.db, result.NeedsFuncResult(DBVarName), "should need database name") + assert.Equal(tc.foundRows, result.NeedsFuncResult(FoundRowsName), "should need found rows") + assert.Equal(tc.rowCount, result.NeedsFuncResult(RowCountName), "should need row count") + assert.Equal(tc.udv, len(result.NeedUserDefinedVariables), "count of user defined variables") + assert.Equal(tc.autocommit, result.NeedsSysVar(sysvars.Autocommit.Name), "should need :__vtautocommit") + assert.Equal(tc.foreignKeyChecks, result.NeedsSysVar(sysvars.ForeignKeyChecks), "should need :__vtforeignKeyChecks") + assert.Equal(tc.clientFoundRows, result.NeedsSysVar(sysvars.ClientFoundRows.Name), "should need :__vtclientFoundRows") + assert.Equal(tc.skipQueryPlanCache, result.NeedsSysVar(sysvars.SkipQueryPlanCache.Name), "should need :__vtskipQueryPlanCache") + assert.Equal(tc.sqlSelectLimit, result.NeedsSysVar(sysvars.SQLSelectLimit.Name), "should need :__vtsqlSelectLimit") + assert.Equal(tc.transactionMode, result.NeedsSysVar(sysvars.TransactionMode.Name), "should need :__vttransactionMode") + assert.Equal(tc.workload, result.NeedsSysVar(sysvars.Workload.Name), "should need :__vtworkload") + assert.Equal(tc.queryTimeout, result.NeedsSysVar(sysvars.QueryTimeout.Name), "should need :__vtquery_timeout") + assert.Equal(tc.ddlStrategy, result.NeedsSysVar(sysvars.DDLStrategy.Name), "should need ddlStrategy") + assert.Equal(tc.migrationContext, result.NeedsSysVar(sysvars.MigrationContext.Name), "should need migrationContext") + assert.Equal(tc.sessionUUID, result.NeedsSysVar(sysvars.SessionUUID.Name), "should need sessionUUID") + assert.Equal(tc.sessionEnableSystemSettings, result.NeedsSysVar(sysvars.SessionEnableSystemSettings.Name), "should need sessionEnableSystemSettings") + assert.Equal(tc.rawGTID, result.NeedsSysVar(sysvars.ReadAfterWriteGTID.Name), "should need rawGTID") + assert.Equal(tc.rawTimeout, result.NeedsSysVar(sysvars.ReadAfterWriteTimeOut.Name), "should need rawTimeout") + assert.Equal(tc.sessTrackGTID, result.NeedsSysVar(sysvars.SessionTrackGTIDs.Name), "should need sessTrackGTID") + assert.Equal(tc.version, result.NeedsSysVar(sysvars.Version.Name), "should need Vitess version") + assert.Equal(tc.versionComment, result.NeedsSysVar(sysvars.VersionComment.Name), "should need Vitess version") + assert.Equal(tc.socket, result.NeedsSysVar(sysvars.Socket.Name), "should need :__vtsocket") + }) + } +} + +type fakeViews struct{} + +func (*fakeViews) FindView(name TableName) TableStatement { + if name.Name.String() != "user_details" { + return nil + } + parser := NewTestParser() + statement, err := parser.Parse("select user.id, user.name, user_extra.salary from user join user_extra where user.id = user_extra.user_id") + if err != nil { + return nil + } + return statement.(TableStatement) +} + +func TestRewritesWithSetVarComment(in *testing.T) { + tests := []testCaseSetVar{{ + in: "select 1", + expected: "select 1", + setVarComment: "", + }, { + in: "select 1", + expected: "select /*+ AA(a) */ 1", + setVarComment: "AA(a)", + }, { + in: "insert /* toto */ into t(id) values(1)", + expected: "insert /*+ AA(a) */ /* toto */ into t(id) values(1)", + setVarComment: "AA(a)", + }, { + in: "select /* toto */ * from t union select * from s", + expected: "select /*+ AA(a) */ /* toto */ * from t union select /*+ AA(a) */ * from s", + setVarComment: "AA(a)", + }, { + in: "vstream /* toto */ * from t1", + expected: "vstream /*+ AA(a) */ /* toto */ * from t1", + setVarComment: "AA(a)", + }, { + in: "stream /* toto */ t from t1", + expected: "stream /*+ AA(a) */ /* toto */ t from t1", + setVarComment: "AA(a)", + }, { + in: "update /* toto */ t set id = 1", + expected: "update /*+ AA(a) */ /* toto */ t set id = 1", + setVarComment: "AA(a)", + }, { + in: "delete /* toto */ from t", + expected: "delete /*+ AA(a) */ /* toto */ from t", + setVarComment: "AA(a)", + }} + + parser := NewTestParser() + for _, tc := range tests { + in.Run(tc.in, func(t *testing.T) { + require := require.New(t) + stmt, err := parser.Parse(tc.in) + require.NoError(err) + vars := NewReservedVars("v", nil) + result, err := PrepareAST( + stmt, + vars, + map[string]*querypb.BindVariable{}, + false, + "ks", + 0, + tc.setVarComment, + map[string]string{}, + nil, + &fakeViews{}, + ) + + require.NoError(err) + + expected, err := parser.Parse(tc.expected) + require.NoError(err, "test expectation does not parse [%s]", tc.expected) + + assert.Equal(t, String(expected), String(result.AST)) + }) + } +} + +func TestRewritesSysVar(in *testing.T) { + tests := []testCaseSysVar{{ + in: "select @x = @@sql_mode", + expected: "select :__vtudvx = @@sql_mode as `@x = @@sql_mode` from dual", + }, { + in: "select @x = @@sql_mode", + expected: "select :__vtudvx = :__vtsql_mode as `@x = @@sql_mode` from dual", + sysVar: map[string]string{"sql_mode": "' '"}, + }, { + in: "SELECT @@tx_isolation", + expected: "select @@tx_isolation from dual", + }, { + in: "SELECT @@transaction_isolation", + expected: "select @@transaction_isolation from dual", + }, { + in: "SELECT @@session.transaction_isolation", + expected: "select @@session.transaction_isolation from dual", + }, { + in: "SELECT @@tx_isolation", + sysVar: map[string]string{"tx_isolation": "'READ-COMMITTED'"}, + expected: "select :__vttx_isolation as `@@tx_isolation` from dual", + }, { + in: "SELECT @@transaction_isolation", + sysVar: map[string]string{"transaction_isolation": "'READ-COMMITTED'"}, + expected: "select :__vttransaction_isolation as `@@transaction_isolation` from dual", + }, { + in: "SELECT @@session.transaction_isolation", + sysVar: map[string]string{"transaction_isolation": "'READ-COMMITTED'"}, + expected: "select :__vttransaction_isolation as `@@session.transaction_isolation` from dual", + }} + + parser := NewTestParser() + for _, tc := range tests { + in.Run(tc.in, func(t *testing.T) { + require := require.New(t) + stmt, err := parser.Parse(tc.in) + require.NoError(err) + vars := NewReservedVars("v", nil) + result, err := PrepareAST( + stmt, + vars, + map[string]*querypb.BindVariable{}, + false, + "ks", + 0, + "", + tc.sysVar, + nil, + &fakeViews{}, + ) + + require.NoError(err) + + expected, err := parser.Parse(tc.expected) + require.NoError(err, "test expectation does not parse [%s]", tc.expected) + + assert.Equal(t, String(expected), String(result.AST)) + }) + } +} + +func TestRewritesWithDefaultKeyspace(in *testing.T) { + tests := []myTestCase{{ + in: "SELECT 1 from x.test", + expected: "SELECT 1 from x.test", // no change + }, { + in: "SELECT x.col as c from x.test", + expected: "SELECT x.col as c from x.test", // no change + }, { + in: "SELECT 1 from test", + expected: "SELECT 1 from sys.test", + }, { + in: "SELECT 1 from test as t", + expected: "SELECT 1 from sys.test as t", + }, { + in: "SELECT 1 from `test 24` as t", + expected: "SELECT 1 from sys.`test 24` as t", + }, { + in: "SELECT 1, (select 1 from test) from x.y", + expected: "SELECT 1, (select 1 from sys.test) from x.y", + }, { + in: "SELECT 1 from (select 2 from test) t", + expected: "SELECT 1 from (select 2 from sys.test) t", + }, { + in: "SELECT 1 from test where exists(select 2 from test)", + expected: "SELECT 1 from sys.test where exists(select 1 from sys.test)", + }, { + in: "SELECT 1 from dual", + expected: "SELECT 1 from dual", + }, { + in: "SELECT (select 2 from dual) from DUAL", + expected: "SELECT 2 as `(select 2 from dual)` from DUAL", + }} + + parser := NewTestParser() + for _, tc := range tests { + in.Run(tc.in, func(t *testing.T) { + require := require.New(t) + stmt, err := parser.Parse(tc.in) + require.NoError(err) + vars := NewReservedVars("v", nil) + result, err := PrepareAST( + stmt, + vars, + map[string]*querypb.BindVariable{}, + false, + "sys", + 0, + "", + map[string]string{}, + nil, + &fakeViews{}, + ) + + require.NoError(err) + + expected, err := parser.Parse(tc.expected) + require.NoError(err, "test expectation does not parse [%s]", tc.expected) + + assert.Equal(t, String(expected), String(result.AST)) + }) + } +} + +func TestReservedVars(t *testing.T) { + for _, prefix := range []string{"vtg", "bv"} { + t.Run("prefix_"+prefix, func(t *testing.T) { + reserved := NewReservedVars(prefix, make(BindVars)) + for i := 1; i < 1000; i++ { + require.Equal(t, fmt.Sprintf("%s%d", prefix, i), reserved.nextUnusedVar()) + } + }) + } +} + /* Skipping ColName, TableName: BenchmarkNormalize-8 1000000 2205 ns/op 821 B/op 27 allocs/op @@ -573,7 +1157,8 @@ func BenchmarkNormalize(b *testing.B) { b.Fatal(err) } for i := 0; i < b.N; i++ { - require.NoError(b, Normalize(ast, NewReservedVars("", reservedVars), map[string]*querypb.BindVariable{})) + _, err := PrepareAST(ast, NewReservedVars("", reservedVars), map[string]*querypb.BindVariable{}, true, "ks", 0, "", map[string]string{}, nil, nil) + require.NoError(b, err) } } @@ -602,7 +1187,8 @@ func BenchmarkNormalizeTraces(b *testing.B) { for i := 0; i < b.N; i++ { for i, query := range parsed { - _ = Normalize(query, NewReservedVars("", reservedVars[i]), map[string]*querypb.BindVariable{}) + _, err := PrepareAST(query, NewReservedVars("", reservedVars[i]), map[string]*querypb.BindVariable{}, true, "ks", 0, "", map[string]string{}, nil, nil) + require.NoError(b, err) } } }) diff --git a/go/vt/sqlparser/redact_query.go b/go/vt/sqlparser/redact_query.go index e6b8c009c68..2d018d7c0eb 100644 --- a/go/vt/sqlparser/redact_query.go +++ b/go/vt/sqlparser/redact_query.go @@ -28,10 +28,10 @@ func (p *Parser) RedactSQLQuery(sql string) (string, error) { return "", err } - err = Normalize(stmt, NewReservedVars("redacted", reservedVars), bv) + out, err := PrepareAST(stmt, NewReservedVars("redacted", reservedVars), bv, true, "ks", 0, "", map[string]string{}, nil, nil) if err != nil { return "", err } - return comments.Leading + String(stmt) + comments.Trailing, nil + return comments.Leading + String(out.AST) + comments.Trailing, nil } diff --git a/go/vt/sqlparser/utils.go b/go/vt/sqlparser/utils.go index b785128917f..c56e7740fc5 100644 --- a/go/vt/sqlparser/utils.go +++ b/go/vt/sqlparser/utils.go @@ -41,11 +41,12 @@ func (p *Parser) QueryMatchesTemplates(query string, queryTemplates []string) (m if err != nil { return "", err } - err = Normalize(stmt, NewReservedVars("", reservedVars), bv) + + out, err := PrepareAST(stmt, NewReservedVars("", reservedVars), bv, true, "ks", 0, "", map[string]string{}, nil, nil) if err != nil { return "", err } - normalized := CanonicalString(stmt) + normalized := CanonicalString(out.AST) return normalized, nil } diff --git a/go/vt/vtgate/executor_test.go b/go/vt/vtgate/executor_test.go index 904805e789b..5e7e5c2a07d 100644 --- a/go/vt/vtgate/executor_test.go +++ b/go/vt/vtgate/executor_test.go @@ -1860,6 +1860,30 @@ func TestPassthroughDDL(t *testing.T) { sbc2.Queries = nil } +func TestShowStatus(t *testing.T) { + executor, sbc1, _, _, ctx := createExecutorEnvWithConfig(t, createExecutorConfigWithNormalizer()) + session := &vtgatepb.Session{ + TargetString: "TestExecutor", + } + + sql1 := "show slave status" + _, err := executorExec(ctx, executor, session, sql1, nil) + require.NoError(t, err) + + sql2 := "show replica status" + _, err = executorExec(ctx, executor, session, sql2, nil) + require.NoError(t, err) + + wantQueries := []*querypb.BoundQuery{{ + Sql: sql1, + BindVariables: map[string]*querypb.BindVariable{}, + }, { + Sql: sql2, + BindVariables: map[string]*querypb.BindVariable{}, + }} + assert.Equal(t, wantQueries, sbc1.Queries) +} + func TestParseEmptyTargetSingleKeyspace(t *testing.T) { r, _, _, _, _ := createExecutorEnv(t) diff --git a/go/vt/vtgate/planbuilder/builder.go b/go/vt/vtgate/planbuilder/builder.go index 065c50a6dfa..85d9f5f94ea 100644 --- a/go/vt/vtgate/planbuilder/builder.go +++ b/go/vt/vtgate/planbuilder/builder.go @@ -73,7 +73,7 @@ func (staticConfig) DirectEnabled() bool { // TestBuilder builds a plan for a query based on the specified vschema. // This method is only used from tests func TestBuilder(query string, vschema plancontext.VSchema, keyspace string) (*engine.Plan, error) { - stmt, reserved, err := vschema.Environment().Parser().Parse2(query) + stmt, known, err := vschema.Environment().Parser().Parse2(query) if err != nil { return nil, err } @@ -93,12 +93,12 @@ func TestBuilder(query string, vschema plancontext.VSchema, keyspace string) (*e }() } } - result, err := sqlparser.RewriteAST(stmt, keyspace, sqlparser.SQLSelectLimitUnset, "", nil, vschema.GetForeignKeyChecksState(), vschema) + reservedVars := sqlparser.NewReservedVars("vtg", known) + result, err := sqlparser.PrepareAST(stmt, reservedVars, map[string]*querypb.BindVariable{}, false, keyspace, sqlparser.SQLSelectLimitUnset, "", nil, vschema.GetForeignKeyChecksState(), vschema) if err != nil { return nil, err } - reservedVars := sqlparser.NewReservedVars("vtg", reserved) return BuildFromStmt(context.Background(), query, result.AST, reservedVars, vschema, result.BindVarNeeds, staticConfig{}) } diff --git a/go/vt/vtgate/planbuilder/simplifier_test.go b/go/vt/vtgate/planbuilder/simplifier_test.go index 5aeb0565f9b..012475ba021 100644 --- a/go/vt/vtgate/planbuilder/simplifier_test.go +++ b/go/vt/vtgate/planbuilder/simplifier_test.go @@ -21,6 +21,8 @@ import ( "fmt" "testing" + querypb "vitess.io/vitess/go/vt/proto/query" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" @@ -45,8 +47,8 @@ func TestSimplifyBuggyQuery(t *testing.T) { stmt, reserved, err := sqlparser.NewTestParser().Parse2(query) require.NoError(t, err) - rewritten, _ := sqlparser.RewriteAST(sqlparser.Clone(stmt), vw.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) reservedVars := sqlparser.NewReservedVars("vtg", reserved) + rewritten, _ := sqlparser.PrepareAST(sqlparser.Clone(stmt), reservedVars, map[string]*querypb.BindVariable{}, false, vw.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) simplified := simplifier.SimplifyStatement( stmt.(sqlparser.TableStatement), @@ -69,8 +71,8 @@ func TestSimplifyPanic(t *testing.T) { stmt, reserved, err := sqlparser.NewTestParser().Parse2(query) require.NoError(t, err) - rewritten, _ := sqlparser.RewriteAST(sqlparser.Clone(stmt), vw.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) reservedVars := sqlparser.NewReservedVars("vtg", reserved) + rewritten, _ := sqlparser.PrepareAST(sqlparser.Clone(stmt), reservedVars, map[string]*querypb.BindVariable{}, false, vw.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) simplified := simplifier.SimplifyStatement( stmt.(sqlparser.TableStatement), @@ -100,12 +102,12 @@ func TestUnsupportedFile(t *testing.T) { t.Skip() return } - rewritten, err := sqlparser.RewriteAST(stmt, vw.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) + reservedVars := sqlparser.NewReservedVars("vtg", reserved) + rewritten, err := sqlparser.PrepareAST(stmt, reservedVars, map[string]*querypb.BindVariable{}, false, vw.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) if err != nil { t.Skip() } - reservedVars := sqlparser.NewReservedVars("vtg", reserved) ast := rewritten.AST origQuery := sqlparser.String(ast) stmt, _, _ = sqlparser.NewTestParser().Parse2(tcase.Query) @@ -133,7 +135,7 @@ func keepSameError(query string, reservedVars *sqlparser.ReservedVars, vschema * if err != nil { panic(err) } - rewritten, _ := sqlparser.RewriteAST(stmt, vschema.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) + rewritten, _ := sqlparser.PrepareAST(stmt, reservedVars, map[string]*querypb.BindVariable{}, false, vschema.CurrentDb(), sqlparser.SQLSelectLimitUnset, "", nil, nil, nil) ast := rewritten.AST _, expected := BuildFromStmt(context.Background(), query, ast, reservedVars, vschema, rewritten.BindVarNeeds, staticConfig{}) if expected == nil { diff --git a/go/vt/vtgate/semantics/typer_test.go b/go/vt/vtgate/semantics/typer_test.go index 7de5ecf1340..1ec642b8168 100644 --- a/go/vt/vtgate/semantics/typer_test.go +++ b/go/vt/vtgate/semantics/typer_test.go @@ -41,15 +41,16 @@ func TestNormalizerAndSemanticAnalysisIntegration(t *testing.T) { for _, test := range tests { t.Run(test.query, func(t *testing.T) { - parse, err := sqlparser.NewTestParser().Parse(test.query) + parse, known, err := sqlparser.NewTestParser().Parse2(test.query) require.NoError(t, err) - err = sqlparser.Normalize(parse, sqlparser.NewReservedVars("bv", sqlparser.BindVars{}), map[string]*querypb.BindVariable{}) + rv := sqlparser.NewReservedVars("", known) + out, err := sqlparser.PrepareAST(parse, rv, map[string]*querypb.BindVariable{}, true, "d", 0, "", map[string]string{}, nil, nil) require.NoError(t, err) - st, err := Analyze(parse, "d", fakeSchemaInfo()) + st, err := Analyze(out.AST, "d", fakeSchemaInfo()) require.NoError(t, err) - bv := parse.(*sqlparser.Select).SelectExprs[0].(*sqlparser.AliasedExpr).Expr.(*sqlparser.Argument) + bv := out.AST.(*sqlparser.Select).SelectExprs[0].(*sqlparser.AliasedExpr).Expr.(*sqlparser.Argument) typ, found := st.ExprTypes[bv] require.True(t, found, "bindvar was not typed") require.Equal(t, test.typ, typ.Type().String()) @@ -68,15 +69,15 @@ func TestColumnCollations(t *testing.T) { for _, test := range tests { t.Run(test.query, func(t *testing.T) { - parse, err := sqlparser.NewTestParser().Parse(test.query) + ast, err := sqlparser.NewTestParser().Parse(test.query) require.NoError(t, err) - err = sqlparser.Normalize(parse, sqlparser.NewReservedVars("bv", sqlparser.BindVars{}), map[string]*querypb.BindVariable{}) + out, err := sqlparser.PrepareAST(ast, sqlparser.NewReservedVars("bv", sqlparser.BindVars{}), map[string]*querypb.BindVariable{}, true, "d", 0, "", map[string]string{}, nil, nil) require.NoError(t, err) - st, err := Analyze(parse, "d", fakeSchemaInfo()) + st, err := Analyze(out.AST, "d", fakeSchemaInfo()) require.NoError(t, err) - col := extract(parse.(*sqlparser.Select), 0) + col := extract(out.AST.(*sqlparser.Select), 0) typ, found := st.TypeForExpr(col) require.True(t, found, "column was not typed")