Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

genai: add helper function NewUserContent #188

Merged
merged 1 commit into from
Jul 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions genai/caching_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func testCaching(t *testing.T, client *Client) {
argcc := &CachedContent{
Model: model,
Expiration: ExpireTimeOrTTL{TTL: ttl},
Contents: []*Content{{Role: "user", Parts: parts}},
Contents: []*Content{NewUserContent(parts...)},
}
cc := must(client.CreateCachedContent(ctx, argcc))
compare(cc, wantExpireTime)
Expand Down Expand Up @@ -158,7 +158,7 @@ func testCaching(t *testing.T, client *Client) {
txt := strings.Repeat("George Washington was the first president of the United States. ", 3000)
argcc := &CachedContent{
Model: model,
Contents: []*Content{{Role: "user", Parts: []Part{Text(txt)}}},
Contents: []*Content{NewUserContent(Text(txt))},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions genai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (m *GenerativeModel) StartChat() *ChatSession {
// SendMessage sends a request to the model as part of a chat session.
func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) {
// Call the underlying client with the entire history plus the argument Content.
cs.History = append(cs.History, newUserContent(parts))
cs.History = append(cs.History, NewUserContent(parts...))
req, err := cs.m.newGenerateContentRequest(cs.History...)
if err != nil {
return nil, err
Expand All @@ -48,7 +48,7 @@ func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*Generat

// SendMessageStream is like SendMessage, but with a streaming request.
func (cs *ChatSession) SendMessageStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator {
cs.History = append(cs.History, newUserContent(parts))
cs.History = append(cs.History, NewUserContent(parts...))
req, err := cs.m.newGenerateContentRequest(cs.History...)
if err != nil {
return &GenerateContentResponseIterator{err: err}
Expand Down
10 changes: 3 additions & 7 deletions genai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func fullModelName(name string) string {

// GenerateContent produces a single request and response.
func (m *GenerativeModel) GenerateContent(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) {
content := newUserContent(parts)
content := NewUserContent(parts...)
req, err := m.newGenerateContentRequest(content)
if err != nil {
return nil, err
Expand All @@ -194,7 +194,7 @@ func (m *GenerativeModel) GenerateContent(ctx context.Context, parts ...Part) (*
// GenerateContentStream returns an iterator that enumerates responses.
func (m *GenerativeModel) GenerateContentStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator {
iter := &GenerateContentResponseIterator{}
req, err := m.newGenerateContentRequest(newUserContent(parts))
req, err := m.newGenerateContentRequest(NewUserContent(parts...))
if err != nil {
iter.err = err
} else {
Expand Down Expand Up @@ -241,10 +241,6 @@ func (m *GenerativeModel) newGenerateContentRequest(contents ...*Content) (*pb.G
})
}

func newUserContent(parts []Part) *Content {
return &Content{Role: roleUser, Parts: parts}
}

// GenerateContentResponseIterator is an iterator over GnerateContentResponse.
type GenerateContentResponseIterator struct {
sc pb.GenerativeService_StreamGenerateContentClient
Expand Down Expand Up @@ -313,7 +309,7 @@ func (iter *GenerateContentResponseIterator) MergedResponse() *GenerateContentRe

// CountTokens counts the number of tokens in the content.
func (m *GenerativeModel) CountTokens(ctx context.Context, parts ...Part) (*CountTokensResponse, error) {
req, err := m.newCountTokensRequest(newUserContent(parts))
req, err := m.newCountTokensRequest(NewUserContent(parts...))
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ func TestRecoverPanic(t *testing.T) {
Response: map[string]any{"x": 1 + 2i}, // complex values are invalid
}
var m GenerativeModel
_, err := m.newGenerateContentRequest(newUserContent([]Part{fr}))
_, err := m.newGenerateContentRequest(NewUserContent(fr))
if err == nil {
t.Fatal("got nil, want error")
}
Expand Down
10 changes: 10 additions & 0 deletions genai/content.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,13 @@ func (c *Candidate) FunctionCalls() []FunctionCall {
}
return fcs
}

// NewUserContent returns a *Content with a "user" role set and one or more
// parts.
func NewUserContent(parts ...Part) *Content {
content := &Content{Role: roleUser, Parts: []Part{}}
for _, part := range parts {
content.Parts = append(content.Parts, part)
}
return content
}
2 changes: 1 addition & 1 deletion genai/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (m *EmbeddingModel) EmbedContentWithTitle(ctx context.Context, title string
func newEmbedContentRequest(model string, tt TaskType, title string, parts []Part) *pb.EmbedContentRequest {
req := &pb.EmbedContentRequest{
Model: model,
Content: newUserContent(parts).toProto(),
Content: NewUserContent(parts...).toProto(),
}
// A non-empty title overrides the task type.
if title != "" {
Expand Down
54 changes: 16 additions & 38 deletions genai/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,7 @@ func ExampleGenerativeModel_GenerateContent_config() {
model.SetTopP(0.5)
model.SetTopK(20)
model.SetMaxOutputTokens(100)
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are Yoda from Star Wars.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are Yoda from Star Wars."))
model.ResponseMIMEType = "application/json"
resp, err := model.GenerateContent(ctx, genai.Text("What is the average size of a swallow?"))
if err != nil {
Expand All @@ -212,9 +210,7 @@ func ExampleGenerativeModel_GenerateContent_systemInstruction() {
defer client.Close()

model := client.GenerativeModel("gemini-1.5-flash")
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are a cat. Your name is Neko.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are a cat. Your name is Neko."))
resp, err := model.GenerateContent(ctx, genai.Text("Good morning! How are you?"))
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -303,7 +299,6 @@ func ExampleGenerativeModel_GenerateContentStream() {
}
defer client.Close()

// START [text_gen_text_only_prompt_streaming]
model := client.GenerativeModel("gemini-1.5-flash")
iter := model.GenerateContentStream(ctx, genai.Text("Write a story about a magic backpack."))
for {
Expand All @@ -316,7 +311,7 @@ func ExampleGenerativeModel_GenerateContentStream() {
}
printResponse(resp)
}
// END [text_gen_text_only_prompt_streaming]

}

func ExampleGenerativeModel_GenerateContentStream_imagePrompt() {
Expand All @@ -327,7 +322,6 @@ func ExampleGenerativeModel_GenerateContentStream_imagePrompt() {
}
defer client.Close()

// START [text_gen_multimodal_one_image_prompt_streaming]
model := client.GenerativeModel("gemini-1.5-flash")

imgData, err := os.ReadFile(filepath.Join(testDataDir, "organ.jpg"))
Expand All @@ -347,7 +341,7 @@ func ExampleGenerativeModel_GenerateContentStream_imagePrompt() {
}
printResponse(resp)
}
// END [text_gen_multimodal_one_image_prompt_streaming]

}

func ExampleGenerativeModel_GenerateContentStream_videoPrompt() {
Expand All @@ -358,7 +352,6 @@ func ExampleGenerativeModel_GenerateContentStream_videoPrompt() {
}
defer client.Close()

// START [text_gen_multimodal_video_prompt_streaming]
model := client.GenerativeModel("gemini-1.5-flash")

file, err := uploadFile(ctx, client, filepath.Join(testDataDir, "earth.mp4"), "")
Expand All @@ -380,7 +373,7 @@ func ExampleGenerativeModel_GenerateContentStream_videoPrompt() {
}
printResponse(resp)
}
// END [text_gen_multimodal_video_prompt_streaming]

}

func ExampleGenerativeModel_CountTokens_contextWindow() {
Expand Down Expand Up @@ -447,7 +440,7 @@ func ExampleGenerativeModel_CountTokens_cachedContent() {
txt := strings.Repeat("George Washington was the first president of the United States. ", 3000)
argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
Contents: []*genai.Content{{Role: "user", Parts: []genai.Part{genai.Text(txt)}}},
Contents: []*genai.Content{genai.NewUserContent(genai.Text(txt))},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -657,9 +650,7 @@ func ExampleGenerativeModel_CountTokens_systemInstruction() {
// ( total_tokens: 10 )

// Same prompt, this time with system instruction
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are a cat. Your name is Neko.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are a cat. Your name is Neko."))
respWithInstruction, err := model.CountTokens(ctx, genai.Text(prompt))
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -1201,8 +1192,8 @@ func ExampleCachedContent_create() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1238,7 +1229,7 @@ func ExampleCachedContent_createFromChat() {

modelName := "gemini-1.5-flash-001"
model := client.GenerativeModel(modelName)
model.SystemInstruction = userContent(genai.Text("You are an expert analyzing transcripts."))
model.SystemInstruction = genai.NewUserContent(genai.Text("You are an expert analyzing transcripts."))

cs := model.StartChat()
resp, err := cs.SendMessage(ctx, genai.Text("Hi, could you summarize this transcript?"), fd)
Expand Down Expand Up @@ -1292,8 +1283,8 @@ func ExampleClient_GetCachedContent() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1337,8 +1328,8 @@ func ExampleClient_ListCachedContents() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1378,8 +1369,8 @@ func ExampleClient_UpdateCachedContent() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1455,19 +1446,6 @@ func ExampleClient_setProxy() {
printResponse(resp)
}

// userContent helps create a *genai.Content with a "user" role and one or
// more parts with less verbosity.
func userContent(parts ...genai.Part) *genai.Content {
content := &genai.Content{
Role: "user",
Parts: []genai.Part{},
}
for _, part := range parts {
content.Parts = append(content.Parts, part)
}
return content
}

func printResponse(resp *genai.GenerateContentResponse) {
for _, cand := range resp.Candidates {
if cand.Content != nil {
Expand Down
45 changes: 13 additions & 32 deletions genai/internal/samples/docs-snippets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,7 @@ func ExampleGenerativeModel_GenerateContent_config() {
model.SetTopP(0.5)
model.SetTopK(20)
model.SetMaxOutputTokens(100)
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are Yoda from Star Wars.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are Yoda from Star Wars."))
model.ResponseMIMEType = "application/json"
resp, err := model.GenerateContent(ctx, genai.Text("What is the average size of a swallow?"))
if err != nil {
Expand All @@ -218,9 +216,7 @@ func ExampleGenerativeModel_GenerateContent_systemInstruction() {

// [START system_instruction]
model := client.GenerativeModel("gemini-1.5-flash")
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are a cat. Your name is Neko.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are a cat. Your name is Neko."))
resp, err := model.GenerateContent(ctx, genai.Text("Good morning! How are you?"))
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -458,7 +454,7 @@ func ExampleGenerativeModel_CountTokens_cachedContent() {
txt := strings.Repeat("George Washington was the first president of the United States. ", 3000)
argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
Contents: []*genai.Content{{Role: "user", Parts: []genai.Part{genai.Text(txt)}}},
Contents: []*genai.Content{genai.NewUserContent(genai.Text(txt))},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -674,9 +670,7 @@ func ExampleGenerativeModel_CountTokens_systemInstruction() {
// ( total_tokens: 10 )

// Same prompt, this time with system instruction
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are a cat. Your name is Neko.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are a cat. Your name is Neko."))
respWithInstruction, err := model.CountTokens(ctx, genai.Text(prompt))
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -1235,8 +1229,8 @@ func ExampleCachedContent_create() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1273,7 +1267,7 @@ func ExampleCachedContent_createFromChat() {

modelName := "gemini-1.5-flash-001"
model := client.GenerativeModel(modelName)
model.SystemInstruction = userContent(genai.Text("You are an expert analyzing transcripts."))
model.SystemInstruction = genai.NewUserContent(genai.Text("You are an expert analyzing transcripts."))

cs := model.StartChat()
resp, err := cs.SendMessage(ctx, genai.Text("Hi, could you summarize this transcript?"), fd)
Expand Down Expand Up @@ -1330,8 +1324,8 @@ func ExampleClient_GetCachedContent() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1377,8 +1371,8 @@ func ExampleClient_ListCachedContents() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1419,8 +1413,8 @@ func ExampleClient_UpdateCachedContent() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1496,19 +1490,6 @@ func ExampleClient_setProxy() {
printResponse(resp)
}

// userContent helps create a *genai.Content with a "user" role and one or
// more parts with less verbosity.
func userContent(parts ...genai.Part) *genai.Content {
content := &genai.Content{
Role: "user",
Parts: []genai.Part{},
}
for _, part := range parts {
content.Parts = append(content.Parts, part)
}
return content
}

func printResponse(resp *genai.GenerateContentResponse) {
for _, cand := range resp.Candidates {
if cand.Content != nil {
Expand Down
Loading