From 6b4e6a44e4ea88c3e9cae745fd3369f012a3bd41 Mon Sep 17 00:00:00 2001 From: amitaifrey Date: Wed, 31 Jul 2024 17:04:29 +0300 Subject: [PATCH 01/17] outputparser: improve BooleanOutputParser The BooleanOutputParser requests the LLM to respond with a boolean, and gives examples such as `true` or `false`. However, it only parsed respones that include YES or NO. This commits adds more values for parsing and changes the tests to fit them. --- outputparser/boolean_parser.go | 26 +++++++++------- outputparser/boolean_parser_test.go | 48 ++++++++++++++++++++++++----- 2 files changed, 56 insertions(+), 18 deletions(-) diff --git a/outputparser/boolean_parser.go b/outputparser/boolean_parser.go index a91d0091d..d25c60047 100644 --- a/outputparser/boolean_parser.go +++ b/outputparser/boolean_parser.go @@ -11,15 +11,15 @@ import ( // BooleanParser is an output parser used to parse the output of an LLM as a boolean. type BooleanParser struct { - TrueStr string - FalseStr string + TrueStrings []string + FalseStrings []string } // NewBooleanParser returns a new BooleanParser. func NewBooleanParser() BooleanParser { return BooleanParser{ - TrueStr: "YES", - FalseStr: "NO", + TrueStrings: []string{"YES", "TRUE"}, + FalseStrings: []string{"NO", "FALSE"}, } } @@ -33,16 +33,20 @@ func (p BooleanParser) GetFormatInstructions() string { func (p BooleanParser) parse(text string) (bool, error) { text = normalize(text) - booleanStrings := []string{p.TrueStr, p.FalseStr} + booleanStrings := append(p.TrueStrings, p.FalseStrings...) - if !slices.Contains(booleanStrings, text) { - return false, ParseError{ - Text: text, - Reason: fmt.Sprintf("Expected output to be either '%s' or '%s', received %s", p.TrueStr, p.FalseStr, text), - } + if slices.Contains(p.TrueStrings, text) { + return true, nil } - return text == p.TrueStr, nil + if slices.Contains(p.FalseStrings, text) { + return false, nil + } + + return false, ParseError{ + Text: text, + Reason: fmt.Sprintf("Expected output to one of %v, received %s", booleanStrings, text), + } } func normalize(text string) string { diff --git a/outputparser/boolean_parser_test.go b/outputparser/boolean_parser_test.go index 9ab92c664..9a0f2a031 100644 --- a/outputparser/boolean_parser_test.go +++ b/outputparser/boolean_parser_test.go @@ -24,6 +24,7 @@ func TestBooleanParser(t *testing.T) { }, { input: "YESNO", + err: outputparser.ParseError{}, expected: false, }, { @@ -31,18 +32,51 @@ func TestBooleanParser(t *testing.T) { err: outputparser.ParseError{}, expected: false, }, + { + input: "true", + expected: true, + }, + { + input: "false", + expected: false, + }, + { + input: "True", + expected: true, + }, + { + input: "False", + expected: false, + }, + { + input: "TRUE", + expected: true, + }, + { + input: "FALSE", + expected: false, + }, } for _, tc := range testCases { + tc := tc parser := outputparser.NewBooleanParser() - actual, err := parser.Parse(tc.input) - if tc.err != nil && err == nil { - t.Errorf("Expected error %v, got nil", tc.err) - } + t.Run(tc.input, func(t *testing.T) { + t.Parallel() + + result, err := parser.Parse(tc.input) + if err != nil && tc.err == nil { + t.Errorf("Unexpected error: %v", err) + } + + if err == nil && tc.err != nil { + t.Errorf("Expected error %v, got nil", tc.err) + } - if actual != tc.expected { - t.Errorf("Expected %v, got %v", tc.expected, actual) - } + if result != tc.expected { + t.Errorf("Expected %v, but got %v", tc.expected, result) + } + }) } } From 1ed2922f7fd5bff569cf5d58509b54413b078a22 Mon Sep 17 00:00:00 2001 From: amitaifrey 
Date: Wed, 31 Jul 2024 17:04:29 +0300 Subject: [PATCH 02/17] outputparser: improve BooleanOutputParser The BooleanOutputParser requests the LLM to respond with a boolean, and gives examples such as `true` or `false`. However, it only parsed respones that include YES or NO. This commits adds more values for parsing and changes the tests to fit them. --- llms/openai/openaillm.go | 1 + outputparser/boolean_parser.go | 26 +++++++------ outputparser/boolean_parser_test.go | 59 +++++++++++++++++++++++++---- 3 files changed, 68 insertions(+), 18 deletions(-) diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go index 699c0d304..c5273f247 100644 --- a/llms/openai/openaillm.go +++ b/llms/openai/openaillm.go @@ -115,6 +115,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten req.ResponseFormat = ResponseFormatJSON } + // since req.Functions is deprecated, we need to use the new Tools API. for _, fn := range opts.Functions { req.Tools = append(req.Tools, openaiclient.Tool{ diff --git a/outputparser/boolean_parser.go b/outputparser/boolean_parser.go index a91d0091d..46970771c 100644 --- a/outputparser/boolean_parser.go +++ b/outputparser/boolean_parser.go @@ -11,15 +11,15 @@ import ( // BooleanParser is an output parser used to parse the output of an LLM as a boolean. type BooleanParser struct { - TrueStr string - FalseStr string + TrueStrings []string + FalseStrings []string } // NewBooleanParser returns a new BooleanParser. func NewBooleanParser() BooleanParser { return BooleanParser{ - TrueStr: "YES", - FalseStr: "NO", + TrueStrings: []string{"YES", "TRUE"}, + FalseStrings: []string{"NO", "FALSE"}, } } @@ -33,20 +33,24 @@ func (p BooleanParser) GetFormatInstructions() string { func (p BooleanParser) parse(text string) (bool, error) { text = normalize(text) - booleanStrings := []string{p.TrueStr, p.FalseStr} - if !slices.Contains(booleanStrings, text) { - return false, ParseError{ - Text: text, - Reason: fmt.Sprintf("Expected output to be either '%s' or '%s', received %s", p.TrueStr, p.FalseStr, text), - } + if slices.Contains(p.TrueStrings, text) { + return true, nil } - return text == p.TrueStr, nil + if slices.Contains(p.FalseStrings, text) { + return false, nil + } + + return false, ParseError{ + Text: text, + Reason: fmt.Sprintf("Expected output to one of %v, received %s", append(p.TrueStrings, p.FalseStrings...), text), + } } func normalize(text string) string { text = strings.TrimSpace(text) + text = strings.Trim(text, "'\"`") text = strings.ToUpper(text) return text diff --git a/outputparser/boolean_parser_test.go b/outputparser/boolean_parser_test.go index 9ab92c664..ee3f5b195 100644 --- a/outputparser/boolean_parser_test.go +++ b/outputparser/boolean_parser_test.go @@ -24,6 +24,7 @@ func TestBooleanParser(t *testing.T) { }, { input: "YESNO", + err: outputparser.ParseError{}, expected: false, }, { @@ -31,18 +32,62 @@ func TestBooleanParser(t *testing.T) { err: outputparser.ParseError{}, expected: false, }, + { + input: "true", + expected: true, + }, + { + input: "false", + expected: false, + }, + { + input: "True", + expected: true, + }, + { + input: "False", + expected: false, + }, + { + input: "TRUE", + expected: true, + }, + { + input: "FALSE", + expected: false, + }, + { + input: "'TRUE'", + expected: true, + }, + { + input: "`TRUE`", + expected: true, + }, + { + input: "'TRUE`", + expected: true, + }, } for _, tc := range testCases { parser := outputparser.NewBooleanParser() - actual, err := parser.Parse(tc.input) - if tc.err != nil && err == 
nil { - t.Errorf("Expected error %v, got nil", tc.err) - } + t.Run(tc.input, func(t *testing.T) { + t.Parallel() + + result, err := parser.Parse(tc.input) + if err != nil && tc.err == nil { + t.Errorf("Unexpected error: %v", err) + } + + if err == nil && tc.err != nil { + t.Errorf("Expected error %v, got nil", tc.err) + } - if actual != tc.expected { - t.Errorf("Expected %v, got %v", tc.expected, actual) - } + if result != tc.expected { + t.Errorf("Expected %v, but got %v", tc.expected, result) + } + }) } } From abd68a903604de1bee835d595bbd004d5d76531a Mon Sep 17 00:00:00 2001 From: amitaifrey Date: Sun, 4 Aug 2024 18:59:02 +0300 Subject: [PATCH 03/17] outputparser: improve DefinedOutputParser The DefinedOutputParser prompted the LLM with a Typescript schema, but expected a json in response: https://github.com/tmc/langchaingo/blob/1975058648b5914fdd9dc53434c5b59f219e2b5c/outputparser/defined.go\#L65-69 Now it also requests a json. --- outputparser/defined.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/outputparser/defined.go b/outputparser/defined.go index e56a6a6f5..138055566 100644 --- a/outputparser/defined.go +++ b/outputparser/defined.go @@ -53,7 +53,7 @@ var _ schema.OutputParser[any] = Defined[any]{} // GetFormatInstructions returns a string describing the format of the output. func (p Defined[T]) GetFormatInstructions() string { - const instructions = "Your output should be in JSON, structured according to this TypeScript:\n```typescript\n%s\n```" + const instructions = "Your output should be in JSON, structured according to this schema:\n```json\n%s\n```" return fmt.Sprintf(instructions, p.schema) } From ac550958d4092daf19965242cf19bc422f1a83b3 Mon Sep 17 00:00:00 2001 From: amitaifrey Date: Wed, 21 Aug 2024 15:11:37 +0300 Subject: [PATCH 04/17] outputparser: allow json to start with ``` Allow the returning schema to start with ```, not just ```json --- outputparser/defined.go | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/outputparser/defined.go b/outputparser/defined.go index 138055566..021712aef 100644 --- a/outputparser/defined.go +++ b/outputparser/defined.go @@ -62,12 +62,22 @@ func (p Defined[T]) Parse(text string) (T, error) { var target T // Removes '```json' and '```' from the start and end of the text. 
- const opening = "```json" + const opening1 = "```json" + const opening2 = "```" + switch { + case len(text) >= len(opening1) && text[:len(opening1)] == opening1: + text = text[len(opening1):] + case len(text) >= len(opening2) && text[:len(opening2)] == opening2: + text = text[len(opening2):] + default: + return target, errors.New("input text should start with '```json' or '```'") + } + const closing = "```" - if text[:len(opening)] != opening || text[len(text)-len(closing):] != closing { - return target, fmt.Errorf("input text should start with %s and end with %s", opening, closing) + if len(text) >= len(closing) && text[len(text)-len(closing):] != closing { + return target, fmt.Errorf("input text should end with %s", closing) } - parseableJSON := text[len(opening) : len(text)-len(closing)] + parseableJSON := text[:len(text)-len(closing)] if err := json.Unmarshal([]byte(parseableJSON), &target); err != nil { return target, fmt.Errorf("could not parse generated JSON: %w", err) } From 10d2006c98f192544522d3e3e61bc6c45acdda56 Mon Sep 17 00:00:00 2001 From: amitaifrey Date: Thu, 21 Nov 2024 16:07:26 +0200 Subject: [PATCH 05/17] outputparser: allow json to start with any prefix Now the output can start with ```json, ``` or nothing; and end with ``` or nothing. --- outputparser/defined.go | 9 +-- outputparser/defined_test.go | 123 ++++++++++++++++++++++++++++------- 2 files changed, 102 insertions(+), 30 deletions(-) diff --git a/outputparser/defined.go b/outputparser/defined.go index 021712aef..4df7b0cda 100644 --- a/outputparser/defined.go +++ b/outputparser/defined.go @@ -69,16 +69,13 @@ func (p Defined[T]) Parse(text string) (T, error) { text = text[len(opening1):] case len(text) >= len(opening2) && text[:len(opening2)] == opening2: text = text[len(opening2):] - default: - return target, errors.New("input text should start with '```json' or '```'") } const closing = "```" - if len(text) >= len(closing) && text[len(text)-len(closing):] != closing { - return target, fmt.Errorf("input text should end with %s", closing) + if len(text) >= len(closing) && text[len(text)-len(closing):] == closing { + text = text[:len(text)-len(closing)] } - parseableJSON := text[:len(text)-len(closing)] - if err := json.Unmarshal([]byte(parseableJSON), &target); err != nil { + if err := json.Unmarshal([]byte(text), &target); err != nil { return target, fmt.Errorf("could not parse generated JSON: %w", err) } return target, nil diff --git a/outputparser/defined_test.go b/outputparser/defined_test.go index e77ca9cd7..85e3ab470 100644 --- a/outputparser/defined_test.go +++ b/outputparser/defined_test.go @@ -97,37 +97,112 @@ interface Foods { } } -func TestDefinedParse(t *testing.T) { - t.Parallel() - var book struct { - Chapters []struct { - Title string `json:"title" describe:"chapter title"` - } `json:"chapters" describe:"chapters"` - } - parser, newErr := NewDefined(book) - if newErr != nil { - t.Error(newErr) - } +type book struct { + Chapters []struct { + Title string `json:"title" describe:"chapter title"` + } `json:"chapters" describe:"chapters"` +} +func getParseTests() map[string]struct { + input string + expected *book + wantErr bool +} { titles := []string{ "A Hello There", "The Meaty Middle", "The Grand Finale", } - output, parseErr := parser.Parse(fmt.Sprintf("```json\n%s\n```", fmt.Sprintf( - `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], - ))) - if parseErr != nil { - t.Error(parseErr) - } - if count := len(output.Chapters); count != 3 { - 
t.Errorf("got %d chapters; want 3", count) + return map[string]struct { + input string + expected *book + wantErr bool + }{ + "empty": { + input: "", + wantErr: true, + expected: nil, + }, + "invalid": { + input: "invalid", + wantErr: true, + expected: nil, + }, + "valid": { + input: fmt.Sprintf("```json\n%s\n```", fmt.Sprintf( + `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], + )), + wantErr: false, + expected: &book{ + Chapters: []struct { + Title string `json:"title" describe:"chapter title"` + }{ + {Title: titles[0]}, + {Title: titles[1]}, + {Title: titles[2]}, + }, + }, + }, + "valid-without-json-tag": { + input: fmt.Sprintf("```\n%s\n```", fmt.Sprintf( + `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], + )), + wantErr: false, + expected: &book{ + Chapters: []struct { + Title string `json:"title" describe:"chapter title"` + }{ + {Title: titles[0]}, + {Title: titles[1]}, + {Title: titles[2]}, + }, + }, + }, + "valid-without-tags": { + input: fmt.Sprintf("\n%s\n", fmt.Sprintf( + `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], + )), + wantErr: false, + expected: &book{ + Chapters: []struct { + Title string `json:"title" describe:"chapter title"` + }{ + {Title: titles[0]}, + {Title: titles[1]}, + {Title: titles[2]}, + }, + }, + }, } - for i, chapter := range output.Chapters { - title := titles[i] - if chapter.Title != titles[i] { - t.Errorf("got '%s'; want '%s'", chapter.Title, title) - } +} + +func TestDefinedParse(t *testing.T) { + t.Parallel() + for name, test := range getParseTests() { + t.Run(name, func(t *testing.T) { + t.Parallel() + parser, newErr := NewDefined(book{}) + if newErr != nil { + t.Error(newErr) + } + output, parseErr := parser.Parse(test.input) + switch { + case parseErr != nil && !test.wantErr: + t.Errorf("%s: unexpected error: %v", name, parseErr) + case parseErr == nil && test.wantErr: + t.Errorf("%s: expected error", name) + case parseErr == nil && test.expected != nil: + if count := len(output.Chapters); count != len(test.expected.Chapters) { + t.Errorf("%s: got %d chapters; want %d", name, count, len(test.expected.Chapters)) + } + for i, chapter := range output.Chapters { + title := test.expected.Chapters[i].Title + if chapter.Title != title { + t.Errorf("%s: got '%s'; want '%s'", name, chapter.Title, title) + } + } + } + }) } } From 99cd08aa8a4c6634cd070f6df70a522aca826e2c Mon Sep 17 00:00:00 2001 From: amitaifrey Date: Tue, 26 Nov 2024 10:33:05 +0200 Subject: [PATCH 06/17] Handle LLM verbosity --- outputparser/defined.go | 18 +++++++++++------- outputparser/defined_test.go | 15 +++++++++++++++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/outputparser/defined.go b/outputparser/defined.go index 4df7b0cda..8cdfadea0 100644 --- a/outputparser/defined.go +++ b/outputparser/defined.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "reflect" + "strings" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/schema" @@ -64,16 +65,19 @@ func (p Defined[T]) Parse(text string) (T, error) { // Removes '```json' and '```' from the start and end of the text. 
const opening1 = "```json" const opening2 = "```" - switch { - case len(text) >= len(opening1) && text[:len(opening1)] == opening1: - text = text[len(opening1):] - case len(text) >= len(opening2) && text[:len(opening2)] == opening2: - text = text[len(opening2):] + + opening1Index := strings.Index(text, opening1) + opening2Index := strings.Index(text, opening2) + if opening1Index != -1 { + text = text[opening1Index+len(opening1):] + } else if opening2Index != -1 { + text = text[opening2Index+len(opening2):] } const closing = "```" - if len(text) >= len(closing) && text[len(text)-len(closing):] == closing { - text = text[:len(text)-len(closing)] + closingIndex := strings.Index(text, closing) + if closingIndex != -1 { + text = text[:closingIndex] } if err := json.Unmarshal([]byte(text), &target); err != nil { return target, fmt.Errorf("could not parse generated JSON: %w", err) diff --git a/outputparser/defined_test.go b/outputparser/defined_test.go index 85e3ab470..3e142300e 100644 --- a/outputparser/defined_test.go +++ b/outputparser/defined_test.go @@ -174,6 +174,21 @@ func getParseTests() map[string]struct { }, }, }, + "llm-explanation-and-valid": { + input: fmt.Sprintf("Sure! Here's the JSON:\n\n```json\n%s\n```\n\nLet me know if you need anything else.", fmt.Sprintf( + `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], + )), + wantErr: false, + expected: &book{ + Chapters: []struct { + Title string `json:"title" describe:"chapter title"` + }{ + {Title: titles[0]}, + {Title: titles[1]}, + {Title: titles[2]}, + }, + }, + }, } } From f0af39fbbc6c37f50a46bcb2e4c8be1bc1c63f7c Mon Sep 17 00:00:00 2001 From: amitaifrey Date: Tue, 26 Nov 2024 11:14:13 +0200 Subject: [PATCH 07/17] Just use brackets --- outputparser/defined.go | 20 +++++--------------- outputparser/defined_test.go | 17 ++++++++++++++++- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/outputparser/defined.go b/outputparser/defined.go index 8cdfadea0..655a52407 100644 --- a/outputparser/defined.go +++ b/outputparser/defined.go @@ -62,23 +62,13 @@ func (p Defined[T]) GetFormatInstructions() string { func (p Defined[T]) Parse(text string) (T, error) { var target T - // Removes '```json' and '```' from the start and end of the text. - const opening1 = "```json" - const opening2 = "```" - - opening1Index := strings.Index(text, opening1) - opening2Index := strings.Index(text, opening2) - if opening1Index != -1 { - text = text[opening1Index+len(opening1):] - } else if opening2Index != -1 { - text = text[opening2Index+len(opening2):] + startIndex := strings.Index(text, "{") + endIndex := strings.LastIndex(text, "}") + if startIndex == -1 || endIndex == -1 { + return target, fmt.Errorf("could not find start or end of JSON object") } + text = text[startIndex : endIndex+1] - const closing = "```" - closingIndex := strings.Index(text, closing) - if closingIndex != -1 { - text = text[:closingIndex] - } if err := json.Unmarshal([]byte(text), &target); err != nil { return target, fmt.Errorf("could not parse generated JSON: %w", err) } diff --git a/outputparser/defined_test.go b/outputparser/defined_test.go index 3e142300e..3adef4aea 100644 --- a/outputparser/defined_test.go +++ b/outputparser/defined_test.go @@ -174,7 +174,7 @@ func getParseTests() map[string]struct { }, }, }, - "llm-explanation-and-valid": { + "llm-explanation-and-tags": { input: fmt.Sprintf("Sure! 
Here's the JSON:\n\n```json\n%s\n```\n\nLet me know if you need anything else.", fmt.Sprintf( `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], )), @@ -189,6 +189,21 @@ func getParseTests() map[string]struct { }, }, }, + "llm-explanation-and-valid": { + input: fmt.Sprintf("Sure! Here's the JSON:\n\n%s\n\nLet me know if you need anything else.", fmt.Sprintf( + `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], + )), + wantErr: false, + expected: &book{ + Chapters: []struct { + Title string `json:"title" describe:"chapter title"` + }{ + {Title: titles[0]}, + {Title: titles[1]}, + {Title: titles[2]}, + }, + }, + }, } } From 363a03ee9576073d91e2691842a4d6077cfb9e94 Mon Sep 17 00:00:00 2001 From: Oryan Moshe <43927816+oryanmoshe@users.noreply.github.com> Date: Tue, 10 Dec 2024 14:08:37 +0200 Subject: [PATCH 08/17] =?UTF-8?q?=E2=9C=A8feat:=20Add=20citations=20to=20o?= =?UTF-8?q?penai=20response?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llms/openai/internal/openaiclient/chat.go | 3 ++- llms/openai/openaillm.go | 19 ++++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/llms/openai/internal/openaiclient/chat.go b/llms/openai/internal/openaiclient/chat.go index 2bb572a0a..d6e7cdd22 100644 --- a/llms/openai/internal/openaiclient/chat.go +++ b/llms/openai/internal/openaiclient/chat.go @@ -34,7 +34,7 @@ type ChatRequest struct { Temperature float64 `json:"temperature"` TopP float64 `json:"top_p,omitempty"` // Deprecated: Use MaxCompletionTokens - MaxTokens int `json:"-,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` N int `json:"n,omitempty"` StopWords []string `json:"stop,omitempty"` @@ -297,6 +297,7 @@ type ChatCompletionResponse struct { Object string `json:"object,omitempty"` Usage ChatUsage `json:"usage,omitempty"` SystemFingerprint string `json:"system_fingerprint"` + Citations []string `json:"citations,omitempty"` } type Usage struct { diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go index 78f8334d2..3b60f1257 100644 --- a/llms/openai/openaillm.go +++ b/llms/openai/openaillm.go @@ -96,15 +96,15 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten chatMsgs = append(chatMsgs, msg) } req := &openaiclient.ChatRequest{ - Model: opts.Model, - StopWords: opts.StopWords, - Messages: chatMsgs, - StreamingFunc: opts.StreamingFunc, - Temperature: opts.Temperature, - N: opts.N, - FrequencyPenalty: opts.FrequencyPenalty, - PresencePenalty: opts.PresencePenalty, - + Model: opts.Model, + StopWords: opts.StopWords, + Messages: chatMsgs, + StreamingFunc: opts.StreamingFunc, + Temperature: opts.Temperature, + N: opts.N, + FrequencyPenalty: opts.FrequencyPenalty, + PresencePenalty: opts.PresencePenalty, + MaxTokens: opts.MaxTokens, MaxCompletionTokens: opts.MaxTokens, ToolChoice: opts.ToolChoice, @@ -160,6 +160,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten "PromptTokens": result.Usage.PromptTokens, "TotalTokens": result.Usage.TotalTokens, "ReasoningTokens": result.Usage.CompletionTokensDetails.ReasoningTokens, + "Citations": result.Citations, }, } From 0960a3be8e0f6c5b79f71582af0e989629e3ad88 Mon Sep 17 00:00:00 2001 From: amitaifrey Date: Tue, 7 Jan 2025 12:34:34 +0200 Subject: [PATCH 09/17] Handle max completion tokens and max tokens --- chains/options.go | 19 
+++++++++++++++++++ llms/openai/openaillm.go | 23 ++++++++++++----------- llms/options.go | 9 +++++++++ 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/chains/options.go b/chains/options.go index bb77ae7d0..6d5e08d7e 100644 --- a/chains/options.go +++ b/chains/options.go @@ -27,6 +27,10 @@ type chainCallOption struct { MaxTokens int maxTokensSet bool + // MaxCompletionTokens is the maximum number of tokens to generate - used in the current openai API + MaxCompletionTokens int + maxCompletionTokensSet bool + // Temperature is the temperature for sampling to use in an LLM call, between 0 and 1. Temperature float64 temperatureSet bool @@ -83,6 +87,14 @@ func WithMaxTokens(maxTokens int) ChainCallOption { } } +// WithMaxCompletionTokens is an option for LLM.Call. +func WithMaxCompletionTokens(maxCompletionTokens int) ChainCallOption { + return func(o *chainCallOption) { + o.MaxCompletionTokens = maxCompletionTokens + o.maxCompletionTokensSet = true + } +} + // WithTemperature is an option for LLM.Call. func WithTemperature(temperature float64) ChainCallOption { return func(o *chainCallOption) { @@ -181,6 +193,9 @@ func getLLMCallOptions(options ...ChainCallOption) []llms.CallOption { //nolint: if opts.maxTokensSet { chainCallOption = append(chainCallOption, llms.WithMaxTokens(opts.MaxTokens)) } + if opts.maxCompletionTokensSet { + chainCallOption = append(chainCallOption, llms.WithMaxCompletionTokens(opts.MaxCompletionTokens)) + } if opts.temperatureSet { chainCallOption = append(chainCallOption, llms.WithTemperature(opts.Temperature)) } @@ -209,3 +224,7 @@ func getLLMCallOptions(options ...ChainCallOption) []llms.CallOption { //nolint: return chainCallOption } + +func GetLLMCallOptions(options ...ChainCallOption) []llms.CallOption { + return getLLMCallOptions(options...) +} diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go index 3b60f1257..183822534 100644 --- a/llms/openai/openaillm.go +++ b/llms/openai/openaillm.go @@ -14,6 +14,7 @@ type ChatMessage = openaiclient.ChatMessage type LLM struct { CallbacksHandler callbacks.Handler client *openaiclient.Client + opts *options } const ( @@ -35,6 +36,7 @@ func New(opts ...Option) (*LLM, error) { return &LLM{ client: c, CallbacksHandler: opt.callbackHandler, + opts: opt, }, err } @@ -96,17 +98,16 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten chatMsgs = append(chatMsgs, msg) } req := &openaiclient.ChatRequest{ - Model: opts.Model, - StopWords: opts.StopWords, - Messages: chatMsgs, - StreamingFunc: opts.StreamingFunc, - Temperature: opts.Temperature, - N: opts.N, - FrequencyPenalty: opts.FrequencyPenalty, - PresencePenalty: opts.PresencePenalty, - MaxTokens: opts.MaxTokens, - MaxCompletionTokens: opts.MaxTokens, - + Model: opts.Model, + StopWords: opts.StopWords, + Messages: chatMsgs, + StreamingFunc: opts.StreamingFunc, + Temperature: opts.Temperature, + N: opts.N, + FrequencyPenalty: opts.FrequencyPenalty, + PresencePenalty: opts.PresencePenalty, + MaxTokens: opts.MaxTokens, + MaxCompletionTokens: opts.MaxCompletionTokens, ToolChoice: opts.ToolChoice, FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior), Seed: opts.Seed, diff --git a/llms/options.go b/llms/options.go index b6b595290..661c8dc16 100644 --- a/llms/options.go +++ b/llms/options.go @@ -14,6 +14,8 @@ type CallOptions struct { CandidateCount int `json:"candidate_count"` // MaxTokens is the maximum number of tokens to generate. 
MaxTokens int `json:"max_tokens"` + // MaxCompletionTokens is the maximum number of tokens to generate - used in the current openai API + MaxCompletionTokens int `json:"max_completion_tokens"` // Temperature is the temperature for sampling, between 0 and 1. Temperature float64 `json:"temperature"` // StopWords is a list of words to stop on. @@ -126,6 +128,13 @@ func WithMaxTokens(maxTokens int) CallOption { } } +// WithMaxCompletionTokens specifies the max number of tokens to generate - used in the current openai API +func WithMaxCompletionTokens(maxCompletionTokens int) CallOption { + return func(o *CallOptions) { + o.MaxCompletionTokens = maxCompletionTokens + } +} + // WithCandidateCount specifies the number of response candidates to generate. func WithCandidateCount(c int) CallOption { return func(o *CallOptions) { From 0eac5bcfdda2cbb1edd32037848aea15d0ffb1a1 Mon Sep 17 00:00:00 2001 From: amitaifrey Date: Wed, 31 Jul 2024 17:04:29 +0300 Subject: [PATCH 10/17] outputparser: improve BooleanOutputParser The BooleanOutputParser requests the LLM to respond with a boolean, and gives examples such as `true` or `false`. However, it only parsed respones that include YES or NO. This commits adds more values for parsing and changes the tests to fit them. --- llms/openai/openaillm.go | 1 + 1 file changed, 1 insertion(+) diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go index 78f8334d2..5a100f6b6 100644 --- a/llms/openai/openaillm.go +++ b/llms/openai/openaillm.go @@ -116,6 +116,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten req.ResponseFormat = ResponseFormatJSON } + // since req.Functions is deprecated, we need to use the new Tools API. for _, fn := range opts.Functions { req.Tools = append(req.Tools, openaiclient.Tool{ From faea3924dd7828b897829bccfe0f55af07889fd4 Mon Sep 17 00:00:00 2001 From: amitaifrey Date: Mon, 27 Jan 2025 12:11:47 +0200 Subject: [PATCH 11/17] outputparser: improve BooleanOutputParser The BooleanOutputParser requests the LLM to respond with a boolean, and gives examples such as `true` or `false`. However, it only parsed respones that include YES or NO. This commits adds more values for parsing and changes the tests to fit them. --- outputparser/boolean_parser_test.go | 1 + 1 file changed, 1 insertion(+) diff --git a/outputparser/boolean_parser_test.go b/outputparser/boolean_parser_test.go index ee3f5b195..f8214402a 100644 --- a/outputparser/boolean_parser_test.go +++ b/outputparser/boolean_parser_test.go @@ -71,6 +71,7 @@ func TestBooleanParser(t *testing.T) { } for _, tc := range testCases { + tc := tc parser := outputparser.NewBooleanParser() t.Run(tc.input, func(t *testing.T) { From 033ac65e2d8230f93e779a64fba03289e37f152b Mon Sep 17 00:00:00 2001 From: amitaifrey Date: Wed, 21 Aug 2024 15:11:37 +0300 Subject: [PATCH 12/17] outputparser: allow json to start with ``` Allow the returning schema to start with ```, not just ```json --- outputparser/defined.go | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/outputparser/defined.go b/outputparser/defined.go index 138055566..021712aef 100644 --- a/outputparser/defined.go +++ b/outputparser/defined.go @@ -62,12 +62,22 @@ func (p Defined[T]) Parse(text string) (T, error) { var target T // Removes '```json' and '```' from the start and end of the text. 
- const opening = "```json" + const opening1 = "```json" + const opening2 = "```" + switch { + case len(text) >= len(opening1) && text[:len(opening1)] == opening1: + text = text[len(opening1):] + case len(text) >= len(opening2) && text[:len(opening2)] == opening2: + text = text[len(opening2):] + default: + return target, errors.New("input text should start with '```json' or '```'") + } + const closing = "```" - if text[:len(opening)] != opening || text[len(text)-len(closing):] != closing { - return target, fmt.Errorf("input text should start with %s and end with %s", opening, closing) + if len(text) >= len(closing) && text[len(text)-len(closing):] != closing { + return target, fmt.Errorf("input text should end with %s", closing) } - parseableJSON := text[len(opening) : len(text)-len(closing)] + parseableJSON := text[:len(text)-len(closing)] if err := json.Unmarshal([]byte(parseableJSON), &target); err != nil { return target, fmt.Errorf("could not parse generated JSON: %w", err) } From a952c8422d8285158e4de4801ec71ab02e874069 Mon Sep 17 00:00:00 2001 From: amitaifrey Date: Thu, 21 Nov 2024 16:07:26 +0200 Subject: [PATCH 13/17] outputparser: allow json to start with any prefix Now the output can start with ```json, ``` or nothing; and end with ``` or nothing. --- outputparser/defined.go | 9 +-- outputparser/defined_test.go | 123 ++++++++++++++++++++++++++++------- 2 files changed, 102 insertions(+), 30 deletions(-) diff --git a/outputparser/defined.go b/outputparser/defined.go index 021712aef..4df7b0cda 100644 --- a/outputparser/defined.go +++ b/outputparser/defined.go @@ -69,16 +69,13 @@ func (p Defined[T]) Parse(text string) (T, error) { text = text[len(opening1):] case len(text) >= len(opening2) && text[:len(opening2)] == opening2: text = text[len(opening2):] - default: - return target, errors.New("input text should start with '```json' or '```'") } const closing = "```" - if len(text) >= len(closing) && text[len(text)-len(closing):] != closing { - return target, fmt.Errorf("input text should end with %s", closing) + if len(text) >= len(closing) && text[len(text)-len(closing):] == closing { + text = text[:len(text)-len(closing)] } - parseableJSON := text[:len(text)-len(closing)] - if err := json.Unmarshal([]byte(parseableJSON), &target); err != nil { + if err := json.Unmarshal([]byte(text), &target); err != nil { return target, fmt.Errorf("could not parse generated JSON: %w", err) } return target, nil diff --git a/outputparser/defined_test.go b/outputparser/defined_test.go index e77ca9cd7..85e3ab470 100644 --- a/outputparser/defined_test.go +++ b/outputparser/defined_test.go @@ -97,37 +97,112 @@ interface Foods { } } -func TestDefinedParse(t *testing.T) { - t.Parallel() - var book struct { - Chapters []struct { - Title string `json:"title" describe:"chapter title"` - } `json:"chapters" describe:"chapters"` - } - parser, newErr := NewDefined(book) - if newErr != nil { - t.Error(newErr) - } +type book struct { + Chapters []struct { + Title string `json:"title" describe:"chapter title"` + } `json:"chapters" describe:"chapters"` +} +func getParseTests() map[string]struct { + input string + expected *book + wantErr bool +} { titles := []string{ "A Hello There", "The Meaty Middle", "The Grand Finale", } - output, parseErr := parser.Parse(fmt.Sprintf("```json\n%s\n```", fmt.Sprintf( - `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], - ))) - if parseErr != nil { - t.Error(parseErr) - } - if count := len(output.Chapters); count != 3 { - 
t.Errorf("got %d chapters; want 3", count) + return map[string]struct { + input string + expected *book + wantErr bool + }{ + "empty": { + input: "", + wantErr: true, + expected: nil, + }, + "invalid": { + input: "invalid", + wantErr: true, + expected: nil, + }, + "valid": { + input: fmt.Sprintf("```json\n%s\n```", fmt.Sprintf( + `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], + )), + wantErr: false, + expected: &book{ + Chapters: []struct { + Title string `json:"title" describe:"chapter title"` + }{ + {Title: titles[0]}, + {Title: titles[1]}, + {Title: titles[2]}, + }, + }, + }, + "valid-without-json-tag": { + input: fmt.Sprintf("```\n%s\n```", fmt.Sprintf( + `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], + )), + wantErr: false, + expected: &book{ + Chapters: []struct { + Title string `json:"title" describe:"chapter title"` + }{ + {Title: titles[0]}, + {Title: titles[1]}, + {Title: titles[2]}, + }, + }, + }, + "valid-without-tags": { + input: fmt.Sprintf("\n%s\n", fmt.Sprintf( + `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], + )), + wantErr: false, + expected: &book{ + Chapters: []struct { + Title string `json:"title" describe:"chapter title"` + }{ + {Title: titles[0]}, + {Title: titles[1]}, + {Title: titles[2]}, + }, + }, + }, } - for i, chapter := range output.Chapters { - title := titles[i] - if chapter.Title != titles[i] { - t.Errorf("got '%s'; want '%s'", chapter.Title, title) - } +} + +func TestDefinedParse(t *testing.T) { + t.Parallel() + for name, test := range getParseTests() { + t.Run(name, func(t *testing.T) { + t.Parallel() + parser, newErr := NewDefined(book{}) + if newErr != nil { + t.Error(newErr) + } + output, parseErr := parser.Parse(test.input) + switch { + case parseErr != nil && !test.wantErr: + t.Errorf("%s: unexpected error: %v", name, parseErr) + case parseErr == nil && test.wantErr: + t.Errorf("%s: expected error", name) + case parseErr == nil && test.expected != nil: + if count := len(output.Chapters); count != len(test.expected.Chapters) { + t.Errorf("%s: got %d chapters; want %d", name, count, len(test.expected.Chapters)) + } + for i, chapter := range output.Chapters { + title := test.expected.Chapters[i].Title + if chapter.Title != title { + t.Errorf("%s: got '%s'; want '%s'", name, chapter.Title, title) + } + } + } + }) } } From bdf674ab68fc250609ef2a1088f345edc862ba18 Mon Sep 17 00:00:00 2001 From: amitaifrey Date: Tue, 26 Nov 2024 10:33:05 +0200 Subject: [PATCH 14/17] Handle LLM verbosity --- outputparser/defined.go | 18 +++++++++++------- outputparser/defined_test.go | 15 +++++++++++++++ 2 files changed, 26 insertions(+), 7 deletions(-) diff --git a/outputparser/defined.go b/outputparser/defined.go index 4df7b0cda..8cdfadea0 100644 --- a/outputparser/defined.go +++ b/outputparser/defined.go @@ -6,6 +6,7 @@ import ( "errors" "fmt" "reflect" + "strings" "github.com/tmc/langchaingo/llms" "github.com/tmc/langchaingo/schema" @@ -64,16 +65,19 @@ func (p Defined[T]) Parse(text string) (T, error) { // Removes '```json' and '```' from the start and end of the text. 
const opening1 = "```json" const opening2 = "```" - switch { - case len(text) >= len(opening1) && text[:len(opening1)] == opening1: - text = text[len(opening1):] - case len(text) >= len(opening2) && text[:len(opening2)] == opening2: - text = text[len(opening2):] + + opening1Index := strings.Index(text, opening1) + opening2Index := strings.Index(text, opening2) + if opening1Index != -1 { + text = text[opening1Index+len(opening1):] + } else if opening2Index != -1 { + text = text[opening2Index+len(opening2):] } const closing = "```" - if len(text) >= len(closing) && text[len(text)-len(closing):] == closing { - text = text[:len(text)-len(closing)] + closingIndex := strings.Index(text, closing) + if closingIndex != -1 { + text = text[:closingIndex] } if err := json.Unmarshal([]byte(text), &target); err != nil { return target, fmt.Errorf("could not parse generated JSON: %w", err) diff --git a/outputparser/defined_test.go b/outputparser/defined_test.go index 85e3ab470..3e142300e 100644 --- a/outputparser/defined_test.go +++ b/outputparser/defined_test.go @@ -174,6 +174,21 @@ func getParseTests() map[string]struct { }, }, }, + "llm-explanation-and-valid": { + input: fmt.Sprintf("Sure! Here's the JSON:\n\n```json\n%s\n```\n\nLet me know if you need anything else.", fmt.Sprintf( + `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], + )), + wantErr: false, + expected: &book{ + Chapters: []struct { + Title string `json:"title" describe:"chapter title"` + }{ + {Title: titles[0]}, + {Title: titles[1]}, + {Title: titles[2]}, + }, + }, + }, } } From c1cefd0aa77f49c1bf616bf6e52dfa143055cc9b Mon Sep 17 00:00:00 2001 From: amitaifrey Date: Tue, 26 Nov 2024 11:14:13 +0200 Subject: [PATCH 15/17] Just use brackets --- outputparser/defined.go | 20 +++++--------------- outputparser/defined_test.go | 17 ++++++++++++++++- 2 files changed, 21 insertions(+), 16 deletions(-) diff --git a/outputparser/defined.go b/outputparser/defined.go index 8cdfadea0..655a52407 100644 --- a/outputparser/defined.go +++ b/outputparser/defined.go @@ -62,23 +62,13 @@ func (p Defined[T]) GetFormatInstructions() string { func (p Defined[T]) Parse(text string) (T, error) { var target T - // Removes '```json' and '```' from the start and end of the text. - const opening1 = "```json" - const opening2 = "```" - - opening1Index := strings.Index(text, opening1) - opening2Index := strings.Index(text, opening2) - if opening1Index != -1 { - text = text[opening1Index+len(opening1):] - } else if opening2Index != -1 { - text = text[opening2Index+len(opening2):] + startIndex := strings.Index(text, "{") + endIndex := strings.LastIndex(text, "}") + if startIndex == -1 || endIndex == -1 { + return target, fmt.Errorf("could not find start or end of JSON object") } + text = text[startIndex : endIndex+1] - const closing = "```" - closingIndex := strings.Index(text, closing) - if closingIndex != -1 { - text = text[:closingIndex] - } if err := json.Unmarshal([]byte(text), &target); err != nil { return target, fmt.Errorf("could not parse generated JSON: %w", err) } diff --git a/outputparser/defined_test.go b/outputparser/defined_test.go index 3e142300e..3adef4aea 100644 --- a/outputparser/defined_test.go +++ b/outputparser/defined_test.go @@ -174,7 +174,7 @@ func getParseTests() map[string]struct { }, }, }, - "llm-explanation-and-valid": { + "llm-explanation-and-tags": { input: fmt.Sprintf("Sure! 
Here's the JSON:\n\n```json\n%s\n```\n\nLet me know if you need anything else.", fmt.Sprintf( `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], )), @@ -189,6 +189,21 @@ func getParseTests() map[string]struct { }, }, }, + "llm-explanation-and-valid": { + input: fmt.Sprintf("Sure! Here's the JSON:\n\n%s\n\nLet me know if you need anything else.", fmt.Sprintf( + `{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2], + )), + wantErr: false, + expected: &book{ + Chapters: []struct { + Title string `json:"title" describe:"chapter title"` + }{ + {Title: titles[0]}, + {Title: titles[1]}, + {Title: titles[2]}, + }, + }, + }, } } From 7e21f43b4b10633be3231f872afc15335303ee07 Mon Sep 17 00:00:00 2001 From: Oryan Moshe <43927816+oryanmoshe@users.noreply.github.com> Date: Tue, 10 Dec 2024 14:08:37 +0200 Subject: [PATCH 16/17] =?UTF-8?q?=E2=9C=A8feat:=20Add=20citations=20to=20o?= =?UTF-8?q?penai=20response?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- llms/openai/internal/openaiclient/chat.go | 3 ++- llms/openai/openaillm.go | 19 ++++++++++--------- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/llms/openai/internal/openaiclient/chat.go b/llms/openai/internal/openaiclient/chat.go index 46938926e..ae93e271a 100644 --- a/llms/openai/internal/openaiclient/chat.go +++ b/llms/openai/internal/openaiclient/chat.go @@ -34,7 +34,7 @@ type ChatRequest struct { Temperature float64 `json:"temperature"` TopP float64 `json:"top_p,omitempty"` // Deprecated: Use MaxCompletionTokens - MaxTokens int `json:"-,omitempty"` + MaxTokens int `json:"max_tokens,omitempty"` MaxCompletionTokens int `json:"max_completion_tokens,omitempty"` N int `json:"n,omitempty"` StopWords []string `json:"stop,omitempty"` @@ -297,6 +297,7 @@ type ChatCompletionResponse struct { Object string `json:"object,omitempty"` Usage ChatUsage `json:"usage,omitempty"` SystemFingerprint string `json:"system_fingerprint"` + Citations []string `json:"citations,omitempty"` } type Usage struct { diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go index 5a100f6b6..69a4fabd5 100644 --- a/llms/openai/openaillm.go +++ b/llms/openai/openaillm.go @@ -96,15 +96,15 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten chatMsgs = append(chatMsgs, msg) } req := &openaiclient.ChatRequest{ - Model: opts.Model, - StopWords: opts.StopWords, - Messages: chatMsgs, - StreamingFunc: opts.StreamingFunc, - Temperature: opts.Temperature, - N: opts.N, - FrequencyPenalty: opts.FrequencyPenalty, - PresencePenalty: opts.PresencePenalty, - + Model: opts.Model, + StopWords: opts.StopWords, + Messages: chatMsgs, + StreamingFunc: opts.StreamingFunc, + Temperature: opts.Temperature, + N: opts.N, + FrequencyPenalty: opts.FrequencyPenalty, + PresencePenalty: opts.PresencePenalty, + MaxTokens: opts.MaxTokens, MaxCompletionTokens: opts.MaxTokens, ToolChoice: opts.ToolChoice, @@ -161,6 +161,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten "PromptTokens": result.Usage.PromptTokens, "TotalTokens": result.Usage.TotalTokens, "ReasoningTokens": result.Usage.CompletionTokensDetails.ReasoningTokens, + "Citations": result.Citations, }, } From 7701e3abf500112e99670710535dac1ece21c7bc Mon Sep 17 00:00:00 2001 From: amitaifrey Date: Tue, 7 Jan 2025 12:34:34 +0200 Subject: [PATCH 17/17] Handle max completion tokens and max tokens --- chains/options.go | 19 
+++++++++++++++++++ llms/openai/openaillm.go | 23 ++++++++++++----------- llms/options.go | 9 +++++++++ 3 files changed, 40 insertions(+), 11 deletions(-) diff --git a/chains/options.go b/chains/options.go index bb77ae7d0..6d5e08d7e 100644 --- a/chains/options.go +++ b/chains/options.go @@ -27,6 +27,10 @@ type chainCallOption struct { MaxTokens int maxTokensSet bool + // MaxCompletionTokens is the maximum number of tokens to generate - used in the current openai API + MaxCompletionTokens int + maxCompletionTokensSet bool + // Temperature is the temperature for sampling to use in an LLM call, between 0 and 1. Temperature float64 temperatureSet bool @@ -83,6 +87,14 @@ func WithMaxTokens(maxTokens int) ChainCallOption { } } +// WithMaxCompletionTokens is an option for LLM.Call. +func WithMaxCompletionTokens(maxCompletionTokens int) ChainCallOption { + return func(o *chainCallOption) { + o.MaxCompletionTokens = maxCompletionTokens + o.maxCompletionTokensSet = true + } +} + // WithTemperature is an option for LLM.Call. func WithTemperature(temperature float64) ChainCallOption { return func(o *chainCallOption) { @@ -181,6 +193,9 @@ func getLLMCallOptions(options ...ChainCallOption) []llms.CallOption { //nolint: if opts.maxTokensSet { chainCallOption = append(chainCallOption, llms.WithMaxTokens(opts.MaxTokens)) } + if opts.maxCompletionTokensSet { + chainCallOption = append(chainCallOption, llms.WithMaxCompletionTokens(opts.MaxCompletionTokens)) + } if opts.temperatureSet { chainCallOption = append(chainCallOption, llms.WithTemperature(opts.Temperature)) } @@ -209,3 +224,7 @@ func getLLMCallOptions(options ...ChainCallOption) []llms.CallOption { //nolint: return chainCallOption } + +func GetLLMCallOptions(options ...ChainCallOption) []llms.CallOption { + return getLLMCallOptions(options...) +} diff --git a/llms/openai/openaillm.go b/llms/openai/openaillm.go index 69a4fabd5..7c87a0e5d 100644 --- a/llms/openai/openaillm.go +++ b/llms/openai/openaillm.go @@ -14,6 +14,7 @@ type ChatMessage = openaiclient.ChatMessage type LLM struct { CallbacksHandler callbacks.Handler client *openaiclient.Client + opts *options } const ( @@ -35,6 +36,7 @@ func New(opts ...Option) (*LLM, error) { return &LLM{ client: c, CallbacksHandler: opt.callbackHandler, + opts: opt, }, err } @@ -96,17 +98,16 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten chatMsgs = append(chatMsgs, msg) } req := &openaiclient.ChatRequest{ - Model: opts.Model, - StopWords: opts.StopWords, - Messages: chatMsgs, - StreamingFunc: opts.StreamingFunc, - Temperature: opts.Temperature, - N: opts.N, - FrequencyPenalty: opts.FrequencyPenalty, - PresencePenalty: opts.PresencePenalty, - MaxTokens: opts.MaxTokens, - MaxCompletionTokens: opts.MaxTokens, - + Model: opts.Model, + StopWords: opts.StopWords, + Messages: chatMsgs, + StreamingFunc: opts.StreamingFunc, + Temperature: opts.Temperature, + N: opts.N, + FrequencyPenalty: opts.FrequencyPenalty, + PresencePenalty: opts.PresencePenalty, + MaxTokens: opts.MaxTokens, + MaxCompletionTokens: opts.MaxCompletionTokens, ToolChoice: opts.ToolChoice, FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior), Seed: opts.Seed, diff --git a/llms/options.go b/llms/options.go index b6b595290..661c8dc16 100644 --- a/llms/options.go +++ b/llms/options.go @@ -14,6 +14,8 @@ type CallOptions struct { CandidateCount int `json:"candidate_count"` // MaxTokens is the maximum number of tokens to generate. 
MaxTokens int `json:"max_tokens"` + // MaxCompletionTokens is the maximum number of tokens to generate - used in the current openai API + MaxCompletionTokens int `json:"max_completion_tokens"` // Temperature is the temperature for sampling, between 0 and 1. Temperature float64 `json:"temperature"` // StopWords is a list of words to stop on. @@ -126,6 +128,13 @@ func WithMaxTokens(maxTokens int) CallOption { } } +// WithMaxCompletionTokens specifies the max number of tokens to generate - used in the current openai API +func WithMaxCompletionTokens(maxCompletionTokens int) CallOption { + return func(o *CallOptions) { + o.MaxCompletionTokens = maxCompletionTokens + } +} + // WithCandidateCount specifies the number of response candidates to generate. func WithCandidateCount(c int) CallOption { return func(o *CallOptions) {
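For reviewers reading the series end to end, here is a minimal, hypothetical sketch (not part of any patch above) of how the patched output parsers behave once the whole series is applied. It assumes the modified `github.com/tmc/langchaingo/outputparser` package from this branch is importable; the `book` type and the sample reply mirror the fixtures in `defined_test.go`, and the expected values in the comments are inferred from the diffs rather than from a separate run.

```go
package main

import (
	"fmt"

	"github.com/tmc/langchaingo/outputparser"
)

// book mirrors the test fixture used in defined_test.go above.
type book struct {
	Chapters []struct {
		Title string `json:"title" describe:"chapter title"`
	} `json:"chapters" describe:"chapters"`
}

func main() {
	// After patches 01/02, BooleanParser accepts TRUE/FALSE (any casing,
	// optionally quoted) in addition to the original YES/NO.
	boolParser := outputparser.NewBooleanParser()
	verdict, err := boolParser.Parse("true")
	fmt.Println(verdict, err) // expected: true <nil>

	// After patch 07/15, Defined extracts the first {...} object, so both
	// fenced output (```json ... ```) and prose-wrapped output parse.
	defParser, err := outputparser.NewDefined(book{})
	if err != nil {
		panic(err)
	}
	reply := "Sure! Here's the JSON:\n\n{\"chapters\": [{\"title\": \"A Hello There\"}]}\n\nLet me know if you need anything else."
	parsed, err := defParser.Parse(reply)
	fmt.Printf("%+v %v\n", parsed, err) // expected: {Chapters:[{Title:A Hello There}]} <nil>
}
```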
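Likewise, a short sketch of the new token-limit plumbing from the final patch, again an assumption-laden illustration rather than part of the series: it relies on the `WithMaxCompletionTokens` helpers and the exported `GetLLMCallOptions` added in `chains/options.go` and `llms/options.go` above, and exercises only option construction, so no API key or network call is needed.

```go
package main

import (
	"fmt"

	"github.com/tmc/langchaingo/chains"
	"github.com/tmc/langchaingo/llms"
)

func main() {
	// Apply the new llms-level option directly to a CallOptions value.
	var callOpts llms.CallOptions
	llms.WithMaxCompletionTokens(256)(&callOpts)
	fmt.Println(callOpts.MaxCompletionTokens) // expected: 256

	// Chain-level options are translated via the newly exported
	// GetLLMCallOptions helper; both limits survive the translation.
	llmOpts := chains.GetLLMCallOptions(
		chains.WithMaxTokens(512),           // pre-existing max_tokens option
		chains.WithMaxCompletionTokens(256), // option added in this series
	)
	var fromChain llms.CallOptions
	for _, opt := range llmOpts {
		opt(&fromChain)
	}
	fmt.Println(fromChain.MaxTokens, fromChain.MaxCompletionTokens) // expected: 512 256
}
```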