diff --git a/go/vt/schemadiff/view.go b/go/vt/schemadiff/view.go index 8783f1803bb..ff34d772dcf 100644 --- a/go/vt/schemadiff/view.go +++ b/go/vt/schemadiff/view.go @@ -427,6 +427,6 @@ func (c *CreateViewEntity) identicalOtherThanName(other *CreateViewEntity) bool c.IsReplace == other.IsReplace && sqlparser.Equals.RefOfDefiner(c.Definer, other.Definer) && sqlparser.Equals.Columns(c.Columns, other.Columns) && - sqlparser.Equals.SelectStatement(c.Select, other.Select) && + sqlparser.Equals.Statement(c.Select, other.Select) && sqlparser.Equals.RefOfParsedComments(c.Comments, other.Comments) } diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index cd1e032a56a..c1c7a0eeef2 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -82,6 +82,11 @@ type ( SetWith(with *With) } + Distinctable interface { + MakeDistinct() + IsDistinct() bool + } + // SelectStatement any SELECT statement. SelectStatement interface { Statement @@ -90,12 +95,11 @@ type ( Commented ColumnResults Withable + Distinctable iSelectStatement() GetLock() Lock SetLock(lock Lock) SetInto(into *SelectInto) - MakeDistinct() - IsDistinct() bool } // DDLStatement represents any DDL Statement diff --git a/go/vt/sqlparser/ast_rewriting.go b/go/vt/sqlparser/ast_rewriting.go index f059efee6ac..b7ebf541950 100644 --- a/go/vt/sqlparser/ast_rewriting.go +++ b/go/vt/sqlparser/ast_rewriting.go @@ -532,7 +532,7 @@ func (er *astRewriter) existsRewrite(cursor *Cursor, node *ExistsExpr) { sel.GroupBy = nil } -// rewriteDistinctableAggr removed Distinct from Max and Min Aggregations as it does not impact the result. But, makes the plan simpler. +// rewriteDistinctableAggr removed Distinctable from Max and Min Aggregations as it does not impact the result. But, makes the plan simpler. func (er *astRewriter) rewriteDistinctableAggr(cursor *Cursor, node DistinctableAggr) { if !node.IsDistinct() { return diff --git a/go/vt/vtgate/planbuilder/ddl.go b/go/vt/vtgate/planbuilder/ddl.go index a0045cec060..e8a4b6f1baf 100644 --- a/go/vt/vtgate/planbuilder/ddl.go +++ b/go/vt/vtgate/planbuilder/ddl.go @@ -196,7 +196,7 @@ func buildCreateViewCommon( vschema plancontext.VSchema, reservedVars *sqlparser.ReservedVars, cfg dynamicconfig.DDL, - ddlSelect sqlparser.SelectStatement, + ddlSelect sqlparser.TableSubquery, ddl sqlparser.DDLStatement, ) (key.Destination, *vindexes.Keyspace, error) { // For Create View, we require that the keyspace exist and the select query can be satisfied within the keyspace itself diff --git a/go/vt/vtgate/planbuilder/operators/SQL_builder.go b/go/vt/vtgate/planbuilder/operators/SQL_builder.go index fc91569981d..577046cacf5 100644 --- a/go/vt/vtgate/planbuilder/operators/SQL_builder.go +++ b/go/vt/vtgate/planbuilder/operators/SQL_builder.go @@ -37,8 +37,9 @@ type ( } ) -func (qb *queryBuilder) asSelectStatement() sqlparser.SelectStatement { - return qb.stmt.(sqlparser.SelectStatement) +func (qb *queryBuilder) asSelectStatement() sqlparser.TableSubquery { + return qb.stmt.(sqlparser.TableSubquery) + } func (qb *queryBuilder) asOrderAndLimit() sqlparser.OrderAndLimit { return qb.stmt.(sqlparser.OrderAndLimit) @@ -191,7 +192,8 @@ func (qb *queryBuilder) pushUnionInsideDerived() { As: sqlparser.NewIdentifierCS("dt"), }}, } - sel.SelectExprs = unionSelects(sqlparser.GetFirstSelect(selStmt).SelectExprs) + firstSelect := getFirstSelect(selStmt) + sel.SelectExprs = unionSelects(firstSelect.SelectExprs) qb.stmt = sel } @@ -208,9 +210,10 @@ func unionSelects(exprs sqlparser.SelectExprs) (selectExprs sqlparser.SelectExpr return } -func checkUnionColumnByName(column *sqlparser.ColName, sel sqlparser.SelectStatement) { +func checkUnionColumnByName(column *sqlparser.ColName, sel sqlparser.TableSubquery) { colName := column.Name.String() - exprs := sqlparser.GetFirstSelect(sel).SelectExprs + firstSelect := getFirstSelect(sel) + exprs := firstSelect.SelectExprs offset := slices.IndexFunc(exprs, func(expr sqlparser.SelectExpr) bool { switch ae := expr.(type) { case *sqlparser.StarExpr: @@ -244,8 +247,8 @@ func (qb *queryBuilder) unionWith(other *queryBuilder, distinct bool) { func (qb *queryBuilder) recursiveCteWith(other *queryBuilder, name, alias string, distinct bool, columns sqlparser.Columns) { cteUnion := &sqlparser.Union{ - Left: qb.stmt.(sqlparser.SelectStatement), - Right: other.stmt.(sqlparser.SelectStatement), + Left: qb.stmt.(sqlparser.TableSubquery), + Right: other.stmt.(sqlparser.TableSubquery), Distinct: distinct, } @@ -393,7 +396,7 @@ func removeKeyspaceFromSelectExpr(expr sqlparser.SelectExpr) { } } -func stripDownQuery(from, to sqlparser.SelectStatement) { +func stripDownQuery(from, to sqlparser.TableSubquery) { switch node := from.(type) { case *sqlparser.Select: toNode, ok := to.(*sqlparser.Select) @@ -450,7 +453,12 @@ func buildQuery(op Operator, qb *queryBuilder) { buildUnion(op, qb) case *Distinct: buildQuery(op.Source, qb) - qb.asSelectStatement().MakeDistinct() + statement := qb.asSelectStatement() + d, ok := statement.(sqlparser.Distinctable) + if !ok { + panic(vterrors.VT13001("expected a select statement with distinct")) + } + d.MakeDistinct() case *Update: buildUpdate(op, qb) case *Delete: diff --git a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go index 73169369a41..ced81df147a 100644 --- a/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go +++ b/go/vt/vtgate/planbuilder/operators/aggregation_pushing.go @@ -544,7 +544,7 @@ func splitAggrColumnsToLeftAndRight( canPushDistinctAggr, distinctExprs := checkIfWeCanPush(ctx, aggregator) - // Distinct aggregation cannot be pushed down in the join. + // Distinctable aggregation cannot be pushed down in the join. // We keep node of the distinct aggregation expression to be used later for ordering. if !canPushDistinctAggr { if len(distinctExprs) != 1 { diff --git a/go/vt/vtgate/planbuilder/operators/ast_to_op.go b/go/vt/vtgate/planbuilder/operators/ast_to_op.go index 12c19bb72a6..08449e83341 100644 --- a/go/vt/vtgate/planbuilder/operators/ast_to_op.go +++ b/go/vt/vtgate/planbuilder/operators/ast_to_op.go @@ -178,7 +178,7 @@ func createOperatorFromUnion(ctx *plancontext.PlanningContext, node *sqlparser.U return newHorizon(union, node) } -func translateQueryToOpForUnion(ctx *plancontext.PlanningContext, node sqlparser.SelectStatement) Operator { +func translateQueryToOpForUnion(ctx *plancontext.PlanningContext, node sqlparser.TableSubquery) Operator { op := translateQueryToOp(ctx, node) if hz, ok := op.(*Horizon); ok { hz.Truncate = true diff --git a/go/vt/vtgate/planbuilder/operators/expressions.go b/go/vt/vtgate/planbuilder/operators/expressions.go index 38848693775..504b2b87ddf 100644 --- a/go/vt/vtgate/planbuilder/operators/expressions.go +++ b/go/vt/vtgate/planbuilder/operators/expressions.go @@ -17,7 +17,10 @@ limitations under the License. package operators import ( + "fmt" + "vitess.io/vitess/go/vt/sqlparser" + "vitess.io/vitess/go/vt/vterrors" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" "vitess.io/vitess/go/vt/vtgate/semantics" ) @@ -121,3 +124,11 @@ func simplifyPredicates(ctx *plancontext.PlanningContext, in sqlparser.Expr) sql } return output } + +func getFirstSelect(selStmt sqlparser.TableSubquery) *sqlparser.Select { + firstSelect, err := sqlparser.GetFirstSelect(selStmt) + if err != nil { + panic(vterrors.VT12001(fmt.Sprintf("first UNION part not a SELECT: %v", err))) + } + return firstSelect +} diff --git a/go/vt/vtgate/planbuilder/operators/horizon.go b/go/vt/vtgate/planbuilder/operators/horizon.go index 292be1b37c5..e3bf76d9d05 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon.go +++ b/go/vt/vtgate/planbuilder/operators/horizon.go @@ -45,7 +45,7 @@ type Horizon struct { // QP contains the QueryProjection for this op QP *QueryProjection - Query sqlparser.SelectStatement + Query sqlparser.TableSubquery // Columns needed to feed other plans Columns []*sqlparser.ColName @@ -54,7 +54,7 @@ type Horizon struct { Truncate bool } -func newHorizon(src Operator, query sqlparser.SelectStatement) *Horizon { +func newHorizon(src Operator, query sqlparser.TableSubquery) *Horizon { return &Horizon{ unaryOperator: newUnaryOp(src), Query: query, @@ -148,7 +148,7 @@ func (h *Horizon) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, return -1 } - for idx, se := range sqlparser.GetFirstSelect(h.Query).SelectExprs { + for idx, se := range getFirstSelect(h.Query).SelectExprs { ae, ok := se.(*sqlparser.AliasedExpr) if !ok { panic(vterrors.VT09015()) @@ -174,7 +174,7 @@ func (h *Horizon) GetColumns(ctx *plancontext.PlanningContext) (exprs []*sqlpars } func (h *Horizon) GetSelectExprs(*plancontext.PlanningContext) sqlparser.SelectExprs { - return sqlparser.GetFirstSelect(h.Query).SelectExprs + return getFirstSelect(h.Query).SelectExprs } func (h *Horizon) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy { @@ -185,7 +185,7 @@ func (h *Horizon) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy { } // TODO: REMOVE -func (h *Horizon) selectStatement() sqlparser.SelectStatement { +func (h *Horizon) selectStatement() sqlparser.TableSubquery { return h.Query } diff --git a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go index dad5ad3a91a..2b53bff74f3 100644 --- a/go/vt/vtgate/planbuilder/operators/horizon_expanding.go +++ b/go/vt/vtgate/planbuilder/operators/horizon_expanding.go @@ -95,7 +95,7 @@ func expandSelectHorizon(ctx *plancontext.PlanningContext, horizon *Horizon, sel if qp.NeedsDistinct() { op = newDistinct(op, qp, true) - extracted = append(extracted, "Distinct") + extracted = append(extracted, "Distinctable") } if sel.Having != nil { diff --git a/go/vt/vtgate/planbuilder/operators/insert.go b/go/vt/vtgate/planbuilder/operators/insert.go index 4ce37901a77..a51cda54334 100644 --- a/go/vt/vtgate/planbuilder/operators/insert.go +++ b/go/vt/vtgate/planbuilder/operators/insert.go @@ -394,7 +394,7 @@ func createInsertOperator(ctx *plancontext.PlanningContext, insStmt *sqlparser.I case sqlparser.Values: op = route route.Source = insertRowsPlan(ctx, insOp, insStmt, rows) - case sqlparser.SelectStatement: + case sqlparser.TableSubquery: op = insertSelectPlan(ctx, insOp, route, insStmt, rows) } if insStmt.Comments != nil { @@ -408,7 +408,7 @@ func insertSelectPlan( insOp *Insert, routeOp *Route, ins *sqlparser.Insert, - sel sqlparser.SelectStatement, + sel sqlparser.TableSubquery, ) *InsertSelection { if columnMismatch(insOp.AutoIncrement, ins, sel) { panic(vterrors.VT03006()) @@ -457,7 +457,7 @@ func insertSelectPlan( return insertSelect } -func columnMismatch(gen *Generate, ins *sqlparser.Insert, sel sqlparser.SelectStatement) bool { +func columnMismatch(gen *Generate, ins *sqlparser.Insert, sel sqlparser.TableSubquery) bool { origColCount := len(ins.Columns) if gen != nil && gen.added { // One column got added to the insert query ast for auto increment column. @@ -468,7 +468,7 @@ func columnMismatch(gen *Generate, ins *sqlparser.Insert, sel sqlparser.SelectSt return true } if origColCount > sel.GetColumnCount() { - sel := sqlparser.GetFirstSelect(sel) + sel := getFirstSelect(sel) var hasStarExpr bool for _, sExpr := range sel.SelectExprs { if _, hasStarExpr = sExpr.(*sqlparser.StarExpr); hasStarExpr { diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index eb6c42b8724..c005c398498 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -59,7 +59,7 @@ func (p Phase) String() string { case addAggrOrdering: return "optimize aggregations with ORDER BY" case cleanOutPerfDistinct: - return "optimize Distinct operations" + return "optimize Distinctable operations" case subquerySettling: return "settle subqueries" case dmlWithInput: diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index 5fe0c7773c1..db716966d47 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -32,7 +32,7 @@ import ( func planQuery(ctx *plancontext.PlanningContext, root Operator) Operator { var selExpr sqlparser.SelectExprs if horizon, isHorizon := root.(*Horizon); isHorizon { - sel := sqlparser.GetFirstSelect(horizon.Query) + sel := getFirstSelect(horizon.Query) selExpr = sqlparser.Clone(sel.SelectExprs) } @@ -207,7 +207,7 @@ func pushOrExpandHorizon(ctx *plancontext.PlanningContext, in *Horizon) (Operato !hasHaving && !needsOrdering && !qp.NeedsAggregation() && - !in.selectStatement().IsDistinct() && + !isDistinctAST(in.selectStatement()) && in.selectStatement().GetLimit() == nil if canPush { @@ -784,7 +784,7 @@ func isDistinct(op Operator) bool { case *Union: return op.distinct case *Horizon: - return op.Query.IsDistinct() + return isDistinctAST(op.Query) case *Limit: return isDistinct(op.Source) default: @@ -792,6 +792,13 @@ func isDistinct(op Operator) bool { } } +func isDistinctAST(s sqlparser.Statement) bool { + if d, ok := s.(sqlparser.Distinctable); ok { + return d.IsDistinct() + } + return false +} + func tryPushUnion(ctx *plancontext.PlanningContext, op *Union) (Operator, *ApplyResult) { if res := compactUnion(op); res != NoRewrite { return op, res diff --git a/go/vt/vtgate/planbuilder/operators/queryprojection.go b/go/vt/vtgate/planbuilder/operators/queryprojection.go index a245831ca13..9136dee9455 100644 --- a/go/vt/vtgate/planbuilder/operators/queryprojection.go +++ b/go/vt/vtgate/planbuilder/operators/queryprojection.go @@ -203,7 +203,7 @@ func (qp *QueryProjection) addSelectExpressions(ctx *plancontext.PlanningContext func createQPFromUnion(ctx *plancontext.PlanningContext, union *sqlparser.Union) *QueryProjection { qp := &QueryProjection{} - sel := sqlparser.GetFirstSelect(union) + sel := getFirstSelect(union) qp.addSelectExpressions(ctx, sel) qp.addOrderBy(ctx, union.OrderBy) @@ -714,7 +714,7 @@ func CompareRefInt(a *int, b *int) bool { return *a < *b } -func CreateQPFromSelectStatement(ctx *plancontext.PlanningContext, stmt sqlparser.SelectStatement) *QueryProjection { +func CreateQPFromSelectStatement(ctx *plancontext.PlanningContext, stmt sqlparser.TableSubquery) *QueryProjection { switch sel := stmt.(type) { case *sqlparser.Select: return createQPFromSelect(ctx, sel) diff --git a/go/vt/vtgate/planbuilder/operators/subquery_builder.go b/go/vt/vtgate/planbuilder/operators/subquery_builder.go index c2256df06f4..afd325afa9b 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_builder.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_builder.go @@ -115,7 +115,7 @@ func createSubqueryOp( // inspectStatement goes through all the predicates contained in the AST // and extracts subqueries into operators func (sqb *SubQueryBuilder) inspectStatement(ctx *plancontext.PlanningContext, - stmt sqlparser.SelectStatement, + stmt sqlparser.TableSubquery, ) (sqlparser.Exprs, []applyJoinColumn) { switch stmt := stmt.(type) { case *sqlparser.Select: diff --git a/go/vt/vtgate/planbuilder/operators/subquery_planning.go b/go/vt/vtgate/planbuilder/operators/subquery_planning.go index e222ae0f343..c248975a8df 100644 --- a/go/vt/vtgate/planbuilder/operators/subquery_planning.go +++ b/go/vt/vtgate/planbuilder/operators/subquery_planning.go @@ -30,7 +30,7 @@ import ( "vitess.io/vitess/go/vt/vtgate/semantics" ) -func isMergeable(ctx *plancontext.PlanningContext, query sqlparser.SelectStatement, op Operator) bool { +func isMergeable(ctx *plancontext.PlanningContext, query sqlparser.TableSubquery, op Operator) bool { validVindex := func(expr sqlparser.Expr) bool { sc := findColumnVindex(ctx, op, expr) return sc != nil && sc.IsUnique() @@ -672,7 +672,7 @@ func (s *subqueryRouteMerger) rewriteASTExpression(ctx *plancontext.PlanningCont if err != nil { panic(err) } - subqStmt, ok := stmt.(sqlparser.SelectStatement) + subqStmt, ok := stmt.(sqlparser.TableSubquery) if !ok { panic(vterrors.VT13001("subqueries should only be select statement")) } @@ -700,7 +700,7 @@ func (s *subqueryRouteMerger) rewriteASTExpression(ctx *plancontext.PlanningCont if !deps.IsSolvedBy(subqID) { cursor.Replace(exprFound) } - }, nil).(sqlparser.SelectStatement) + }, nil).(sqlparser.TableSubquery) if err != nil { panic(err) } diff --git a/go/vt/vtgate/planbuilder/operators/union.go b/go/vt/vtgate/planbuilder/operators/union.go index 7d09391cf7d..1c692c9f38e 100644 --- a/go/vt/vtgate/planbuilder/operators/union.go +++ b/go/vt/vtgate/planbuilder/operators/union.go @@ -150,7 +150,7 @@ func (u *Union) GetSelectFor(source int) *sqlparser.Select { for { switch op := src.(type) { case *Horizon: - return sqlparser.GetFirstSelect(op.Query) + return getFirstSelect(op.Query) case *Route: src = op.Source default: diff --git a/go/vt/vtgate/planbuilder/simplifier_test.go b/go/vt/vtgate/planbuilder/simplifier_test.go index dce21b3e175..a382874d8e3 100644 --- a/go/vt/vtgate/planbuilder/simplifier_test.go +++ b/go/vt/vtgate/planbuilder/simplifier_test.go @@ -49,7 +49,7 @@ func TestSimplifyBuggyQuery(t *testing.T) { reservedVars := sqlparser.NewReservedVars("vtg", reserved) simplified := simplifier.SimplifyStatement( - stmt.(sqlparser.SelectStatement), + stmt.(sqlparser.TableSubquery), vw.CurrentDb(), vw, keepSameError(query, reservedVars, vw, rewritten.BindVarNeeds), @@ -73,7 +73,7 @@ func TestSimplifyPanic(t *testing.T) { reservedVars := sqlparser.NewReservedVars("vtg", reserved) simplified := simplifier.SimplifyStatement( - stmt.(sqlparser.SelectStatement), + stmt.(sqlparser.TableSubquery), vw.CurrentDb(), vw, keepPanicking(query, reservedVars, vw, rewritten.BindVarNeeds), @@ -95,7 +95,7 @@ func TestUnsupportedFile(t *testing.T) { log.Errorf("unsupported_cases.txt - %s", tcase.Query) stmt, reserved, err := sqlparser.NewTestParser().Parse2(tcase.Query) require.NoError(t, err) - _, ok := stmt.(sqlparser.SelectStatement) + _, ok := stmt.(sqlparser.TableSubquery) if !ok { t.Skip() return @@ -110,7 +110,7 @@ func TestUnsupportedFile(t *testing.T) { origQuery := sqlparser.String(ast) stmt, _, _ = sqlparser.NewTestParser().Parse2(tcase.Query) simplified := simplifier.SimplifyStatement( - stmt.(sqlparser.SelectStatement), + stmt.(sqlparser.TableSubquery), vw.CurrentDb(), vw, keepSameError(tcase.Query, reservedVars, vw, rewritten.BindVarNeeds), @@ -128,7 +128,7 @@ func TestUnsupportedFile(t *testing.T) { } } -func keepSameError(query string, reservedVars *sqlparser.ReservedVars, vschema *vschemawrapper.VSchemaWrapper, needs *sqlparser.BindVarNeeds) func(statement sqlparser.SelectStatement) bool { +func keepSameError(query string, reservedVars *sqlparser.ReservedVars, vschema *vschemawrapper.VSchemaWrapper, needs *sqlparser.BindVarNeeds) func(statement sqlparser.TableSubquery) bool { stmt, _, err := sqlparser.NewTestParser().Parse2(query) if err != nil { panic(err) @@ -139,7 +139,7 @@ func keepSameError(query string, reservedVars *sqlparser.ReservedVars, vschema * if expected == nil { panic("query does not fail to plan") } - return func(statement sqlparser.SelectStatement) bool { + return func(statement sqlparser.TableSubquery) bool { _, myErr := BuildFromStmt(context.Background(), query, statement, reservedVars, vschema, needs, staticConfig{}) if myErr == nil { return false @@ -152,8 +152,8 @@ func keepSameError(query string, reservedVars *sqlparser.ReservedVars, vschema * } } -func keepPanicking(query string, reservedVars *sqlparser.ReservedVars, vschema *vschemawrapper.VSchemaWrapper, needs *sqlparser.BindVarNeeds) func(statement sqlparser.SelectStatement) bool { - cmp := func(statement sqlparser.SelectStatement) (res bool) { +func keepPanicking(query string, reservedVars *sqlparser.ReservedVars, vschema *vschemawrapper.VSchemaWrapper, needs *sqlparser.BindVarNeeds) func(statement sqlparser.TableSubquery) bool { + cmp := func(statement sqlparser.TableSubquery) (res bool) { defer func() { r := recover() if r != nil { @@ -172,7 +172,7 @@ func keepPanicking(query string, reservedVars *sqlparser.ReservedVars, vschema * if err != nil { panic(err.Error()) } - if !cmp(stmt.(sqlparser.SelectStatement)) { + if !cmp(stmt.(sqlparser.TableSubquery)) { panic("query is not panicking") } diff --git a/go/vt/vtgate/semantics/semantic_table.go b/go/vt/vtgate/semantics/semantic_table.go index 30a41ba5f12..bc5ecd81f3b 100644 --- a/go/vt/vtgate/semantics/semantic_table.go +++ b/go/vt/vtgate/semantics/semantic_table.go @@ -493,7 +493,7 @@ func (st *SemTable) ForeignKeysPresent() bool { return false } -func (st *SemTable) SelectExprs(sel sqlparser.SelectStatement) sqlparser.SelectExprs { +func (st *SemTable) SelectExprs(sel sqlparser.TableSubquery) sqlparser.SelectExprs { switch sel := sel.(type) { case *sqlparser.Select: return sel.SelectExprs diff --git a/go/vt/vtgate/simplifier/simplifier.go b/go/vt/vtgate/simplifier/simplifier.go index e838450e3a2..3af5f20a0a2 100644 --- a/go/vt/vtgate/simplifier/simplifier.go +++ b/go/vt/vtgate/simplifier/simplifier.go @@ -25,17 +25,17 @@ import ( // SimplifyStatement simplifies the AST of a query. It basically iteratively prunes leaves of the AST, as long as the pruning // continues to return true from the `test` function. func SimplifyStatement( - in sqlparser.SelectStatement, + in sqlparser.TableSubquery, currentDB string, si semantics.SchemaInformation, - testF func(sqlparser.SelectStatement) bool, -) sqlparser.SelectStatement { + testF func(sqlparser.TableSubquery) bool, +) sqlparser.TableSubquery { tables, err := getTables(in, currentDB, si) if err != nil { panic(err) } - test := func(s sqlparser.SelectStatement) bool { + test := func(s sqlparser.TableSubquery) bool { // Since our semantic analysis changes the AST, we clone it first, so we have a pristine AST to play with return testF(sqlparser.Clone(s)) } @@ -68,7 +68,7 @@ func SimplifyStatement( return in } -func trySimplifyDistinct(in sqlparser.SelectStatement, test func(statement sqlparser.SelectStatement) bool) sqlparser.SelectStatement { +func trySimplifyDistinct(in sqlparser.TableSubquery, test func(statement sqlparser.TableSubquery) bool) sqlparser.TableSubquery { simplified := false alwaysVisitChildren := func(node, parent sqlparser.SQLNode) bool { return true @@ -100,7 +100,7 @@ func trySimplifyDistinct(in sqlparser.SelectStatement, test func(statement sqlpa return nil } -func trySimplifyExpressions(in sqlparser.SelectStatement, test func(sqlparser.SelectStatement) bool) sqlparser.SelectStatement { +func trySimplifyExpressions(in sqlparser.TableSubquery, test func(sqlparser.TableSubquery) bool) sqlparser.TableSubquery { simplified := false visit := func(cursor expressionCursor) bool { // first - let's try to remove the expression @@ -141,7 +141,7 @@ func trySimplifyExpressions(in sqlparser.SelectStatement, test func(sqlparser.Se return nil } -func trySimplifyUnions(in sqlparser.SelectStatement, test func(sqlparser.SelectStatement) bool) (res sqlparser.SelectStatement) { +func trySimplifyUnions(in sqlparser.TableSubquery, test func(subquery sqlparser.TableSubquery) bool) (res sqlparser.TableSubquery) { if union, ok := in.(*sqlparser.Union); ok { // the root object is an UNION if test(sqlparser.Clone(union.Left)) { @@ -193,7 +193,7 @@ func trySimplifyUnions(in sqlparser.SelectStatement, test func(sqlparser.SelectS return nil } -func tryRemoveTable(tables []semantics.TableInfo, in sqlparser.SelectStatement, currentDB string, si semantics.SchemaInformation, test func(sqlparser.SelectStatement) bool) sqlparser.SelectStatement { +func tryRemoveTable(tables []semantics.TableInfo, in sqlparser.TableSubquery, currentDB string, si semantics.SchemaInformation, test func(sqlparser.TableSubquery) bool) sqlparser.TableSubquery { // we start by removing one table at a time, and see if we still have an interesting plan for idx, tbl := range tables { clone := sqlparser.Clone(in) @@ -209,7 +209,7 @@ func tryRemoveTable(tables []semantics.TableInfo, in sqlparser.SelectStatement, return nil } -func getTables(in sqlparser.SelectStatement, currentDB string, si semantics.SchemaInformation) ([]semantics.TableInfo, error) { +func getTables(in sqlparser.TableSubquery, currentDB string, si semantics.SchemaInformation) ([]semantics.TableInfo, error) { // Since our semantic analysis changes the AST, we clone it first, so we have a pristine AST to play with clone := sqlparser.Clone(in) semTable, err := semantics.Analyze(clone, currentDB, si) @@ -219,7 +219,7 @@ func getTables(in sqlparser.SelectStatement, currentDB string, si semantics.Sche return semTable.Tables, nil } -func simplifyStarExpr(in sqlparser.SelectStatement, test func(sqlparser.SelectStatement) bool) sqlparser.SelectStatement { +func simplifyStarExpr(in sqlparser.TableSubquery, test func(sqlparser.TableSubquery) bool) sqlparser.TableSubquery { simplified := false alwaysVisitChildren := func(node, parent sqlparser.SQLNode) bool { return true @@ -254,7 +254,7 @@ func simplifyStarExpr(in sqlparser.SelectStatement, test func(sqlparser.SelectSt // removeTable removes the table with the given index from the select statement, which includes the FROM clause // but also all expressions and predicates that depend on the table -func removeTable(clone sqlparser.SelectStatement, searchedTS semantics.TableSet, db string, si semantics.SchemaInformation) bool { +func removeTable(clone sqlparser.TableSubquery, searchedTS semantics.TableSet, db string, si semantics.SchemaInformation) bool { semTable, err := semantics.Analyze(clone, db, si) if err != nil { panic(err) @@ -429,7 +429,7 @@ func newExprCursor(expr sqlparser.Expr, replace func(replaceWith sqlparser.Expr) // This cursor has a few extra capabilities that the normal sqlparser.SafeRewrite does not have, // such as visiting and being able to change individual expressions in a AND tree // if visit returns true, then traversal continues, otherwise traversal stops -func visitAllExpressionsInAST(clone sqlparser.SelectStatement, visit func(expressionCursor) bool) { +func visitAllExpressionsInAST(clone sqlparser.TableSubquery, visit func(expressionCursor) bool) { alwaysVisitChildren := func(node, parent sqlparser.SQLNode) bool { return true } diff --git a/go/vt/vtgate/simplifier/simplifier_test.go b/go/vt/vtgate/simplifier/simplifier_test.go index 340497da8ef..3d1cdd7544b 100644 --- a/go/vt/vtgate/simplifier/simplifier_test.go +++ b/go/vt/vtgate/simplifier/simplifier_test.go @@ -52,7 +52,7 @@ limit 123 offset 456 ` ast, err := sqlparser.NewTestParser().Parse(query) require.NoError(t, err) - visitAllExpressionsInAST(ast.(sqlparser.SelectStatement), func(cursor expressionCursor) bool { + visitAllExpressionsInAST(ast.(sqlparser.TableSubquery), func(cursor expressionCursor) bool { fmt.Printf(">> found expression: %s\n", sqlparser.String(cursor.expr)) cursor.remove() fmt.Printf("remove: %s\n", sqlparser.String(ast)) @@ -70,7 +70,7 @@ func TestAbortExpressionCursor(t *testing.T) { query := "select user.id, count(*), unsharded.name from user join unsharded on 13 = 14 where unsharded.id = 42 and name = 'foo' and user.id = unsharded.id" ast, err := sqlparser.NewTestParser().Parse(query) require.NoError(t, err) - visitAllExpressionsInAST(ast.(sqlparser.SelectStatement), func(cursor expressionCursor) bool { + visitAllExpressionsInAST(ast.(sqlparser.TableSubquery), func(cursor expressionCursor) bool { fmt.Println(sqlparser.String(cursor.expr)) cursor.replace(sqlparser.NewIntLiteral("1")) fmt.Println(sqlparser.String(ast))