llms: configure multiple response candidates for models that support them (#533)

eliben authored Jan 19, 2024
1 parent 2e8220b commit 6287034

Showing 4 changed files with 45 additions and 6 deletions.
llms/googleai/googleai_llm.go: 7 additions & 5 deletions
@@ -74,17 +74,19 @@ func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageContent
 	}
 
 	opts := llms.CallOptions{
-		Model:       g.opts.defaultModel,
-		MaxTokens:   g.opts.defaultMaxTokens,
-		Temperature: g.opts.defaultTemperature,
-		TopP:        g.opts.defaultTopP,
-		TopK:        g.opts.defaultTopK,
+		Model:          g.opts.defaultModel,
+		CandidateCount: g.opts.defaultCandidateCount,
+		MaxTokens:      g.opts.defaultMaxTokens,
+		Temperature:    g.opts.defaultTemperature,
+		TopP:           g.opts.defaultTopP,
+		TopK:           g.opts.defaultTopK,
 	}
 	for _, opt := range options {
 		opt(&opts)
 	}
 
 	model := g.client.GenerativeModel(opts.Model)
+	model.SetCandidateCount(int32(opts.CandidateCount))
 	model.SetMaxOutputTokens(int32(opts.MaxTokens))
 	model.SetTemperature(float32(opts.Temperature))
 	model.SetTopP(float32(opts.TopP))
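
The hunk above uses the functional-options pattern: the client's defaults seed a llms.CallOptions value, each per-call llms.CallOption then mutates it, and only the merged result is handed to the genai model, so a caller's WithCandidateCount overrides defaultCandidateCount. A minimal standalone sketch of that merge order, using only names from this diff and assuming the usual github.com/tmc/langchaingo/llms import path:

    package main

    import (
    	"fmt"

    	"github.com/tmc/langchaingo/llms"
    )

    func main() {
    	// Defaults seed the struct, mirroring the CallOptions literal above.
    	opts := llms.CallOptions{CandidateCount: 1, MaxTokens: 256}

    	// Per-call options run afterwards, so they win over the defaults.
    	for _, opt := range []llms.CallOption{llms.WithCandidateCount(2)} {
    		opt(&opts)
    	}

    	fmt.Println(opts.CandidateCount, opts.MaxTokens) // prints: 2 256
    }
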
llms/googleai/googleai_llm_test.go: 25 additions & 0 deletions
@@ -148,6 +148,31 @@ func TestEmbeddings(t *testing.T) {
 	assert.NotEmpty(t, res[1])
 }
 
+func TestCandidateCountSetting(t *testing.T) {
+	t.Parallel()
+	llm := newClient(t)
+
+	parts := []llms.ContentPart{
+		llms.TextContent{Text: "Name five countries in Africa"},
+	}
+	content := []llms.MessageContent{
+		{
+			Role:  schema.ChatMessageTypeHuman,
+			Parts: parts,
+		},
+	}
+
+	{
+		rsp, err := llm.GenerateContent(context.Background(), content,
+			llms.WithCandidateCount(1), llms.WithTemperature(1))
+		require.NoError(t, err)
+
+		assert.Len(t, rsp.Choices, 1)
+	}
+
+	// TODO: test multiple candidates when the backend supports it
+}
+
 func TestMaxTokensSetting(t *testing.T) {
 	t.Parallel()
 	llm := newClient(t)
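
The TODO above records that the backend currently accepts only a single candidate. Once that restriction is lifted, a follow-up test in the same file might look like this sketch (hypothetical, not part of this commit; it reuses the file's newClient helper and existing imports, and the count of 2 is arbitrary):

    // Hypothetical follow-up for the TODO: expect as many choices as
    // candidates requested once the backend supports counts above one.
    func TestMultipleCandidates(t *testing.T) {
    	t.Parallel()
    	llm := newClient(t)

    	content := []llms.MessageContent{
    		{
    			Role: schema.ChatMessageTypeHuman,
    			Parts: []llms.ContentPart{
    				llms.TextContent{Text: "Name five countries in Africa"},
    			},
    		},
    	}

    	rsp, err := llm.GenerateContent(context.Background(), content,
    		llms.WithCandidateCount(2), llms.WithTemperature(1))
    	require.NoError(t, err)
    	assert.Len(t, rsp.Choices, 2)
    }
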
llms/googleai/googleai_option.go: 2 additions & 0 deletions
@@ -5,6 +5,7 @@ type options struct {
 	apiKey                string
 	defaultModel          string
 	defaultEmbeddingModel string
+	defaultCandidateCount int
 	defaultMaxTokens      int
 	defaultTemperature    float64
 	defaultTopK           int
@@ -16,6 +17,7 @@ func defaultOptions() options {
 		apiKey:                "",
 		defaultModel:          "gemini-pro",
 		defaultEmbeddingModel: "embedding-001",
+		defaultCandidateCount: 1,
 		defaultMaxTokens:      256,
 		defaultTemperature:    0.5,
 		defaultTopK:           3,
llms/options.go: 11 additions & 1 deletion
@@ -5,10 +5,13 @@ import "context"
 // CallOption is a function that configures a CallOptions.
 type CallOption func(*CallOptions)
 
-// CallOptions is a set of options for calling models.
+// CallOptions is a set of options for calling models. Not all models support
+// all options.
 type CallOptions struct {
 	// Model is the model to use.
 	Model string `json:"model"`
+	// CandidateCount is the number of response candidates to generate.
+	CandidateCount int `json:"candidate_count"`
 	// MaxTokens is the maximum number of tokens to generate.
 	MaxTokens int `json:"max_tokens"`
 	// Temperature is the temperature for sampling, between 0 and 1.
@@ -80,6 +83,13 @@ func WithMaxTokens(maxTokens int) CallOption {
 	}
 }
 
+// WithCandidateCount specifies the number of response candidates to generate.
+func WithCandidateCount(c int) CallOption {
+	return func(o *CallOptions) {
+		o.CandidateCount = c
+	}
+}
+
 // WithTemperature specifies the model temperature, a hyperparameter that
 // regulates the randomness, or creativity, of the AI's responses.
 func WithTemperature(temperature float64) CallOption {
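
Taken together, callers reach the new behavior purely through the public option surface. Below is a hedged end-to-end sketch as a small main program; the constructor name googleai.New, the googleai.WithAPIKey option, the GENAI_API_KEY variable, and the Content field on each entry of rsp.Choices are assumptions not shown in this diff, so check the package before copying:

    package main

    import (
    	"context"
    	"fmt"
    	"log"
    	"os"

    	"github.com/tmc/langchaingo/llms"
    	"github.com/tmc/langchaingo/llms/googleai"
    	"github.com/tmc/langchaingo/schema"
    )

    func main() {
    	ctx := context.Background()

    	// Constructor and option names are assumptions; consult the
    	// googleai package for the exact API surface.
    	llm, err := googleai.New(ctx, googleai.WithAPIKey(os.Getenv("GENAI_API_KEY")))
    	if err != nil {
    		log.Fatal(err)
    	}

    	content := []llms.MessageContent{
    		{
    			Role:  schema.ChatMessageTypeHuman,
    			Parts: []llms.ContentPart{llms.TextContent{Text: "Name five countries in Africa"}},
    		},
    	}

    	// One candidate for now; raise the count once the backend allows it.
    	rsp, err := llm.GenerateContent(ctx, content, llms.WithCandidateCount(1))
    	if err != nil {
    		log.Fatal(err)
    	}
    	for i, choice := range rsp.Choices {
    		// The Content field is assumed from other langchaingo providers.
    		fmt.Printf("candidate %d: %s\n", i, choice.Content)
    	}
    }
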
