Skip to content

Commit

Permalink
feat: change the value-join to have a RowID mode
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
systay committed Feb 5, 2025
1 parent ef9b52b commit 0c70b4f
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 74 deletions.
36 changes: 25 additions & 11 deletions go/vt/vtgate/engine/values_join.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,22 +34,24 @@ type ValuesJoin struct {
// of the Join. They can be any primitive.
Left, Right Primitive

// The name for the bind var containing the tuple-of-tuples being sent to the RHS
BindVarName string

// RowID controls whether a row ID is appended to each tuple sent to the RHS, so that columns from the LHS can be used in the output.
// If RowID is false, the output will be the same as the RHS, so the following fields are ignored - Cols, ColNames.
// We copy everything from the LHS to the RHS in this case, and column names are taken from the RHS.
RowID bool

// CopyColumnsToRHS are the offsets of columns from LHS we are copying over to the RHS
// []int{0,2} means that the first column in the t-o-t is the first offset from the left and the second column is the third offset
CopyColumnsToRHS []int

// The name for the bind var containing the tuple-of-tuples being sent to the RHS
BindVarName string

// Cols tells us which side the output columns come from:
// negative numbers are offsets to the left, and positive to the right
Cols []int

// ColNames are the output column names
ColNames []string

// LHSRowID is the offset of the row ID in the LHS, used to use columns from the LHS in the output
RowID bool
}

// TryExecute performs a non-streaming exec.
Expand All @@ -73,6 +75,9 @@ func (jv *ValuesJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars
bv.Values = append(bv.Values, sqltypes.TupleToProto(vals))

bindVars[jv.BindVarName] = bv
if jv.RowID {
panic("implement me")
}
return jv.Right.GetFields(ctx, vcursor, bindVars)
}

Expand All @@ -81,11 +86,15 @@ func (jv *ValuesJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars
rowSize++ // +1 since we add the row ID
}
for i, row := range lresult.Rows {
newRow := make(sqltypes.Row, 0, rowSize) // +1 since we always add the row ID
newRow = append(newRow, sqltypes.NewInt64(int64(i))) // Adding the LHS row ID

for _, loffset := range jv.CopyColumnsToRHS {
newRow = append(newRow, row[loffset])
newRow := make(sqltypes.Row, 0, rowSize)

if jv.RowID {
for _, loffset := range jv.CopyColumnsToRHS {
newRow = append(newRow, row[loffset])
}
newRow = append(newRow, sqltypes.NewInt64(int64(i))) // Adding the LHS row ID
} else {
newRow = row
}

bv.Values = append(bv.Values, sqltypes.TupleToProto(newRow))
Expand All @@ -97,6 +106,11 @@ func (jv *ValuesJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars
return nil, err
}

if !jv.RowID {
// if we are not using the row ID, we can just return the result from the RHS
return rresult, nil
}

result := &sqltypes.Result{}

result.Fields = joinFields(lresult.Fields, rresult.Fields, jv.Cols)
Expand Down
174 changes: 111 additions & 63 deletions go/vt/vtgate/engine/values_join_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,84 +18,132 @@ package engine

import (
"context"
"fmt"
"testing"

"github.com/stretchr/testify/require"

"vitess.io/vitess/go/sqltypes"
querypb "vitess.io/vitess/go/vt/proto/query"

"vitess.io/vitess/go/sqltypes"
)

func TestJoinValuesExecute(t *testing.T) {

/*
select col1, col2, col3, col4, col5, col6 from left join right on left.col1 = right.col4
LHS: select col1, col2, col3 from left
RHS: select col5, col6, id from (values row(1,2), ...) left(id,col1) join right on left.col1 = right.col4
*/
type testCase struct {
rowID bool
cols []int
CopyColumnsToRHS []int
rhsResults []*sqltypes.Result
expectedRHSLog []string
}

leftPrim := &fakePrimitive{
useNewPrintBindVars: true,
results: []*sqltypes.Result{
sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
testCases := []testCase{
{
/*
select col1, col2, col3, col4, col5, col6 from left join right on left.col1 = right.col4
LHS: select col1, col2, col3 from left
RHS: select col5, col6, id from (values row(1,0), ...) left(col1,id) join right on left.col1 = right.col4
*/

rowID: true,
cols: []int{-1, -2, -3, -1, 1, 2},
CopyColumnsToRHS: []int{0},
rhsResults: []*sqltypes.Result{
sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col5|col6|id",
"varchar|varchar|int64",
),
"d|dd|0",
"e|ee|1",
"f|ff|2",
"g|gg|3",
),
"1|a|aa",
"2|b|bb",
"3|c|cc",
"4|d|dd",
),
},
}
rightPrim := &fakePrimitive{
useNewPrintBindVars: true,
results: []*sqltypes.Result{
sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col5|col6|id",
"varchar|varchar|int64",
},
expectedRHSLog: []string{
`Execute a: type:INT64 value:"10" v: [[INT64(1) INT64(0)][INT64(2) INT64(1)][INT64(3) INT64(2)][INT64(4) INT64(3)]] true`,
},
}, {
/*
select col1, col2, col3, col4, col5, col6 from left join right on left.col1 = right.col4
LHS: select col1, col2, col3 from left
RHS: select col1, col2, col3, col4, col5, col6 from (values row(1,2,3), ...) left(col1,col2,col3) join right on left.col1 = right.col4
*/

rowID: false,
rhsResults: []*sqltypes.Result{
sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3|col4|col5|col6",
"int64|varchar|varchar|int64|varchar|varchar",
),
"1|a|aa|1|d|dd",
"2|b|bb|2|e|ee",
"3|c|cc|3|f|ff",
"4|d|dd|4|g|gg",
),
"d|dd|0",
"e|ee|1",
"f|ff|2",
"g|gg|3",
),
},
expectedRHSLog: []string{
`Execute a: type:INT64 value:"10" v: [[INT64(1) VARCHAR("a") VARCHAR("aa")][INT64(2) VARCHAR("b") VARCHAR("bb")][INT64(3) VARCHAR("c") VARCHAR("cc")][INT64(4) VARCHAR("d") VARCHAR("dd")]] true`,
},
},
}

