diff --git a/llms/googleai/googleai_llm.go b/llms/googleai/googleai_llm.go index 5faaea717..048c37a31 100644 --- a/llms/googleai/googleai_llm.go +++ b/llms/googleai/googleai_llm.go @@ -74,17 +74,19 @@ func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageC } opts := llms.CallOptions{ - Model: g.opts.defaultModel, - MaxTokens: g.opts.defaultMaxTokens, - Temperature: g.opts.defaultTemperature, - TopP: g.opts.defaultTopP, - TopK: g.opts.defaultTopK, + Model: g.opts.defaultModel, + CandidateCount: g.opts.defaultCandidateCount, + MaxTokens: g.opts.defaultMaxTokens, + Temperature: g.opts.defaultTemperature, + TopP: g.opts.defaultTopP, + TopK: g.opts.defaultTopK, } for _, opt := range options { opt(&opts) } model := g.client.GenerativeModel(opts.Model) + model.SetCandidateCount(int32(opts.CandidateCount)) model.SetMaxOutputTokens(int32(opts.MaxTokens)) model.SetTemperature(float32(opts.Temperature)) model.SetTopP(float32(opts.TopP)) diff --git a/llms/googleai/googleai_llm_test.go b/llms/googleai/googleai_llm_test.go index 5c79cb267..4e3a48b25 100644 --- a/llms/googleai/googleai_llm_test.go +++ b/llms/googleai/googleai_llm_test.go @@ -148,6 +148,31 @@ func TestEmbeddings(t *testing.T) { assert.NotEmpty(t, res[1]) } +func TestCandidateCountSetting(t *testing.T) { + t.Parallel() + llm := newClient(t) + + parts := []llms.ContentPart{ + llms.TextContent{Text: "Name five countries in Africa"}, + } + content := []llms.MessageContent{ + { + Role: schema.ChatMessageTypeHuman, + Parts: parts, + }, + } + + { + rsp, err := llm.GenerateContent(context.Background(), content, + llms.WithCandidateCount(1), llms.WithTemperature(1)) + require.NoError(t, err) + + assert.Len(t, rsp.Choices, 1) + } + + // TODO: test multiple candidates when the backend supports it +} + func TestMaxTokensSetting(t *testing.T) { t.Parallel() llm := newClient(t) diff --git a/llms/googleai/googleai_option.go b/llms/googleai/googleai_option.go index 300b6a811..66704ea7a 100644 --- a/llms/googleai/googleai_option.go +++ b/llms/googleai/googleai_option.go @@ -5,6 +5,7 @@ type options struct { apiKey string defaultModel string defaultEmbeddingModel string + defaultCandidateCount int defaultMaxTokens int defaultTemperature float64 defaultTopK int @@ -16,6 +17,7 @@ func defaultOptions() options { apiKey: "", defaultModel: "gemini-pro", defaultEmbeddingModel: "embedding-001", + defaultCandidateCount: 1, defaultMaxTokens: 256, defaultTemperature: 0.5, defaultTopK: 3, diff --git a/llms/options.go b/llms/options.go index 9a626f163..9d8173f59 100644 --- a/llms/options.go +++ b/llms/options.go @@ -5,10 +5,13 @@ import "context" // CallOption is a function that configures a CallOptions. type CallOption func(*CallOptions) -// CallOptions is a set of options for calling models. +// CallOptions is a set of options for calling models. Not all models support +// all options. type CallOptions struct { // Model is the model to use. Model string `json:"model"` + // CandidateCount is the number of response candidates to generate. + CandidateCount int `json:"candidate_count"` // MaxTokens is the maximum number of tokens to generate. MaxTokens int `json:"max_tokens"` // Temperature is the temperature for sampling, between 0 and 1. @@ -80,6 +83,13 @@ func WithMaxTokens(maxTokens int) CallOption { } } +// WithCandidateCount specifies the number of response candidates to generate. +func WithCandidateCount(c int) CallOption { + return func(o *CallOptions) { + o.CandidateCount = c + } +} + // WithTemperature specifies the model temperature, a hyperparameter that // regulates the randomness, or creativity, of the AI's responses. func WithTemperature(temperature float64) CallOption {