diff --git a/go/vt/vtgate/semantics/analyzer.go b/go/vt/vtgate/semantics/analyzer.go index 62cdc019ddf..d5cacb5c6fa 100644 --- a/go/vt/vtgate/semantics/analyzer.go +++ b/go/vt/vtgate/semantics/analyzer.go @@ -286,14 +286,14 @@ func containsStar(s sqlparser.SelectExprs) bool { } func checkUnionColumns(union *sqlparser.Union) error { - firstProj := sqlparser.GetFirstSelect(union).SelectExprs + firstProj := sqlparser.GetFirstSelect(union).GetColumns() if containsStar(firstProj) { // if we still have *, we can't figure out if the query is invalid or not // we'll fail it at run time instead return nil } - secondProj := sqlparser.GetFirstSelect(union.Right).SelectExprs + secondProj := sqlparser.GetFirstSelect(union.Right).GetColumns() if containsStar(secondProj) { return nil } diff --git a/go/vt/vtgate/semantics/cte_table.go b/go/vt/vtgate/semantics/cte_table.go index 498fc5076c1..29330e17ce2 100644 --- a/go/vt/vtgate/semantics/cte_table.go +++ b/go/vt/vtgate/semantics/cte_table.go @@ -150,7 +150,7 @@ func (cte *CTETable) GetMirrorRule() *vindexes.MirrorRule { type CTE struct { Name string - Query sqlparser.SelectStatement + Query sqlparser.TableSubquery isAuthoritative bool recursiveDeps *TableSet Columns sqlparser.Columns diff --git a/go/vt/vtgate/semantics/early_rewriter.go b/go/vt/vtgate/semantics/early_rewriter.go index 3e53ed0816a..99823f24ef5 100644 --- a/go/vt/vtgate/semantics/early_rewriter.go +++ b/go/vt/vtgate/semantics/early_rewriter.go @@ -385,12 +385,13 @@ func getIntLiteral(e sqlparser.Expr) *sqlparser.Literal { // handleOrderBy processes the ORDER BY clause. func (r *earlyRewriter) handleOrderBy(parent sqlparser.SQLNode, iter iterator) error { - stmt, ok := parent.(sqlparser.SelectStatement) + stmt, ok := parent.(sqlparser.TableSubquery) if !ok { return nil } sel := sqlparser.GetFirstSelect(stmt) + for e := iter.next(); e != nil; e = iter.next() { lit, err := r.replaceLiteralsInOrderBy(e, iter) if err != nil { diff --git a/go/vt/vtgate/semantics/table_collector.go b/go/vt/vtgate/semantics/table_collector.go index 45a50fd23a2..a5cd18d0b6c 100644 --- a/go/vt/vtgate/semantics/table_collector.go +++ b/go/vt/vtgate/semantics/table_collector.go @@ -155,7 +155,7 @@ func (tc *tableCollector) visitAliasedTableExpr(node *sqlparser.AliasedTableExpr func (tc *tableCollector) visitUnion(union *sqlparser.Union) error { firstSelect := sqlparser.GetFirstSelect(union) - expanded, selectExprs := getColumnNames(firstSelect.SelectExprs) + expanded, selectExprs := getColumnNames(firstSelect.GetColumns()) info := unionInfo{ isAuthoritative: expanded, exprs: selectExprs, @@ -165,7 +165,7 @@ func (tc *tableCollector) visitUnion(union *sqlparser.Union) error { return nil } - size := len(firstSelect.SelectExprs) + size := firstSelect.GetColumnCount() info.recursive = make([]TableSet, size) typers := make([]evalengine.TypeAggregator, size) collations := tc.org.collationEnv() @@ -414,7 +414,8 @@ func checkValidRecursiveCTE(cteDef *CTE) error { } firstSelect := sqlparser.GetFirstSelect(union.Right) - if firstSelect.GroupBy != nil { + + if slct, ok := firstSelect.(*sqlparser.Select); ok && slct.GroupBy != nil { return vterrors.VT09027(cteDef.Name) } @@ -470,24 +471,35 @@ func (tc *tableCollector) addSelectDerivedTable( return scope.addTable(tableInfo) } -func (tc *tableCollector) addUnionDerivedTable(union *sqlparser.Union, node *sqlparser.AliasedTableExpr, columns sqlparser.Columns, alias sqlparser.IdentifierCS) error { - firstSelect := sqlparser.GetFirstSelect(union) - tables := tc.scoper.wScope[firstSelect] - info, found := tc.unionInfo[union] - if !found { - return vterrors.VT13001("information about union is not available") - } +func (tc *tableCollector) addUnionDerivedTable( + union *sqlparser.Union, + node *sqlparser.AliasedTableExpr, + columns sqlparser.Columns, + alias sqlparser.IdentifierCS, +) error { + switch firstSelect := sqlparser.GetFirstSelect(union).(type) { + case *sqlparser.Select: + tables := tc.scoper.wScope[firstSelect] + info, found := tc.unionInfo[union] + if !found { + return vterrors.VT13001("information about union is not available") + } - tableInfo := createDerivedTableForExpressions(info.exprs, columns, tables.tables, tc.org, info.isAuthoritative, info.recursive, info.types) - if err := tableInfo.checkForDuplicates(); err != nil { - return err - } - tableInfo.ASTNode = node - tableInfo.tableName = alias.String() + tableInfo := createDerivedTableForExpressions(info.exprs, columns, tables.tables, tc.org, info.isAuthoritative, info.recursive, info.types) + if err := tableInfo.checkForDuplicates(); err != nil { + return err + } + tableInfo.ASTNode = node + tableInfo.tableName = alias.String() - tc.Tables = append(tc.Tables, tableInfo) - scope := tc.scoper.currentScope() - return scope.addTable(tableInfo) + tc.Tables = append(tc.Tables, tableInfo) + scope := tc.scoper.currentScope() + return scope.addTable(tableInfo) + case *sqlparser.ValuesStatement: + return vterrors.VT12001("still don't support values inside derived tables") + default: + return vterrors.VT12001(fmt.Sprintf("type not expected %T", firstSelect)) + } } func newVindexTable(t sqlparser.IdentifierCS) *vindexes.Table {