bv := map[string]*querypb.BindVariable{
"a": sqltypes.Int64BindVariable(10),
}
for _, tc := range testCases {
t.Run(fmt.Sprintf("rowID:%t", tc.rowID), func(t *testing.T) {
leftPrim := &fakePrimitive{
useNewPrintBindVars: true,
results: []*sqltypes.Result{
sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3",
"int64|varchar|varchar",
),
"1|a|aa",
"2|b|bb",
"3|c|cc",
"4|d|dd",
),
},
}
rightPrim := &fakePrimitive{
useNewPrintBindVars: true,
results: tc.rhsResults,
}

vjn := &ValuesJoin{
Left: leftPrim,
Right: rightPrim,
CopyColumnsToRHS: []int{0},
BindVarName: "v",
Cols: []int{-1, -2, -3, -1, 1, 2},
ColNames: []string{"col1", "col2", "col3", "col4", "col5", "col6"},
}
bv := map[string]*querypb.BindVariable{
"a": sqltypes.Int64BindVariable(10),
}

vjn := &ValuesJoin{
Left: leftPrim,
Right: rightPrim,
CopyColumnsToRHS: tc.CopyColumnsToRHS,
BindVarName: "v",
Cols: tc.cols,
ColNames: []string{"col1", "col2", "col3", "col4", "col5", "col6"},
RowID: tc.rowID,
}

r, err := vjn.TryExecute(context.Background(), &noopVCursor{}, bv, true)
require.NoError(t, err)
leftPrim.ExpectLog(t, []string{
`Execute a: type:INT64 value:"10" true`,
})
rightPrim.ExpectLog(t, []string{
`Execute a: type:INT64 value:"10" v: [[INT64(0) INT64(1)][INT64(1) INT64(2)][INT64(2) INT64(3)][INT64(3) INT64(4)]] true`,
})

result := sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3|col4|col5|col6",
"int64|varchar|varchar|int64|varchar|varchar",
),
"1|a|aa|1|d|dd",
"2|b|bb|2|e|ee",
"3|c|cc|3|f|ff",
"4|d|dd|4|g|gg",
)
expectResult(t, r, result)
r, err := vjn.TryExecute(context.Background(), &noopVCursor{}, bv, true)
require.NoError(t, err)
leftPrim.ExpectLog(t, []string{
`Execute a: type:INT64 value:"10" true`,
})
rightPrim.ExpectLog(t, tc.expectedRHSLog)

result := sqltypes.MakeTestResult(
sqltypes.MakeTestFields(
"col1|col2|col3|col4|col5|col6",
"int64|varchar|varchar|int64|varchar|varchar",
),
"1|a|aa|1|d|dd",
"2|b|bb|2|e|ee",
"3|c|cc|3|f|ff",
"4|d|dd|4|g|gg",
)
expectResult(t, r, result)
})
}
}

0 comments on commit 0c70b4f

Please sign in to comment.