diff --git a/go/vt/sqlparser/ast.go b/go/vt/sqlparser/ast.go index 806784c27e8..a182acfb5a0 100644 --- a/go/vt/sqlparser/ast.go +++ b/go/vt/sqlparser/ast.go @@ -57,8 +57,8 @@ type ( OrderAndLimit interface { AddOrder(*Order) SetLimit(*Limit) - GetOrderBy() OrderBy - SetOrderBy(OrderBy) + GetOrderBy() []*Order + SetOrderBy([]*Order) GetLimit() *Limit } @@ -300,7 +300,7 @@ type ( GroupBy *GroupBy Having *Where Windows NamedWindows - OrderBy OrderBy + OrderBy []*Order Limit *Limit Lock Lock Into *SelectInto @@ -329,7 +329,7 @@ type ( Left TableStatement Right TableStatement Distinct bool - OrderBy OrderBy + OrderBy []*Order Limit *Limit Lock Lock Into *SelectInto @@ -388,7 +388,7 @@ type ( TableExprs []TableExpr Exprs UpdateExprs Where *Where - OrderBy OrderBy + OrderBy []*Order Limit *Limit } @@ -402,7 +402,7 @@ type ( Targets TableNames Partitions Partitions Where *Where - OrderBy OrderBy + OrderBy []*Order Limit *Limit } @@ -2211,7 +2211,7 @@ type ( WindowSpecification struct { Name IdentifierCI PartitionClause []Expr - OrderClause OrderBy + OrderBy []*Order FrameClause *FrameClause } @@ -3063,7 +3063,7 @@ type ( GroupConcatExpr struct { Distinct bool Exprs []Expr - OrderBy OrderBy + OrderBy []*Order Separator string Limit *Limit } @@ -3617,7 +3617,7 @@ type ValuesStatement struct { ListArg ListArg Comments *ParsedComments - Order OrderBy + OrderBy []*Order Limit *Limit } diff --git a/go/vt/sqlparser/ast_clone.go b/go/vt/sqlparser/ast_clone.go index 997485e086c..df1482e74a1 100644 --- a/go/vt/sqlparser/ast_clone.go +++ b/go/vt/sqlparser/ast_clone.go @@ -1196,7 +1196,7 @@ func CloneRefOfDelete(n *Delete) *Delete { out.Targets = CloneTableNames(n.Targets) out.Partitions = ClonePartitions(n.Partitions) out.Where = CloneRefOfWhere(n.Where) - out.OrderBy = CloneOrderBy(n.OrderBy) + out.OrderBy = CloneSliceOfRefOfOrder(n.OrderBy) out.Limit = CloneRefOfLimit(n.Limit) return &out } @@ -1568,7 +1568,7 @@ func CloneRefOfGroupConcatExpr(n *GroupConcatExpr) *GroupConcatExpr { } out := *n out.Exprs = CloneSliceOfExpr(n.Exprs) - out.OrderBy = CloneOrderBy(n.OrderBy) + out.OrderBy = CloneSliceOfRefOfOrder(n.OrderBy) out.Limit = CloneRefOfLimit(n.Limit) return &out } @@ -2769,7 +2769,7 @@ func CloneRefOfSelect(n *Select) *Select { out.GroupBy = CloneRefOfGroupBy(n.GroupBy) out.Having = CloneRefOfWhere(n.Having) out.Windows = CloneNamedWindows(n.Windows) - out.OrderBy = CloneOrderBy(n.OrderBy) + out.OrderBy = CloneSliceOfRefOfOrder(n.OrderBy) out.Limit = CloneRefOfLimit(n.Limit) out.Into = CloneRefOfSelectInto(n.Into) return &out @@ -3182,7 +3182,7 @@ func CloneRefOfUnion(n *Union) *Union { out.With = CloneRefOfWith(n.With) out.Left = CloneTableStatement(n.Left) out.Right = CloneTableStatement(n.Right) - out.OrderBy = CloneOrderBy(n.OrderBy) + out.OrderBy = CloneSliceOfRefOfOrder(n.OrderBy) out.Limit = CloneRefOfLimit(n.Limit) out.Into = CloneRefOfSelectInto(n.Into) return &out @@ -3208,7 +3208,7 @@ func CloneRefOfUpdate(n *Update) *Update { out.TableExprs = CloneSliceOfTableExpr(n.TableExprs) out.Exprs = CloneUpdateExprs(n.Exprs) out.Where = CloneRefOfWhere(n.Where) - out.OrderBy = CloneOrderBy(n.OrderBy) + out.OrderBy = CloneSliceOfRefOfOrder(n.OrderBy) out.Limit = CloneRefOfLimit(n.Limit) return &out } @@ -3335,7 +3335,7 @@ func CloneRefOfValuesStatement(n *ValuesStatement) *ValuesStatement { out.With = CloneRefOfWith(n.With) out.Rows = CloneValues(n.Rows) out.Comments = CloneRefOfParsedComments(n.Comments) - out.Order = CloneOrderBy(n.Order) + out.OrderBy = CloneSliceOfRefOfOrder(n.OrderBy) out.Limit = CloneRefOfLimit(n.Limit) return &out } @@ -3463,7 +3463,7 @@ func CloneRefOfWindowSpecification(n *WindowSpecification) *WindowSpecification out := *n out.Name = CloneIdentifierCI(n.Name) out.PartitionClause = CloneSliceOfExpr(n.PartitionClause) - out.OrderClause = CloneOrderBy(n.OrderClause) + out.OrderBy = CloneSliceOfRefOfOrder(n.OrderBy) out.FrameClause = CloneRefOfFrameClause(n.FrameClause) return &out } @@ -4493,6 +4493,18 @@ func CloneSliceOfTableExpr(n []TableExpr) []TableExpr { return res } +// CloneSliceOfRefOfOrder creates a deep clone of the input. +func CloneSliceOfRefOfOrder(n []*Order) []*Order { + if n == nil { + return nil + } + res := make([]*Order, len(n)) + for i, x := range n { + res[i] = CloneRefOfOrder(x) + } + return res +} + // CloneSliceOfRefOfVariable creates a deep clone of the input. func CloneSliceOfRefOfVariable(n []*Variable) []*Variable { if n == nil { diff --git a/go/vt/sqlparser/ast_copy_on_rewrite.go b/go/vt/sqlparser/ast_copy_on_rewrite.go index 46ce38ee29d..7f7229ac9f9 100644 --- a/go/vt/sqlparser/ast_copy_on_rewrite.go +++ b/go/vt/sqlparser/ast_copy_on_rewrite.go @@ -1894,7 +1894,15 @@ func (c *cow) copyOnRewriteRefOfDelete(n *Delete, parent SQLNode) (out SQLNode, _Targets, changedTargets := c.copyOnRewriteTableNames(n.Targets, n) _Partitions, changedPartitions := c.copyOnRewritePartitions(n.Partitions, n) _Where, changedWhere := c.copyOnRewriteRefOfWhere(n.Where, n) - _OrderBy, changedOrderBy := c.copyOnRewriteOrderBy(n.OrderBy, n) + var changedOrderBy bool + _OrderBy := make([]*Order, len(n.OrderBy)) + for x, el := range n.OrderBy { + this, changed := c.copyOnRewriteRefOfOrder(el, n) + _OrderBy[x] = this.(*Order) + if changed { + changedOrderBy = true + } + } _Limit, changedLimit := c.copyOnRewriteRefOfLimit(n.Limit, n) if changedWith || changedComments || changedTableExprs || changedTargets || changedPartitions || changedWhere || changedOrderBy || changedLimit { res := *n @@ -1904,7 +1912,7 @@ func (c *cow) copyOnRewriteRefOfDelete(n *Delete, parent SQLNode) (out SQLNode, res.Targets, _ = _Targets.(TableNames) res.Partitions, _ = _Partitions.(Partitions) res.Where, _ = _Where.(*Where) - res.OrderBy, _ = _OrderBy.(OrderBy) + res.OrderBy = _OrderBy res.Limit, _ = _Limit.(*Limit) out = &res if c.cloned != nil { @@ -2733,12 +2741,20 @@ func (c *cow) copyOnRewriteRefOfGroupConcatExpr(n *GroupConcatExpr, parent SQLNo changedExprs = true } } - _OrderBy, changedOrderBy := c.copyOnRewriteOrderBy(n.OrderBy, n) + var changedOrderBy bool + _OrderBy := make([]*Order, len(n.OrderBy)) + for x, el := range n.OrderBy { + this, changed := c.copyOnRewriteRefOfOrder(el, n) + _OrderBy[x] = this.(*Order) + if changed { + changedOrderBy = true + } + } _Limit, changedLimit := c.copyOnRewriteRefOfLimit(n.Limit, n) if changedExprs || changedOrderBy || changedLimit { res := *n res.Exprs = _Exprs - res.OrderBy, _ = _OrderBy.(OrderBy) + res.OrderBy = _OrderBy res.Limit, _ = _Limit.(*Limit) out = &res if c.cloned != nil { @@ -5366,7 +5382,15 @@ func (c *cow) copyOnRewriteRefOfSelect(n *Select, parent SQLNode) (out SQLNode, _GroupBy, changedGroupBy := c.copyOnRewriteRefOfGroupBy(n.GroupBy, n) _Having, changedHaving := c.copyOnRewriteRefOfWhere(n.Having, n) _Windows, changedWindows := c.copyOnRewriteNamedWindows(n.Windows, n) - _OrderBy, changedOrderBy := c.copyOnRewriteOrderBy(n.OrderBy, n) + var changedOrderBy bool + _OrderBy := make([]*Order, len(n.OrderBy)) + for x, el := range n.OrderBy { + this, changed := c.copyOnRewriteRefOfOrder(el, n) + _OrderBy[x] = this.(*Order) + if changed { + changedOrderBy = true + } + } _Limit, changedLimit := c.copyOnRewriteRefOfLimit(n.Limit, n) _Into, changedInto := c.copyOnRewriteRefOfSelectInto(n.Into, n) if changedWith || changedFrom || changedComments || changedSelectExprs || changedWhere || changedGroupBy || changedHaving || changedWindows || changedOrderBy || changedLimit || changedInto { @@ -5379,7 +5403,7 @@ func (c *cow) copyOnRewriteRefOfSelect(n *Select, parent SQLNode) (out SQLNode, res.GroupBy, _ = _GroupBy.(*GroupBy) res.Having, _ = _Having.(*Where) res.Windows, _ = _Windows.(NamedWindows) - res.OrderBy, _ = _OrderBy.(OrderBy) + res.OrderBy = _OrderBy res.Limit, _ = _Limit.(*Limit) res.Into, _ = _Into.(*SelectInto) out = &res @@ -6231,7 +6255,15 @@ func (c *cow) copyOnRewriteRefOfUnion(n *Union, parent SQLNode) (out SQLNode, ch _With, changedWith := c.copyOnRewriteRefOfWith(n.With, n) _Left, changedLeft := c.copyOnRewriteTableStatement(n.Left, n) _Right, changedRight := c.copyOnRewriteTableStatement(n.Right, n) - _OrderBy, changedOrderBy := c.copyOnRewriteOrderBy(n.OrderBy, n) + var changedOrderBy bool + _OrderBy := make([]*Order, len(n.OrderBy)) + for x, el := range n.OrderBy { + this, changed := c.copyOnRewriteRefOfOrder(el, n) + _OrderBy[x] = this.(*Order) + if changed { + changedOrderBy = true + } + } _Limit, changedLimit := c.copyOnRewriteRefOfLimit(n.Limit, n) _Into, changedInto := c.copyOnRewriteRefOfSelectInto(n.Into, n) if changedWith || changedLeft || changedRight || changedOrderBy || changedLimit || changedInto { @@ -6239,7 +6271,7 @@ func (c *cow) copyOnRewriteRefOfUnion(n *Union, parent SQLNode) (out SQLNode, ch res.With, _ = _With.(*With) res.Left, _ = _Left.(TableStatement) res.Right, _ = _Right.(TableStatement) - res.OrderBy, _ = _OrderBy.(OrderBy) + res.OrderBy = _OrderBy res.Limit, _ = _Limit.(*Limit) res.Into, _ = _Into.(*SelectInto) out = &res @@ -6285,7 +6317,15 @@ func (c *cow) copyOnRewriteRefOfUpdate(n *Update, parent SQLNode) (out SQLNode, } _Exprs, changedExprs := c.copyOnRewriteUpdateExprs(n.Exprs, n) _Where, changedWhere := c.copyOnRewriteRefOfWhere(n.Where, n) - _OrderBy, changedOrderBy := c.copyOnRewriteOrderBy(n.OrderBy, n) + var changedOrderBy bool + _OrderBy := make([]*Order, len(n.OrderBy)) + for x, el := range n.OrderBy { + this, changed := c.copyOnRewriteRefOfOrder(el, n) + _OrderBy[x] = this.(*Order) + if changed { + changedOrderBy = true + } + } _Limit, changedLimit := c.copyOnRewriteRefOfLimit(n.Limit, n) if changedWith || changedComments || changedTableExprs || changedExprs || changedWhere || changedOrderBy || changedLimit { res := *n @@ -6294,7 +6334,7 @@ func (c *cow) copyOnRewriteRefOfUpdate(n *Update, parent SQLNode) (out SQLNode, res.TableExprs = _TableExprs res.Exprs, _ = _Exprs.(UpdateExprs) res.Where, _ = _Where.(*Where) - res.OrderBy, _ = _OrderBy.(OrderBy) + res.OrderBy = _OrderBy res.Limit, _ = _Limit.(*Limit) out = &res if c.cloned != nil { @@ -6547,15 +6587,23 @@ func (c *cow) copyOnRewriteRefOfValuesStatement(n *ValuesStatement, parent SQLNo _Rows, changedRows := c.copyOnRewriteValues(n.Rows, n) _ListArg, changedListArg := c.copyOnRewriteListArg(n.ListArg, n) _Comments, changedComments := c.copyOnRewriteRefOfParsedComments(n.Comments, n) - _Order, changedOrder := c.copyOnRewriteOrderBy(n.Order, n) + var changedOrderBy bool + _OrderBy := make([]*Order, len(n.OrderBy)) + for x, el := range n.OrderBy { + this, changed := c.copyOnRewriteRefOfOrder(el, n) + _OrderBy[x] = this.(*Order) + if changed { + changedOrderBy = true + } + } _Limit, changedLimit := c.copyOnRewriteRefOfLimit(n.Limit, n) - if changedWith || changedRows || changedListArg || changedComments || changedOrder || changedLimit { + if changedWith || changedRows || changedListArg || changedComments || changedOrderBy || changedLimit { res := *n res.With, _ = _With.(*With) res.Rows, _ = _Rows.(Values) res.ListArg, _ = _ListArg.(ListArg) res.Comments, _ = _Comments.(*ParsedComments) - res.Order, _ = _Order.(OrderBy) + res.OrderBy = _OrderBy res.Limit, _ = _Limit.(*Limit) out = &res if c.cloned != nil { @@ -6849,13 +6897,21 @@ func (c *cow) copyOnRewriteRefOfWindowSpecification(n *WindowSpecification, pare changedPartitionClause = true } } - _OrderClause, changedOrderClause := c.copyOnRewriteOrderBy(n.OrderClause, n) + var changedOrderBy bool + _OrderBy := make([]*Order, len(n.OrderBy)) + for x, el := range n.OrderBy { + this, changed := c.copyOnRewriteRefOfOrder(el, n) + _OrderBy[x] = this.(*Order) + if changed { + changedOrderBy = true + } + } _FrameClause, changedFrameClause := c.copyOnRewriteRefOfFrameClause(n.FrameClause, n) - if changedName || changedPartitionClause || changedOrderClause || changedFrameClause { + if changedName || changedPartitionClause || changedOrderBy || changedFrameClause { res := *n res.Name, _ = _Name.(IdentifierCI) res.PartitionClause = _PartitionClause - res.OrderClause, _ = _OrderClause.(OrderBy) + res.OrderBy = _OrderBy res.FrameClause, _ = _FrameClause.(*FrameClause) out = &res if c.cloned != nil { diff --git a/go/vt/sqlparser/ast_equals.go b/go/vt/sqlparser/ast_equals.go index 2795daab2c5..b119797ddb5 100644 --- a/go/vt/sqlparser/ast_equals.go +++ b/go/vt/sqlparser/ast_equals.go @@ -2406,7 +2406,7 @@ func (cmp *Comparator) RefOfDelete(a, b *Delete) bool { cmp.TableNames(a.Targets, b.Targets) && cmp.Partitions(a.Partitions, b.Partitions) && cmp.RefOfWhere(a.Where, b.Where) && - cmp.OrderBy(a.OrderBy, b.OrderBy) && + cmp.SliceOfRefOfOrder(a.OrderBy, b.OrderBy) && cmp.RefOfLimit(a.Limit, b.Limit) } @@ -2840,7 +2840,7 @@ func (cmp *Comparator) RefOfGroupConcatExpr(a, b *GroupConcatExpr) bool { return a.Distinct == b.Distinct && a.Separator == b.Separator && cmp.SliceOfExpr(a.Exprs, b.Exprs) && - cmp.OrderBy(a.OrderBy, b.OrderBy) && + cmp.SliceOfRefOfOrder(a.OrderBy, b.OrderBy) && cmp.RefOfLimit(a.Limit, b.Limit) } @@ -4211,7 +4211,7 @@ func (cmp *Comparator) RefOfSelect(a, b *Select) bool { cmp.RefOfGroupBy(a.GroupBy, b.GroupBy) && cmp.RefOfWhere(a.Having, b.Having) && cmp.NamedWindows(a.Windows, b.Windows) && - cmp.OrderBy(a.OrderBy, b.OrderBy) && + cmp.SliceOfRefOfOrder(a.OrderBy, b.OrderBy) && cmp.RefOfLimit(a.Limit, b.Limit) && a.Lock == b.Lock && cmp.RefOfSelectInto(a.Into, b.Into) @@ -4689,7 +4689,7 @@ func (cmp *Comparator) RefOfUnion(a, b *Union) bool { cmp.RefOfWith(a.With, b.With) && cmp.TableStatement(a.Left, b.Left) && cmp.TableStatement(a.Right, b.Right) && - cmp.OrderBy(a.OrderBy, b.OrderBy) && + cmp.SliceOfRefOfOrder(a.OrderBy, b.OrderBy) && cmp.RefOfLimit(a.Limit, b.Limit) && a.Lock == b.Lock && cmp.RefOfSelectInto(a.Into, b.Into) @@ -4720,7 +4720,7 @@ func (cmp *Comparator) RefOfUpdate(a, b *Update) bool { cmp.SliceOfTableExpr(a.TableExprs, b.TableExprs) && cmp.UpdateExprs(a.Exprs, b.Exprs) && cmp.RefOfWhere(a.Where, b.Where) && - cmp.OrderBy(a.OrderBy, b.OrderBy) && + cmp.SliceOfRefOfOrder(a.OrderBy, b.OrderBy) && cmp.RefOfLimit(a.Limit, b.Limit) } @@ -4861,7 +4861,7 @@ func (cmp *Comparator) RefOfValuesStatement(a, b *ValuesStatement) bool { cmp.Values(a.Rows, b.Rows) && a.ListArg == b.ListArg && cmp.RefOfParsedComments(a.Comments, b.Comments) && - cmp.OrderBy(a.Order, b.Order) && + cmp.SliceOfRefOfOrder(a.OrderBy, b.OrderBy) && cmp.RefOfLimit(a.Limit, b.Limit) } @@ -5003,7 +5003,7 @@ func (cmp *Comparator) RefOfWindowSpecification(a, b *WindowSpecification) bool } return cmp.IdentifierCI(a.Name, b.Name) && cmp.SliceOfExpr(a.PartitionClause, b.PartitionClause) && - cmp.OrderBy(a.OrderClause, b.OrderClause) && + cmp.SliceOfRefOfOrder(a.OrderBy, b.OrderBy) && cmp.RefOfFrameClause(a.FrameClause, b.FrameClause) } @@ -7419,6 +7419,19 @@ func (cmp *Comparator) SliceOfTableExpr(a, b []TableExpr) bool { return true } +// SliceOfRefOfOrder does deep equals between the two objects. +func (cmp *Comparator) SliceOfRefOfOrder(a, b []*Order) bool { + if len(a) != len(b) { + return false + } + for i := 0; i < len(a); i++ { + if !cmp.RefOfOrder(a[i], b[i]) { + return false + } + } + return true +} + // SliceOfRefOfVariable does deep equals between the two objects. func (cmp *Comparator) SliceOfRefOfVariable(a, b []*Variable) bool { if len(a) != len(b) { diff --git a/go/vt/sqlparser/ast_format.go b/go/vt/sqlparser/ast_format.go index 8939af71f51..276ab89ae94 100644 --- a/go/vt/sqlparser/ast_format.go +++ b/go/vt/sqlparser/ast_format.go @@ -72,9 +72,8 @@ func (node *Select) Format(buf *TrackedBuffer) { if node.Windows != nil { buf.astPrintf(node, " %v", node.Windows) } - - buf.astPrintf(node, "%v%v%s%v", - node.OrderBy, + formatSlice(buf, OrderByForStr, node.OrderBy) + buf.astPrintf(node, "%v%s%v", node.Limit, node.Lock.ToString(), node.Into) } @@ -111,7 +110,8 @@ func (node *Union) Format(buf *TrackedBuffer) { buf.astPrintf(node, "%v", node.Right) } - buf.astPrintf(node, "%v%v%s", node.OrderBy, node.Limit, node.Lock.ToString()) + formatSlice(buf, OrderByForStr, node.OrderBy) + buf.astPrintf(node, "%v%s", node.Limit, node.Lock.ToString()) } // Format formats the node. @@ -136,8 +136,8 @@ func (node *ValuesStatement) Format(buf *TrackedBuffer) { } } } - buf.astPrintf(node, "%v%v", - node.Order, node.Limit) + formatSlice(buf, OrderByForStr, node.OrderBy) + buf.astPrintf(node, "%v", node.Limit) } // Format formats the node. @@ -202,7 +202,9 @@ func (node *Update) Format(buf *TrackedBuffer) { buf.astPrintf(node, "%s%v", prefix, expr) prefix = ", " } - buf.astPrintf(node, " set %v%v%v%v", node.Exprs, node.Where, node.OrderBy, node.Limit) + buf.astPrintf(node, " set %v%v", node.Exprs, node.Where) + formatSlice(buf, OrderByForStr, node.OrderBy) + buf.astPrintf(node, "%v", node.Limit) } // Format formats the node. @@ -222,7 +224,9 @@ func (node *Delete) Format(buf *TrackedBuffer) { buf.astPrintf(node, "%s%v", prefix, expr) prefix = ", " } - buf.astPrintf(node, "%v%v%v%v", node.Partitions, node.Where, node.OrderBy, node.Limit) + buf.astPrintf(node, "%v%v", node.Partitions, node.Where) + formatSlice(buf, OrderByForStr, node.OrderBy) + buf.astPrintf(node, "%v", node.Limit) } // Format formats the node. @@ -1700,10 +1704,11 @@ func (node *FuncExpr) Format(buf *TrackedBuffer) { // Format formats the node func (node *GroupConcatExpr) Format(buf *TrackedBuffer) { if node.Distinct { - buf.astPrintf(node, "group_concat(%s%n%v", DistinctStr, node.Exprs, node.OrderBy) + buf.astPrintf(node, "group_concat(%s%n", DistinctStr, node.Exprs) } else { - buf.astPrintf(node, "group_concat(%n%v", node.Exprs, node.OrderBy) + buf.astPrintf(node, "group_concat(%n", node.Exprs) } + formatSlice(buf, OrderByForStr, node.OrderBy) if node.Separator != "" { buf.astPrintf(node, " %s %#s", keywordStrings[SEPARATOR], node.Separator) } @@ -1752,9 +1757,7 @@ func (node *WindowSpecification) Format(buf *TrackedBuffer) { if node.PartitionClause != nil { buf.astPrintf(node, " partition by %n", node.PartitionClause) } - if node.OrderClause != nil { - buf.astPrintf(node, "%v", node.OrderClause) - } + formatSlice(buf, OrderByForStr, node.OrderBy) if node.FrameClause != nil { buf.astPrintf(node, "%v", node.FrameClause) } diff --git a/go/vt/sqlparser/ast_format_fast.go b/go/vt/sqlparser/ast_format_fast.go index 2a6797871eb..8f5f67d8980 100644 --- a/go/vt/sqlparser/ast_format_fast.go +++ b/go/vt/sqlparser/ast_format_fast.go @@ -80,8 +80,7 @@ func (node *Select) FormatFast(buf *TrackedBuffer) { buf.WriteByte(' ') node.Windows.FormatFast(buf) } - - node.OrderBy.FormatFast(buf) + formatSlice(buf, OrderByForStr, node.OrderBy) node.Limit.FormatFast(buf) buf.WriteString(node.Lock.ToString()) @@ -126,7 +125,7 @@ func (node *Union) FormatFast(buf *TrackedBuffer) { node.Right.FormatFast(buf) } - node.OrderBy.FormatFast(buf) + formatSlice(buf, OrderByForStr, node.OrderBy) node.Limit.FormatFast(buf) buf.WriteString(node.Lock.ToString()) } @@ -159,10 +158,8 @@ func (node *ValuesStatement) FormatFast(buf *TrackedBuffer) { } } } - - node.Order.FormatFast(buf) + formatSlice(buf, OrderByForStr, node.OrderBy) node.Limit.FormatFast(buf) - } // FormatFast formats the node. @@ -289,7 +286,7 @@ func (node *Update) FormatFast(buf *TrackedBuffer) { buf.WriteString(" set ") node.Exprs.FormatFast(buf) node.Where.FormatFast(buf) - node.OrderBy.FormatFast(buf) + formatSlice(buf, OrderByForStr, node.OrderBy) node.Limit.FormatFast(buf) } @@ -315,7 +312,7 @@ func (node *Delete) FormatFast(buf *TrackedBuffer) { } node.Partitions.FormatFast(buf) node.Where.FormatFast(buf) - node.OrderBy.FormatFast(buf) + formatSlice(buf, OrderByForStr, node.OrderBy) node.Limit.FormatFast(buf) } @@ -2238,12 +2235,11 @@ func (node *GroupConcatExpr) FormatFast(buf *TrackedBuffer) { buf.WriteString("group_concat(") buf.WriteString(DistinctStr) buf.formatExprs(node.Exprs) - node.OrderBy.FormatFast(buf) } else { buf.WriteString("group_concat(") buf.formatExprs(node.Exprs) - node.OrderBy.FormatFast(buf) } + formatSlice(buf, OrderByForStr, node.OrderBy) if node.Separator != "" { buf.WriteByte(' ') buf.WriteString(keywordStrings[SEPARATOR]) @@ -2309,9 +2305,7 @@ func (node *WindowSpecification) FormatFast(buf *TrackedBuffer) { buf.WriteString(" partition by ") buf.formatExprs(node.PartitionClause) } - if node.OrderClause != nil { - node.OrderClause.FormatFast(buf) - } + formatSlice(buf, OrderByForStr, node.OrderBy) if node.FrameClause != nil { node.FrameClause.FormatFast(buf) } diff --git a/go/vt/sqlparser/ast_funcs.go b/go/vt/sqlparser/ast_funcs.go index 47806e4afd4..06a2d2b7ea7 100644 --- a/go/vt/sqlparser/ast_funcs.go +++ b/go/vt/sqlparser/ast_funcs.go @@ -1222,12 +1222,12 @@ func (node *Select) AddOrder(order *Order) { } // SetOrderBy sets the order by clause -func (node *Select) SetOrderBy(orderBy OrderBy) { +func (node *Select) SetOrderBy(orderBy []*Order) { node.OrderBy = orderBy } // GetOrderBy gets the order by clause -func (node *Select) GetOrderBy() OrderBy { +func (node *Select) GetOrderBy() []*Order { return node.OrderBy } @@ -1365,12 +1365,12 @@ func (node *Union) AddOrder(order *Order) { } // SetOrderBy sets the order by clause -func (node *Union) SetOrderBy(orderBy OrderBy) { +func (node *Union) SetOrderBy(orderBy []*Order) { node.OrderBy = orderBy } // GetOrderBy gets the order by clause -func (node *Union) GetOrderBy() OrderBy { +func (node *Union) GetOrderBy() []*Order { return node.OrderBy } @@ -2910,11 +2910,11 @@ func (node *Update) SetLimit(limit *Limit) { node.Limit = limit } -func (node *Update) GetOrderBy() OrderBy { +func (node *Update) GetOrderBy() []*Order { return node.OrderBy } -func (node *Update) SetOrderBy(by OrderBy) { +func (node *Update) SetOrderBy(by []*Order) { node.OrderBy = by } @@ -2922,11 +2922,11 @@ func (node *Update) GetLimit() *Limit { return node.Limit } -func (node *Delete) GetOrderBy() OrderBy { +func (node *Delete) GetOrderBy() []*Order { return node.OrderBy } -func (node *Delete) SetOrderBy(by OrderBy) { +func (node *Delete) SetOrderBy(by []*Order) { node.OrderBy = by } @@ -3049,12 +3049,12 @@ func (node *ValuesStatement) SetWith(with *With) { node.With = with } -func (node *ValuesStatement) GetOrderBy() OrderBy { - return node.Order +func (node *ValuesStatement) GetOrderBy() []*Order { + return node.OrderBy } -func (node *ValuesStatement) SetOrderBy(by OrderBy) { - node.Order = by +func (node *ValuesStatement) SetOrderBy(by []*Order) { + node.OrderBy = by } func (node *ValuesStatement) GetLimit() *Limit { @@ -3062,7 +3062,7 @@ func (node *ValuesStatement) GetLimit() *Limit { } func (node *ValuesStatement) AddOrder(order *Order) { - node.Order = append(node.Order, order) + node.OrderBy = append(node.OrderBy, order) } func (node *ValuesStatement) SetLimit(limit *Limit) { diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index a4cfeb4f01c..0dfca3974f1 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -2682,10 +2682,14 @@ func (a *application) rewriteRefOfDelete(parent SQLNode, node *Delete, replacer }) { return false } - if !a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { - parent.(*Delete).OrderBy = newNode.(OrderBy) - }) { - return false + for x, el := range node.OrderBy { + if !a.rewriteRefOfOrder(node, el, func(idx int) replacerFunc { + return func(newNode, parent SQLNode) { + parent.(*Delete).OrderBy[idx] = newNode.(*Order) + } + }(x)) { + return false + } } if !a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { parent.(*Delete).Limit = newNode.(*Limit) @@ -3949,10 +3953,14 @@ func (a *application) rewriteRefOfGroupConcatExpr(parent SQLNode, node *GroupCon return false } } - if !a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { - parent.(*GroupConcatExpr).OrderBy = newNode.(OrderBy) - }) { - return false + for x, el := range node.OrderBy { + if !a.rewriteRefOfOrder(node, el, func(idx int) replacerFunc { + return func(newNode, parent SQLNode) { + parent.(*GroupConcatExpr).OrderBy[idx] = newNode.(*Order) + } + }(x)) { + return false + } } if !a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { parent.(*GroupConcatExpr).Limit = newNode.(*Limit) @@ -8078,10 +8086,14 @@ func (a *application) rewriteRefOfSelect(parent SQLNode, node *Select, replacer }) { return false } - if !a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { - parent.(*Select).OrderBy = newNode.(OrderBy) - }) { - return false + for x, el := range node.OrderBy { + if !a.rewriteRefOfOrder(node, el, func(idx int) replacerFunc { + return func(newNode, parent SQLNode) { + parent.(*Select).OrderBy[idx] = newNode.(*Order) + } + }(x)) { + return false + } } if !a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { parent.(*Select).Limit = newNode.(*Limit) @@ -9452,10 +9464,14 @@ func (a *application) rewriteRefOfUnion(parent SQLNode, node *Union, replacer re }) { return false } - if !a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { - parent.(*Union).OrderBy = newNode.(OrderBy) - }) { - return false + for x, el := range node.OrderBy { + if !a.rewriteRefOfOrder(node, el, func(idx int) replacerFunc { + return func(newNode, parent SQLNode) { + parent.(*Union).OrderBy[idx] = newNode.(*Order) + } + }(x)) { + return false + } } if !a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { parent.(*Union).Limit = newNode.(*Limit) @@ -9552,10 +9568,14 @@ func (a *application) rewriteRefOfUpdate(parent SQLNode, node *Update, replacer }) { return false } - if !a.rewriteOrderBy(node, node.OrderBy, func(newNode, parent SQLNode) { - parent.(*Update).OrderBy = newNode.(OrderBy) - }) { - return false + for x, el := range node.OrderBy { + if !a.rewriteRefOfOrder(node, el, func(idx int) replacerFunc { + return func(newNode, parent SQLNode) { + parent.(*Update).OrderBy[idx] = newNode.(*Order) + } + }(x)) { + return false + } } if !a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { parent.(*Update).Limit = newNode.(*Limit) @@ -9978,10 +9998,14 @@ func (a *application) rewriteRefOfValuesStatement(parent SQLNode, node *ValuesSt }) { return false } - if !a.rewriteOrderBy(node, node.Order, func(newNode, parent SQLNode) { - parent.(*ValuesStatement).Order = newNode.(OrderBy) - }) { - return false + for x, el := range node.OrderBy { + if !a.rewriteRefOfOrder(node, el, func(idx int) replacerFunc { + return func(newNode, parent SQLNode) { + parent.(*ValuesStatement).OrderBy[idx] = newNode.(*Order) + } + }(x)) { + return false + } } if !a.rewriteRefOfLimit(node, node.Limit, func(newNode, parent SQLNode) { parent.(*ValuesStatement).Limit = newNode.(*Limit) @@ -10426,10 +10450,14 @@ func (a *application) rewriteRefOfWindowSpecification(parent SQLNode, node *Wind return false } } - if !a.rewriteOrderBy(node, node.OrderClause, func(newNode, parent SQLNode) { - parent.(*WindowSpecification).OrderClause = newNode.(OrderBy) - }) { - return false + for x, el := range node.OrderBy { + if !a.rewriteRefOfOrder(node, el, func(idx int) replacerFunc { + return func(newNode, parent SQLNode) { + parent.(*WindowSpecification).OrderBy[idx] = newNode.(*Order) + } + }(x)) { + return false + } } if !a.rewriteRefOfFrameClause(node, node.FrameClause, func(newNode, parent SQLNode) { parent.(*WindowSpecification).FrameClause = newNode.(*FrameClause) diff --git a/go/vt/sqlparser/ast_visit.go b/go/vt/sqlparser/ast_visit.go index b8ece9e3ae3..f6b7f891ecf 100644 --- a/go/vt/sqlparser/ast_visit.go +++ b/go/vt/sqlparser/ast_visit.go @@ -1416,8 +1416,10 @@ func VisitRefOfDelete(in *Delete, f Visit) error { if err := VisitRefOfWhere(in.Where, f); err != nil { return err } - if err := VisitOrderBy(in.OrderBy, f); err != nil { - return err + for _, el := range in.OrderBy { + if err := VisitRefOfOrder(el, f); err != nil { + return err + } } if err := VisitRefOfLimit(in.Limit, f); err != nil { return err @@ -1927,8 +1929,10 @@ func VisitRefOfGroupConcatExpr(in *GroupConcatExpr, f Visit) error { return err } } - if err := VisitOrderBy(in.OrderBy, f); err != nil { - return err + for _, el := range in.OrderBy { + if err := VisitRefOfOrder(el, f); err != nil { + return err + } } if err := VisitRefOfLimit(in.Limit, f); err != nil { return err @@ -3578,8 +3582,10 @@ func VisitRefOfSelect(in *Select, f Visit) error { if err := VisitNamedWindows(in.Windows, f); err != nil { return err } - if err := VisitOrderBy(in.OrderBy, f); err != nil { - return err + for _, el := range in.OrderBy { + if err := VisitRefOfOrder(el, f); err != nil { + return err + } } if err := VisitRefOfLimit(in.Limit, f); err != nil { return err @@ -4111,8 +4117,10 @@ func VisitRefOfUnion(in *Union, f Visit) error { if err := VisitTableStatement(in.Right, f); err != nil { return err } - if err := VisitOrderBy(in.OrderBy, f); err != nil { - return err + for _, el := range in.OrderBy { + if err := VisitRefOfOrder(el, f); err != nil { + return err + } } if err := VisitRefOfLimit(in.Limit, f); err != nil { return err @@ -4155,8 +4163,10 @@ func VisitRefOfUpdate(in *Update, f Visit) error { if err := VisitRefOfWhere(in.Where, f); err != nil { return err } - if err := VisitOrderBy(in.OrderBy, f); err != nil { - return err + for _, el := range in.OrderBy { + if err := VisitRefOfOrder(el, f); err != nil { + return err + } } if err := VisitRefOfLimit(in.Limit, f); err != nil { return err @@ -4329,8 +4339,10 @@ func VisitRefOfValuesStatement(in *ValuesStatement, f Visit) error { if err := VisitRefOfParsedComments(in.Comments, f); err != nil { return err } - if err := VisitOrderBy(in.Order, f); err != nil { - return err + for _, el := range in.OrderBy { + if err := VisitRefOfOrder(el, f); err != nil { + return err + } } if err := VisitRefOfLimit(in.Limit, f); err != nil { return err @@ -4509,8 +4521,10 @@ func VisitRefOfWindowSpecification(in *WindowSpecification, f Visit) error { return err } } - if err := VisitOrderBy(in.OrderClause, f); err != nil { - return err + for _, el := range in.OrderBy { + if err := VisitRefOfOrder(el, f); err != nil { + return err + } } if err := VisitRefOfFrameClause(in.FrameClause, f); err != nil { return err diff --git a/go/vt/sqlparser/cached_size.go b/go/vt/sqlparser/cached_size.go index 4f17041bdbe..4658e5be864 100644 --- a/go/vt/sqlparser/cached_size.go +++ b/go/vt/sqlparser/cached_size.go @@ -4780,8 +4780,8 @@ func (cached *ValuesStatement) CachedSize(alloc bool) int64 { size += cached.Comments.CachedSize(true) // field Order vitess.io/vitess/go/vt/sqlparser.OrderBy { - size += hack.RuntimeAllocSize(int64(cap(cached.Order)) * int64(8)) - for _, elem := range cached.Order { + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(8)) + for _, elem := range cached.OrderBy { size += elem.CachedSize(true) } } @@ -4979,10 +4979,10 @@ func (cached *WindowSpecification) CachedSize(alloc bool) int64 { } } } - // field OrderClause vitess.io/vitess/go/vt/sqlparser.OrderBy + // field OrderBy vitess.io/vitess/go/vt/sqlparser.OrderBy { - size += hack.RuntimeAllocSize(int64(cap(cached.OrderClause)) * int64(8)) - for _, elem := range cached.OrderClause { + size += hack.RuntimeAllocSize(int64(cap(cached.OrderBy)) * int64(8)) + for _, elem := range cached.OrderBy { size += elem.CachedSize(true) } } diff --git a/go/vt/sqlparser/sql.go b/go/vt/sqlparser/sql.go index ef2036265e6..f6504b1ac6f 100644 --- a/go/vt/sqlparser/sql.go +++ b/go/vt/sqlparser/sql.go @@ -18511,7 +18511,7 @@ yydefault: var yyLOCAL *WindowSpecification //line sql.y:5768 { - yyLOCAL = &WindowSpecification{Name: yyDollar[1].identifierCI, PartitionClause: yyDollar[2].exprsUnion(), OrderClause: yyDollar[3].orderByUnion(), FrameClause: yyDollar[4].frameClauseUnion()} + yyLOCAL = &WindowSpecification{Name: yyDollar[1].identifierCI, PartitionClause: yyDollar[2].exprsUnion(), OrderBy: yyDollar[3].orderByUnion(), FrameClause: yyDollar[4].frameClauseUnion()} } yyVAL.union = yyLOCAL case 1117: diff --git a/go/vt/sqlparser/sql.y b/go/vt/sqlparser/sql.y index 5ecdeabe108..872a6cebb0a 100644 --- a/go/vt/sqlparser/sql.y +++ b/go/vt/sqlparser/sql.y @@ -5766,7 +5766,7 @@ sql_id_opt: window_spec: sql_id_opt window_partition_clause_opt order_by_opt frame_clause_opt { - $$ = &WindowSpecification{ Name: $1, PartitionClause: $2, OrderClause: $3, FrameClause: $4} + $$ = &WindowSpecification{ Name: $1, PartitionClause: $2, OrderBy: $3, FrameClause: $4} } over_clause: diff --git a/go/vt/sqlparser/tracked_buffer.go b/go/vt/sqlparser/tracked_buffer.go index 9ef0cfd1f93..b4c4a21cf74 100644 --- a/go/vt/sqlparser/tracked_buffer.go +++ b/go/vt/sqlparser/tracked_buffer.go @@ -255,6 +255,15 @@ func (buf *TrackedBuffer) formatExprs(exprs []Expr) { } } +func formatSlice[T SQLNode](buf *TrackedBuffer, name string, valueExprs []T) { + prefix := " " + name + " " + for _, n := range valueExprs { + _, _ = buf.literal(prefix) + buf.formatter(n) + prefix = ", " + } +} + func (buf *TrackedBuffer) formatNodes(input any) { switch nodes := input.(type) { case []Expr: