diff --git a/genai/client.go b/genai/client.go index fe9653f..7d0304d 100644 --- a/genai/client.go +++ b/genai/client.go @@ -75,10 +75,14 @@ Import the option package as "google.golang.org/api/option".`) if err != nil { return nil, fmt.Errorf("creating file client: %w", err) } - cc, err := gl.NewCacheClient(ctx, opts...) + + // Workaround for https://github.com/google/generative-ai-go/issues/151 + optsForCache := removeHTTPClientOption(opts) + cc, err := gl.NewCacheClient(ctx, optsForCache...) if err != nil { return nil, fmt.Errorf("creating cache client: %w", err) } + ds, err := gld.NewService(ctx, opts...) if err != nil { return nil, fmt.Errorf("creating discovery client: %w", err) @@ -110,6 +114,19 @@ func hasAuthOption(opts []option.ClientOption) bool { return false } +// removeHTTPClientOption removes option.withHTTPClient from the given list +// of options, if it exists; it returns the new (filtered) list. +func removeHTTPClientOption(opts []option.ClientOption) []option.ClientOption { + var newOpts []option.ClientOption + for _, opt := range opts { + ts := reflect.ValueOf(opt).Type().String() + if ts != "option.withHTTPClient" { + newOpts = append(newOpts, opt) + } + } + return newOpts +} + // Close closes the client. func (c *Client) Close() error { return errors.Join(c.gc.Close(), c.mc.Close(), c.fc.Close()) diff --git a/genai/client_test.go b/genai/client_test.go index a3eb5a6..a6aca80 100644 --- a/genai/client_test.go +++ b/genai/client_test.go @@ -766,9 +766,33 @@ func TestRecoverPanic(t *testing.T) { } } +type customRT struct { + APIKey string +} + +func (t *customRT) RoundTrip(req *http.Request) (*http.Response, error) { + transport := http.DefaultTransport.(*http.Transport).Clone() + newReq := req.Clone(req.Context()) + vals := newReq.URL.Query() + vals.Set("key", t.APIKey) + newReq.URL.RawQuery = vals.Encode() + + resp, err := transport.RoundTrip(newReq) + if err != nil { + return nil, err + } + + return resp, nil +} + func TestCustomHTTPClient(t *testing.T) { - t.Skip("custom HTTP client not working right now") - c := http.DefaultClient + apiKey := os.Getenv("GEMINI_API_KEY") + if testing.Short() || apiKey == "" { + t.Skip("skipping live test in -short mode, or when API key isn't provided") + } + c := &http.Client{ + Transport: &customRT{APIKey: apiKey}, + } ctx := context.Background() client, err := NewClient(ctx, option.WithHTTPClient(c)) @@ -780,9 +804,10 @@ func TestCustomHTTPClient(t *testing.T) { model := client.GenerativeModel(defaultModel) resp, err := model.GenerateContent(ctx, Text("What are some of the largest cities in the US?")) if err != nil { - log.Fatal(err) + t.Fatal(err) } - fmt.Println(resp) + got := responseString(resp) + checkMatch(t, got, `new york`) } func uploadFile(t *testing.T, ctx context.Context, client *Client, filename string) *File {