From 60455968864941a09bcf4a7978c55214e722c735 Mon Sep 17 00:00:00 2001
From: Travis Cline
Date: Mon, 18 Mar 2024 18:31:37 -0700
Subject: [PATCH] llms: Improve json mode support (#683)

* llms: Improve json response format coverage, add example
---
 examples/json-mode-example/go.mod            | 13 +++++
 examples/json-mode-example/go.sum            | 21 ++++++++
 .../json-mode-example/json_mode_example.go   | 52 +++++++++++++++++++
 llms/anthropic/anthropicllm.go               |  8 ++-
 llms/ollama/ollamallm.go                     |  7 ++-
 llms/openai/internal/openaiclient/chat.go    |  7 +++
 llms/openai/openaillm.go                     |  9 ++--
 llms/openai/openaillm_option.go              | 16 ++++++
 llms/options.go                              | 11 ++++
 9 files changed, 137 insertions(+), 7 deletions(-)
 create mode 100644 examples/json-mode-example/go.mod
 create mode 100644 examples/json-mode-example/go.sum
 create mode 100644 examples/json-mode-example/json_mode_example.go

diff --git a/examples/json-mode-example/go.mod b/examples/json-mode-example/go.mod
new file mode 100644
index 000000000..b2edba6fe
--- /dev/null
+++ b/examples/json-mode-example/go.mod
@@ -0,0 +1,13 @@
+module github.com/tmc/langchaingo/examples/json-mode-example
+
+go 1.21
+
+toolchain go1.21.4
+
+require github.com/tmc/langchaingo v0.1.6-alpha.0.0.20240318012619-9dbcc88fd002
+
+require (
+	github.com/dlclark/regexp2 v1.10.0 // indirect
+	github.com/google/uuid v1.6.0 // indirect
+	github.com/pkoukk/tiktoken-go v0.1.6 // indirect
+)
diff --git a/examples/json-mode-example/go.sum b/examples/json-mode-example/go.sum
new file mode 100644
index 000000000..565366567
--- /dev/null
+++ b/examples/json-mode-example/go.sum
@@ -0,0 +1,21 @@
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
+github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/dlclark/regexp2 v1.8.1 h1:6Lcdwya6GjPUNsBct8Lg/yRPwMhABj269AAzdGSiR+0=
+github.com/dlclark/regexp2 v1.8.1/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
+github.com/dlclark/regexp2 v1.10.0/go.mod h1:DHkYz0B9wPfa6wondMfaivmHpzrQ3v9q8cnmRbL6yW8=
+github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4=
+github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
+github.com/pkoukk/tiktoken-go v0.1.2 h1:u7PCSBiWJ3nJYoTGShyM9iHXz4dNyYkurwwp+GHtyHY=
+github.com/pkoukk/tiktoken-go v0.1.2/go.mod h1:boMWvk9pQCOTx11pgu0DrIdrAKgQzzJKUP6vLXaz7Rw=
+github.com/pkoukk/tiktoken-go v0.1.6/go.mod h1:9NiV+i9mJKGj1rYOT+njbv+ZwA/zJxYdewGl6qVatpg=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
+github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
+github.com/tmc/langchaingo v0.1.5 h1:PNPFu54wn5uVPRt9GS/quRwdFZW4omSab9/dcFAsGmU=
+github.com/tmc/langchaingo v0.1.5/go.mod h1:RLtnUED/hH2v765vdjS9Z6gonErZAXURuJHph0BttqM=
+github.com/tmc/langchaingo v0.1.6-alpha.0.0.20240318012619-9dbcc88fd002 h1:qM/fnCN2BvGZmDS3gyxeV3m4p6veX/8KCttIMtIYrps=
+github.com/tmc/langchaingo v0.1.6-alpha.0.0.20240318012619-9dbcc88fd002/go.mod h1:m+VxH55LmyknIgla6GyUu0U/syv03r4wtIfrJYmWXMY=
+gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
+gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/examples/json-mode-example/json_mode_example.go b/examples/json-mode-example/json_mode_example.go
new file mode 100644
index 000000000..5f8be15e2
--- /dev/null
+++ b/examples/json-mode-example/json_mode_example.go
@@ -0,0 +1,52 @@
+package main
+
+import (
+	"context"
+	"flag"
+	"fmt"
+	"log"
+	"os"
+
+	"github.com/tmc/langchaingo/llms"
+	"github.com/tmc/langchaingo/llms/anthropic"
+	"github.com/tmc/langchaingo/llms/googleai"
+	"github.com/tmc/langchaingo/llms/ollama"
+	"github.com/tmc/langchaingo/llms/openai"
+)
+
+var flagBackend = flag.String("backend", "openai", "backend to use")
+
+func main() {
+	flag.Parse()
+	ctx := context.Background()
+	llm, err := initBackend(ctx)
+	if err != nil {
+		log.Fatal(err)
+	}
+	completion, err := llms.GenerateFromSinglePrompt(ctx,
+		llm,
+		"Who was the first man to walk on the moon? Respond in json format, include `first_man` in response keys.",
+		llms.WithTemperature(0.0),
+		llms.WithJSONMode(),
+	)
+	if err != nil {
+		log.Fatal(err)
+	}
+
+	fmt.Println(completion)
+}
+
+func initBackend(ctx context.Context) (llms.Model, error) {
+	switch *flagBackend {
+	case "openai":
+		return openai.New()
+	case "ollama":
+		return ollama.New(ollama.WithModel("mistral"))
+	case "anthropic":
+		return anthropic.New(anthropic.WithModel("claude-2.1"))
+	case "googleai":
+		return googleai.New(ctx, googleai.WithAPIKey(os.Getenv("GOOGLE_AI_API_KEY")))
+	default:
+		return nil, fmt.Errorf("unknown backend: %s", *flagBackend)
+	}
+}
diff --git a/llms/anthropic/anthropicllm.go b/llms/anthropic/anthropicllm.go
index 7d69f8063..0d3b738ab 100644
--- a/llms/anthropic/anthropicllm.go
+++ b/llms/anthropic/anthropicllm.go
@@ -3,6 +3,7 @@ package anthropic
 import (
 	"context"
 	"errors"
+	"fmt"
 	"net/http"
 	"os"
 
@@ -71,9 +72,14 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
 	// Assume we get a single text message
 	msg0 := messages[0]
 	part := msg0.Parts[0]
+	partText, ok := part.(llms.TextContent)
+	if !ok {
+		return nil, fmt.Errorf("unexpected message type: %T", part)
+	}
+	prompt := fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", partText.Text)
 	result, err := o.client.CreateCompletion(ctx, &anthropicclient.CompletionRequest{
 		Model:       opts.Model,
-		Prompt:      part.(llms.TextContent).Text,
+		Prompt:      prompt,
 		MaxTokens:   opts.MaxTokens,
 		StopWords:   opts.StopWords,
 		Temperature: opts.Temperature,
diff --git a/llms/ollama/ollamallm.go b/llms/ollama/ollamallm.go
index 3ee413d45..d4be8c446 100644
--- a/llms/ollama/ollamallm.go
+++ b/llms/ollama/ollamallm.go
@@ -97,11 +97,16 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
 		chatMsgs = append(chatMsgs, msg)
 	}
 
+	format := o.options.format
+	if opts.JSONMode {
+		format = "json"
+	}
+
 	// Get our ollamaOptions from llms.CallOptions
 	ollamaOptions := makeOllamaOptionsFromOptions(o.options.ollamaOptions, opts)
 	req := &ollamaclient.ChatRequest{
 		Model:    model,
-		Format:   o.options.format,
+		Format:   format,
 		Messages: chatMsgs,
 		Options:  ollamaOptions,
 		Stream:   func(b bool) *bool { return &b }(opts.StreamingFunc != nil),
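
Not part of the patch: a minimal end-to-end sketch of the new Ollama path. llms.WithJSONMode() flips CallOptions.JSONMode, and the hunk above translates that into Format: "json" on the chat request. This assumes a local Ollama server with the "mistral" model pulled; the prompt is a placeholder.

package main

import (
	"context"
	"fmt"
	"log"

	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/ollama"
)

func main() {
	llm, err := ollama.New(ollama.WithModel("mistral"))
	if err != nil {
		log.Fatal(err)
	}
	// WithJSONMode sets CallOptions.JSONMode; the Ollama backend then sends
	// Format: "json", constraining the model to emit valid JSON.
	out, err := llms.GenerateFromSinglePrompt(context.Background(), llm,
		"Give the boiling point of water as JSON with a `celsius` key.",
		llms.WithJSONMode(),
	)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(out)
}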
diff --git a/llms/openai/internal/openaiclient/chat.go b/llms/openai/internal/openaiclient/chat.go
index 6ab38f905..527ecf4a4 100644
--- a/llms/openai/internal/openaiclient/chat.go
+++ b/llms/openai/internal/openaiclient/chat.go
@@ -33,6 +33,8 @@ type ChatRequest struct {
 	FrequencyPenalty float64 `json:"frequency_penalty,omitempty"`
 	PresencePenalty  float64 `json:"presence_penalty,omitempty"`
 
+	ResponseFormat ResponseFormat `json:"response_format,omitempty"`
+
 	// Function definitions to include in the request.
 	Functions []FunctionDefinition `json:"functions,omitempty"`
 	// FunctionCallBehavior is the behavior to use when calling functions.
@@ -46,6 +48,11 @@ type ChatRequest struct {
 	StreamingFunc func(ctx context.Context, chunk []byte) error `json:"-"`
 }
 
+// ResponseFormat is the format of the response.
+type ResponseFormat struct {
+	Type string `json:"type"`
+}
+
 // ChatMessage is a message in a chat request.
 type ChatMessage struct { //nolint:musttag
 	// The role of the author of this message. One of system, user, or assistant.
diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go
index e16d225c6..d4bcaf5b3 100644
--- a/llms/openai/openaillm.go
+++ b/llms/openai/openaillm.go
@@ -44,9 +44,7 @@ func (o *LLM) Call(ctx context.Context, prompt string, options ...llms.CallOptio
 }
 
 // GenerateContent implements the Model interface.
-//
-//nolint:goerr113
-func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop
+func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageContent, options ...llms.CallOption) (*llms.ContentResponse, error) { //nolint: lll, cyclop, goerr113, funlen
 	if o.CallbacksHandler != nil {
 		o.CallbacksHandler.HandleLLMGenerateContentStart(ctx, messages)
 	}
@@ -76,7 +74,6 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
 		}
 		chatMsgs = append(chatMsgs, msg)
 	}
-
 	req := &openaiclient.ChatRequest{
 		Model:     opts.Model,
 		StopWords: opts.StopWords,
@@ -89,7 +86,9 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
 		PresencePenalty:  opts.PresencePenalty,
 		FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior),
 	}
-
+	if opts.JSONMode {
+		req.ResponseFormat = ResponseFormatJSON
+	}
 	for _, fn := range opts.Functions {
 		req.Functions = append(req.Functions, openaiclient.FunctionDefinition{
 			Name:        fn.Name,
diff --git a/llms/openai/openaillm_option.go b/llms/openai/openaillm_option.go
index fb08cdf67..d005e80fe 100644
--- a/llms/openai/openaillm_option.go
+++ b/llms/openai/openaillm_option.go
@@ -33,6 +33,8 @@ type options struct {
 	apiType    APIType
 	httpClient openaiclient.Doer
 
+	responseFormat ResponseFormat
+
 	// required when APIType is APITypeAzure or APITypeAzureAD
 	apiVersion     string
 	embeddingModel string
@@ -40,8 +42,15 @@ type options struct {
 	callbackHandler callbacks.Handler
 }
 
+// Option is a functional option for the OpenAI client.
 type Option func(*options)
 
+// ResponseFormat is the response format for the OpenAI client.
+type ResponseFormat = openaiclient.ResponseFormat
+
+// ResponseFormatJSON is the JSON response format.
+var ResponseFormatJSON = ResponseFormat{Type: "json_object"} //nolint:gochecknoglobals
+
 // WithToken passes the OpenAI API token to the client. If not set, the token
 // is read from the OPENAI_API_KEY environment variable.
 func WithToken(token string) Option {
@@ -112,3 +121,10 @@ func WithCallback(callbackHandler callbacks.Handler) Option {
 		opts.callbackHandler = callbackHandler
 	}
 }
+
+// WithResponseFormat allows setting a custom response format.
+func WithResponseFormat(responseFormat ResponseFormat) Option {
+	return func(opts *options) {
+		opts.responseFormat = responseFormat
+	}
+}
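
Not part of the patch: a small sketch of the wire shape the new field produces. openaiclient.ChatRequest lives in an internal package, so a local mirror of the two relevant fields is used here; the model name is a placeholder.

package main

import (
	"encoding/json"
	"fmt"
)

type responseFormat struct {
	Type string `json:"type"`
}

type chatRequest struct {
	Model          string         `json:"model"`
	ResponseFormat responseFormat `json:"response_format,omitempty"`
}

func main() {
	req := chatRequest{Model: "gpt-3.5-turbo"}
	// Equivalent of what the OpenAI backend does when opts.JSONMode is set:
	// req.ResponseFormat = ResponseFormatJSON, i.e. {Type: "json_object"}.
	req.ResponseFormat = responseFormat{Type: "json_object"}

	b, err := json.Marshal(req)
	if err != nil {
		panic(err)
	}
	fmt.Println(string(b))
	// {"model":"gpt-3.5-turbo","response_format":{"type":"json_object"}}
}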
diff --git a/llms/options.go b/llms/options.go
index 9d8173f59..da02668c3 100644
--- a/llms/options.go
+++ b/llms/options.go
@@ -40,6 +40,9 @@ type CallOptions struct {
 	// PresencePenalty is the presence penalty for sampling.
 	PresencePenalty float64 `json:"presence_penalty"`
 
+	// JSONMode is a flag to enable JSON mode.
+	JSONMode bool `json:"json"`
+
 	// Function defitions to include in the request.
 	Functions []FunctionDefinition `json:"functions"`
 	// FunctionCallBehavior is the behavior to use when calling functions.
@@ -195,3 +198,11 @@ func WithFunctions(functions []FunctionDefinition) CallOption {
 		o.Functions = functions
 	}
 }
+
+// WithJSONMode will add an option to set the response format to JSON.
+// This is useful for models that return structured data.
+func WithJSONMode() CallOption {
+	return func(o *CallOptions) {
+		o.JSONMode = true
+	}
+}
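
Not part of the patch: one way the example's JSON-mode completion could be consumed. The first_man key matches what the prompt asks for; models may add extra keys, which json.Unmarshal simply ignores.

package main

import (
	"encoding/json"
	"fmt"
	"log"
)

func main() {
	// Stand-in for the completion returned by the example program above.
	completion := `{"first_man": "Neil Armstrong"}`

	var got struct {
		FirstMan string `json:"first_man"`
	}
	if err := json.Unmarshal([]byte(completion), &got); err != nil {
		// JSON mode makes the output parseable, but the schema is still prompt-driven.
		log.Fatal(err)
	}
	fmt.Println(got.FirstMan) // Neil Armstrong
}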