Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
Signed-off-by: Florent Poinsard <florent.poinsard@outlook.fr>
  • Loading branch information
frouioui committed Feb 25, 2025
1 parent b392c85 commit 950edc4
Show file tree
Hide file tree
Showing 11 changed files with 252 additions and 106 deletions.
14 changes: 9 additions & 5 deletions go/vt/vtgate/planbuilder/operators/expressions.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,19 +134,23 @@ func breakValuesJoinExpressionInLHS(ctx *plancontext.PlanningContext,
expr sqlparser.Expr,
lhs semantics.TableSet,
) (result valuesJoinColumn) {
result.Original = expr
result.Original = sqlparser.Clone(expr)
result.PureLHS = true
_ = sqlparser.Walk(func(node sqlparser.SQLNode) (kontinue bool, err error) {
result.RHS = expr
_ = sqlparser.Rewrite(expr, func(cursor *sqlparser.Cursor) bool {
node := cursor.Node()
col, ok := node.(*sqlparser.ColName)
if !ok {
return true, nil
return true
}
if ctx.SemTable.RecursiveDeps(col) == lhs {
result.LHS = append(result.LHS, col)
// TODO: Find all the LHS columns, and
// rewrite the expression to use the value join name and the column.
} else {
result.PureLHS = false
}
return true, nil
}, expr)
return true
}, nil)
return
}
6 changes: 6 additions & 0 deletions go/vt/vtgate/planbuilder/operators/info_schema_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"slices"
"strings"

"vitess.io/vitess/go/vt/vtgate/semantics"

"vitess.io/vitess/go/mysql/collations"
"vitess.io/vitess/go/sqltypes"
"vitess.io/vitess/go/vt/sqlparser"
Expand Down Expand Up @@ -102,6 +104,10 @@ func (isr *InfoSchemaRouting) updateRoutingLogic(ctx *plancontext.PlanningContex
return isr
}

// AddValuesTableID is not supported for info schema routing yet: it ignores
// the given table set and always panics with a VT13001 error.
func (isr *InfoSchemaRouting) AddValuesTableID(_ semantics.TableSet) {
	const msg = "think about values and info schema routing"
	panic(vterrors.VT13001(msg))
}

// Cost returns the cost of this routing, which is always 0 for InfoSchemaRouting.
func (isr *InfoSchemaRouting) Cost() int {
return 0
}
Expand Down
16 changes: 13 additions & 3 deletions go/vt/vtgate/planbuilder/operators/misc_routing.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vtgate/engine"
"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
"vitess.io/vitess/go/vt/vtgate/semantics"
"vitess.io/vitess/go/vt/vtgate/vindexes"
)

