Skip to content

Commit

Permalink
examples: add vision support and example
Browse files Browse the repository at this point in the history
  • Loading branch information
tmc committed Nov 21, 2023
1 parent 4b80634 commit 23d87d1
Showing 9 changed files with 214 additions and 10 deletions.
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@ import (
)

func main() {
llm, err := openai.NewChat(openai.WithModel("gpt-3.5-turbo-0613"))
llm, err := openai.NewChat(openai.WithModel("gpt-3.5-turbo-1106"))
if err != nil {
log.Fatal(err)
}
@@ -29,11 +29,17 @@ func main() {
}
}

func getCurrentWeather(location string, unit string) (string, error) {
// Get the current weather in the specified location.
func getCurrentWeather(req struct {
// The city and state, e.g. San Francisco, CA.
location string
// The temperature unit, either "celsius" or "fahrenheit".
unit string
}) (string, error) {
weatherInfo := map[string]interface{}{
"location": location,
"location": req.location,
"temperature": "72",
"unit": unit,
"unit": req.unit,
"forecast": []string{"sunny", "windy"},
}
b, err := json.Marshal(weatherInfo)
@@ -47,6 +53,6 @@ var functions = []llms.FunctionDefinition{
{
Name: "getCurrentWeather",
Description: "Get the current weather in a given location",
Parameters: json.RawMessage(`{"type": "object", "properties": {"location": {"type": "string", "description": "The city and state, e.g. San Francisco, CA"}, "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}}, "required": ["location"]}`),
Parameters: json.RawMessage(`{"type":"object","properties":{"req":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA."},"unit":{"type":"string","description":"The temperature unit, either \"celcius\" or \"fahrenheit\"."}}}},"required":["req","unit"]}`),
},
}
2 changes: 1 addition & 1 deletion examples/openai-gpt4-turbo-example/openai_gpt4_turbo.go
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@ import (
)

func main() {
llm, err := openai.NewChat(openai.WithModel("gpt-4-1106-preview"))
llm, err := openai.NewChat(openai.WithModel("gpt-3.5-turbo"))
if err != nil {
log.Fatal(err)
}
11 changes: 11 additions & 0 deletions examples/openai-gpt4-vision-example/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
module github.com/tmc/langchaingo/examples/openai-gpt4-vision-example

go 1.19

require github.com/tmc/langchaingo v0.0.0-20231028223410-5f4451567823

require (
github.com/dlclark/regexp2 v1.8.1 // indirect
github.com/google/uuid v1.3.1 // indirect
github.com/pkoukk/tiktoken-go v0.1.2 // indirect
)
12 changes: 12 additions & 0 deletions examples/openai-gpt4-vision-example/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4=
github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHtyHY=
github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
github.com/tmc/langchaingo v0.0.0-20231028223410-5f4451567823 h1:vIYDdskaXk+iwfrMIaN9dM66o3wUedqvWWww1o7a1m4=
github.com/tmc/langchaingo v0.0.0-20231028223410-5f4451567823/go.mod h1:wwzKIaam0XFmiWfTlvSvdKwq7CkxE9Tz5rIkz1KKDws=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
90 changes: 90 additions & 0 deletions examples/openai-gpt4-vision-example/openai_gpt4_vision_example.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
package main

import (
"context"
"encoding/base64"
"flag"
"fmt"
"io"
"log"
"net/http"
"os"

"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/schema"
)

var (
flagImagePath = flag.String("image", "-", "path to image to send to model")
)

// main sends a local image (or stdin, with -image=-) to the GPT-4 vision
// model together with a text question, and prints the model's answer.
func main() {
	flag.Parse()

	llm, err := openai.NewChat(openai.WithModel("gpt-4-vision-preview"))
	if err != nil {
		log.Fatal(err)
	}

	// Encode the image as a base64 data URI so it can be embedded
	// directly in the chat message instead of being hosted at a URL.
	img, err := loadImageBase64(*flagImagePath)
	if err != nil {
		log.Fatal(err)
	}

	msg := schema.CompoundChatMessage{
		Type: schema.ChatMessageTypeHuman,
		Parts: []schema.ChatMessageContentPart{
			schema.ChatMessageContentPartText{
				Type: "text",
				Text: "What is in this image?",
			},
			schema.ChatMessageContentPartImage{
				Type: "image_url",
				ImageURL: schema.ChatMessageContentPartImageURL{
					URL: img,
				},
			},
		},
	}

	ctx := context.Background()
	completion, err := llm.Call(ctx, []schema.ChatMessage{msg})
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(completion)
}

func loadImageBase64(path string) (string, error) {
f, err := pathToReader(path)
if err != nil {
return "", fmt.Errorf("failed to open image: %w", err)
}
defer f.Close()
data, err := io.ReadAll(f)
if err != nil {
return "", fmt.Errorf("failed to read image: %w", err)
}

// Determine the content type of the image file
mimeType := http.DetectContentType(data)

var base64Encoding string
// Prepend the appropriate URI scheme header depending
// on the MIME type
switch mimeType {
case "image/jpeg":
base64Encoding += "data:image/jpeg;base64,"
case "image/png":
base64Encoding += "data:image/png;base64,"
}
base64Encoding += base64.StdEncoding.EncodeToString(data)
return base64Encoding, nil
}

func pathToReader(path string) (io.ReadCloser, error) {
if path == "-" {
return os.Stdin, nil
}
return os.Open(path)
}
13 changes: 11 additions & 2 deletions llms/openai/internal/openaiclient/chat.go
Original file line number Diff line number Diff line change
@@ -7,8 +7,10 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"log"
"net/http"
"os"
"strings"
)

@@ -47,7 +49,7 @@ type ChatMessage struct {
// The role of the author of this message. One of system, user, or assistant.
Role string `json:"role"`
// The content of the message.
Content string `json:"content"`
Content any `json:"content"`
// The name of the author of this message. May contain a-z, A-Z, 0-9, and underscores,
// with a maximum length of 64 characters.
Name string `json:"name,omitempty"`
@@ -141,6 +143,9 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatRes
return nil, err
}

// print raw payload:
fmt.Println(string(payloadBytes))

// Build request
body := bytes.NewReader(payloadBytes)
if c.baseURL == "" {
@@ -160,6 +165,9 @@ func (c *Client) createChat(ctx context.Context, payload *ChatRequest) (*ChatRes
}
defer r.Body.Close()

buf, err := io.ReadAll(r.Body)
io.Copy(os.Stdout, bytes.NewReader(buf))

if r.StatusCode != http.StatusOK {
msg := fmt.Sprintf("API returned unexpected status code: %d", r.StatusCode)

@@ -220,7 +228,8 @@ func parseStreamingChatResponse(ctx context.Context, r *http.Response, payload *
continue
}
chunk := []byte(streamResponse.Choices[0].Delta.Content)
response.Choices[0].Message.Content += streamResponse.Choices[0].Delta.Content
// response.Choices[0].Message.Content += streamResponse.Choices[0].Delta.Content
response.Choices[0].Message.Content = fmt.Sprintf("%s%s", response.Choices[0].Message.Content, streamResponse.Choices[0].Delta.Content)
response.Choices[0].FinishReason = streamResponse.Choices[0].FinishReason
if streamResponse.Choices[0].Delta.FunctionCall != nil {
if response.Choices[0].Message.FunctionCall == nil {
2 changes: 1 addition & 1 deletion llms/openai/internal/openaiclient/openaiclient.go
Original file line number Diff line number Diff line change
@@ -86,7 +86,7 @@ func (c *Client) CreateCompletion(ctx context.Context, r *CompletionRequest) (*C
return nil, ErrEmptyResponse
}
return &Completion{
Text: resp.Choices[0].Message.Content,
Text: fmt.Sprint(resp.Choices[0].Message.Content),
}, nil
}

7 changes: 6 additions & 1 deletion llms/openai/openaillm_chat.go
Original file line number Diff line number Diff line change
@@ -2,6 +2,7 @@ package openai

import (
"context"
"fmt"
"reflect"

"github.com/tmc/langchaingo/callbacks"
@@ -93,7 +94,7 @@ func (o *Chat) Generate(ctx context.Context, messageSets [][]schema.ChatMessage,
generationInfo["PromptTokens"] = result.Usage.PromptTokens
generationInfo["TotalTokens"] = result.Usage.TotalTokens
msg := &schema.AIChatMessage{
Content: result.Choices[0].Message.Content,
Content: fmt.Sprint(result.Choices[0].Message.Content),
}
if result.Choices[0].FinishReason == "function_call" {
msg.FunctionCall = &schema.FunctionCall{
@@ -175,6 +176,10 @@ func messagesToClientMessages(messages []schema.ChatMessage) []*openaiclient.Cha
if n, ok := m.(schema.Named); ok {
msg.Name = n.GetName()
}
if cl, ok := m.(schema.ContentList); ok {
fmt.Println("has content list")
msg.Content = cl.GetContentList()
}
msgs[i] = msg
}

71 changes: 71 additions & 0 deletions schema/chat_messages.go
Original file line number Diff line number Diff line change
@@ -39,13 +39,29 @@ type Named interface {
GetName() string
}

// ContentList is an interface for chat messages whose content is a list of
// typed parts (text, images) rather than a single string.
type ContentList interface {
	// GetContentList returns the ordered content parts of the message.
	GetContentList() []ChatMessageContentPart
}

// Statically assert that the types implement the interfaces.
var (
	_ ChatMessage = AIChatMessage{}
	_ ChatMessage = HumanChatMessage{}
	_ ChatMessage = SystemChatMessage{}
	_ ChatMessage = GenericChatMessage{}
	_ ChatMessage = FunctionChatMessage{}
	_ ChatMessage = CompoundChatMessage{}

	// CompoundChatMessage is the only multi-part message type; pin its
	// ContentList implementation so an accidental signature drift fails
	// to compile here rather than silently at a type assertion site.
	_ ContentList = CompoundChatMessage{}
)

// ContentType is the type of content in a message part.
// NOTE(review): the ChatMessageContentPart* structs below currently set
// their Type fields from raw string literals rather than these constants;
// keep the values in sync.
type ContentType string

const (
	// ContentTypeText is a plain-text content part.
	ContentTypeText ContentType = "text"
	// ContentTypeImage is an image content part.
	ContentTypeImage ContentType = "image_url"
)

// AIChatMessage is a message sent by an AI.
@@ -103,6 +119,61 @@ func (m FunctionChatMessage) GetType() ChatMessageType { return ChatMessageTypeF
func (m FunctionChatMessage) GetContent() string { return m.Content }
func (m FunctionChatMessage) GetName() string { return m.Name }

// CompoundChatMessage is a chat message composed of multiple content parts,
// e.g. a text question plus one or more images.
type CompoundChatMessage struct {
	// Type is the chat role of the message.
	Type ChatMessageType
	// Parts holds the individual content parts, in order.
	Parts []ChatMessageContentPart
}

// GetType returns the chat role of the message.
func (m CompoundChatMessage) GetType() ChatMessageType { return m.Type }

// GetContent returns the parts serialized as a JSON array. A marshal
// failure yields the empty string (same as stringifying a nil byte slice).
func (m CompoundChatMessage) GetContent() string {
	serialized, err := json.Marshal(m.Parts)
	if err != nil {
		return ""
	}
	return string(serialized)
}

// GetContentList exposes the raw parts, satisfying the ContentList interface.
func (m CompoundChatMessage) GetContentList() []ChatMessageContentPart {
	return m.Parts
}

// ChatMessageContentPart is a part of a chat message. The unexported marker
// method makes this a sealed interface: only types in this package can
// implement it.
type ChatMessageContentPart interface {
	isChatMessageContentPart()
}

// ChatMessageContentPartText is a text part of a chat message.
type ChatMessageContentPartText struct {
	// Type is the part discriminator; callers currently pass the literal
	// "text" (see ContentTypeText).
	Type string `json:"type"`
	// Text is the text content itself.
	Text string `json:"text"`
}

// ChatMessageContentPartImageURL carries the location of an image part:
// either a regular URL or a base64 "data:" URI.
type ChatMessageContentPartImageURL struct {
	// URL is the image URL or data URI.
	URL string `json:"url"`
	// Details optionally selects the image detail level; note the JSON
	// key is the singular "detail".
	Details ImageDetail `json:"detail,omitempty"`
}

// ImageDetail is the detail level requested for an image part.
type ImageDetail string

const (
	// ImageDetailAuto is the default image detail.
	ImageDetailAuto ImageDetail = "auto"
	// ImageDetailLow is the low image detail.
	ImageDetailLow ImageDetail = "low"
	// ImageDetailHigh is the high image detail.
	ImageDetailHigh ImageDetail = "high"
)

// isChatMessageContentPart marks ChatMessageContentPartText as a valid
// message content part.
func (ChatMessageContentPartText) isChatMessageContentPart() {}

// ChatMessageContentPartImage is an image part of a chat message.
type ChatMessageContentPartImage struct {
	// Type is the part discriminator; callers currently pass the literal
	// "image_url" (see ContentTypeImage).
	Type string `json:"type"`
	// ImageURL carries the image location and optional detail level.
	// The previous `,omitempty` tag was removed: encoding/json never
	// treats a non-pointer struct as empty, so the option was a no-op
	// and the field was always emitted — output is unchanged.
	ImageURL ChatMessageContentPartImageURL `json:"image_url"`
}

// isChatMessageContentPart marks ChatMessageContentPartImage as a valid
// message content part.
func (ChatMessageContentPartImage) isChatMessageContentPart() {}

// ChatGeneration is the output of a single chat generation.
type ChatGeneration struct {
Generation

0 comments on commit 23d87d1

Please sign in to comment.