Skip to content

Commit

Permalink
refactor: change how we handle Values and ValuesJoin planning - wip
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
systay committed Feb 20, 2025
1 parent 4dd3589 commit df126a2
Show file tree
Hide file tree
Showing 8 changed files with 125 additions and 105 deletions.
44 changes: 25 additions & 19 deletions go/vt/vtgate/planbuilder/operator_transformers.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@ func transformToPrimitive(ctx *plancontext.PlanningContext, op operators.Operato
return transformSubQuery(ctx, op)
case *operators.Filter:
return transformFilter(ctx, op)
case *operators.Horizon:
panic("should have been solved in the operator")
case *operators.Projection:
return transformProjection(ctx, op)
case *operators.Limit:
Expand Down Expand Up @@ -80,28 +78,36 @@ func transformToPrimitive(ctx *plancontext.PlanningContext, op operators.Operato
case *operators.PercentBasedMirror:
return transformPercentBasedMirror(ctx, op)
case *operators.ValuesJoin:
lhs, err := transformToPrimitive(ctx, op.LHS)
if err != nil {
return nil, err
}
rhs, err := transformToPrimitive(ctx, op.RHS)
if err != nil {
return nil, err
}

return &engine.ValuesJoin{
Left: lhs,
Right: rhs,
CopyColumnsToRHS: op.CopyColumnsToRHS,
BindVarName: op.BindVarName,
Cols: op.Columns,
ColNames: op.ColumnName,
}, nil
return transformValuesJoin(ctx, op)
case *operators.Values:
panic("should have been pushed under a route")
case *operators.Horizon:
panic("should have been solved in the operator")
}

return nil, vterrors.VT13001(fmt.Sprintf("unknown type encountered: %T (transformToPrimitive)", op))
}

// transformValuesJoin turns a ValuesJoin operator into an engine.ValuesJoin
// primitive. The two inputs are transformed recursively, left side first; the
// first error encountered is returned as-is and no primitive is produced.
func transformValuesJoin(ctx *plancontext.PlanningContext, op *operators.ValuesJoin) (engine.Primitive, error) {
	left, err := transformToPrimitive(ctx, op.LHS)
	if err != nil {
		return nil, err
	}

	right, err := transformToPrimitive(ctx, op.RHS)
	if err != nil {
		return nil, err
	}

	// The operator's ValuesDestination is what the engine primitive calls
	// BindVarName; the remaining fields are copied over unchanged.
	join := &engine.ValuesJoin{
		Left:             left,
		Right:            right,
		CopyColumnsToRHS: op.CopyColumnsToRHS,
		BindVarName:      op.ValuesDestination,
		Cols:             op.Columns,
		ColNames:         op.ColumnName,
	}
	return join, nil
}

func transformPercentBasedMirror(ctx *plancontext.PlanningContext, op *operators.PercentBasedMirror) (engine.Primitive, error) {
primitive, err := transformToPrimitive(ctx, op.Operator())
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion go/vt/vtgate/planbuilder/operators/expressions_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ func TestSplitComplexPredicateToLHS(t *testing.T) {
}, ast)

