From 0c70b4f147f03141b00b99bef83264039d3fd942 Mon Sep 17 00:00:00 2001 From: Andres Taylor Date: Wed, 5 Feb 2025 15:29:07 +0100 Subject: [PATCH] feat: change the value-join to have a RowID mode Signed-off-by: Andres Taylor --- go/vt/vtgate/engine/values_join.go | 36 +++-- go/vt/vtgate/engine/values_join_test.go | 174 +++++++++++++++--------- 2 files changed, 136 insertions(+), 74 deletions(-) diff --git a/go/vt/vtgate/engine/values_join.go b/go/vt/vtgate/engine/values_join.go index ed4bd42367c..ced6283dbe8 100644 --- a/go/vt/vtgate/engine/values_join.go +++ b/go/vt/vtgate/engine/values_join.go @@ -34,22 +34,24 @@ type ValuesJoin struct { // of the Join. They can be any primitive. Left, Right Primitive + // The name for the bind var containing the tuple-of-tuples being sent to the RHS + BindVarName string + + // LHSRowID is the offset of the row ID in the LHS, used to use columns from the LHS in the output + // If LHSRowID is false, the output will be the same as the RHS, so the following fields are ignored - Cols, ColNames. + // We copy everything from the LHS to the RHS in this case, and column names are taken from the RHS. + RowID bool + // CopyColumnsToRHS are the offsets of columns from LHS we are copying over to the RHS // []int{0,2} means that the first column in the t-o-t is the first offset from the left and the second column is the third offset CopyColumnsToRHS []int - // The name for the bind var containing the tuple-of-tuples being sent to the RHS - BindVarName string - // Cols tells use which side the output columns come from: // negative numbers are offsets to the left, and positive to the right Cols []int // ColNames are the output column names ColNames []string - - // LHSRowID is the offset of the row ID in the LHS, used to use columns from the LHS in the output - RowID bool } // TryExecute performs a non-streaming exec. @@ -73,6 +75,9 @@ func (jv *ValuesJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars bv.Values = append(bv.Values, sqltypes.TupleToProto(vals)) bindVars[jv.BindVarName] = bv + if jv.RowID { + panic("implement me") + } return jv.Right.GetFields(ctx, vcursor, bindVars) } @@ -81,11 +86,15 @@ func (jv *ValuesJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars rowSize++ // +1 since we add the row ID } for i, row := range lresult.Rows { - newRow := make(sqltypes.Row, 0, rowSize) // +1 since we always add the row ID - newRow = append(newRow, sqltypes.NewInt64(int64(i))) // Adding the LHS row ID - - for _, loffset := range jv.CopyColumnsToRHS { - newRow = append(newRow, row[loffset]) + newRow := make(sqltypes.Row, 0, rowSize) + + if jv.RowID { + for _, loffset := range jv.CopyColumnsToRHS { + newRow = append(newRow, row[loffset]) + } + newRow = append(newRow, sqltypes.NewInt64(int64(i))) // Adding the LHS row ID + } else { + newRow = row } bv.Values = append(bv.Values, sqltypes.TupleToProto(newRow)) @@ -97,6 +106,11 @@ func (jv *ValuesJoin) TryExecute(ctx context.Context, vcursor VCursor, bindVars return nil, err } + if !jv.RowID { + // if we are not using the row ID, we can just return the result from the RHS + return rresult, nil + } + result := &sqltypes.Result{} result.Fields = joinFields(lresult.Fields, rresult.Fields, jv.Cols) diff --git a/go/vt/vtgate/engine/values_join_test.go b/go/vt/vtgate/engine/values_join_test.go index 21427113fa7..29297d6aa32 100644 --- a/go/vt/vtgate/engine/values_join_test.go +++ b/go/vt/vtgate/engine/values_join_test.go @@ -18,84 +18,132 @@ package engine import ( "context" + "fmt" "testing" "github.com/stretchr/testify/require" - "vitess.io/vitess/go/sqltypes" querypb "vitess.io/vitess/go/vt/proto/query" + + "vitess.io/vitess/go/sqltypes" ) func TestJoinValuesExecute(t *testing.T) { - /* - select col1, col2, col3, col4, col5, col6 from left join right on left.col1 = right.col4 - LHS: select col1, col2, col3 from left - RHS: select col5, col6, id from (values row(1,2), ...) left(id,col1) join right on left.col1 = right.col4 - */ + type testCase struct { + rowID bool + cols []int + CopyColumnsToRHS []int + rhsResults []*sqltypes.Result + expectedRHSLog []string + } - leftPrim := &fakePrimitive{ - useNewPrintBindVars: true, - results: []*sqltypes.Result{ - sqltypes.MakeTestResult( - sqltypes.MakeTestFields( - "col1|col2|col3", - "int64|varchar|varchar", + testCases := []testCase{ + { + /* + select col1, col2, col3, col4, col5, col6 from left join right on left.col1 = right.col4 + LHS: select col1, col2, col3 from left + RHS: select col5, col6, id from (values row(1,2), ...) left(id,col1) join right on left.col1 = right.col4 + */ + + rowID: true, + cols: []int{-1, -2, -3, -1, 1, 2}, + CopyColumnsToRHS: []int{0}, + rhsResults: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col5|col6|id", + "varchar|varchar|int64", + ), + "d|dd|0", + "e|ee|1", + "f|ff|2", + "g|gg|3", ), - "1|a|aa", - "2|b|bb", - "3|c|cc", - "4|d|dd", - ), - }, - } - rightPrim := &fakePrimitive{ - useNewPrintBindVars: true, - results: []*sqltypes.Result{ - sqltypes.MakeTestResult( - sqltypes.MakeTestFields( - "col5|col6|id", - "varchar|varchar|int64", + }, + expectedRHSLog: []string{ + `Execute a: type:INT64 value:"10" v: [[INT64(1) INT64(0)][INT64(2) INT64(1)][INT64(3) INT64(2)][INT64(4) INT64(3)]] true`, + }, + }, { + /* + select col1, col2, col3, col4, col5, col6 from left join right on left.col1 = right.col4 + LHS: select col1, col2, col3 from left + RHS: select col1, col2, col3, col4, col5, col6 from (values row(1,2,3), ...) left(col1,col2,col3) join right on left.col1 = right.col4 + */ + + rowID: false, + rhsResults: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3|col4|col5|col6", + "int64|varchar|varchar|int64|varchar|varchar", + ), + "1|a|aa|1|d|dd", + "2|b|bb|2|e|ee", + "3|c|cc|3|f|ff", + "4|d|dd|4|g|gg", ), - "d|dd|0", - "e|ee|1", - "f|ff|2", - "g|gg|3", - ), + }, + expectedRHSLog: []string{ + `Execute a: type:INT64 value:"10" v: [[INT64(1) VARCHAR("a") VARCHAR("aa")][INT64(2) VARCHAR("b") VARCHAR("bb")][INT64(3) VARCHAR("c") VARCHAR("cc")][INT64(4) VARCHAR("d") VARCHAR("dd")]] true`, + }, }, } - bv := map[string]*querypb.BindVariable{ - "a": sqltypes.Int64BindVariable(10), - } + for _, tc := range testCases { + t.Run(fmt.Sprintf("rowID:%t", tc.rowID), func(t *testing.T) { + leftPrim := &fakePrimitive{ + useNewPrintBindVars: true, + results: []*sqltypes.Result{ + sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3", + "int64|varchar|varchar", + ), + "1|a|aa", + "2|b|bb", + "3|c|cc", + "4|d|dd", + ), + }, + } + rightPrim := &fakePrimitive{ + useNewPrintBindVars: true, + results: tc.rhsResults, + } - vjn := &ValuesJoin{ - Left: leftPrim, - Right: rightPrim, - CopyColumnsToRHS: []int{0}, - BindVarName: "v", - Cols: []int{-1, -2, -3, -1, 1, 2}, - ColNames: []string{"col1", "col2", "col3", "col4", "col5", "col6"}, - } + bv := map[string]*querypb.BindVariable{ + "a": sqltypes.Int64BindVariable(10), + } + + vjn := &ValuesJoin{ + Left: leftPrim, + Right: rightPrim, + CopyColumnsToRHS: tc.CopyColumnsToRHS, + BindVarName: "v", + Cols: tc.cols, + ColNames: []string{"col1", "col2", "col3", "col4", "col5", "col6"}, + RowID: tc.rowID, + } - r, err := vjn.TryExecute(context.Background(), &noopVCursor{}, bv, true) - require.NoError(t, err) - leftPrim.ExpectLog(t, []string{ - `Execute a: type:INT64 value:"10" true`, - }) - rightPrim.ExpectLog(t, []string{ - `Execute a: type:INT64 value:"10" v: [[INT64(0) INT64(1)][INT64(1) INT64(2)][INT64(2) INT64(3)][INT64(3) INT64(4)]] true`, - }) - - result := sqltypes.MakeTestResult( - sqltypes.MakeTestFields( - "col1|col2|col3|col4|col5|col6", - "int64|varchar|varchar|int64|varchar|varchar", - ), - "1|a|aa|1|d|dd", - "2|b|bb|2|e|ee", - "3|c|cc|3|f|ff", - "4|d|dd|4|g|gg", - ) - expectResult(t, r, result) + r, err := vjn.TryExecute(context.Background(), &noopVCursor{}, bv, true) + require.NoError(t, err) + leftPrim.ExpectLog(t, []string{ + `Execute a: type:INT64 value:"10" true`, + }) + rightPrim.ExpectLog(t, tc.expectedRHSLog) + + result := sqltypes.MakeTestResult( + sqltypes.MakeTestFields( + "col1|col2|col3|col4|col5|col6", + "int64|varchar|varchar|int64|varchar|varchar", + ), + "1|a|aa|1|d|dd", + "2|b|bb|2|e|ee", + "3|c|cc|3|f|ff", + "4|d|dd|4|g|gg", + ) + expectResult(t, r, result) + }) + } }