Expand Down Expand Up @@ -77,6 +78,7 @@ func (tr *TargetedRouting) Clone() Routing {
func (tr *TargetedRouting) updateRoutingLogic(_ *plancontext.PlanningContext, _ sqlparser.Expr) Routing {
return tr
}
// AddValuesTableID is a no-op for TargetedRouting: the given table set is ignored.
func (tr *TargetedRouting) AddValuesTableID(semantics.TableSet) {}

func (tr *TargetedRouting) Cost() int {
return 1
Expand All @@ -102,6 +104,8 @@ func (n *NoneRouting) updateRoutingLogic(*plancontext.PlanningContext, sqlparser
return n
}

// AddValuesTableID is a no-op for NoneRouting: the given table set is ignored.
func (*NoneRouting) AddValuesTableID(semantics.TableSet) {}

// Cost returns the cost of this routing, which is always 0 for NoneRouting.
func (n *NoneRouting) Cost() int {
return 0
}
Expand Down Expand Up @@ -129,6 +133,8 @@ func (rr *AnyShardRouting) updateRoutingLogic(*plancontext.PlanningContext, sqlp
return rr
}

// AddValuesTableID is a no-op for AnyShardRouting: the given table set is ignored.
// The receiver is named rr to match the other AnyShardRouting methods.
func (rr *AnyShardRouting) AddValuesTableID(semantics.TableSet) {}

// Cost returns the cost of this routing, which is always 0 for AnyShardRouting.
func (rr *AnyShardRouting) Cost() int {
return 0
}
Expand Down Expand Up @@ -166,6 +172,8 @@ func (dr *DualRouting) updateRoutingLogic(*plancontext.PlanningContext, sqlparse
return dr
}

// AddValuesTableID is a no-op for DualRouting: the given table set is ignored.
// The receiver is named dr to match the other DualRouting methods.
func (dr *DualRouting) AddValuesTableID(semantics.TableSet) {}

// Cost returns the cost of this routing, which is always 0 for DualRouting.
func (dr *DualRouting) Cost() int {
return 0
}
Expand All @@ -191,14 +199,16 @@ func (sr *SequenceRouting) updateRoutingLogic(*plancontext.PlanningContext, sqlp
return sr
}

func (sr *SequenceRouting) Cost() int {
func (*SequenceRouting) AddValuesTableID(semantics.TableSet) {}

func (*SequenceRouting) Cost() int {
return 0
}

func (sr *SequenceRouting) OpCode() engine.Opcode {
func (*SequenceRouting) OpCode() engine.Opcode {
return engine.Next
}

func (sr *SequenceRouting) Keyspace() *vindexes.Keyspace {
func (*SequenceRouting) Keyspace() *vindexes.Keyspace {
return nil
}
106 changes: 42 additions & 64 deletions go/vt/vtgate/planbuilder/operators/op_to_ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,6 @@ package operators

import (
"fmt"
"strings"

"vitess.io/vitess/go/slice"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
Expand All @@ -38,45 +36,6 @@ func ToAST(ctx *plancontext.PlanningContext, op Operator) (_ sqlparser.Statement
return q.stmt, q.dmlOperator, nil
}

func removeKeyspaceFromSelectExpr(expr sqlparser.SelectExpr) {
switch expr := expr.(type) {
case *sqlparser.AliasedExpr:
sqlparser.RemoveKeyspaceInCol(expr.Expr)
case *sqlparser.StarExpr:
expr.TableName.Qualifier = sqlparser.NewIdentifierCS("")
}
}

func stripDownQuery(from, to sqlparser.TableStatement) {
switch node := from.(type) {
case *sqlparser.Select:
toNode, ok := to.(*sqlparser.Select)
if !ok {
panic(vterrors.VT13001("AST did not match"))
}
toNode.Distinct = node.Distinct
toNode.GroupBy = node.GroupBy
toNode.Having = node.Having
toNode.OrderBy = node.OrderBy
toNode.Comments = node.Comments
toNode.Limit = node.Limit
toNode.SelectExprs = node.SelectExprs
for _, expr := range toNode.GetColumns() {
removeKeyspaceFromSelectExpr(expr)
}
case *sqlparser.Union:
toNode, ok := to.(*sqlparser.Union)
if !ok {
panic(vterrors.VT13001("AST did not match"))
}
stripDownQuery(node.Left, toNode.Left)
stripDownQuery(node.Right, toNode.Right)
toNode.OrderBy = node.OrderBy
default:
panic(vterrors.VT13001(fmt.Sprintf("this should not happen - we have covered all implementations of SelectStatement %T", from)))
}
}

// buildAST recursively builds the query into an AST, from an operator tree
func buildAST(op Operator, qb *queryBuilder) {
switch op := op.(type) {
Expand Down Expand Up @@ -121,6 +80,45 @@ func buildAST(op Operator, qb *queryBuilder) {
}
}

// removeKeyspaceFromSelectExpr strips keyspace information from a single
// select expression, in place:
//   - for an aliased expression, it delegates to sqlparser.RemoveKeyspaceInCol
//     on the wrapped expression;
//   - for a star expression, it resets the table name qualifier to the empty
//     identifier.
//
// Any other kind of select expression is left untouched.
func removeKeyspaceFromSelectExpr(expr sqlparser.SelectExpr) {
	if aliased, isAliased := expr.(*sqlparser.AliasedExpr); isAliased {
		sqlparser.RemoveKeyspaceInCol(aliased.Expr)
		return
	}
	if star, isStar := expr.(*sqlparser.StarExpr); isStar {
		star.TableName.Qualifier = sqlparser.NewIdentifierCS("")
	}
}

// stripDownQuery copies the query-shaping clauses of from onto to. Both
// statements must have the same shape, otherwise it panics with a VT13001
// error.
//
// For a SELECT it copies DISTINCT, GROUP BY, HAVING, ORDER BY, comments, LIMIT
// and the select expressions, then removes keyspace information from every
// column of to. For a UNION it recurses into the left and right sides and then
// copies ORDER BY. Any other statement type is a planner bug and panics.
//
// NOTE(review): the select expressions are assigned rather than cloned, so the
// keyspace stripping appears to also affect nodes still referenced by from —
// confirm this is intended.
func stripDownQuery(from, to sqlparser.TableStatement) {
	switch src := from.(type) {
	case *sqlparser.Union:
		dst, sameShape := to.(*sqlparser.Union)
		if !sameShape {
			panic(vterrors.VT13001("AST did not match"))
		}
		// Both arms are stripped before the union-level ORDER BY is copied.
		stripDownQuery(src.Left, dst.Left)
		stripDownQuery(src.Right, dst.Right)
		dst.OrderBy = src.OrderBy

	case *sqlparser.Select:
		dst, sameShape := to.(*sqlparser.Select)
		if !sameShape {
			panic(vterrors.VT13001("AST did not match"))
		}
		dst.SelectExprs = src.SelectExprs
		dst.Distinct = src.Distinct
		dst.Comments = src.Comments
		dst.GroupBy = src.GroupBy
		dst.Having = src.Having
		dst.OrderBy = src.OrderBy
		dst.Limit = src.Limit

		// The columns were just taken from the source, so drop their keyspace.
		for _, column := range dst.GetColumns() {
			removeKeyspaceFromSelectExpr(column)
		}

	default:
		msg := fmt.Sprintf("this should not happen - we have covered all implementations of SelectStatement %T", from)
		panic(vterrors.VT13001(msg))
	}
}

func buildDistinct(op *Distinct, qb *queryBuilder) {
buildAST(op.Source, qb)
statement := qb.asSelectStatement()
Expand Down Expand Up @@ -151,32 +149,12 @@ func buildValues(op *Values, qb *queryBuilder) {
},
}

apa := semantics.EmptyTableSet()
deps := semantics.EmptyTableSet()
for _, ae := range qb.ctx.ValuesJoinColumns[op.Name] {
apa = apa.Merge(qb.ctx.SemTable.RecursiveDeps(ae.Expr))
deps = deps.Merge(qb.ctx.SemTable.RecursiveDeps(ae.Expr))
}

tableName := getTableName(qb.ctx, apa)

qb.addTableExpr(tableName, tableName, TableID(op), expr, nil, op.getColumnNamesFromCtx(qb.ctx))
}

func getTableName(ctx *plancontext.PlanningContext, id semantics.TableSet) string {
var names []string
for _, ts := range id.Constituents() {
ti, err := ctx.SemTable.TableInfoFor(ts)
if err != nil {
names = append(names, "X")
continue
}
name, err := ti.Name()
if err != nil {
names = append(names, "X")
continue
}
names = append(names, name.Name.String())
}
return strings.Join(names, "_")
qb.addTableExpr(op.Name, op.Name, TableID(op), expr, nil, op.getColumnNamesFromCtx(qb.ctx))
}

func buildDelete(op *Delete, qb *queryBuilder) {
Expand Down
103 changes: 80 additions & 23 deletions go/vt/vtgate/planbuilder/operators/phases.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,42 +114,99 @@ func (p Phase) act(ctx *plancontext.PlanningContext, op Operator) Operator {
case dmlWithInput:
return findDMLAboveRoute(ctx, op)
case rewriteApplyJoin:
visit := func(op Operator, lhsTables semantics.TableSet, isRoot bool) (Operator, *ApplyResult) {
aj, ok := op.(*ApplyJoin)
if !ok {
return op, NoRewrite
}
return rewriteApplyToValues(ctx, op)

vj := newValuesJoin(ctx, aj.LHS, aj.RHS, aj.JoinType)
if vj == nil {
return op, NoRewrite
default:
return op
}
}

func rewriteApplyToValues(ctx *plancontext.PlanningContext, op Operator) Operator {
var skipped []sqlparser.Expr
isSkipped := func(expr sqlparser.Expr) bool {
for _, skip := range skipped {
if ctx.SemTable.EqualsExpr(expr, skip) {
return true
}
}
return false
}

for _, column := range aj.JoinColumns.columns {
vj.AddColumn(ctx, true, false, aeWrap(column.Original))
// Traverse the operator tree to convert ApplyJoin to ValuesJoin.
// Then add a Values node to the RHS of the new ValuesJoin,
// and usually a filter containing the join predicates is placed there.
visit := func(op Operator, lhsTables semantics.TableSet, isRoot bool) (Operator, *ApplyResult) {
aj, ok := op.(*ApplyJoin)
if !ok {
return op, NoRewrite
}

vj, valuesTableID := newValuesJoin(ctx, aj.LHS, aj.RHS, aj.JoinType)
if vj == nil {
return op, NoRewrite
}

for _, column := range aj.JoinColumns.columns {
vj.AddColumn(ctx, true, false, aeWrap(column.Original))
}

for _, pred := range aj.JoinPredicates.columns {
skipped = append(skipped, pred.RHSExpr)
err := ctx.SkipJoinPredicates(pred.Original)
if err != nil {
panic(err)
}

for _, pred := range aj.JoinPredicates.columns {
err := ctx.SkipJoinPredicates(pred.Original)
if err != nil {
panic(err)
newOriginal := sqlparser.Rewrite(pred.Original, nil, func(cursor *sqlparser.Cursor) bool {
col, isCol := cursor.Node().(*sqlparser.ColName)
if !isCol || ctx.SemTable.RecursiveDeps(col) != valuesTableID {
return true
}
vj.AddJoinPredicate(ctx, pred.Original)
}

return vj, Rewrote("rewrote ApplyJoin to ValuesJoin")
cursor.Replace(&sqlparser.ColName{
Name: sqlparser.NewIdentifierCI(getValuesJoinColName(ctx, vj.ValuesDestination, valuesTableID, col)),
Qualifier: sqlparser.NewTableName(vj.ValuesDestination),
})
return true
})

vj.AddJoinPredicate(ctx, newOriginal.(sqlparser.Expr))
}

return TopDown(op, TableID, visit, stopAtRoute)
return vj, Rewrote("rewrote ApplyJoin to ValuesJoin")
}

default:
return op
shouldVisit := func(op Operator) VisitRule {
rb, ok := op.(*Route)
if !ok {
return VisitChildren
}

routing, ok := rb.Routing.(*ShardedRouting)
if !ok {
return SkipChildren
}

// We need to skip the predicates that are already pushed down to MySQL -
// we will push down the ValuesJoin predicates, and they will be used for routing
var preds []sqlparser.Expr
for _, pred := range routing.SeenPredicates {
if !isSkipped(pred) {
preds = append(preds, pred)
}
}
routing.SeenPredicates = preds

rb.Routing = routing.resetRoutingLogic(ctx)
return SkipChildren
}

return TopDown(op, TableID, visit, shouldVisit)
}

func newValuesJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinType sqlparser.JoinType) *ValuesJoin {
func newValuesJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinType sqlparser.JoinType) (*ValuesJoin, semantics.TableSet) {
if !joinType.IsInner() {
return nil
return nil, semantics.EmptyTableSet()
}

bindVariableName := ctx.ReservedVars.ReserveVariable("values")
Expand All @@ -162,7 +219,7 @@ func newValuesJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinType
return &ValuesJoin{
binaryOperator: newBinaryOp(lhs, v),
ValuesDestination: bindVariableName,
}
}, v.TableID
}

type phaser struct {
Expand Down
4 changes: 1 addition & 3 deletions go/vt/vtgate/planbuilder/operators/query_planning.go
Original file line number Diff line number Diff line change
Expand Up @@ -133,9 +133,7 @@ func tryPushValues(ctx *plancontext.PlanningContext, in *Values) (Operator, *App
case *Filter:
return Swap(in, src, "pushed values under filter")
case *Route:
if !reachedPhase(ctx, rewriteApplyJoin+1) {
return in, NoRewrite
}
src.Routing.AddValuesTableID(in.TableID)
return Swap(in, src, "pushed values under route")
}
return in, NoRewrite
Expand Down
Loading

0 comments on commit 950edc4

Please sign in to comment.