Skip to content

Commit

Permalink
googleai: add codegen tool to generate vertex.go from googleai.go (#553)
Browse files Browse the repository at this point in the history
Fixes #410
  • Loading branch information
eliben authored Jan 25, 2024
1 parent 25b04d4 commit e5e2fa5
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 2 deletions.
1 change: 1 addition & 0 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ require (
gitlab.com/golang-commonmark/markdown v0.0.0-20211110145824-bf3e522c626a
go.starlark.net v0.0.0-20230302034142-4b1e35fe2254
golang.org/x/exp v0.0.0-20230713183714-613f0c0eb8a1
golang.org/x/tools v0.14.0
google.golang.org/api v0.152.0
google.golang.org/grpc v1.60.0
google.golang.org/protobuf v1.31.0
Expand Down
4 changes: 4 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,8 @@ golang.org/x/mod v0.3.0/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA=
golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4=
golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.13.0 h1:I/DsJXRlw/8l/0c24sM9yb0T4z9liZTduXvdAWYiysY=
golang.org/x/mod v0.13.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
Expand Down Expand Up @@ -920,6 +922,8 @@ golang.org/x/tools v0.0.0-20210106214847-113979e3529a/go.mod h1:emZCQorbCU4vsT4f
golang.org/x/tools v0.1.3/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk=
golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc=
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.14.0 h1:jvNa2pY0M4r62jkRQ6RwEZZyPcymeL9XZMLBbV7U2nc=
golang.org/x/tools v0.14.0/go.mod h1:uYBEerGOWcJyEORxN+Ek8+TT266gXkNlHdJBwexUsBg=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
Expand Down
115 changes: 115 additions & 0 deletions llms/googleai/internal/cmd/generate-vertex.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
// Code generator for vertex.go from googleai.go
// nolint
package main

import (
"fmt"
"go/ast"
"go/format"
"go/parser"
"go/token"
"log"
"os"
"strings"

"golang.org/x/tools/go/ast/astutil"
)

func main() {
fset := token.NewFileSet()
file, err := parser.ParseFile(fset, "src.go", os.Stdin, parser.ParseComments)
if err != nil {
log.Fatal(err)
}

file.Name.Name = "vertex"

astutil.Apply(file, nil, func(c *astutil.Cursor) bool {
n := c.Node()
switch x := n.(type) {
case *ast.ImportSpec:
rewriteImport(x)

case *ast.FuncDecl:
if x.Recv != nil && len(x.Recv.List) == 1 {
rewriteReceiverName(x)
}
addCastToTopK(x)
removeTokenCount(x)
}

return true
})

fmt.Println(strings.TrimLeft(preamble, "\r\n"))
format.Node(os.Stdout, fset, file)
}

const preamble = `
// DO NOT EDIT THIS FILE -- it is automatically generated from googleai.go
// See the README file in this directory for additional details
`

func rewriteImport(x *ast.ImportSpec) {
if strings.Index(x.Path.Value, "generative-ai-go/genai") > 0 {
x.Path.Value = `"cloud.google.com/go/vertexai/genai"`
}
}

func rewriteReceiverName(fun *ast.FuncDecl) {
recv := fun.Recv.List[0]
ty := recv.Type.(*ast.StarExpr)
tyName := ty.X.(*ast.Ident)
tyName.Name = "Vertex"
}

func addCastToTopK(fun *ast.FuncDecl) {
ast.Inspect(fun, func(n ast.Node) bool {
switch x := n.(type) {
case *ast.CallExpr:
if getIdentName(x.Fun) == "int32" && len(x.Args) == 1 {
arg0 := x.Args[0]
if sel, ok := arg0.(*ast.SelectorExpr); ok {
if getIdentName(sel.X) == "opts" {
if getIdentName(sel.Sel) == "TopK" {
funcId := x.Fun.(*ast.Ident)
funcId.Name = "float32"
}
}
}
}
}
return true
})
}

func removeTokenCount(fun *ast.FuncDecl) {
ast.Inspect(fun, func(n ast.Node) bool {
if block, ok := n.(*ast.BlockStmt); ok {
idx := -1
for i, stmt := range block.List {
if assign, ok := stmt.(*ast.AssignStmt); ok {
lhs0 := assign.Lhs[0]
if lhs, ok := lhs0.(*ast.SelectorExpr); ok && getIdentName(lhs.Sel) == "TokenCount" && getIdentName(lhs.X) == "candidate" {
idx = i
break
}
}
}

if idx > 0 {
block.List = append(block.List[:idx], block.List[idx+1:]...)
}
}
return true
})
}

// getIdentName returns the identifier name from ast.Ident expressions; for
// other expressions, returns an empty string.
func getIdentName(x ast.Expr) string {
if id, ok := x.(*ast.Ident); ok {
return id.Name
}
return ""
}
2 changes: 2 additions & 0 deletions llms/googleai/new.go
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// DO NOT EDIT

// package googleai implements a langchaingo provider for Google AI LLMs.
// See https://ai.google.dev/ for more details.
package googleai
Expand Down
12 changes: 12 additions & 0 deletions llms/googleai/vertex/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
This package implements a langchaingo provider for Google Vertex AI LLMs,
using the SDK at https://pkg.go.dev/cloud.google.com/go/vertexai

Since Vertex SDK is so similar to the Google AI SDK, we generate the main part
of this package from `llms/googleai/googleai.go` to create
`llms/googleai/vertex/vertex.go`.

To re-generate, run this from the root of the repository:

go run ./llms/googleai/internal/cmd/generate-vertex.go < llms/googleai/googleai.go > llms/googleai/vertex/vertex.go

See the script in `llms/googleai/internal/cmd/generate-vertex.go` for details.
5 changes: 3 additions & 2 deletions llms/googleai/vertex/vertex.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package vertex
// DO NOT EDIT THIS FILE -- it is automatically generated from googleai.go
// See the README file in this directory for additional details

// DO NOT EDIT: this code is auto-generated from llms/googleai/googleai.go
package vertex

import (
"context"
Expand Down

0 comments on commit e5e2fa5

Please sign in to comment.