diff --git a/outputparser/boolean_parser.go b/outputparser/boolean_parser.go index a91d0091d..46970771c 100644 --- a/outputparser/boolean_parser.go +++ b/outputparser/boolean_parser.go @@ -11,15 +11,15 @@ import ( // BooleanParser is an output parser used to parse the output of an LLM as a boolean. type BooleanParser struct { - TrueStr string - FalseStr string + TrueStrings []string + FalseStrings []string } // NewBooleanParser returns a new BooleanParser. func NewBooleanParser() BooleanParser { return BooleanParser{ - TrueStr: "YES", - FalseStr: "NO", + TrueStrings: []string{"YES", "TRUE"}, + FalseStrings: []string{"NO", "FALSE"}, } } @@ -33,20 +33,24 @@ func (p BooleanParser) GetFormatInstructions() string { func (p BooleanParser) parse(text string) (bool, error) { text = normalize(text) - booleanStrings := []string{p.TrueStr, p.FalseStr} - if !slices.Contains(booleanStrings, text) { - return false, ParseError{ - Text: text, - Reason: fmt.Sprintf("Expected output to be either '%s' or '%s', received %s", p.TrueStr, p.FalseStr, text), - } + if slices.Contains(p.TrueStrings, text) { + return true, nil } - return text == p.TrueStr, nil + if slices.Contains(p.FalseStrings, text) { + return false, nil + } + + return false, ParseError{ + Text: text, + Reason: fmt.Sprintf("Expected output to one of %v, received %s", append(p.TrueStrings, p.FalseStrings...), text), + } } func normalize(text string) string { text = strings.TrimSpace(text) + text = strings.Trim(text, "'\"`") text = strings.ToUpper(text) return text diff --git a/outputparser/boolean_parser_test.go b/outputparser/boolean_parser_test.go index 9ab92c664..ee3f5b195 100644 --- a/outputparser/boolean_parser_test.go +++ b/outputparser/boolean_parser_test.go @@ -24,6 +24,7 @@ func TestBooleanParser(t *testing.T) { }, { input: "YESNO", + err: outputparser.ParseError{}, expected: false, }, { @@ -31,18 +32,62 @@ func TestBooleanParser(t *testing.T) { err: outputparser.ParseError{}, expected: false, }, + { + input: "true", + expected: true, + }, + { + input: "false", + expected: false, + }, + { + input: "True", + expected: true, + }, + { + input: "False", + expected: false, + }, + { + input: "TRUE", + expected: true, + }, + { + input: "FALSE", + expected: false, + }, + { + input: "'TRUE'", + expected: true, + }, + { + input: "`TRUE`", + expected: true, + }, + { + input: "'TRUE`", + expected: true, + }, } for _, tc := range testCases { parser := outputparser.NewBooleanParser() - actual, err := parser.Parse(tc.input) - if tc.err != nil && err == nil { - t.Errorf("Expected error %v, got nil", tc.err) - } + t.Run(tc.input, func(t *testing.T) { + t.Parallel() + + result, err := parser.Parse(tc.input) + if err != nil && tc.err == nil { + t.Errorf("Unexpected error: %v", err) + } + + if err == nil && tc.err != nil { + t.Errorf("Expected error %v, got nil", tc.err) + } - if actual != tc.expected { - t.Errorf("Expected %v, got %v", tc.expected, actual) - } + if result != tc.expected { + t.Errorf("Expected %v, but got %v", tc.expected, result) + } + }) } }