Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sync fork #1113

Closed
wants to merge 28 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
6b4e6a4
outputparser: improve BooleanOutputParser
amitaifrey Jul 31, 2024
034a0a4
Merge pull request #1 from lema-ai/improve-boolean-parser
amitaifrey Aug 4, 2024
1ed2922
outputparser: improve BooleanOutputParser
amitaifrey Jul 31, 2024
3432c64
Merge branch 'main' of github.com:lema-ai/langchaingo into local/impr…
amitaifrey Aug 4, 2024
36a9dba
Merge branch 'main' of github.com:lema-ai/langchaingo into local/impr…
amitaifrey Aug 4, 2024
dafebdf
Merge branch 'local/improve-boolean-parser' of github.com:lema-ai/lan…
amitaifrey Aug 4, 2024
fe7b3ff
Merge pull request #3 from lema-ai/local/improve-boolean-parser
amitaifrey Aug 4, 2024
abd68a9
outputparser: improve DefinedOutputParser
amitaifrey Aug 4, 2024
679443e
Merge pull request #4 from lema-ai/improve-defined-parser
amitaifrey Aug 4, 2024
ac55095
outputparser: allow json to start with ```
amitaifrey Aug 21, 2024
4c1b0a5
Merge pull request #5 from lema-ai/relax-defined-parser
amitaifrey Aug 27, 2024
10d2006
outputparser: allow json to start with any prefix
amitaifrey Nov 21, 2024
94458aa
Merge branch 'main' of github.com:tmc/langchaingo
amitaifrey Nov 25, 2024
99cd08a
Handle LLM verbosity
amitaifrey Nov 26, 2024
f0af39f
Just use brackets
amitaifrey Nov 26, 2024
f74fbd8
Merge pull request #6 from lema-ai/relax-defined-parser-more
amitaifrey Dec 4, 2024
363a03e
✨feat: Add citations to openai response
oryanmoshe Dec 10, 2024
0960a3b
Handle max completion tokens and max tokens
amitaifrey Jan 7, 2025
956c7bc
Merge pull request #8 from lema-ai/max-completion-tokens
amitaifrey Jan 7, 2025
0eac5bc
outputparser: improve BooleanOutputParser
amitaifrey Jul 31, 2024
faea392
outputparser: improve BooleanOutputParser
amitaifrey Jan 27, 2025
033ac65
outputparser: allow json to start with ```
amitaifrey Aug 21, 2024
a952c84
outputparser: allow json to start with any prefix
amitaifrey Nov 21, 2024
bdf674a
Handle LLM verbosity
amitaifrey Nov 26, 2024
c1cefd0
Just use brackets
amitaifrey Nov 26, 2024
7e21f43
✨feat: Add citations to openai response
oryanmoshe Dec 10, 2024
7701e3a
Handle max completion tokens and max tokens
amitaifrey Jan 7, 2025
d0fb78a
Merge branch 'main' of github.com:lema-ai/langchaingo
amitaifrey Jan 27, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions chains/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ type chainCallOption struct {
MaxTokens int
maxTokensSet bool

// MaxCompletionTokens is the maximum number of tokens to generate - used in the current openai API
MaxCompletionTokens int
maxCompletionTokensSet bool

// Temperature is the temperature for sampling to use in an LLM call, between 0 and 1.
Temperature float64
temperatureSet bool
Expand Down Expand Up @@ -83,6 +87,14 @@ func WithMaxTokens(maxTokens int) ChainCallOption {
}
}

// WithMaxCompletionTokens is an option for LLM.Call.
func WithMaxCompletionTokens(maxCompletionTokens int) ChainCallOption {
	// Record both the value and the "explicitly set" flag so that
	// getLLMCallOptions forwards it only when the caller asked for it.
	return func(opt *chainCallOption) {
		opt.maxCompletionTokensSet = true
		opt.MaxCompletionTokens = maxCompletionTokens
	}
}

