Skip to content

Commit

Permalink
feat: add support for Visitable on CopyOnRewrite
Browse files Browse the repository at this point in the history
Signed-off-by: Andres Taylor <andres@planetscale.com>
  • Loading branch information
systay committed Feb 27, 2025
1 parent 3546eda commit 5033886
Show file tree
Hide file tree
Showing 8 changed files with 219 additions and 6 deletions.
60 changes: 60 additions & 0 deletions go/tools/asthelpergen/copy_on_rewrite_gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ func (c *cowGen) addFunc(code *jen.Statement) {
}

func (c *cowGen) genFile(generatorSPI) (string, *jen.File) {
c.genVisitableMethod()
return "ast_copy_on_rewrite.go", c.file
}

Expand Down Expand Up @@ -192,6 +193,9 @@ func (c *cowGen) interfaceMethod(t types.Type, iface *types.Interface, spi gener
})

cases = append(cases,
jen.Case(jen.Id(visitableName)).Block(
jen.Return(jen.Id("c."+cowName+visitableName).Call(jen.Id("n, parent"))),
),
jen.Default().Block(
jen.Comment("this should never happen"),
jen.Return(jen.Nil(), jen.False()),
Expand Down Expand Up @@ -368,6 +372,62 @@ func (c *cowGen) visitStruct(t types.Type, strct *types.Struct, spi generatorSPI
c.addFunc(funcDeclaration.Block(stmts...))
}

func (c *cowGen) genVisitableMethod() {
c.file.Func().
Params(jen.Id("c").Op("*").Id("cow")).
Id("copyOnRewriteVisitable").
Params(
jen.Id("n").Id("Visitable"),
jen.Id("parent").Id(c.baseType),
).
Params(
jen.Id("out").Id(c.baseType),
jen.Id("changed").Bool(),
).
Block(
// if c.cursor.stop { return n, false }
jen.If(jen.Id("c").Dot("cursor").Dot("stop")).Block(
jen.Return(jen.Id("n"), jen.False()),
),

// out = n
jen.Id("out").Op("=").Id("n"),

// if c.pre == nil || c.pre(n, parent) { ... }
jen.If(
jen.Id("c").Dot("pre").Op("==").Nil().Op("||").
Id("c").Dot("pre").Call(jen.Id("n"), jen.Id("parent")),
).Block(
// _inner, changedInner := c.copyOnRewriteAST(n.VisitThis(), n)
jen.List(jen.Id("_inner"), jen.Id("changedInner")).
Op(":=").
Id("c").Dot("copyOnRewrite"+c.baseType).
Call(
jen.Id("n").Dot("VisitThis").Call(),
jen.Id("n"),
),

// if changedInner { res := n.Clone(_inner); out = res; changed = true }
jen.If(jen.Id("changedInner")).Block(
jen.Id("res").Op(":=").Id("n").Dot("Clone").Call(jen.Id("_inner")),
jen.Id("out").Op("=").Id("res"),
jen.Id("changed").Op("=").True(),
),
),

// if c.post != nil { out, changed = c.postVisit(out, parent, changed) }
jen.If(jen.Id("c").Dot("post").Op("!=").Nil()).Block(
jen.List(jen.Id("out"), jen.Id("changed")).
Op("=").
Id("c").Dot("postVisit").
Call(jen.Id("out"), jen.Id("parent"), jen.Id("changed")),
),

// return
jen.Return(),
)
}

func ifPostNotNilVisit(out string) *jen.Statement {
return ifNotNil("c.post", jen.List(jen.Id(out), jen.Id("changed")).Op("=").Id("c").Dot("postVisit").Params(jen.Id(out), jen.Id("parent"), jen.Id("changed")))
}
22 changes: 22 additions & 0 deletions go/tools/asthelpergen/integration/ast_copy_on_rewrite.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

14 changes: 9 additions & 5 deletions go/tools/asthelpergen/integration/integration_rewriter_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,8 @@ func TestVisitableRewrite(t *testing.T) {
Pre{visitable},
Pre{leaf},
Post{leaf},
Post{visitable},
Post{refContainer},
Pre{visitable},
Pre{refContainer},
})
}

Expand Down Expand Up @@ -418,15 +418,19 @@ func (tv *rewriteTestVisitor) post(cursor *Cursor) bool {
return true
}
func (tv *rewriteTestVisitor) assertEquals(t *testing.T, expected []step) {
assertStepsEqual(t, tv.walk, expected)
}

