Skip to content

Commit

Permalink
Merge branch 'main' into tool-count
Browse files Browse the repository at this point in the history
  • Loading branch information
eliben committed Jul 19, 2024
2 parents b9197be + 70f6989 commit e345ce6
Show file tree
Hide file tree
Showing 8 changed files with 45 additions and 77 deletions.
4 changes: 2 additions & 2 deletions genai/caching_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func testCaching(t *testing.T, client *Client) {
argcc := &CachedContent{
Model: model,
Expiration: ExpireTimeOrTTL{TTL: ttl},
Contents: []*Content{{Role: "user", Parts: parts}},
Contents: []*Content{NewUserContent(parts...)},
}
cc := must(client.CreateCachedContent(ctx, argcc))
compare(cc, wantExpireTime)
Expand Down Expand Up @@ -158,7 +158,7 @@ func testCaching(t *testing.T, client *Client) {
txt := strings.Repeat("George Washington was the first president of the United States. ", 3000)
argcc := &CachedContent{
Model: model,
Contents: []*Content{{Role: "user", Parts: []Part{Text(txt)}}},
Contents: []*Content{NewUserContent(Text(txt))},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down
4 changes: 2 additions & 2 deletions genai/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (m *GenerativeModel) StartChat() *ChatSession {
// SendMessage sends a request to the model as part of a chat session.
func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) {
// Call the underlying client with the entire history plus the argument Content.
cs.History = append(cs.History, newUserContent(parts))
cs.History = append(cs.History, NewUserContent(parts...))
req, err := cs.m.newGenerateContentRequest(cs.History...)
if err != nil {
return nil, err
Expand All @@ -48,7 +48,7 @@ func (cs *ChatSession) SendMessage(ctx context.Context, parts ...Part) (*Generat

// SendMessageStream is like SendMessage, but with a streaming request.
func (cs *ChatSession) SendMessageStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator {
cs.History = append(cs.History, newUserContent(parts))
cs.History = append(cs.History, NewUserContent(parts...))
req, err := cs.m.newGenerateContentRequest(cs.History...)
if err != nil {
return &GenerateContentResponseIterator{err: err}
Expand Down
10 changes: 3 additions & 7 deletions genai/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,7 @@ func fullModelName(name string) string {

// GenerateContent produces a single request and response.
func (m *GenerativeModel) GenerateContent(ctx context.Context, parts ...Part) (*GenerateContentResponse, error) {
content := newUserContent(parts)
content := NewUserContent(parts...)
req, err := m.newGenerateContentRequest(content)
if err != nil {
return nil, err
Expand All @@ -194,7 +194,7 @@ func (m *GenerativeModel) GenerateContent(ctx context.Context, parts ...Part) (*
// GenerateContentStream returns an iterator that enumerates responses.
func (m *GenerativeModel) GenerateContentStream(ctx context.Context, parts ...Part) *GenerateContentResponseIterator {
iter := &GenerateContentResponseIterator{}
req, err := m.newGenerateContentRequest(newUserContent(parts))
req, err := m.newGenerateContentRequest(NewUserContent(parts...))
if err != nil {
iter.err = err
} else {
Expand Down Expand Up @@ -241,10 +241,6 @@ func (m *GenerativeModel) newGenerateContentRequest(contents ...*Content) (*pb.G
})
}

func newUserContent(parts []Part) *Content {
return &Content{Role: roleUser, Parts: parts}
}

// GenerateContentResponseIterator is an iterator over GnerateContentResponse.
type GenerateContentResponseIterator struct {
sc pb.GenerativeService_StreamGenerateContentClient
Expand Down Expand Up @@ -313,7 +309,7 @@ func (iter *GenerateContentResponseIterator) MergedResponse() *GenerateContentRe

// CountTokens counts the number of tokens in the content.
func (m *GenerativeModel) CountTokens(ctx context.Context, parts ...Part) (*CountTokensResponse, error) {
req, err := m.newCountTokensRequest(newUserContent(parts))
req, err := m.newCountTokensRequest(NewUserContent(parts...))
if err != nil {
return nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion genai/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -814,7 +814,7 @@ func TestRecoverPanic(t *testing.T) {
Response: map[string]any{"x": 1 + 2i}, // complex values are invalid
}
var m GenerativeModel
_, err := m.newGenerateContentRequest(newUserContent([]Part{fr}))
_, err := m.newGenerateContentRequest(NewUserContent(fr))
if err == nil {
t.Fatal("got nil, want error")
}
Expand Down
10 changes: 10 additions & 0 deletions genai/content.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,13 @@ func (c *Candidate) FunctionCalls() []FunctionCall {
}
return fcs
}

// NewUserContent returns a *Content with a "user" role set and one or more
// parts.
func NewUserContent(parts ...Part) *Content {
content := &Content{Role: roleUser, Parts: []Part{}}
for _, part := range parts {
content.Parts = append(content.Parts, part)
}
return content
}
2 changes: 1 addition & 1 deletion genai/embed.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (m *EmbeddingModel) EmbedContentWithTitle(ctx context.Context, title string
func newEmbedContentRequest(model string, tt TaskType, title string, parts []Part) *pb.EmbedContentRequest {
req := &pb.EmbedContentRequest{
Model: model,
Content: newUserContent(parts).toProto(),
Content: NewUserContent(parts...).toProto(),
}
// A non-empty title overrides the task type.
if title != "" {
Expand Down
45 changes: 13 additions & 32 deletions genai/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,9 +191,7 @@ func ExampleGenerativeModel_GenerateContent_config() {
model.SetTopP(0.5)
model.SetTopK(20)
model.SetMaxOutputTokens(100)
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are Yoda from Star Wars.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are Yoda from Star Wars."))
model.ResponseMIMEType = "application/json"
resp, err := model.GenerateContent(ctx, genai.Text("What is the average size of a swallow?"))
if err != nil {
Expand All @@ -212,9 +210,7 @@ func ExampleGenerativeModel_GenerateContent_systemInstruction() {
defer client.Close()

model := client.GenerativeModel("gemini-1.5-flash")
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are a cat. Your name is Neko.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are a cat. Your name is Neko."))
resp, err := model.GenerateContent(ctx, genai.Text("Good morning! How are you?"))
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -482,7 +478,7 @@ func ExampleGenerativeModel_CountTokens_cachedContent() {
txt := strings.Repeat("George Washington was the first president of the United States. ", 3000)
argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
Contents: []*genai.Content{{Role: "user", Parts: []genai.Part{genai.Text(txt)}}},
Contents: []*genai.Content{genai.NewUserContent(genai.Text(txt))},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -692,9 +688,7 @@ func ExampleGenerativeModel_CountTokens_systemInstruction() {
// ( total_tokens: 10 )

// Same prompt, this time with system instruction
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are a cat. Your name is Neko.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are a cat. Your name is Neko."))
respWithInstruction, err := model.CountTokens(ctx, genai.Text(prompt))
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -1236,8 +1230,8 @@ func ExampleCachedContent_create() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1273,7 +1267,7 @@ func ExampleCachedContent_createFromChat() {

modelName := "gemini-1.5-flash-001"
model := client.GenerativeModel(modelName)
model.SystemInstruction = userContent(genai.Text("You are an expert analyzing transcripts."))
model.SystemInstruction = genai.NewUserContent(genai.Text("You are an expert analyzing transcripts."))

cs := model.StartChat()
resp, err := cs.SendMessage(ctx, genai.Text("Hi, could you summarize this transcript?"), fd)
Expand Down Expand Up @@ -1327,8 +1321,8 @@ func ExampleClient_GetCachedContent() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1372,8 +1366,8 @@ func ExampleClient_ListCachedContents() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1413,8 +1407,8 @@ func ExampleClient_UpdateCachedContent() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1490,19 +1484,6 @@ func ExampleClient_setProxy() {
printResponse(resp)
}

// userContent helps create a *genai.Content with a "user" role and one or
// more parts with less verbosity.
func userContent(parts ...genai.Part) *genai.Content {
content := &genai.Content{
Role: "user",
Parts: []genai.Part{},
}
for _, part := range parts {
content.Parts = append(content.Parts, part)
}
return content
}

func printResponse(resp *genai.GenerateContentResponse) {
for _, cand := range resp.Candidates {
if cand.Content != nil {
Expand Down
45 changes: 13 additions & 32 deletions genai/internal/samples/docs-snippets_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -196,9 +196,7 @@ func ExampleGenerativeModel_GenerateContent_config() {
model.SetTopP(0.5)
model.SetTopK(20)
model.SetMaxOutputTokens(100)
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are Yoda from Star Wars.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are Yoda from Star Wars."))
model.ResponseMIMEType = "application/json"
resp, err := model.GenerateContent(ctx, genai.Text("What is the average size of a swallow?"))
if err != nil {
Expand All @@ -218,9 +216,7 @@ func ExampleGenerativeModel_GenerateContent_systemInstruction() {

// [START system_instruction]
model := client.GenerativeModel("gemini-1.5-flash")
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are a cat. Your name is Neko.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are a cat. Your name is Neko."))
resp, err := model.GenerateContent(ctx, genai.Text("Good morning! How are you?"))
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -498,7 +494,7 @@ func ExampleGenerativeModel_CountTokens_cachedContent() {
txt := strings.Repeat("George Washington was the first president of the United States. ", 3000)
argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
Contents: []*genai.Content{{Role: "user", Parts: []genai.Part{genai.Text(txt)}}},
Contents: []*genai.Content{genai.NewUserContent(genai.Text(txt))},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -714,9 +710,7 @@ func ExampleGenerativeModel_CountTokens_systemInstruction() {
// ( total_tokens: 10 )

// Same prompt, this time with system instruction
model.SystemInstruction = &genai.Content{
Parts: []genai.Part{genai.Text("You are a cat. Your name is Neko.")},
}
model.SystemInstruction = genai.NewUserContent(genai.Text("You are a cat. Your name is Neko."))
respWithInstruction, err := model.CountTokens(ctx, genai.Text(prompt))
if err != nil {
log.Fatal(err)
Expand Down Expand Up @@ -1275,8 +1269,8 @@ func ExampleCachedContent_create() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1313,7 +1307,7 @@ func ExampleCachedContent_createFromChat() {

modelName := "gemini-1.5-flash-001"
model := client.GenerativeModel(modelName)
model.SystemInstruction = userContent(genai.Text("You are an expert analyzing transcripts."))
model.SystemInstruction = genai.NewUserContent(genai.Text("You are an expert analyzing transcripts."))

cs := model.StartChat()
resp, err := cs.SendMessage(ctx, genai.Text("Hi, could you summarize this transcript?"), fd)
Expand Down Expand Up @@ -1370,8 +1364,8 @@ func ExampleClient_GetCachedContent() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1417,8 +1411,8 @@ func ExampleClient_ListCachedContents() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1459,8 +1453,8 @@ func ExampleClient_UpdateCachedContent() {

argcc := &genai.CachedContent{
Model: "gemini-1.5-flash-001",
SystemInstruction: userContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{userContent(fd)},
SystemInstruction: genai.NewUserContent(genai.Text("You are an expert analyzing transcripts.")),
Contents: []*genai.Content{genai.NewUserContent(fd)},
}
cc, err := client.CreateCachedContent(ctx, argcc)
if err != nil {
Expand Down Expand Up @@ -1536,19 +1530,6 @@ func ExampleClient_setProxy() {
printResponse(resp)
}

// userContent helps create a *genai.Content with a "user" role and one or
// more parts with less verbosity.
func userContent(parts ...genai.Part) *genai.Content {
content := &genai.Content{
Role: "user",
Parts: []genai.Part{},
}
for _, part := range parts {
content.Parts = append(content.Parts, part)
}
return content
}

func printResponse(resp *genai.GenerateContentResponse) {
for _, cand := range resp.Candidates {
if cand.Content != nil {
Expand Down

0 comments on commit e345ce6

Please sign in to comment.