Skip to content

Commit

Permalink
Merge pull request #479 from chyroc/feat-fstring
Browse files Browse the repository at this point in the history
  • Loading branch information
tmc authored Jan 20, 2024
2 parents f2c4e86 + a0ce29d commit 02ba20b
Show file tree
Hide file tree
Showing 6 changed files with 225 additions and 1 deletion.
2 changes: 2 additions & 0 deletions prompts/internal/fstring/doc.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
// Package fstring contains template format with f-string.
package fstring
20 changes: 20 additions & 0 deletions prompts/internal/fstring/fstring.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package fstring

import "errors"

var (
ErrEmptyExpression = errors.New("empty expression not allowed")
ErrArgsNotDefined = errors.New("args not defined")
ErrLeftBracketNotClosed = errors.New("single '{' is not allowed")
ErrRightBracketNotClosed = errors.New("single '}' is not allowed")
)

// Format interpolates the given template with the given values by using
// f-string.
func Format(template string, values map[string]any) (string, error) {
p := newParser(template, values)
if err := p.parse(); err != nil {
return "", err
}
return string(p.result), nil
}
49 changes: 49 additions & 0 deletions prompts/internal/fstring/fstring_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
package fstring

import (
"strings"
"testing"
)

func TestFormat(t *testing.T) {
t.Parallel()

type args struct {
format string
values map[string]any
}
tests := []struct {
name string
args args
want string
wantErr string
}{
{"1", args{"{", map[string]any{}}, "", "single '{' is not allowed"},
{"2", args{"{{", map[string]any{}}, "{", ""},
{"3", args{"}", map[string]any{}}, "", "single '}' is not allowed"},
{"4", args{"}}", map[string]any{}}, "}", ""},
{"4", args{"{}", map[string]any{}}, "", "empty expression not allowed"},
{"4", args{"{val}", map[string]any{}}, "", "args not defined"},
{"4", args{"a={val}", map[string]any{"val": 1}}, "a=1", ""},
{"4", args{"a= {val}", map[string]any{"val": 1}}, "a= 1", ""},
{"4", args{"a= { val }", map[string]any{"val": 1}}, "a= 1", ""},
}
for _, tt := range tests {
tt := tt
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
got, err := Format(tt.args.format, tt.args.values)
if (err != nil) != (tt.wantErr != "") {
t.Errorf("Format() error = %v, wantErr %v", err, tt.wantErr)
return
}
if err != nil && !strings.Contains(err.Error(), tt.wantErr) {
t.Errorf("Format() error = %v, wantErr %v", err, tt.wantErr)
return
}
if got != tt.want {
t.Errorf("Format() got = %v, want %v", got, tt.want)
}
})
}
}
149 changes: 149 additions & 0 deletions prompts/internal/fstring/parser.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package fstring

import (
"fmt"
"strconv"
"strings"
)

type parser struct {
data []rune
result []rune
idx int
values map[string]any
}

func newParser(s string, values map[string]any) *parser {
if len(values) == 0 {
values = map[string]any{}
}
return &parser{
data: []rune(s),
result: nil,
idx: 0,
values: values,
}
}

func (r *parser) parse() error {
for r.hasMore() {
existLeftCurlyBracket, tmp, err := r.scanToLeftCurlyBracket()
if err != nil {
return err
}
r.result = append(r.result, tmp...)
if !existLeftCurlyBracket {
continue
}

tmp = r.scanToRightCurlyBracket()
valName := strings.TrimSpace(string(tmp))
if valName == "" {
return ErrEmptyExpression
}
val, ok := r.values[valName]
if !ok {
return fmt.Errorf("%w: %s", ErrArgsNotDefined, valName)
}
r.result = append(r.result, []rune(toString(val))...)
}
return nil
}

func (r *parser) scanToLeftCurlyBracket() (bool, []rune, error) {
res := []rune{}
for r.hasMore() {
s := r.get()
r.idx++
switch s {
case '}':
if r.hasMore() && r.get() == '}' {
res = append(res, '}') // nolint:ineffassign,staticcheck
r.idx++
continue
}
return false, nil, ErrRightBracketNotClosed
case '{':
if !r.hasMore() {
return false, nil, ErrLeftBracketNotClosed
}
if r.get() == '{' {
// {{ -> {
r.idx++
res = append(res, '{')
continue
}
return true, res, nil
default:
res = append(res, s)
}
}
return false, res, nil
}

func (r *parser) scanToRightCurlyBracket() []rune {
var res []rune
for r.hasMore() {
s := r.get()
if s != '}' {
// xxx
res = append(res, s)
r.idx++
continue
}
r.idx++
break
}
return res
}

func (r *parser) hasMore() bool {
return r.idx < len(r.data)
}

func (r *parser) get() rune {
return r.data[r.idx]
}

// nolint: cyclop
func toString(val any) string {
if val == nil {
return "nil" // f'None' -> "None"
}
switch val := val.(type) {
case string:
return val
case []rune:
return string(val)
case []byte:
return string(val)
case int:
return strconv.FormatInt(int64(val), 10)
case int8:
return strconv.FormatInt(int64(val), 10)
case int16:
return strconv.FormatInt(int64(val), 10)
case int32:
return strconv.FormatInt(int64(val), 10)
case int64:
return strconv.FormatInt(val, 10)
case uint:
return strconv.FormatUint(uint64(val), 10)
case uint8:
return strconv.FormatUint(uint64(val), 10)
case uint16:
return strconv.FormatUint(uint64(val), 10)
case uint32:
return strconv.FormatUint(uint64(val), 10)
case uint64:
return strconv.FormatUint(val, 10)
case float32:
return strconv.FormatFloat(float64(val), 'f', -1, 32)
case float64:
return strconv.FormatFloat(val, 'f', -1, 64)
case bool:
return strconv.FormatBool(val)
default:
return fmt.Sprintf("%s", val)
}
}
4 changes: 4 additions & 0 deletions prompts/templates.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/Masterminds/sprig/v3"
"github.com/nikolalohinski/gonja"
"github.com/tmc/langchaingo/prompts/internal/fstring"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
Expand All @@ -24,6 +25,8 @@ const (
TemplateFormatGoTemplate TemplateFormat = "go-template"
// TemplateFormatJinja2 is the format for jinja2.
TemplateFormatJinja2 TemplateFormat = "jinja2"
// TemplateFormatFString is the format for f-string.
TemplateFormatFString TemplateFormat = "f-string"
)

// interpolator is the function that interpolates the given template with the given values.
Expand All @@ -33,6 +36,7 @@ type interpolator func(template string, values map[string]any) (string, error)
var defaultFormatterMapping = map[TemplateFormat]interpolator{ //nolint:gochecknoglobals
TemplateFormatGoTemplate: interpolateGoTemplate,
TemplateFormatJinja2: interpolateJinja2,
TemplateFormatFString: fstring.Format,
}

// interpolateGoTemplate interpolates the given template with the given values by using
Expand Down
2 changes: 1 addition & 1 deletion prompts/templates_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ func TestCheckValidTemplate(t *testing.T) {
err := CheckValidTemplate("Hello, {test}", "unknown", []string{"test"})
require.Error(t, err)
require.ErrorIs(t, err, ErrInvalidTemplateFormat)
require.EqualError(t, err, "invalid template format, got: unknown, should be one of [go-template jinja2]")
require.EqualError(t, err, "invalid template format, got: unknown, should be one of [f-string go-template jinja2]")
})

t.Run("TemplateErrored", func(t *testing.T) {
Expand Down

0 comments on commit 02ba20b

Please sign in to comment.