-
-
Notifications
You must be signed in to change notification settings - Fork 745
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
googleai: add codegen tool to generate vertex.go from googleai.go (#553)
Fixes #410
- Loading branch information
Showing
6 changed files
with
137 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
// Code generator for vertex.go from googleai.go | ||
// nolint | ||
package main | ||
|
||
import ( | ||
"fmt" | ||
"go/ast" | ||
"go/format" | ||
"go/parser" | ||
"go/token" | ||
"log" | ||
"os" | ||
"strings" | ||
|
||
"golang.org/x/tools/go/ast/astutil" | ||
) | ||
|
||
func main() { | ||
fset := token.NewFileSet() | ||
file, err := parser.ParseFile(fset, "src.go", os.Stdin, parser.ParseComments) | ||
if err != nil { | ||
log.Fatal(err) | ||
} | ||
|
||
file.Name.Name = "vertex" | ||
|
||
astutil.Apply(file, nil, func(c *astutil.Cursor) bool { | ||
n := c.Node() | ||
switch x := n.(type) { | ||
case *ast.ImportSpec: | ||
rewriteImport(x) | ||
|
||
case *ast.FuncDecl: | ||
if x.Recv != nil && len(x.Recv.List) == 1 { | ||
rewriteReceiverName(x) | ||
} | ||
addCastToTopK(x) | ||
removeTokenCount(x) | ||
} | ||
|
||
return true | ||
}) | ||
|
||
fmt.Println(strings.TrimLeft(preamble, "\r\n")) | ||
format.Node(os.Stdout, fset, file) | ||
} | ||
|
||
const preamble = ` | ||
// DO NOT EDIT THIS FILE -- it is automatically generated from googleai.go | ||
// See the README file in this directory for additional details | ||
` | ||
|
||
func rewriteImport(x *ast.ImportSpec) { | ||
if strings.Index(x.Path.Value, "generative-ai-go/genai") > 0 { | ||
x.Path.Value = `"cloud.google.com/go/vertexai/genai"` | ||
} | ||
} | ||
|
||
func rewriteReceiverName(fun *ast.FuncDecl) { | ||
recv := fun.Recv.List[0] | ||
ty := recv.Type.(*ast.StarExpr) | ||
tyName := ty.X.(*ast.Ident) | ||
tyName.Name = "Vertex" | ||
} | ||
|
||
func addCastToTopK(fun *ast.FuncDecl) { | ||
ast.Inspect(fun, func(n ast.Node) bool { | ||
switch x := n.(type) { | ||
case *ast.CallExpr: | ||
if getIdentName(x.Fun) == "int32" && len(x.Args) == 1 { | ||
arg0 := x.Args[0] | ||
if sel, ok := arg0.(*ast.SelectorExpr); ok { | ||
if getIdentName(sel.X) == "opts" { | ||
if getIdentName(sel.Sel) == "TopK" { | ||
funcId := x.Fun.(*ast.Ident) | ||
funcId.Name = "float32" | ||
} | ||
} | ||
} | ||
} | ||
} | ||
return true | ||
}) | ||
} | ||
|
||
func removeTokenCount(fun *ast.FuncDecl) { | ||
ast.Inspect(fun, func(n ast.Node) bool { | ||
if block, ok := n.(*ast.BlockStmt); ok { | ||
idx := -1 | ||
for i, stmt := range block.List { | ||
if assign, ok := stmt.(*ast.AssignStmt); ok { | ||
lhs0 := assign.Lhs[0] | ||
if lhs, ok := lhs0.(*ast.SelectorExpr); ok && getIdentName(lhs.Sel) == "TokenCount" && getIdentName(lhs.X) == "candidate" { | ||
idx = i | ||
break | ||
} | ||
} | ||
} | ||
|
||
if idx > 0 { | ||
block.List = append(block.List[:idx], block.List[idx+1:]...) | ||
} | ||
} | ||
return true | ||
}) | ||
} | ||
|
||
// getIdentName returns the identifier name from ast.Ident expressions; for | ||
// other expressions, returns an empty string. | ||
func getIdentName(x ast.Expr) string { | ||
if id, ok := x.(*ast.Ident); ok { | ||
return id.Name | ||
} | ||
return "" | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
This package implements a langchaingo provider for Google Vertex AI LLMs, | ||
using the SDK at https://pkg.go.dev/cloud.google.com/go/vertexai | ||
|
||
Since Vertex SDK is so similar to the Google AI SDK, we generate the main part | ||
of this package from `llms/googleai/googleai.go` to create | ||
`llms/googleai/vertex/vertex.go`. | ||
|
||
To re-generate, run this from the root of the repository: | ||
|
||
go run ./llms/googleai/internal/cmd/generate-vertex.go < llms/googleai/googleai.go > llms/googleai/vertex/vertex.go | ||
|
||
See the script in `llms/googleai/internal/cmd/generate-vertex.go` for details. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters