diff --git a/go/tools/asthelpergen/integration/ast_path_test.go b/go/tools/asthelpergen/integration/ast_path_test.go index 36ef35951aa..157322c374d 100644 --- a/go/tools/asthelpergen/integration/ast_path_test.go +++ b/go/tools/asthelpergen/integration/ast_path_test.go @@ -43,7 +43,7 @@ func TestWalkAllPartsOfAST(t *testing.T) { } var leafPaths []ASTPath - RewriteWithPaths(ast, func(c *Cursor) bool { + Rewrite(ast, func(c *Cursor) bool { node := c.Node() if !reflect.TypeOf(node).Comparable() { return true diff --git a/go/tools/asthelpergen/integration/ast_rewrite.go b/go/tools/asthelpergen/integration/ast_rewrite.go index 89f6966558e..1a22c7d0162 100644 --- a/go/tools/asthelpergen/integration/ast_rewrite.go +++ b/go/tools/asthelpergen/integration/ast_rewrite.go @@ -372,7 +372,7 @@ func (a *application) rewriteRefOfRefSliceContainer(parent AST, node *RefSliceCo return false } } - if a.collectPaths { + if a.collectPaths && len(node.ASTImplementationElements) > 0 { a.cur.current.Pop() } if a.post != nil { diff --git a/go/tools/asthelpergen/integration/test_helpers.go b/go/tools/asthelpergen/integration/test_helpers.go index f79feff12dc..b9e60b4940b 100644 --- a/go/tools/asthelpergen/integration/test_helpers.go +++ b/go/tools/asthelpergen/integration/test_helpers.go @@ -83,25 +83,9 @@ func (c *Cursor) ReplaceAndRevisit(newNode AST) { type replacerFunc func(newNode, parent AST) -// Rewrite is the api. func Rewrite(node AST, pre, post ApplyFunc) AST { outer := &struct{ AST }{node} - a := &application{ - pre: pre, - post: post, - } - - a.rewriteAST(outer, node, func(newNode, parent AST) { - outer.AST = newNode - }) - - return outer.AST -} - -func RewriteWithPaths(node AST, pre, post ApplyFunc) AST { - outer := &struct{ AST }{node} - a := &application{ pre: pre, post: post, diff --git a/go/tools/asthelpergen/rewrite_gen.go b/go/tools/asthelpergen/rewrite_gen.go index 52b476a7c9d..754a8c1e89d 100644 --- a/go/tools/asthelpergen/rewrite_gen.go +++ b/go/tools/asthelpergen/rewrite_gen.go @@ -100,12 +100,17 @@ func (r *rewriteGen) interfaceMethod(t types.Type, iface *types.Interface, spi g } func (r *rewriteGen) visitStructFields(t types.Type, strct *types.Struct, spi generatorSPI, fail bool) (stmts []jen.Code) { - fields := r.rewriteAllStructFields(t, strct, spi, fail) + fields, prevSliceField := r.rewriteAllStructFields(t, strct, spi, fail) stmts = append(stmts, r.executePre()) stmts = append(stmts, fields...) if len(fields) > 0 { - stmts = append(stmts, jen.If(jen.Id("a.collectPaths")).Block(jen.Id("a.cur.current.Pop").Params())) + ifCondition := jen.Id("a.collectPaths") + if prevSliceField != "" { + ifCondition = ifCondition.Op("&&").Len(jen.Id("node." + prevSliceField)).Op(">").Lit(0) + } + + stmts = append(stmts, jen.If(ifCondition).Block(jen.Id("a.cur.current.Pop").Params())) } stmts = append(stmts, executePost(len(fields) > 0)) stmts = append(stmts, returnTrue()) @@ -284,7 +289,7 @@ func (r *rewriteGen) rewriteFunc(t types.Type, stmts []jen.Code, source string) r.file.Add(code) } -func (r *rewriteGen) rewriteAllStructFields(t types.Type, strct *types.Struct, spi generatorSPI, fail bool) []jen.Code { +func (r *rewriteGen) rewriteAllStructFields(t types.Type, strct *types.Struct, spi generatorSPI, fail bool) ([]jen.Code, string) { /* if errF := rewriteAST(node, node.ASTType, func(newNode, parent AST) { err = vterrors.New(vtrpcpb.Code_INTERNAL, "[BUG] tried to replace '%s' on '%s'") @@ -326,7 +331,7 @@ func (r *rewriteGen) rewriteAllStructFields(t types.Type, strct *types.Struct, s fieldNumber++ } } - return output + return output, prevSliceField } func failReplacer(t types.Type, f string) *jen.Statement { diff --git a/go/vt/sqlparser/ast_rewrite.go b/go/vt/sqlparser/ast_rewrite.go index 33e5b0cda42..e098960fe7e 100644 --- a/go/vt/sqlparser/ast_rewrite.go +++ b/go/vt/sqlparser/ast_rewrite.go @@ -1998,7 +1998,7 @@ func (a *application) rewriteRefOfCallProc(parent SQLNode, node *CallProc, repla return false } } - if a.collectPaths { + if a.collectPaths && len(node.Params) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -2221,7 +2221,7 @@ func (a *application) rewriteRefOfCharExpr(parent SQLNode, node *CharExpr, repla return false } } - if a.collectPaths { + if a.collectPaths && len(node.Exprs) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -3713,7 +3713,7 @@ func (a *application) rewriteRefOfExecuteStmt(parent SQLNode, node *ExecuteStmt, return false } } - if a.collectPaths { + if a.collectPaths && len(node.Arguments) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -3890,7 +3890,7 @@ func (a *application) rewriteRefOfExprs(parent SQLNode, node *Exprs, replacer re return false } } - if a.collectPaths { + if a.collectPaths && len(node.Exprs) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -4354,7 +4354,7 @@ func (a *application) rewriteRefOfFuncExpr(parent SQLNode, node *FuncExpr, repla return false } } - if a.collectPaths { + if a.collectPaths && len(node.Exprs) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -4995,7 +4995,7 @@ func (a *application) rewriteRefOfGroupBy(parent SQLNode, node *GroupBy, replace return false } } - if a.collectPaths { + if a.collectPaths && len(node.Exprs) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -5205,7 +5205,7 @@ func (a *application) rewriteRefOfIndexHint(parent SQLNode, node *IndexHint, rep return false } } - if a.collectPaths { + if a.collectPaths && len(node.Indexes) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -5571,7 +5571,7 @@ func (a *application) rewriteRefOfIntervalFuncExpr(parent SQLNode, node *Interva return false } } - if a.collectPaths { + if a.collectPaths && len(node.Exprs) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -5748,7 +5748,7 @@ func (a *application) rewriteRefOfJSONArrayExpr(parent SQLNode, node *JSONArrayE return false } } - if a.collectPaths { + if a.collectPaths && len(node.Params) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -5865,7 +5865,7 @@ func (a *application) rewriteRefOfJSONContainsExpr(parent SQLNode, node *JSONCon return false } } - if a.collectPaths { + if a.collectPaths && len(node.PathList) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -5933,7 +5933,7 @@ func (a *application) rewriteRefOfJSONContainsPathExpr(parent SQLNode, node *JSO return false } } - if a.collectPaths { + if a.collectPaths && len(node.PathList) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -5992,7 +5992,7 @@ func (a *application) rewriteRefOfJSONExtractExpr(parent SQLNode, node *JSONExtr return false } } - if a.collectPaths { + if a.collectPaths && len(node.PathList) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -6147,7 +6147,7 @@ func (a *application) rewriteRefOfJSONObjectExpr(parent SQLNode, node *JSONObjec return false } } - if a.collectPaths { + if a.collectPaths && len(node.Params) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -6384,7 +6384,7 @@ func (a *application) rewriteRefOfJSONRemoveExpr(parent SQLNode, node *JSONRemov return false } } - if a.collectPaths { + if a.collectPaths && len(node.PathList) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -6568,7 +6568,7 @@ func (a *application) rewriteRefOfJSONSearchExpr(parent SQLNode, node *JSONSearc return false } } - if a.collectPaths { + if a.collectPaths && len(node.PathList) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -6725,7 +6725,7 @@ func (a *application) rewriteRefOfJSONTableExpr(parent SQLNode, node *JSONTableE return false } } - if a.collectPaths { + if a.collectPaths && len(node.Columns) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -6900,7 +6900,7 @@ func (a *application) rewriteRefOfJSONValueMergeExpr(parent SQLNode, node *JSONV return false } } - if a.collectPaths { + if a.collectPaths && len(node.JSONDocList) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -6959,7 +6959,7 @@ func (a *application) rewriteRefOfJSONValueModifierExpr(parent SQLNode, node *JS return false } } - if a.collectPaths { + if a.collectPaths && len(node.Params) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -7372,7 +7372,7 @@ func (a *application) rewriteRefOfLineStringExpr(parent SQLNode, node *LineStrin return false } } - if a.collectPaths { + if a.collectPaths && len(node.PointParams) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -7953,7 +7953,7 @@ func (a *application) rewriteRefOfMultiLinestringExpr(parent SQLNode, node *Mult return false } } - if a.collectPaths { + if a.collectPaths && len(node.LinestringParams) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -8001,7 +8001,7 @@ func (a *application) rewriteRefOfMultiPointExpr(parent SQLNode, node *MultiPoin return false } } - if a.collectPaths { + if a.collectPaths && len(node.PointParams) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -8049,7 +8049,7 @@ func (a *application) rewriteRefOfMultiPolygonExpr(parent SQLNode, node *MultiPo return false } } - if a.collectPaths { + if a.collectPaths && len(node.PolygonParams) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -9102,7 +9102,7 @@ func (a *application) rewriteRefOfPartitionOption(parent SQLNode, node *Partitio return false } } - if a.collectPaths { + if a.collectPaths && len(node.Definitions) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -9179,7 +9179,7 @@ func (a *application) rewriteRefOfPartitionSpec(parent SQLNode, node *PartitionS return false } } - if a.collectPaths { + if a.collectPaths && len(node.Definitions) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -9453,7 +9453,7 @@ func (a *application) rewriteRefOfPolygonExpr(parent SQLNode, node *PolygonExpr, return false } } - if a.collectPaths { + if a.collectPaths && len(node.LinestringParams) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -10605,7 +10605,7 @@ func (a *application) rewriteRefOfSelectExprs(parent SQLNode, node *SelectExprs, return false } } - if a.collectPaths { + if a.collectPaths && len(node.Exprs) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -13348,7 +13348,7 @@ func (a *application) rewriteRefOfVindexSpec(parent SQLNode, node *VindexSpec, r return false } } - if a.collectPaths { + if a.collectPaths && len(node.Params) > 0 { a.cur.current.Pop() } if a.post != nil { @@ -13708,7 +13708,7 @@ func (a *application) rewriteRefOfWith(parent SQLNode, node *With, replacer repl return false } } - if a.collectPaths { + if a.collectPaths && len(node.CTEs) > 0 { a.cur.current.Pop() } if a.post != nil { diff --git a/go/vt/sqlparser/parse_test.go b/go/vt/sqlparser/parse_test.go index 6c4a6e634ec..f35cdcd7c2e 100644 --- a/go/vt/sqlparser/parse_test.go +++ b/go/vt/sqlparser/parse_test.go @@ -3917,6 +3917,10 @@ func TestValid(t *testing.T) { _ = Walk(func(node SQLNode) (bool, error) { return true, nil }, tree) + + _ = RewriteWithPath(tree, func(cursor *Cursor) bool { + return true + }, nil) }) } }