// WithTemperature is an option for LLM.Call.
func WithTemperature(temperature float64) ChainCallOption {
return func(o *chainCallOption) {
Expand Down Expand Up @@ -181,6 +193,9 @@ func getLLMCallOptions(options ...ChainCallOption) []llms.CallOption { //nolint:
if opts.maxTokensSet {
chainCallOption = append(chainCallOption, llms.WithMaxTokens(opts.MaxTokens))
}
if opts.maxCompletionTokensSet {
chainCallOption = append(chainCallOption, llms.WithMaxCompletionTokens(opts.MaxCompletionTokens))
}
if opts.temperatureSet {
chainCallOption = append(chainCallOption, llms.WithTemperature(opts.Temperature))
}
Expand Down Expand Up @@ -209,3 +224,7 @@ func getLLMCallOptions(options ...ChainCallOption) []llms.CallOption { //nolint:

return chainCallOption
}

// GetLLMCallOptions converts the given ChainCallOptions into the
// corresponding llms.CallOption slice. It is the exported wrapper around
// the package-private getLLMCallOptions.
func GetLLMCallOptions(options ...ChainCallOption) []llms.CallOption {
	return getLLMCallOptions(options...)
}
3 changes: 2 additions & 1 deletion llms/openai/internal/openaiclient/chat.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ type ChatRequest struct {
Temperature float64 `json:"temperature"`
TopP float64 `json:"top_p,omitempty"`
// Deprecated: Use MaxCompletionTokens
MaxTokens int `json:"-,omitempty"`
MaxTokens int `json:"max_tokens,omitempty"`
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
N int `json:"n,omitempty"`
StopWords []string `json:"stop,omitempty"`
Expand Down Expand Up @@ -297,6 +297,7 @@ type ChatCompletionResponse struct {
Object string `json:"object,omitempty"`
Usage ChatUsage `json:"usage,omitempty"`
SystemFingerprint string `json:"system_fingerprint"`
Citations []string `json:"citations,omitempty"`
}

type Usage struct {
Expand Down
25 changes: 14 additions & 11 deletions llms/openai/openaillm.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ type ChatMessage = openaiclient.ChatMessage
type LLM struct {
CallbacksHandler callbacks.Handler
client *openaiclient.Client
opts *options
}

const (
Expand All @@ -35,6 +36,7 @@ func New(opts ...Option) (*LLM, error) {
return &LLM{
client: c,
CallbacksHandler: opt.callbackHandler,
opts: opt,
}, err
}

Expand Down Expand Up @@ -96,17 +98,16 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
chatMsgs = append(chatMsgs, msg)
}
req := &openaiclient.ChatRequest{
Model: opts.Model,
StopWords: opts.StopWords,
Messages: chatMsgs,
StreamingFunc: opts.StreamingFunc,
Temperature: opts.Temperature,
N: opts.N,
FrequencyPenalty: opts.FrequencyPenalty,
PresencePenalty: opts.PresencePenalty,

MaxCompletionTokens: opts.MaxTokens,

Model: opts.Model,
StopWords: opts.StopWords,
Messages: chatMsgs,
StreamingFunc: opts.StreamingFunc,
Temperature: opts.Temperature,
N: opts.N,
FrequencyPenalty: opts.FrequencyPenalty,
PresencePenalty: opts.PresencePenalty,
MaxTokens: opts.MaxTokens,
MaxCompletionTokens: opts.MaxCompletionTokens,
ToolChoice: opts.ToolChoice,
FunctionCallBehavior: openaiclient.FunctionCallBehavior(opts.FunctionCallBehavior),
Seed: opts.Seed,
Expand All @@ -116,6 +117,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
req.ResponseFormat = ResponseFormatJSON
}


// since req.Functions is deprecated, we need to use the new Tools API.
for _, fn := range opts.Functions {
req.Tools = append(req.Tools, openaiclient.Tool{
Expand Down Expand Up @@ -160,6 +162,7 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
"PromptTokens": result.Usage.PromptTokens,
"TotalTokens": result.Usage.TotalTokens,
"ReasoningTokens": result.Usage.CompletionTokensDetails.ReasoningTokens,
"Citations": result.Citations,
},
}

Expand Down
9 changes: 9 additions & 0 deletions llms/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
CandidateCount int `json:"candidate_count"`
// MaxTokens is the maximum number of tokens to generate.
MaxTokens int `json:"max_tokens"`
// MaxCompletionTokens is the maximum number of tokens to generate - used in the current openai API
MaxCompletionTokens int `json:"max_completion_tokens"`
// Temperature is the temperature for sampling, between 0 and 1.
Temperature float64 `json:"temperature"`
// StopWords is a list of words to stop on.
Expand Down Expand Up @@ -126,6 +128,13 @@
}
}

// WithMaxCompletionTokens specifies the max number of tokens to generate - used in the current openai API

Check failure on line 131 in llms/options.go

View workflow job for this annotation

GitHub Actions / Lint

Comment should end in a period (godot)
func WithMaxCompletionTokens(maxCompletionTokens int) CallOption {
return func(o *CallOptions) {
o.MaxCompletionTokens = maxCompletionTokens
}
}

// WithCandidateCount specifies the number of response candidates to generate.
func WithCandidateCount(c int) CallOption {
return func(o *CallOptions) {
Expand Down
1 change: 1 addition & 0 deletions outputparser/boolean_parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@
}

for _, tc := range testCases {
tc := tc

Check failure on line 74 in outputparser/boolean_parser_test.go

View workflow job for this annotation

GitHub Actions / Lint

The copy of the 'for' variable "tc" can be deleted (Go 1.22+) (copyloopvar)
parser := outputparser.NewBooleanParser()

t.Run(tc.input, func(t *testing.T) {
Expand Down
15 changes: 8 additions & 7 deletions outputparser/defined.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"errors"
"fmt"
"reflect"
"strings"

"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
Expand Down Expand Up @@ -61,14 +62,14 @@ func (p Defined[T]) GetFormatInstructions() string {
func (p Defined[T]) Parse(text string) (T, error) {
var target T

// Removes '```json' and '```' from the start and end of the text.
const opening = "```json"
const closing = "```"
if text[:len(opening)] != opening || text[len(text)-len(closing):] != closing {
return target, fmt.Errorf("input text should start with %s and end with %s", opening, closing)
startIndex := strings.Index(text, "{")
endIndex := strings.LastIndex(text, "}")
if startIndex == -1 || endIndex == -1 {
return target, fmt.Errorf("could not find start or end of JSON object")
}
parseableJSON := text[len(opening) : len(text)-len(closing)]
if err := json.Unmarshal([]byte(parseableJSON), &target); err != nil {
text = text[startIndex : endIndex+1]

if err := json.Unmarshal([]byte(text), &target); err != nil {
return target, fmt.Errorf("could not parse generated JSON: %w", err)
}
return target, nil
Expand Down
153 changes: 129 additions & 24 deletions outputparser/defined_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,37 +97,142 @@ interface Foods {
}
}

func TestDefinedParse(t *testing.T) {
t.Parallel()
var book struct {
Chapters []struct {
Title string `json:"title" describe:"chapter title"`
} `json:"chapters" describe:"chapters"`
}
parser, newErr := NewDefined(book)
if newErr != nil {
t.Error(newErr)
}
// book is the concrete target type the Defined output parser is asked to
// populate in these tests. The `describe` tags presumably feed the parser's
// generated format instructions — confirm against NewDefined.
type book struct {
	Chapters []struct {
		Title string `json:"title" describe:"chapter title"`
	} `json:"chapters" describe:"chapters"`
}

func getParseTests() map[string]struct {
input string
expected *book
wantErr bool
} {
titles := []string{
"A Hello There",
"The Meaty Middle",
"The Grand Finale",
}

output, parseErr := parser.Parse(fmt.Sprintf("```json\n%s\n```", fmt.Sprintf(
`{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2],
)))
if parseErr != nil {
t.Error(parseErr)
}
if count := len(output.Chapters); count != 3 {
t.Errorf("got %d chapters; want 3", count)
return map[string]struct {
input string
expected *book
wantErr bool
}{
"empty": {
input: "",
wantErr: true,
expected: nil,
},
"invalid": {
input: "invalid",
wantErr: true,
expected: nil,
},
"valid": {
input: fmt.Sprintf("```json\n%s\n```", fmt.Sprintf(
`{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2],
)),
wantErr: false,
expected: &book{
Chapters: []struct {
Title string `json:"title" describe:"chapter title"`
}{
{Title: titles[0]},
{Title: titles[1]},
{Title: titles[2]},
},
},
},
"valid-without-json-tag": {
input: fmt.Sprintf("```\n%s\n```", fmt.Sprintf(
`{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2],
)),
wantErr: false,
expected: &book{
Chapters: []struct {
Title string `json:"title" describe:"chapter title"`
}{
{Title: titles[0]},
{Title: titles[1]},
{Title: titles[2]},
},
},
},
"valid-without-tags": {
input: fmt.Sprintf("\n%s\n", fmt.Sprintf(
`{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2],
)),
wantErr: false,
expected: &book{
Chapters: []struct {
Title string `json:"title" describe:"chapter title"`
}{
{Title: titles[0]},
{Title: titles[1]},
{Title: titles[2]},
},
},
},
"llm-explanation-and-tags": {
input: fmt.Sprintf("Sure! Here's the JSON:\n\n```json\n%s\n```\n\nLet me know if you need anything else.", fmt.Sprintf(
`{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2],
)),
wantErr: false,
expected: &book{
Chapters: []struct {
Title string `json:"title" describe:"chapter title"`
}{
{Title: titles[0]},
{Title: titles[1]},
{Title: titles[2]},
},
},
},
"llm-explanation-and-valid": {
input: fmt.Sprintf("Sure! Here's the JSON:\n\n%s\n\nLet me know if you need anything else.", fmt.Sprintf(
`{"chapters": [{"title": "%s"}, {"title": "%s"}, {"title": "%s"}]}`, titles[0], titles[1], titles[2],
)),
wantErr: false,
expected: &book{
Chapters: []struct {
Title string `json:"title" describe:"chapter title"`
}{
{Title: titles[0]},
{Title: titles[1]},
{Title: titles[2]},
},
},
},
}
for i, chapter := range output.Chapters {
title := titles[i]
if chapter.Title != titles[i] {
t.Errorf("got '%s'; want '%s'", chapter.Title, title)
}
}

// TestDefinedParse runs the table from getParseTests through
// Defined[book].Parse, checking that each input either fails as expected
// or yields the expected chapter titles.
func TestDefinedParse(t *testing.T) {
	t.Parallel()
	for name, test := range getParseTests() {
		t.Run(name, func(t *testing.T) {
			t.Parallel()
			parser, newErr := NewDefined(book{})
			if newErr != nil {
				t.Error(newErr)
			}
			output, parseErr := parser.Parse(test.input)
			switch {
			case parseErr != nil && !test.wantErr:
				t.Errorf("%s: unexpected error: %v", name, parseErr)
			case parseErr == nil && test.wantErr:
				t.Errorf("%s: expected error", name)
			case parseErr == nil && test.expected != nil:
				// Compare chapter count first, then each title in order.
				if count := len(output.Chapters); count != len(test.expected.Chapters) {
					t.Errorf("%s: got %d chapters; want %d", name, count, len(test.expected.Chapters))
				}
				for i, chapter := range output.Chapters {
					title := test.expected.Chapters[i].Title
					if chapter.Title != title {
						t.Errorf("%s: got '%s'; want '%s'", name, chapter.Title, title)
					}
				}
			}
		})
	}
}
Loading