valuesJoinCols := breakValuesJoinExpressionInLHS(ctx, ast, lID)
nodes := slice.Map(valuesJoinCols.LHS, func(from *sqlparser.ColName) string {
nodes := slice.Map(valuesJoinCols.LHS, func(from sqlparser.Expr) string {
return sqlparser.String(from)
})

Expand Down
39 changes: 34 additions & 5 deletions go/vt/vtgate/planbuilder/operators/op_to_ast.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,13 @@ package operators

import (
"fmt"
"strings"

"vitess.io/vitess/go/slice"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
"vitess.io/vitess/go/vt/vtgate/semantics"
)

func ToAST(ctx *plancontext.PlanningContext, op Operator) (_ sqlparser.Statement, _ Operator, err error) {
Expand Down Expand Up @@ -130,7 +132,7 @@ func buildDistinct(op *Distinct, qb *queryBuilder) {
}

func buildValuesJoin(op *ValuesJoin, qb *queryBuilder) {
qb.ctx.SkipValuesArgument(op.BindVarName)
qb.ctx.SkipValuesArgument(op.ValuesDestination)
buildAST(op.LHS, qb)
qbR := &queryBuilder{ctx: qb.ctx}
buildAST(op.RHS, qbR)
Expand All @@ -139,15 +141,42 @@ func buildValuesJoin(op *ValuesJoin, qb *queryBuilder) {

func buildValues(op *Values, qb *queryBuilder) {
buildAST(op.Source, qb)
if qb.ctx.IsValuesArgumentSkipped(op.Arg) {
if qb.ctx.IsValuesArgumentSkipped(op.Name) {
return
}

qb.addTableExpr(op.Name, op.Name, TableID(op), &sqlparser.DerivedTable{
expr := &sqlparser.DerivedTable{
Select: &sqlparser.ValuesStatement{
ListArg: sqlparser.NewListArg(op.Arg),
ListArg: sqlparser.NewListArg(op.Name),
},
}, nil, op.getColsFromCtx(qb.ctx))
}

apa := semantics.EmptyTableSet()
for _, ae := range qb.ctx.ValuesJoinColumns[op.Name] {
apa = apa.Merge(qb.ctx.SemTable.RecursiveDeps(ae.Expr))
}

tableName := getTableName(qb.ctx, apa)

qb.addTableExpr(tableName, tableName, TableID(op), expr, nil, op.getColumnNamesFromCtx(qb.ctx))
}

// getTableName builds a single identifier for the given table set by joining
// the name of each constituent table with "_". A constituent whose table info
// or name cannot be resolved contributes the placeholder "X" instead of
// failing the whole lookup.
func getTableName(ctx *plancontext.PlanningContext, id semantics.TableSet) string {
	// nameOf resolves one constituent to its table name, or "X" on any error.
	nameOf := func(ts semantics.TableSet) string {
		info, err := ctx.SemTable.TableInfoFor(ts)
		if err != nil {
			return "X"
		}
		tblName, err := info.Name()
		if err != nil {
			return "X"
		}
		return tblName.Name.String()
	}

	var parts []string
	for _, ts := range id.Constituents() {
		parts = append(parts, nameOf(ts))
	}
	return strings.Join(parts, "_")
}

func buildDelete(op *Delete, qb *queryBuilder) {
Expand Down
26 changes: 12 additions & 14 deletions go/vt/vtgate/planbuilder/operators/op_to_ast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,21 @@ import (

func TestToSQLValues(t *testing.T) {
ctx := plancontext.CreateEmptyPlanningContext()
bindVarName := "toto"
ctx.ValuesJoinColumns[bindVarName] = sqlparser.Columns{sqlparser.NewIdentifierCI("user_id")}
name := "toto"
ctx.ValuesJoinColumns[name] = []*sqlparser.AliasedExpr{{Expr: sqlparser.NewColName("user_id")}}

tableName := sqlparser.NewTableName("x")
tableColumn := sqlparser.NewColName("id")
source := &Table{
QTable: &QueryTable{
Table: tableName,
Alias: sqlparser.NewAliasedTableExpr(tableName, ""),
},
Columns: []*sqlparser.ColName{tableColumn},
}
op := &Values{
unaryOperator: newUnaryOp(&Table{
QTable: &QueryTable{
Table: tableName,
Alias: sqlparser.NewAliasedTableExpr(tableName, ""),
},
Columns: []*sqlparser.ColName{tableColumn},
}),
Name: "t",
Arg: bindVarName,
unaryOperator: newUnaryOp(source),
Name: name,
}

stmt, _, err := ToAST(ctx, op)
Expand Down Expand Up @@ -85,7 +85,7 @@ func TestToSQLValuesJoin(t *testing.T) {
}

const argumentName = "v"
ctx.ValuesJoinColumns[argumentName] = sqlparser.Columns{sqlparser.NewIdentifierCI("id")}
ctx.ValuesJoinColumns[argumentName] = []*sqlparser.AliasedExpr{{Expr: sqlparser.NewColName("user_id")}}
rhsTableName := sqlparser.NewTableName("y")
rhsTableColumn := sqlparser.NewColName("tata")
rhsFilterPred, err := parser.ParseExpr("y.tata = 42")
Expand All @@ -103,14 +103,12 @@ func TestToSQLValuesJoin(t *testing.T) {
Columns: []*sqlparser.ColName{rhsTableColumn},
}),
Name: lhsTableName.Name.String(),
Arg: argumentName,
}),
Predicates: []sqlparser.Expr{rhsFilterPred, rhsJoinFilterPred},
}

vj := &ValuesJoin{
binaryOperator: newBinaryOp(LHS, RHS),
BindVarName: argumentName,
}

stmt, _, err := ToAST(ctx, vj)
Expand Down
30 changes: 5 additions & 25 deletions go/vt/vtgate/planbuilder/operators/phases.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@ package operators

import (
"io"
"strings"

"vitess.io/vitess/go/slice"
"vitess.io/vitess/go/vt/sqlparser"
Expand Down Expand Up @@ -152,36 +151,17 @@ func newValuesJoin(ctx *plancontext.PlanningContext, lhs, rhs Operator, joinType
if !joinType.IsInner() {
return nil
}
lhsID := TableID(lhs)
lhsTableName := getTableName(ctx, lhsID)

bindVariableName := ctx.ReservedVars.ReserveVariable("values")
ctx.ValueJoins[bindVariableName] = bindVariableName
v := &Values{
unaryOperator: newUnaryOp(rhs),
Name: lhsTableName,
Arg: bindVariableName,
Name: bindVariableName,
}
return &ValuesJoin{
binaryOperator: newBinaryOp(lhs, v),
BindVarName: bindVariableName,
}
}

func getTableName(ctx *plancontext.PlanningContext, lhsID semantics.TableSet) string {
var parts []string
for _, ts := range lhsID.Constituents() {
lhsTableInfo, err := ctx.SemTable.TableInfoFor(ts)
if err != nil {
parts = append(parts, "X")
continue
}
lhsTableName, err := lhsTableInfo.Name()
if err != nil {
parts = append(parts, "X")
continue
}
parts = append(parts, lhsTableName.Name.String())
binaryOperator: newBinaryOp(lhs, v),
ValuesDestination: bindVariableName,
}
return strings.Join(parts, "_")
}

type phaser struct {
Expand Down
32 changes: 17 additions & 15 deletions go/vt/vtgate/planbuilder/operators/values.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ limitations under the License.
package operators

import (
"vitess.io/vitess/go/slice"
"vitess.io/vitess/go/vt/sqlparser"
"vitess.io/vitess/go/vt/vterrors"
"vitess.io/vitess/go/vt/vtgate/planbuilder/plancontext"
Expand All @@ -26,7 +27,6 @@ type Values struct {
unaryOperator

Name string
Arg string
}

func (v *Values) Clone(inputs []Operator) Operator {
Expand All @@ -51,32 +51,34 @@ func (v *Values) AddWSColumn(ctx *plancontext.PlanningContext, offset int, under
}

func (v *Values) FindCol(ctx *plancontext.PlanningContext, expr sqlparser.Expr, underRoute bool) int {
col, ok := expr.(*sqlparser.ColName)
if !ok {
return -1
}
for i, column := range v.getColsFromCtx(ctx) {
if col.Name.Equal(column) {
for i, column := range v.getExprsFromCtx(ctx) {
if ctx.SemTable.EqualsExpr(column, expr) {
return i
}
}
return -1
}

func (v *Values) getColsFromCtx(ctx *plancontext.PlanningContext) sqlparser.Columns {
columns, found := ctx.ValuesJoinColumns[v.Arg]
func (v *Values) getColumnNamesFromCtx(ctx *plancontext.PlanningContext) sqlparser.Columns {
columns, found := ctx.ValuesJoinColumns[v.Name]
if !found {
panic(vterrors.VT13001("columns not found"))
}
return columns
return slice.Map(columns, func(ae *sqlparser.AliasedExpr) sqlparser.IdentifierCI {
return sqlparser.NewIdentifierCI(ae.ColumnName())
})
}

// getExprsFromCtx returns the bare expressions of the aliased expressions
// registered in ctx.ValuesJoinColumns under this operator's Name.
// A missing entry is not treated as an error here: the lookup's zero value is
// handed straight to slice.Map.
func (v *Values) getExprsFromCtx(ctx *plancontext.PlanningContext) []sqlparser.Expr {
	unwrap := func(ae *sqlparser.AliasedExpr) sqlparser.Expr {
		return ae.Expr
	}
	return slice.Map(ctx.ValuesJoinColumns[v.Name], unwrap)
}

func (v *Values) GetColumns(ctx *plancontext.PlanningContext) []*sqlparser.AliasedExpr {
var cols []*sqlparser.AliasedExpr
for _, column := range v.getColsFromCtx(ctx) {
cols = append(cols, sqlparser.NewAliasedExpr(sqlparser.NewColNameWithQualifier(column.String(), sqlparser.NewTableName(v.Name)), ""))
}
return cols
columns := ctx.ValuesJoinColumns[v.Name]
return columns
}

func (v *Values) GetSelectExprs(ctx *plancontext.PlanningContext) []sqlparser.SelectExpr {
Expand Down
43 changes: 20 additions & 23 deletions go/vt/vtgate/planbuilder/operators/values_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package operators

import (
"fmt"
"slices"
"strings"

"vitess.io/vitess/go/slice"
Expand All @@ -29,7 +30,7 @@ type (
ValuesJoin struct {
binaryOperator

BindVarName string
ValuesDestination string

JoinColumns []valuesJoinColumn
JoinPredicates []valuesJoinColumn
Expand All @@ -46,7 +47,7 @@ type (

valuesJoinColumn struct {
Original sqlparser.Expr
LHS []*sqlparser.ColName
LHS []sqlparser.Expr
PureLHS bool
}
)
Expand Down Expand Up @@ -149,7 +150,7 @@ func (vj *ValuesJoin) ShortDescription() string {
return strings.Join(out, ", ")
}

firstPart := fmt.Sprintf("on %s columns: %s", fn(vj.JoinPredicates), fn(vj.JoinColumns))
firstPart := fmt.Sprintf("%s on %s columns: %s", vj.ValuesDestination, fn(vj.JoinPredicates), fn(vj.JoinColumns))

return firstPart
}
Expand All @@ -159,34 +160,30 @@ func (vj *ValuesJoin) GetOrdering(ctx *plancontext.PlanningContext) []OrderBy {
}

func (vj *ValuesJoin) planOffsets(ctx *plancontext.PlanningContext) Operator {
valuesColumns := ctx.ValuesJoinColumns[vj.BindVarName]
exprs := ctx.ValuesJoinColumns[vj.ValuesDestination]
for _, jc := range vj.JoinColumns {
vj.planOffsetsForValueJoinPredicate(ctx, jc.LHS, &valuesColumns)
ctx.ValuesJoinColumns[vj.BindVarName] = valuesColumns

newExprs := vj.planOffsetsForLHSExprs(ctx, jc.LHS)
exprs = append(exprs, newExprs...)
offset := vj.RHS.AddColumn(ctx, true, false, aeWrap(jc.Original))
vj.Columns = append(vj.Columns, ToRightOffset(offset))
}

for _, predicate := range vj.JoinPredicates {
vj.planOffsetsForValueJoinPredicate(ctx, predicate.LHS, &valuesColumns)
for _, jc := range vj.JoinPredicates {
// for join predicates, we only need to push the LHS dependencies. The RHS expressions are already pushed
newExprs := vj.planOffsetsForLHSExprs(ctx, jc.LHS)
exprs = append(exprs, newExprs...)
}

ctx.ValuesJoinColumns[vj.BindVarName] = valuesColumns
ctx.ValuesJoinColumns[vj.ValuesDestination] = exprs
return vj
}

func (vj *ValuesJoin) planOffsetsForValueJoinPredicate(ctx *plancontext.PlanningContext, lhsPred []*sqlparser.ColName, valuesColumns *sqlparser.Columns) {
outer:
for _, lh := range lhsPred {
for _, ci := range *valuesColumns {
if ci.Equal(lh.Name) {
// already there, no need to add it again
continue outer
}
func (vj *ValuesJoin) planOffsetsForLHSExprs(ctx *plancontext.PlanningContext, input []sqlparser.Expr) (exprs []*sqlparser.AliasedExpr) {
for _, lhsExpr := range input {
offset := vj.LHS.AddColumn(ctx, true, false, aeWrap(lhsExpr))
// only add it if we don't already have it
if slices.Index(vj.CopyColumnsToRHS, offset) == -1 {
vj.CopyColumnsToRHS = append(vj.CopyColumnsToRHS, offset)
exprs = append(exprs, aeWrap(lhsExpr))
}
offset := vj.LHS.AddColumn(ctx, true, false, aeWrap(lh))
vj.CopyColumnsToRHS = append(vj.CopyColumnsToRHS, offset)
*valuesColumns = append(*valuesColumns, lh.Name)
}
return exprs
}
Loading

0 comments on commit df126a2

Please sign in to comment.