func assertStepsEqual(t *testing.T, walk, expected []step) {
t.Helper()
var lines []string
error := false
expectedSize := len(expected)
for i, step := range tv.walk {
for i, step := range walk {
t.Run(fmt.Sprintf("step %d", i), func(t *testing.T) {
t.Helper()
if expectedSize <= i {
t.Fatalf("❌️ - Expected less elements %v", tv.walk[i:])
t.Fatalf("❌️ - Expected less elements %v", walk[i:])
} else {
e := expected[i]
if reflect.DeepEqual(e, step) {
Expand All @@ -449,7 +453,7 @@ func (tv *rewriteTestVisitor) assertEquals(t *testing.T, expected []step) {
}
})
}
walkSize := len(tv.walk)
walkSize := len(walk)
if expectedSize > walkSize {
t.Errorf("❌️ - Expected more elements %v", expected[walkSize:])
}
Expand Down
52 changes: 52 additions & 0 deletions go/tools/asthelpergen/integration/integration_visit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ import (
"reflect"
"testing"

"github.com/stretchr/testify/assert"

"github.com/stretchr/testify/require"
)

Expand Down Expand Up @@ -157,6 +159,10 @@ type testVisitable struct {
inner AST
}

func (t *testVisitable) Clone(inner AST) AST {
return &testVisitable{inner: inner}
}

func (t *testVisitable) String() string {
return t.inner.String()
}
Expand Down Expand Up @@ -184,6 +190,52 @@ func TestVisitableVisit(t *testing.T) {
})
}

func TestCopyOnRewriteVisitable(t *testing.T) {
leaf := &Leaf{v: 1}
visitable := &testVisitable{inner: leaf}
refContainer := &RefContainer{ASTType: visitable}
var walk []step
pre := func(node, parent AST) bool {
walk = append(walk, Pre{node})
return true
}
post := func(cursor *cursor) {
walk = append(walk, Post{cursor.node})
}
CopyOnRewrite(refContainer, pre, post, nil)

assertStepsEqual(t, walk, []step{
Pre{refContainer},
Pre{visitable},
Pre{leaf},
Post{leaf},
Post{visitable},
Post{refContainer},
})
}

func TestCopyOnRewriteReplaceVisitable(t *testing.T) {
leaf := &Leaf{v: 1}
visitable := &testVisitable{inner: leaf}
refContainer := &RefContainer{ASTType: visitable}

result := CopyOnRewrite(refContainer, nil, func(cursor *cursor) {
_, ok := cursor.node.(*Leaf)
if !ok {
return
}
cursor.replaced = &Leaf{v: 2}
}, nil)

assert.NotSame(t, refContainer, result)
resRefCon := result.(*RefContainer)
assert.NotSame(t, visitable, resRefCon.ASTType)
newLeaf := resRefCon.ASTType.(*testVisitable).inner.(*Leaf)
assert.Equal(t, 2, newLeaf.v)
assert.Equal(t, 1, leaf.v)

}

func (tv *testVisitor) assertVisitOrder(t *testing.T, expected []AST) {
t.Helper()
var lines []string
Expand Down
21 changes: 20 additions & 1 deletion go/tools/asthelpergen/integration/test_helpers.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,11 +124,19 @@ type (
cursor cursor
}
cursor struct {
stop bool
stop bool
node, replaced AST
}
)

func (c *cow) postVisit(a, b AST, d bool) (AST, bool) {
c.cursor.node = a
c.cursor.replaced = nil
c.post(&c.cursor)
if c.cursor.replaced != nil {
return c.cursor.replaced, true
}

return a, d
}

Expand Down Expand Up @@ -175,3 +183,14 @@ func (path ASTPath) DebugString() string {

return sb.String()
}

func CopyOnRewrite(
node AST,
pre func(node, parent AST) bool,
post func(cursor *cursor),
cloned func(before, after AST),
) AST {
cow := cow{pre: pre, post: post, cursor: cursor{}, cloned: cloned}
out, _ := cow.copyOnRewriteAST(node, nil)
return out
}
1 change: 1 addition & 0 deletions go/tools/asthelpergen/integration/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,7 @@ type (
Visitable interface {
AST
VisitThis() AST
Clone(inner AST) AST
}
)

Expand Down
Loading

0 comments on commit 5033886

Please sign in to comment.