diff --git a/llms/googleai/googleai_llm.go b/llms/googleai/googleai_llm.go
index 0dc0859d4..5faaea717 100644
--- a/llms/googleai/googleai_llm.go
+++ b/llms/googleai/googleai_llm.go
@@ -89,6 +89,7 @@ func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageC
 	model.SetTemperature(float32(opts.Temperature))
 	model.SetTopP(float32(opts.TopP))
 	model.SetTopK(int32(opts.TopK))
+	model.StopSequences = opts.StopWords
 
 	var response *llms.ContentResponse
 	var err error
@@ -120,14 +121,16 @@ func convertCandidates(candidates []*genai.Candidate) (*llms.ContentResponse, er
 	for _, candidate := range candidates {
 		buf := strings.Builder{}
 
-		for _, part := range candidate.Content.Parts {
-			if v, ok := part.(genai.Text); ok {
-				_, err := buf.WriteString(string(v))
-				if err != nil {
-					return nil, err
+		if candidate.Content != nil {
+			for _, part := range candidate.Content.Parts {
+				if v, ok := part.(genai.Text); ok {
+					_, err := buf.WriteString(string(v))
+					if err != nil {
+						return nil, err
+					}
+				} else {
+					return nil, ErrUnknownPartInResponse
 				}
-			} else {
-				return nil, ErrUnknownPartInResponse
 			}
 		}
 
diff --git a/llms/googleai/googleai_llm_test.go b/llms/googleai/googleai_llm_test.go
index e1fc53b55..5c79cb267 100644
--- a/llms/googleai/googleai_llm_test.go
+++ b/llms/googleai/googleai_llm_test.go
@@ -45,7 +45,7 @@ func TestMultiContentText(t *testing.T) {
 
 	assert.NotEmpty(t, rsp.Choices)
 	c1 := rsp.Choices[0]
-	assert.Regexp(t, "dog|canid|canine", strings.ToLower(c1.Content))
+	assert.Regexp(t, "(?i)dog|canid|canine", c1.Content)
 }
 
 func TestMultiContentTextStream(t *testing.T) {
@@ -109,7 +109,7 @@ func TestMultiContentTextChatSequence(t *testing.T) {
 
 	assert.NotEmpty(t, rsp.Choices)
 	c1 := rsp.Choices[0]
-	assert.Regexp(t, "spain.*larger", strings.ToLower(c1.Content))
+	assert.Regexp(t, "(?i)spain.*larger", c1.Content)
 }
 
 func TestMultiContentImage(t *testing.T) {
@@ -132,7 +132,7 @@ func TestMultiContentImage(t *testing.T) {
 
 	assert.NotEmpty(t, rsp.Choices)
 	c1 := rsp.Choices[0]
-	assert.Regexp(t, "parrot", strings.ToLower(c1.Content))
+	assert.Regexp(t, "(?i)parrot", c1.Content)
 }
 
 func TestEmbeddings(t *testing.T) {
@@ -147,3 +147,44 @@ func TestEmbeddings(t *testing.T) {
 	assert.NotEmpty(t, res[0])
 	assert.NotEmpty(t, res[1])
 }
+
+func TestMaxTokensSetting(t *testing.T) {
+	t.Parallel()
+	llm := newClient(t)
+
+	parts := []llms.ContentPart{
+		llms.TextContent{Text: "I'm a pomeranian"},
+		llms.TextContent{Text: "Describe my taxonomy, health and care"},
+	}
+	content := []llms.MessageContent{
+		{
+			Role:  schema.ChatMessageTypeHuman,
+			Parts: parts,
+		},
+	}
+
+	// First, try this with a very low MaxTokens setting for such a query;
+	// expect a stop reason indicating that the token limit was reached.
+	{
+		rsp, err := llm.GenerateContent(context.Background(), content,
+			llms.WithMaxTokens(16))
+		require.NoError(t, err)
+
+		assert.NotEmpty(t, rsp.Choices)
+		c1 := rsp.Choices[0]
+		assert.Regexp(t, "(?i)MaxTokens", c1.StopReason)
+	}
+
+	// Now, try it again with a much larger MaxTokens setting and expect the
+	// generation to finish successfully and produce a full response.
+	{
+		rsp, err := llm.GenerateContent(context.Background(), content,
+			llms.WithMaxTokens(2048))
+		require.NoError(t, err)
+
+		assert.NotEmpty(t, rsp.Choices)
+		c1 := rsp.Choices[0]
+		assert.Regexp(t, "(?i)stop", c1.StopReason)
+		assert.Regexp(t, "(?i)dog|canid|canine", c1.Content)
+	}
+}
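
Usage note: with this change, stop words passed through the llms call options
take effect for the GoogleAI provider (they are copied into the genai model's
StopSequences). A minimal sketch of how the new behavior could be exercised;
the constructor and option names below (googleai.New, googleai.WithAPIKey, the
GENAI_API_KEY variable) are assumptions to be checked against the package's
actual public API, while llms.WithStopWords is the call option that populates
the opts.StopWords field used above:

	package main

	import (
		"context"
		"fmt"
		"log"
		"os"

		"github.com/tmc/langchaingo/llms"
		"github.com/tmc/langchaingo/llms/googleai"
		"github.com/tmc/langchaingo/schema"
	)

	func main() {
		ctx := context.Background()

		// Assumed constructor/option names; consult the googleai package.
		llm, err := googleai.New(ctx, googleai.WithAPIKey(os.Getenv("GENAI_API_KEY")))
		if err != nil {
			log.Fatal(err)
		}

		content := []llms.MessageContent{{
			Role: schema.ChatMessageTypeHuman,
			Parts: []llms.ContentPart{
				llms.TextContent{Text: "Count from 1 to 20, one number per line."},
			},
		}}

		// The stop words below reach the Gemini model as StopSequences, so
		// generation should halt as soon as "10" would be emitted.
		rsp, err := llm.GenerateContent(ctx, content,
			llms.WithStopWords([]string{"10"}))
		if err != nil {
			log.Fatal(err)
		}
		fmt.Println(rsp.Choices[0].Content)
	}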