diff --git a/go/vt/vtgate/planbuilder/operators/info_schema_planning.go b/go/vt/vtgate/planbuilder/operators/info_schema_planning.go index 1e15237f30d..2d54e012c7e 100644 --- a/go/vt/vtgate/planbuilder/operators/info_schema_planning.go +++ b/go/vt/vtgate/planbuilder/operators/info_schema_planning.go @@ -21,6 +21,8 @@ import ( "slices" "strings" + "vitess.io/vitess/go/vt/vtgate/semantics" + "vitess.io/vitess/go/mysql/collations" "vitess.io/vitess/go/sqltypes" "vitess.io/vitess/go/vt/sqlparser" @@ -102,6 +104,10 @@ func (isr *InfoSchemaRouting) updateRoutingLogic(ctx *plancontext.PlanningContex return isr } +func (isr *InfoSchemaRouting) AddValuesTableID(id semantics.TableSet) { + panic(vterrors.VT13001("think about values and info schema routing")) +} + func (isr *InfoSchemaRouting) Cost() int { return 0 } diff --git a/go/vt/vtgate/planbuilder/operators/misc_routing.go b/go/vt/vtgate/planbuilder/operators/misc_routing.go index 575aa7b4e9a..27415b39355 100644 --- a/go/vt/vtgate/planbuilder/operators/misc_routing.go +++ b/go/vt/vtgate/planbuilder/operators/misc_routing.go @@ -21,6 +21,7 @@ import ( "vitess.io/vitess/go/vt/sqlparser" "vitess.io/vitess/go/vt/vtgate/engine" "vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext" + "vitess.io/vitess/go/vt/vtgate/semantics" "vitess.io/vitess/go/vt/vtgate/vindexes" ) @@ -77,6 +78,7 @@ func (tr *TargetedRouting) Clone() Routing { func (tr *TargetedRouting) updateRoutingLogic(_ *plancontext.PlanningContext, _ sqlparser.Expr) Routing { return tr } +func (tr *TargetedRouting) AddValuesTableID(semantics.TableSet) {} func (tr *TargetedRouting) Cost() int { return 1 @@ -102,6 +104,8 @@ func (n *NoneRouting) updateRoutingLogic(*plancontext.PlanningContext, sqlparser return n } +func (*NoneRouting) AddValuesTableID(semantics.TableSet) {} + func (n *NoneRouting) Cost() int { return 0 } @@ -129,6 +133,8 @@ func (rr *AnyShardRouting) updateRoutingLogic(*plancontext.PlanningContext, sqlp return rr } +func (tr *AnyShardRouting) AddValuesTableID(semantics.TableSet) {} + func (rr *AnyShardRouting) Cost() int { return 0 } @@ -166,6 +172,8 @@ func (dr *DualRouting) updateRoutingLogic(*plancontext.PlanningContext, sqlparse return dr } +func (tr *DualRouting) AddValuesTableID(semantics.TableSet) {} + func (dr *DualRouting) Cost() int { return 0 } @@ -191,14 +199,16 @@ func (sr *SequenceRouting) updateRoutingLogic(*plancontext.PlanningContext, sqlp return sr } -func (sr *SequenceRouting) Cost() int { +func (*SequenceRouting) AddValuesTableID(semantics.TableSet) {} + +func (*SequenceRouting) Cost() int { return 0 } -func (sr *SequenceRouting) OpCode() engine.Opcode { +func (*SequenceRouting) OpCode() engine.Opcode { return engine.Next } -func (sr *SequenceRouting) Keyspace() *vindexes.Keyspace { +func (*SequenceRouting) Keyspace() *vindexes.Keyspace { return nil } diff --git a/go/vt/vtgate/planbuilder/operators/phases.go b/go/vt/vtgate/planbuilder/operators/phases.go index bca7d4ebab0..69111b62cd2 100644 --- a/go/vt/vtgate/planbuilder/operators/phases.go +++ b/go/vt/vtgate/planbuilder/operators/phases.go @@ -114,37 +114,80 @@ func (p Phase) act(ctx *plancontext.PlanningContext, op Operator) Operator { case dmlWithInput: return findDMLAboveRoute(ctx, op) case rewriteApplyJoin: - visit := func(op Operator, lhsTables semantics.TableSet, isRoot bool) (Operator, *ApplyResult) { - aj, ok := op.(*ApplyJoin) - if !ok { - return op, NoRewrite - } + return rewriteApplyToValues(ctx, op) - vj := newValuesJoin(ctx, aj.LHS, aj.RHS, aj.JoinType) - if vj == nil { - return op, NoRewrite - } + default: + return op + } +} - for _, column := range aj.JoinColumns.columns { - vj.AddColumn(ctx, true, false, aeWrap(column.Original)) +func rewriteApplyToValues(ctx *plancontext.PlanningContext, op Operator) Operator { + var skipped []sqlparser.Expr + isSkipped := func(expr sqlparser.Expr) bool { + for _, skip := range skipped { + if ctx.SemTable.EqualsExpr(expr, skip) { + return true } + } + return false + } + + // Traverse the operator tree to convert ApplyJoin to ValuesJoin. + // Then add a Values node to the RHS of the new ValuesJoin, + // and usually a filter containing the join predicates is placed there. + visit := func(op Operator, lhsTables semantics.TableSet, isRoot bool) (Operator, *ApplyResult) { + aj, ok := op.(*ApplyJoin) + if !ok { + return op, NoRewrite + } + + vj := newValuesJoin(ctx, aj.LHS, aj.RHS, aj.JoinType) + if vj == nil { + return op, NoRewrite + } - for _, pred := range aj.JoinPredicates.columns { - err := ctx.SkipJoinPredicates(pred.Original) - if err != nil { - panic(err) - } - vj.AddJoinPredicate(ctx, pred.Original) + for _, column := range aj.JoinColumns.columns { + vj.AddColumn(ctx, true, false, aeWrap(column.Original)) + } + + for _, pred := range aj.JoinPredicates.columns { + skipped = append(skipped, pred.RHSExpr) + err := ctx.SkipJoinPredicates(pred.Original) + if err != nil { + panic(err) } + vj.AddJoinPredicate(ctx, pred.Original) + } + + return vj, Rewrote("rewrote ApplyJoin to ValuesJoin") + } + + shouldVisit := func(op Operator) VisitRule { + rb, ok := op.(*Route) + if !ok { + return VisitChildren + } - return vj, Rewrote("rewrote ApplyJoin to ValuesJoin") + routing, ok := rb.Routing.(*ShardedRouting) + if !ok { + return SkipChildren } - return TopDown(op, TableID, visit, stopAtRoute) + // We need to skip the predicates that are already pushed down to the mysql - + // we will push down the JoinValues predicates, and they will be used for routing + var preds []sqlparser.Expr + for _, pred := range routing.SeenPredicates { + if !isSkipped(pred) { + preds = append(preds, pred) + } + } + routing.SeenPredicates = preds - default: - return op + rb.Routing = routing.resetRoutingLogic(ctx) + return SkipChildren } + + return TopDown(op, TableID, visit, shouldVisit) } func newValuesJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinType sqlparser.JoinType) *ValuesJoin { diff --git a/go/vt/vtgate/planbuilder/operators/query_planning.go b/go/vt/vtgate/planbuilder/operators/query_planning.go index 82df7285b12..6c7a38faf0f 100644 --- a/go/vt/vtgate/planbuilder/operators/query_planning.go +++ b/go/vt/vtgate/planbuilder/operators/query_planning.go @@ -134,9 +134,7 @@ func tryPushValues(ctx *plancontext.PlanningContext, in *Values) (Operator, *App case *Filter: return Swap(in, src, "pushed values under filter") case *Route: - if !reachedPhase(ctx, rewriteApplyJoin+1) { - return in, NoRewrite - } + src.Routing.AddValuesTableID(in.TableID) return Swap(in, src, "pushed values under route") } return in, NoRewrite diff --git a/go/vt/vtgate/planbuilder/operators/route.go b/go/vt/vtgate/planbuilder/operators/route.go index 9aeafec2799..532f76a3219 100644 --- a/go/vt/vtgate/planbuilder/operators/route.go +++ b/go/vt/vtgate/planbuilder/operators/route.go @@ -101,13 +101,15 @@ type ( OpCode() engine.Opcode Keyspace() *vindexes.Keyspace // note that all routings do not have a keyspace, so this method can return nil + AddValuesTableID(id semantics.TableSet) + // updateRoutingLogic updates the routing to take predicates into account. This can be used for routing // using vindexes or for figuring out which keyspace an information_schema query should be sent to. updateRoutingLogic(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Routing } ) -// UpdateRoutingLogic first checks if we are dealing with a predicate that +// UpdateRoutingLogic first checks if we are dealing with a predicate that can be evaluated to false or NULL. func UpdateRoutingLogic(ctx *plancontext.PlanningContext, expr sqlparser.Expr, r Routing) Routing { ks := r.Keyspace() if ks == nil { diff --git a/go/vt/vtgate/planbuilder/operators/sharded_routing.go b/go/vt/vtgate/planbuilder/operators/sharded_routing.go index 891e3cf5862..c7a06092b87 100644 --- a/go/vt/vtgate/planbuilder/operators/sharded_routing.go +++ b/go/vt/vtgate/planbuilder/operators/sharded_routing.go @@ -46,7 +46,8 @@ type ShardedRouting struct { // SeenPredicates contains all the predicates that have had a chance to influence routing. // If we need to replan routing, we'll use this list - SeenPredicates []sqlparser.Expr + SeenPredicates []sqlparser.Expr + ValuesTablesIDs semantics.TableSet } var _ Routing = (*ShardedRouting)(nil) @@ -189,6 +190,10 @@ func (tr *ShardedRouting) Clone() Routing { } } +func (sr *ShardedRouting) AddValuesTableID(id semantics.TableSet) { + sr.ValuesTablesIDs = sr.ValuesTablesIDs.Merge(id) +} + func (tr *ShardedRouting) updateRoutingLogic(ctx *plancontext.PlanningContext, expr sqlparser.Expr) Routing { tr.SeenPredicates = append(tr.SeenPredicates, expr) @@ -206,6 +211,7 @@ func (tr *ShardedRouting) updateRoutingLogic(ctx *plancontext.PlanningContext, e return tr } +// resetRoutingLogic resets the routing logic to the initial state, and uses the predicates to recompute the routing func (tr *ShardedRouting) resetRoutingLogic(ctx *plancontext.PlanningContext) Routing { tr.RouteOpCode = engine.Scatter tr.Selected = nil @@ -537,6 +543,20 @@ func (tr *ShardedRouting) planEqualOp(ctx *plancontext.PlanningContext, node *sq } val := makeEvalEngineExpr(ctx, vdValue) if val == nil { + col, ok := vdValue.(*sqlparser.ColName) + if !ok { + return false + } + from := ctx.SemTable.RecursiveDeps(col) + if from.IsSolvedBy(tr.ValuesTablesIDs) { + multiEual := func(vindex *vindexes.ColumnVindex) engine.Opcode { + // TODO @harshit - what else should we do here? + return engine.MultiEqual + } + arg := sqlparser.NewListArg("values") // TODO: HACK - we need to store these names? + + return tr.haveMatchingVindex(ctx, node, arg, column, val, multiEual, justTheVindex) + } return false } diff --git a/go/vt/vtgate/planbuilder/testdata/onecase.json b/go/vt/vtgate/planbuilder/testdata/onecase.json index 9d653b2f6e9..f34b2bd60d7 100644 --- a/go/vt/vtgate/planbuilder/testdata/onecase.json +++ b/go/vt/vtgate/planbuilder/testdata/onecase.json @@ -1,8 +1,52 @@ [ { "comment": "Add your test case here for debugging and run go test -run=One.", - "query": "", + "query": "select /*vt+ ALLOW_VALUES_JOIN */ user.foo, user_extra.user_id from user, user_extra where user.id = user_extra.toto and user.foo = 1 and user_extra.bar = 2", "plan": { + "QueryType": "SELECT", + "Original": "select /*vt+ ALLOW_VALUES_JOIN */ user.foo, user_extra.user_id from user, user_extra where user.id = user_extra.toto", + "Instructions": { + "OperatorType": "Join", + "Variant": "Values", + "BindVarName": "values", + "CopyColumnsToRHS": [ + 0, + 1 + ], + "RowID": "false", + "Inputs": [ + { + "OperatorType": "Route", + "Variant": "Scatter", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select user_extra.user_id, user_extra.toto from user_extra where 1 != 1", + "Query": "select /*vt+ ALLOW_VALUES_JOIN */ user_extra.user_id, user_extra.toto from user_extra where user_extra.bar = 2", + "Table": "user_extra" + }, + { + "OperatorType": "Route", + "Variant": "EqualUnique", + "Keyspace": { + "Name": "user", + "Sharded": true + }, + "FieldQuery": "select `user`.foo, user_extra.user_id from (values ::values) as `values`(user_id, toto), `user` where 1 != 1", + "Query": "select /*vt+ ALLOW_VALUES_JOIN */ `user`.foo, values.user_extra_user_id as user_id from (values ::values) as `values`(user_extra_user_id, user_extra_toto), `user` where `user`.foo = 1 and `user`.id = values.user_extra_toto", + "Table": "`user`", + "Values": [ + ":user_extra_toto" + ], + "Vindex": "user_index" + } + ] + }, + "TablesUsed": [ + "user.user", + "user.user_extra" + ] } } ] \ No newline at end of file