diff --git a/llms/googleai/googleai_llm.go b/llms/googleai/googleai_llm.go
index 4c8735ce3..bc51b5827 100644
--- a/llms/googleai/googleai_llm.go
+++ b/llms/googleai/googleai_llm.go
@@ -64,12 +64,12 @@ func NewGoogleAI(ctx context.Context, opts ...Option) (*GoogleAI, error) {
 	return gi, nil
 }
 
-// Call Implement the call interface for LLM.
+// Call implements the [llms.Model] interface.
 func (g *GoogleAI) Call(ctx context.Context, prompt string, options ...llms.CallOption) (string, error) {
 	return llms.CallLLM(ctx, g, prompt, options...)
 }
 
-// GenerateContent calls the LLM with the provided parts.
+// GenerateContent implements the [llms.Model] interface.
 func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) {
 	if g.CallbacksHandler != nil {
 		g.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
@@ -77,8 +77,10 @@ func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageC
 
 	opts := llms.CallOptions{
 		Model:       g.opts.defaultModel,
-		MaxTokens:   int(g.opts.defaultMaxTokens),
-		Temperature: float64(g.opts.defaultTemperature),
+		MaxTokens:   g.opts.defaultMaxTokens,
+		Temperature: g.opts.defaultTemperature,
+		TopP:        g.opts.defaultTopP,
+		TopK:        g.opts.defaultTopK,
 	}
 	for _, opt := range options {
 		opt(&opts)
@@ -87,6 +89,8 @@ func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageC
 	model := g.client.GenerativeModel(opts.Model)
 	model.SetMaxOutputTokens(int32(opts.MaxTokens))
 	model.SetTemperature(float32(opts.Temperature))
+	model.SetTopP(float32(opts.TopP))
+	model.SetTopK(int32(opts.TopK))
 
 	var response *llms.ContentResponse
 	var err error
diff --git a/llms/googleai/googleai_option.go b/llms/googleai/googleai_option.go
index 38e591db4..bef2ef2be 100644
--- a/llms/googleai/googleai_option.go
+++ b/llms/googleai/googleai_option.go
@@ -6,8 +6,10 @@ type options struct {
 	apiKey                string
 	defaultModel          string
 	defaultEmbeddingModel string
-	defaultMaxTokens      int32
-	defaultTemperature    float32
+	defaultMaxTokens      int
+	defaultTemperature    float64
+	defaultTopK           int
+	defaultTopP           float64
 }
 
 func defaultOptions() options {
@@ -17,6 +19,8 @@ func defaultOptions() options {
 		defaultEmbeddingModel: "embedding-001",
 		defaultMaxTokens:      256,
 		defaultTemperature:    0.5,
+		defaultTopK:           3,
+		defaultTopP:           0.95,
 	}
 }
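
Usage sketch for reviewers: with this change, the package defaults (MaxTokens 256, Temperature 0.5, TopK 3, TopP 0.95) are seeded into llms.CallOptions and can be overridden per call via llms.WithTopK / llms.WithTopP. This is not part of the patch; googleai.WithAPIKey is assumed to exist in googleai_option.go (it is not shown in this diff).

package main

import (
	"context"
	"fmt"
	"log"
	"os"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/googleai"
)

func main() {
	ctx := context.Background()

	// googleai.WithAPIKey is assumed here; key handling is not part of this diff.
	llm, err := googleai.NewGoogleAI(ctx, googleai.WithAPIKey(os.Getenv("GOOGLE_API_KEY")))
	if err != nil {
		log.Fatal(err)
	}

	// With no call options the defaults from defaultOptions() apply
	// (TopK 3, TopP 0.95); the options below override them for this call only.
	out, err := llm.Call(ctx, "Name three Go concurrency primitives.",
		llms.WithTopK(40),
		llms.WithTopP(0.8),
	)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out)
}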