Skip to content

Commit

Permalink
genai: make custom HTTP client work again (#152)
Browse files Browse the repository at this point in the history
Workaround until #151 is fixed
  • Loading branch information
eliben authored Jun 28, 2024
1 parent 8e25c39 commit fcf3f0b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 5 deletions.
19 changes: 18 additions & 1 deletion genai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,10 +75,14 @@ Import the option package as "google.golang.org/api/option".`)
if err != nil {
return nil, fmt.Errorf("creating file client: %w", err)
}
cc, err := gl.NewCacheClient(ctx, opts...)

// Workaround for https://github.com/google/generative-ai-go/issues/151
optsForCache := removeHTTPClientOption(opts)
cc, err := gl.NewCacheClient(ctx, optsForCache...)
if err != nil {
return nil, fmt.Errorf("creating cache client: %w", err)
}

ds, err := gld.NewService(ctx, opts...)
if err != nil {
return nil, fmt.Errorf("creating discovery client: %w", err)
Expand Down Expand Up @@ -110,6 +114,19 @@ func hasAuthOption(opts []option.ClientOption) bool {
return false
}

// removeHTTPClientOption removes option.withHTTPClient from the given list
// of options, if it exists; it returns the new (filtered) list.
func removeHTTPClientOption(opts []option.ClientOption) []option.ClientOption {
var newOpts []option.ClientOption
for _, opt := range opts {
ts := reflect.ValueOf(opt).Type().String()
if ts != "option.withHTTPClient" {
newOpts = append(newOpts, opt)
}
}
return newOpts
}

// Close closes the client.
func (c *Client) Close() error {
return errors.Join(c.gc.Close(), c.mc.Close(), c.fc.Close())
Expand Down
33 changes: 29 additions & 4 deletions genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -766,9 +766,33 @@ func TestRecoverPanic(t *testing.T) {
}
}

type customRT struct {
APIKey string
}

func (t *customRT) RoundTrip(req *http.Request) (*http.Response, error) {
transport := http.DefaultTransport.(*http.Transport).Clone()
newReq := req.Clone(req.Context())
vals := newReq.URL.Query()
vals.Set("key", t.APIKey)
newReq.URL.RawQuery = vals.Encode()

resp, err := transport.RoundTrip(newReq)
if err != nil {
return nil, err
}

return resp, nil
}

func TestCustomHTTPClient(t *testing.T) {
t.Skip("custom HTTP client not working right now")
c := http.DefaultClient
apiKey := os.Getenv("GEMINI_API_KEY")
if testing.Short() || apiKey == "" {
t.Skip("skipping live test in -short mode, or when API key isn't provided")
}
c := &http.Client{
Transport: &customRT{APIKey: apiKey},
}

ctx := context.Background()
client, err := NewClient(ctx, option.WithHTTPClient(c))
Expand All @@ -780,9 +804,10 @@ func TestCustomHTTPClient(t *testing.T) {
model := client.GenerativeModel(defaultModel)
resp, err := model.GenerateContent(ctx, Text("What are some of the largest cities in the US?"))
if err != nil {
log.Fatal(err)
t.Fatal(err)
}
fmt.Println(resp)
got := responseString(resp)
checkMatch(t, got, `new york`)
}

func uploadFile(t *testing.T, ctx context.Context, client *Client, filename string) *File {
Expand Down

0 comments on commit fcf3f0b

Please sign in to comment.