diff --git a/chains/options.go b/chains/options.go index bb77ae7d0..6d5e08d7e 100644 --- a/chains/options.go +++ b/chains/options.go @@ -27,6 +27,10 @@ type chainCallOption struct { MaxTokens int maxTokensSet bool + // MaxCompletionTokens is the maximum number of tokens to generate - used in the current openai API + MaxCompletionTokens int + maxCompletionTokensSet bool + // Temperature is the temperature for sampling to use in an LLM call, between 0 and 1. Temperature float64 temperatureSet bool @@ -83,6 +87,14 @@ func WithMaxTokens(maxTokens int) ChainCallOption { } } +// WithMaxCompletionTokens is an option for LLM.Call. +func WithMaxCompletionTokens(maxCompletionTokens int) ChainCallOption { + return func(o *chainCallOption) { + o.MaxCompletionTokens = maxCompletionTokens + o.maxCompletionTokensSet = true + } +} + // WithTemperature is an option for LLM.Call. func WithTemperature(temperature float64) ChainCallOption { return func(o *chainCallOption) { @@ -181,6 +193,9 @@ func getLLMCallOptions(options ...ChainCallOption) []llms.CallOption { //nolint: if opts.maxTokensSet { chainCallOption = append(chainCallOption, llms.WithMaxTokens(opts.MaxTokens)) } + if opts.maxCompletionTokensSet { + chainCallOption = append(chainCallOption, llms.WithMaxCompletionTokens(opts.MaxCompletionTokens)) + } if opts.temperatureSet { chainCallOption = append(chainCallOption, llms.WithTemperature(opts.Temperature)) } @@ -209,3 +224,7 @@ func getLLMCallOptions(options ...ChainCallOption) []llms.CallOption { //nolint: return chainCallOption } + +func GetLLMCallOptions(options ...ChainCallOption) []llms.CallOption { + return getLLMCallOptions(options...) +} diff --git a/llms/openai/internal/openaiclient/chat.go b/llms/openai/internal/openaiclient/chat.go index 46938926e..ae93e271a 100644 --- a/llms/openai/internal/openaiclient/chat.go +++ b/llms/openai/internal/openaiclient/chat.go @@ -34,7 +34,7 @@ type ChatRequest struct { Temperature float64 `json:"temperature"` TopP float64 `json:"top_p,omitempty"` // Deprecated: Use MaxCompletionTokens - MaxTokens int `json:"-,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` N int `json:"n,omitempty"` StopWords []string `json:"stop,omitempty"` @@ -297,6 +297,7 @@ type ChatCompletionResponse struct { Object string `json:"object,omitempty"` Usage ChatUsage `json:"usage,omitempty"` SystemFingerprint string `json:"system_fingerprint"` + Citations []string `json:"citations,omitempty"` } type Usage struct { diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go index 78f8334d2..7c87a0e5d 100644 --- a/llms/openai/openaillm.go +++ b/llms/openai/openaillm.go @@ -14,6 +14,7 @@ type ChatMessage = openaiclient.ChatMessage type LLM struct { CallbacksHandler callbacks.Handler client *openaiclient.Client + opts *options } const ( @@ -35,6 +36,7 @@ func New(opts ...Option) (*LLM, error) { return &LLM{ client: c, CallbacksHandler: opt.callbackHandler, + opts: opt, }, err } @@ -96,17 +98,16 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten chatMsgs = append(chatMsgs, msg) } req := &openaiclient.ChatRequest{ - Model: opts.Model, - StopWords: opts.StopWords, - Messages: chatMsgs, - StreamingFunc: opts.StreamingFunc, - Temperature: opts.Temperature, - N: opts.N, - FrequencyPenalty: opts.FrequencyPenalty, - PresencePenalty: opts.PresencePenalty, - - MaxCompletionTokens: opts.MaxTokens, - + Model: opts.Model, + StopWords: opts.StopWords, + Messages: chatMsgs, + StreamingFunc: opts.StreamingFunc, + Temperature: opts.Temperature, + N: opts.N, + FrequencyPenalty: opts.FrequencyPenalty, + PresencePenalty: opts.PresencePenalty, + MaxTokens: opts.MaxTokens, + MaxCompletionTokens: opts.MaxCompletionTokens, ToolChoice: opts.ToolChoice, FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior), Seed: opts.Seed, @@ -116,6 +117,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten req.ResponseFormat = ResponseFormatJSON } + // since req.Functions is deprecated, we need to use the new Tools API. for _, fn := range opts.Functions { req.Tools = append(req.Tools, openaiclient.Tool{ @@ -160,6 +162,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten "PromptTokens": result.Usage.PromptTokens, "TotalTokens": result.Usage.TotalTokens, "ReasoningTokens": result.Usage.CompletionTokensDetails.ReasoningTokens, + "Citations": result.Citations, }, } diff --git a/llms/options.go b/llms/options.go index b6b595290..661c8dc16 100644 --- a/llms/options.go +++ b/llms/options.go @@ -14,6 +14,8 @@ type CallOptions struct { CandidateCount int `json:"candidate_count"` // MaxTokens is the maximum number of tokens to generate. MaxTokens int `json:"max_tokens"` + // MaxCompletionTokens is the maximum number of tokens to generate - used in the current openai API + MaxCompletionTokens int `json:"max_completion_tokens"` // Temperature is the temperature for sampling, between 0 and 1. Temperature float64 `json:"temperature"` // StopWords is a list of words to stop on. @@ -126,6 +128,13 @@ func WithMaxTokens(maxTokens int) CallOption { } } +// WithMaxCompletionTokens specifies the max number of tokens to generate - used in the current openai API +func WithMaxCompletionTokens(maxCompletionTokens int) CallOption { + return func(o *CallOptions) { + o.MaxCompletionTokens = maxCompletionTokens + } +} + // WithCandidateCount specifies the number of response candidates to generate. func WithCandidateCount(c int) CallOption { return func(o *CallOptions) { diff --git a/outputparser/boolean_parser_test.go b/outputparser/boolean_parser_test.go index ee3f5b195..f8214402a 100644 --- a/outputparser/boolean_parser_test.go +++ b/outputparser/boolean_parser_test.go @@ -71,6 +71,7 @@ func TestBooleanParser(t *testing.T) { } for _, tc := range testCases { + tc := tc parser := outputparser.NewBooleanParser() t.Run(tc.input, func(t *testing.T) { diff --git a/outputparser/defined.go b/outputparser/defined.go index 138055566..655a52407 100644 --- a/outputparser/defined.go +++ b/outputparser/defined.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "reflect" + "strings" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/schema" @@ -61,14 +62,14 @@ func (p Defined[T]) GetFormatInstructions() string { func (p Defined[T]) Parse(text string) (T, error) { var target T - // Removes '```json' and '```' from the start and end of the text. - const opening = "```json" - const closing = "```" - if text[:len(opening)] != opening || text[len(text)-len(closing):] != closing { - return target, fmt.Errorf("input text should start with %s and end with %s", opening, closing) + startIndex := strings.Index(text, "{") + endIndex := strings.LastIndex(text, "}") + if startIndex == -1 || endIndex == -1 { + return target, fmt.Errorf("could not find start or end of JSON object") } - parseableJSON := text[len(opening) : len(text)-len(closing)] - if err := json.Unmarshal([]byte(parseableJSON), &target); err != nil { + text = text[startIndex : endIndex+1] + + if err := json.Unmarshal([]byte(text), &target); err != nil { return target, fmt.Errorf("could not parse generated JSON: %w", err) } return target, nil diff --git a/outputparser/defined_test.go b/outputparser/defined_test.go index e77ca9cd7..3adef4aea 100644 --- a/outputparser/defined_test.go +++ b/outputparser/defined_test.go @@ -97,37 +97,142 @@ interface Foods { } } -func TestDefinedParse(t *testing.T) { - t.Parallel() - var book struct { - Chapters []struct { - Title string `json:"title" describe:"chapter title"` - } `json:"chapters" describe:"chapters"` - } - parser, newErr := NewDefined(book) - if newErr != nil { - t.Error(newErr) - } +type book struct { + Chapters []struct { + Title string `json:"title" describe:"chapter title"` + } `json:"chapters" describe:"chapters"` +} +func getParseTests() map[string]struct { + input string + expected *book + wantErr bool +} { titles := []string{ "A Hello There", "The Meaty Middle", "The Grand Finale", } - output, parseErr := parser.Parse(fmt.Sprintf("```json\n%s\n```", fmt.Sprintf( - `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], - ))) - if parseErr != nil { - t.Error(parseErr) - } - if count := len(output.Chapters); count != 3 { - t.Errorf("got %d chapters; want 3", count) + return map[string]struct { + input string + expected *book + wantErr bool + }{ + "empty": { + input: "", + wantErr: true, + expected: nil, + }, + "invalid": { + input: "invalid", + wantErr: true, + expected: nil, + }, + "valid": { + input: fmt.Sprintf("```json\n%s\n```", fmt.Sprintf( + `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], + )), + wantErr: false, + expected: &book{ + Chapters: []struct { + Title string `json:"title" describe:"chapter title"` + }{ + {Title: titles[0]}, + {Title: titles[1]}, + {Title: titles[2]}, + }, + }, + }, + "valid-without-json-tag": { + input: fmt.Sprintf("```\n%s\n```", fmt.Sprintf( + `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], + )), + wantErr: false, + expected: &book{ + Chapters: []struct { + Title string `json:"title" describe:"chapter title"` + }{ + {Title: titles[0]}, + {Title: titles[1]}, + {Title: titles[2]}, + }, + }, + }, + "valid-without-tags": { + input: fmt.Sprintf("\n%s\n", fmt.Sprintf( + `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], + )), + wantErr: false, + expected: &book{ + Chapters: []struct { + Title string `json:"title" describe:"chapter title"` + }{ + {Title: titles[0]}, + {Title: titles[1]}, + {Title: titles[2]}, + }, + }, + }, + "llm-explanation-and-tags": { + input: fmt.Sprintf("Sure! Here's the JSON:\n\n```json\n%s\n```\n\nLet me know if you need anything else.", fmt.Sprintf( + `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], + )), + wantErr: false, + expected: &book{ + Chapters: []struct { + Title string `json:"title" describe:"chapter title"` + }{ + {Title: titles[0]}, + {Title: titles[1]}, + {Title: titles[2]}, + }, + }, + }, + "llm-explanation-and-valid": { + input: fmt.Sprintf("Sure! Here's the JSON:\n\n%s\n\nLet me know if you need anything else.", fmt.Sprintf( + `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], + )), + wantErr: false, + expected: &book{ + Chapters: []struct { + Title string `json:"title" describe:"chapter title"` + }{ + {Title: titles[0]}, + {Title: titles[1]}, + {Title: titles[2]}, + }, + }, + }, } - for i, chapter := range output.Chapters { - title := titles[i] - if chapter.Title != titles[i] { - t.Errorf("got '%s'; want '%s'", chapter.Title, title) - } +} + +func TestDefinedParse(t *testing.T) { + t.Parallel() + for name, test := range getParseTests() { + t.Run(name, func(t *testing.T) { + t.Parallel() + parser, newErr := NewDefined(book{}) + if newErr != nil { + t.Error(newErr) + } + output, parseErr := parser.Parse(test.input) + switch { + case parseErr != nil && !test.wantErr: + t.Errorf("%s: unexpected error: %v", name, parseErr) + case parseErr == nil && test.wantErr: + t.Errorf("%s: expected error", name) + case parseErr == nil && test.expected != nil: + if count := len(output.Chapters); count != len(test.expected.Chapters) { + t.Errorf("%s: got %d chapters; want %d", name, count, len(test.expected.Chapters)) + } + for i, chapter := range output.Chapters { + title := test.expected.Chapters[i].Title + if chapter.Title != title { + t.Errorf("%s: got '%s'; want '%s'", name, chapter.Title, title) + } + } + } + }) } }