bug fix: openai_functions_agent not compatible with the new llms.Model interface (#536)
devinyf authored Jan 20, 2024
1 parent 39ca49e commit 4c0c4f9
Showing 3 changed files with 25 additions and 23 deletions.
agents/executor_test.go (2 changes: 1 addition & 1 deletion)

@@ -111,7 +111,7 @@ func TestExecutorWithOpenAIFunctionAgent(t *testing.T) {
 		t.Skip("SERPAPI_API_KEY not set")
 	}
 
-	llm, err := openai.NewChat()
+	llm, err := openai.New()
 	require.NoError(t, err)
 
 	searchTool, err := serpapi.New()
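For context: openai.NewChat was the chat-specific constructor under the old split API; with the unified llms.Model interface, openai.New is the single entry point. A minimal sketch of driving the new client directly, assuming the tmc/langchaingo import paths (illustrative only, not part of this commit; the message conversion mirrors the Plan change below):

package main

import (
	"context"
	"fmt"
	"log"

	// Import paths assume the tmc/langchaingo module layout.
	"github.com/tmc/langchaingo/llms"
	"github.com/tmc/langchaingo/llms/openai"
	"github.com/tmc/langchaingo/schema"
)

func main() {
	// openai.New replaces openai.NewChat; the returned client satisfies llms.Model.
	llm, err := openai.New()
	if err != nil {
		log.Fatal(err)
	}

	// Under llms.Model, messages are expressed as []llms.MessageContent.
	msgs := []llms.MessageContent{{
		Role:  schema.ChatMessageTypeHuman,
		Parts: []llms.ContentPart{llms.TextContent{Text: "Hello"}},
	}}

	resp, err := llm.GenerateContent(context.Background(), msgs)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(resp.Choices[0].Content)
}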
agents/openai_functions_agent.go (35 changes: 24 additions & 11 deletions)

@@ -19,7 +19,7 @@ const agentScratchpad = "agent_scratchpad"
 type OpenAIFunctionsAgent struct {
 	// LLM is the llm used to call with the values. The llm should have an
 	// input called "agent_scratchpad" for the agent to put its thoughts in.
-	LLM    llms.ChatLLM
+	LLM    llms.Model
 	Prompt prompts.FormatPrompter
 	// Chain chains.Chain
 	// Tools is a list of the tools the agent can use.
@@ -33,7 +33,7 @@ type OpenAIFunctionsAgent struct {
 var _ Agent = (*OpenAIFunctionsAgent)(nil)
 
 // NewOpenAIFunctionsAgent creates a new OpenAIFunctionsAgent.
-func NewOpenAIFunctionsAgent(llm llms.ChatLLM, tools []tools.Tool, opts ...CreationOption) *OpenAIFunctionsAgent {
+func NewOpenAIFunctionsAgent(llm llms.Model, tools []tools.Tool, opts ...CreationOption) *OpenAIFunctionsAgent {
 	options := openAIFunctionsDefaultOptions()
 	for _, opt := range opts {
 		opt(&options)
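For reference, the shape of llms.Model that this diff relies on can be read off the call sites below: a single GenerateContent method taking []llms.MessageContent plus functional options and returning *llms.ContentResponse. A sketch of that contract (the authoritative definition lives in llms/llms.go and may carry more methods and doc comments than shown):

// Sketch of the llms.Model contract as exercised by this commit.
type Model interface {
	// GenerateContent is the unified entry point; function definitions and
	// streaming callbacks arrive via CallOptions such as llms.WithFunctions
	// and llms.WithStreamingFunc.
	GenerateContent(ctx context.Context, messages []MessageContent,
		options ...CallOption) (*ContentResponse, error)
}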
@@ -92,7 +92,19 @@ func (o *OpenAIFunctionsAgent) Plan(
 		return nil, nil, err
 	}
 
-	result, err := o.LLM.Generate(ctx, [][]schema.ChatMessage{prompt.Messages()},
+	mcList := make([]llms.MessageContent, len(prompt.Messages()))
+	for i, msg := range prompt.Messages() {
+		role := msg.GetType()
+		text := msg.GetContent()
+
+		mc := llms.MessageContent{
+			Role:  role,
+			Parts: []llms.ContentPart{llms.TextContent{Text: text}},
+		}
+		mcList[i] = mc
+	}
+
+	result, err := o.LLM.GenerateContent(ctx, mcList,
 		llms.WithFunctions(o.functions()), llms.WithStreamingFunc(stream))
 	if err != nil {
 		return nil, nil, err
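This loop is the heart of the fix: the old ChatLLM.Generate took [][]schema.ChatMessage, while Model.GenerateContent takes []llms.MessageContent, so each legacy message is re-wrapped as a single text part. Factored into a standalone helper, the mapping would look like this (hypothetical; the commit inlines it in Plan):

// chatMessagesToContent re-wraps legacy schema.ChatMessage values as the
// llms.MessageContent slice expected by llms.Model.GenerateContent.
// Hypothetical helper; assumes imports of the llms and schema packages.
func chatMessagesToContent(msgs []schema.ChatMessage) []llms.MessageContent {
	mcList := make([]llms.MessageContent, len(msgs))
	for i, msg := range msgs {
		mcList[i] = llms.MessageContent{
			Role:  msg.GetType(), // the chat role (human, AI, system, ...)
			Parts: []llms.ContentPart{llms.TextContent{Text: msg.GetContent()}},
		}
	}
	return mcList
}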
@@ -148,22 +160,23 @@ func (o *OpenAIFunctionsAgent) constructScratchPad(steps []schema.AgentStep) []s
 	return messages
 }
 
-func (o *OpenAIFunctionsAgent) ParseOutput(generations []*llms.Generation) (
+func (o *OpenAIFunctionsAgent) ParseOutput(contentResp *llms.ContentResponse) (
 	[]schema.AgentAction, *schema.AgentFinish, error,
 ) {
-	msg := generations[0].Message
+	choice := contentResp.Choices[0]
+
 	// finish
-	if generations[0].Message.FunctionCall == nil {
+	if choice.FuncCall == nil {
 		return nil, &schema.AgentFinish{
 			ReturnValues: map[string]any{
-				"output": msg.Content,
+				"output": choice.Content,
 			},
-			Log: msg.Content,
+			Log: choice.Content,
 		}, nil
 	}
 
 	// action
-	functionCall := msg.FunctionCall
+	functionCall := choice.FuncCall
 	functionName := functionCall.Name
 	toolInputStr := functionCall.Arguments
 	toolInputMap := make(map[string]any, 0)
@@ -181,8 +194,8 @@ func (o *OpenAIFunctionsAgent) ParseOutput(generations []*llms.Generation) (
 	}
 
 	contentMsg := "\n"
-	if msg.Content != "" {
-		contentMsg = fmt.Sprintf("responded: %s\n", msg.Content)
+	if choice.Content != "" {
+		contentMsg = fmt.Sprintf("responded: %s\n", choice.Content)
 	}
 
 	return []schema.AgentAction{
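ParseOutput now consumes the *llms.ContentResponse returned by GenerateContent instead of a []*llms.Generation slice; the first choice's FuncCall field replaces Message.FunctionCall as the signal that distinguishes a tool call from a final answer. Condensed into a sketch (hypothetical helper, for illustration only):

// isFinalAnswer sketches the branching contract of the new ParseOutput:
// a nil FuncCall means the model answered in plain text and the agent is done;
// otherwise FuncCall.Name and FuncCall.Arguments describe the tool to invoke.
func isFinalAnswer(resp *llms.ContentResponse) (output string, done bool) {
	choice := resp.Choices[0]
	if choice.FuncCall == nil {
		return choice.Content, true
	}
	return "", false
}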
llms/llms.go (11 changes: 0 additions & 11 deletions)

@@ -48,14 +48,3 @@ func CallLLM(ctx context.Context, llm Model, prompt string, options ...CallOptio
 	c1 := choices[0]
 	return c1.Content, nil
 }
-
-func GenerateChatPrompt(ctx context.Context, l ChatLLM, promptValues []schema.PromptValue, options ...CallOption) (LLMResult, error) { //nolint:lll
-	messages := make([][]schema.ChatMessage, 0, len(promptValues))
-	for _, promptValue := range promptValues {
-		messages = append(messages, promptValue.Messages())
-	}
-	generations, err := l.Generate(ctx, messages, options...)
-	return LLMResult{
-		Generations: [][]*Generation{generations},
-	}, err
-}
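The deleted GenerateChatPrompt was the last helper bound to the old ChatLLM interface. Its job, flattening schema.PromptValue messages into a model call, now lives at each call site via GenerateContent, as the Plan change above shows. A rough caller-side equivalent under the new interface (hypothetical sketch; the name and error handling are illustrative):

// generateFromPromptValues is a hypothetical stand-in for the removed
// GenerateChatPrompt: one GenerateContent call per prompt value.
// Assumes imports of the llms and schema packages.
func generateFromPromptValues(ctx context.Context, model llms.Model,
	promptValues []schema.PromptValue, options ...llms.CallOption,
) ([]*llms.ContentResponse, error) {
	responses := make([]*llms.ContentResponse, 0, len(promptValues))
	for _, pv := range promptValues {
		msgs := pv.Messages()
		mcList := make([]llms.MessageContent, len(msgs))
		for i, msg := range msgs {
			mcList[i] = llms.MessageContent{
				Role:  msg.GetType(),
				Parts: []llms.ContentPart{llms.TextContent{Text: msg.GetContent()}},
			}
		}
		resp, err := model.GenerateContent(ctx, mcList, options...)
		if err != nil {
			return nil, err
		}
		responses = append(responses, resp)
	}
	return responses, nil
}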
