Skip to content

Commit

Permalink
googleai: propagate more options and add test (#532)
Browse files Browse the repository at this point in the history
For #410
  • Loading branch information
eliben authored Jan 19, 2024
1 parent 71ece2d commit 2e8220b
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 10 deletions.
17 changes: 10 additions & 7 deletions llms/googleai/googleai_llm.go
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@ func (g *GoogleAI) GenerateContent(ctx context.Context, messages []llms.MessageC
model.SetTemperature(float32(opts.Temperature))
model.SetTopP(float32(opts.TopP))
model.SetTopK(int32(opts.TopK))
model.StopSequences = opts.StopWords

var response *llms.ContentResponse
var err error
Expand Down Expand Up @@ -120,14 +121,16 @@ func convertCandidates(candidates []*genai.Candidate) (*llms.ContentResponse, er
for _, candidate := range candidates {
buf := strings.Builder{}

for _, part := range candidate.Content.Parts {
if v, ok := part.(genai.Text); ok {
_, err := buf.WriteString(string(v))
if err != nil {
return nil, err
if candidate.Content != nil {
for _, part := range candidate.Content.Parts {
if v, ok := part.(genai.Text); ok {
_, err := buf.WriteString(string(v))
if err != nil {
return nil, err
}
} else {
return nil, ErrUnknownPartInResponse
}
} else {
return nil, ErrUnknownPartInResponse
}
}

Expand Down
47 changes: 44 additions & 3 deletions llms/googleai/googleai_llm_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ func TestMultiContentText(t *testing.T) {

assert.NotEmpty(t, rsp.Choices)
c1 := rsp.Choices[0]
assert.Regexp(t, "dog|canid|canine", strings.ToLower(c1.Content))
assert.Regexp(t, "(?i)dog|canid|canine", c1.Content)
}

func TestMultiContentTextStream(t *testing.T) {
Expand Down Expand Up @@ -109,7 +109,7 @@ func TestMultiContentTextChatSequence(t *testing.T) {

assert.NotEmpty(t, rsp.Choices)
c1 := rsp.Choices[0]
assert.Regexp(t, "spain.*larger", strings.ToLower(c1.Content))
assert.Regexp(t, "(?i)spain.*larger", c1.Content)
}

func TestMultiContentImage(t *testing.T) {
Expand All @@ -132,7 +132,7 @@ func TestMultiContentImage(t *testing.T) {

assert.NotEmpty(t, rsp.Choices)
c1 := rsp.Choices[0]
assert.Regexp(t, "parrot", strings.ToLower(c1.Content))
assert.Regexp(t, "(?i)parrot", c1.Content)
}

func TestEmbeddings(t *testing.T) {
Expand All @@ -147,3 +147,44 @@ func TestEmbeddings(t *testing.T) {
assert.NotEmpty(t, res[0])
assert.NotEmpty(t, res[1])
}

// TestMaxTokensSetting verifies that the MaxTokens call option is actually
// propagated to the model: a tiny budget should stop generation early with a
// max-tokens stop reason, while a generous budget should let the model finish
// normally and produce relevant content.
func TestMaxTokensSetting(t *testing.T) {
	t.Parallel()
	llm := newClient(t)

	content := []llms.MessageContent{
		{
			Role: schema.ChatMessageTypeHuman,
			Parts: []llms.ContentPart{
				llms.TextContent{Text: "I'm a pomeranian"},
				llms.TextContent{Text: "Describe my taxonomy, health and care"},
			},
		},
	}

	// With a very low MaxTokens for such a query, the model should report
	// that it stopped because the token budget was exhausted.
	{
		rsp, err := llm.GenerateContent(context.Background(), content,
			llms.WithMaxTokens(16))
		require.NoError(t, err)

		assert.NotEmpty(t, rsp.Choices)
		first := rsp.Choices[0]
		assert.Regexp(t, "(?i)MaxTokens", first.StopReason)
	}

	// With a much larger MaxTokens, generation should complete on its own
	// and the reply should actually talk about the dog.
	{
		rsp, err := llm.GenerateContent(context.Background(), content,
			llms.WithMaxTokens(2048))
		require.NoError(t, err)

		assert.NotEmpty(t, rsp.Choices)
		first := rsp.Choices[0]
		assert.Regexp(t, "(?i)stop", first.StopReason)
		assert.Regexp(t, "(?i)dog|canid|canine", first.Content)
	}
}

0 comments on commit 2e8220b

Please sign in to comment.