Skip to content

Commit

Permalink
Merge pull request #1 from nuffin/main
Browse files Browse the repository at this point in the history
merge in toolcall fix
  • Loading branch information
taigrr authored Jan 18, 2025
2 parents d6af34a + f2fd461 commit 0b0343f
Show file tree
Hide file tree
Showing 11 changed files with 280 additions and 170 deletions.
3 changes: 2 additions & 1 deletion agents/agents.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package agents
import (
"context"

"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/tools"
)
Expand All @@ -11,7 +12,7 @@ import (
type Agent interface {
// Plan Given an input and previous steps decide what to do next. Returns
// either actions or a finish.
Plan(ctx context.Context, intermediateSteps []schema.AgentStep, inputs map[string]string) ([]schema.AgentAction, *schema.AgentFinish, error) //nolint:lll
Plan(ctx context.Context, intermediateSteps []schema.AgentStep, inputs map[string]any, intermediateMessages []llms.ChatMessage) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) //nolint:lll
GetInputKeys() []string
GetOutputKeys() []string
GetTools() []tools.Tool
Expand Down
15 changes: 8 additions & 7 deletions agents/conversational.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,8 +62,9 @@ func NewConversationalAgent(llm llms.Model, tools []tools.Tool, opts ...Option)
func (a *ConversationalAgent) Plan(
ctx context.Context,
intermediateSteps []schema.AgentStep,
inputs map[string]string,
) ([]schema.AgentAction, *schema.AgentFinish, error) {
inputs map[string]any,
_ []llms.ChatMessage,
) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) {
fullInputs := make(map[string]any, len(inputs))
for key, value := range inputs {
fullInputs[key] = value
Expand All @@ -88,7 +89,7 @@ func (a *ConversationalAgent) Plan(
chains.WithStreamingFunc(stream),
)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}

return a.parseOutput(output)
Expand Down Expand Up @@ -130,7 +131,7 @@ func constructScratchPad(steps []schema.AgentStep) string {
return scratchPad
}

func (a *ConversationalAgent) parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, error) {
func (a *ConversationalAgent) parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) {
if strings.Contains(output, _conversationalFinalAnswerAction) {
splits := strings.Split(output, _conversationalFinalAnswerAction)

Expand All @@ -141,18 +142,18 @@ func (a *ConversationalAgent) parseOutput(output string) ([]schema.AgentAction,
Log: output,
}

return nil, finishAction, nil
return nil, finishAction, nil, nil
}

r := regexp.MustCompile(`Action: (.*?)[\n]*Action Input: (.*)`)
matches := r.FindStringSubmatch(output)
if len(matches) == 0 {
return nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output)
return nil, nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output)
}

return []schema.AgentAction{
{Tool: strings.TrimSpace(matches[1]), ToolInput: strings.TrimSpace(matches[2]), Log: output},
}, nil, nil
}, nil, nil, nil
}

//go:embed prompts/conversational_prefix.txt
Expand Down
126 changes: 76 additions & 50 deletions agents/executor.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/tools"
)
Expand Down Expand Up @@ -48,16 +49,18 @@ func NewExecutor(agent Agent, opts ...Option) *Executor {
}

func (e *Executor) Call(ctx context.Context, inputValues map[string]any, _ ...chains.ChainCallOption) (map[string]any, error) { //nolint:lll
inputs, err := inputsToString(inputValues)
if err != nil {
return nil, err
}
// inputs, err := inputsToString(inputValues)
// if err != nil {
// return nil, err
//}
nameToTool := getNameToTool(e.Agent.GetTools())

steps := make([]schema.AgentStep, 0)
var intermediateMessages []llms.ChatMessage
var err error
for i := 0; i < e.MaxIterations; i++ {
var finish map[string]any
steps, finish, err = e.doIteration(ctx, steps, nameToTool, inputs)
steps, finish, intermediateMessages, err = e.doIteration(ctx, steps, nameToTool, inputValues, intermediateMessages)
if finish != nil || err != nil {
return finish, err
}
Expand All @@ -78,9 +81,13 @@ func (e *Executor) doIteration( // nolint
ctx context.Context,
steps []schema.AgentStep,
nameToTool map[string]tools.Tool,
inputs map[string]string,
) ([]schema.AgentStep, map[string]any, error) {
actions, finish, err := e.Agent.Plan(ctx, steps, inputs)
inputs map[string]any,
intermediateMessages []llms.ChatMessage,
) ([]schema.AgentStep, map[string]any, []llms.ChatMessage, error) {
actions, finish, newIntermediateMessages, err := e.Agent.Plan(ctx, steps, inputs, intermediateMessages)
if len(newIntermediateMessages) > 0 {
intermediateMessages = append(intermediateMessages, newIntermediateMessages...)
}
if errors.Is(err, ErrUnableToParseOutput) && e.ErrorHandler != nil {
formattedObservation := err.Error()
if e.ErrorHandler.Formatter != nil {
Expand All @@ -89,60 +96,79 @@ func (e *Executor) doIteration( // nolint
steps = append(steps, schema.AgentStep{
Observation: formattedObservation,
})
return steps, nil, nil
return steps, nil, intermediateMessages, nil
}
if err != nil {
return steps, nil, err
return steps, nil, intermediateMessages, err
}

if len(actions) == 0 && finish == nil {
return steps, nil, ErrAgentNoReturn
return steps, nil, intermediateMessages, ErrAgentNoReturn
}

if finish != nil {
if e.CallbacksHandler != nil {
e.CallbacksHandler.HandleAgentFinish(ctx, *finish)
}
return steps, e.getReturn(finish, steps), nil
return steps, e.getReturn(finish, steps), intermediateMessages, nil
}

for _, action := range actions {
steps, err = e.doAction(ctx, steps, nameToTool, action)
if err != nil {
return steps, nil, err
stepStreams := make([]<-chan schema.AgentStepWithError, len(actions))
for index, action := range actions {
stepStreams[index] = e.doAction(ctx, nameToTool, action)
}
for _, stepStream := range stepStreams {
agentStepWithError := <-stepStream
if agentStepWithError.Error != nil {
return steps, nil, intermediateMessages, agentStepWithError.Error
}
steps = append(steps, agentStepWithError.AgentStep)
}

return steps, nil, nil
return steps, nil, intermediateMessages, nil
}

func (e *Executor) doAction(
ctx context.Context,
steps []schema.AgentStep,
nameToTool map[string]tools.Tool,
action schema.AgentAction,
) ([]schema.AgentStep, error) {
if e.CallbacksHandler != nil {
e.CallbacksHandler.HandleAgentAction(ctx, action)
}
) <-chan schema.AgentStepWithError {
agentStepStream := make(chan schema.AgentStepWithError)
go func() {
defer close(agentStepStream)
if e.CallbacksHandler != nil {
e.CallbacksHandler.HandleAgentAction(ctx, action)
}

tool, ok := nameToTool[strings.ToUpper(action.Tool)]
if !ok {
return append(steps, schema.AgentStep{
Action: action,
Observation: fmt.Sprintf("%s is not a valid tool, try another one", action.Tool),
}), nil
}
tool, ok := nameToTool[strings.ToUpper(action.Tool)]
if !ok {
agentStepStream <- schema.AgentStepWithError{
AgentStep: schema.AgentStep{
Action: action,
Observation: fmt.Sprintf("%s is not a valid tool, try another one", action.Tool),
},
Error: nil,
}
return
}

observation, err := tool.Call(ctx, action.ToolInput)
if err != nil {
return nil, err
}
observation, err := tool.Call(ctx, action.ToolInput)
if err != nil {
agentStepStream <- schema.AgentStepWithError{
AgentStep: schema.AgentStep{}, Error: err,
}
return
}

return append(steps, schema.AgentStep{
Action: action,
Observation: observation,
}), nil
agentStepStream <- schema.AgentStepWithError{
AgentStep: schema.AgentStep{
Action: action,
Observation: observation,
},
Error: nil,
}
}()
return agentStepStream
}

func (e *Executor) getReturn(finish *schema.AgentFinish, steps []schema.AgentStep) map[string]any {
Expand Down Expand Up @@ -172,19 +198,19 @@ func (e *Executor) GetCallbackHandler() callbacks.Handler { //nolint:ireturn
return e.CallbacksHandler
}

func inputsToString(inputValues map[string]any) (map[string]string, error) {
inputs := make(map[string]string, len(inputValues))
for key, value := range inputValues {
valueStr, ok := value.(string)
if !ok {
return nil, fmt.Errorf("%w: %s", ErrExecutorInputNotString, key)
}

inputs[key] = valueStr
}

return inputs, nil
}
// func inputsToString(inputValues map[string]any) (map[string]string, error) {
// inputs := make(map[string]string, len(inputValues))
// for key, value := range inputValues {
// valueStr, ok := value.(string)
// if !ok {
// return nil, fmt.Errorf("%w: %s", ErrExecutorInputNotString, key)
// }
//
// inputs[key] = valueStr
// }
//
// return inputs, nil
//}

func getNameToTool(t []tools.Tool) map[string]tools.Tool {
if len(t) == 0 {
Expand Down
13 changes: 7 additions & 6 deletions agents/executor_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/require"
"github.com/tmc/langchaingo/agents"
"github.com/tmc/langchaingo/chains"
"github.com/tmc/langchaingo/llms"
"github.com/tmc/langchaingo/llms/openai"
"github.com/tmc/langchaingo/prompts"
"github.com/tmc/langchaingo/schema"
Expand All @@ -24,27 +25,27 @@ type testAgent struct {
outputKeys []string

recordedIntermediateSteps []schema.AgentStep
recordedInputs map[string]string
recordedInputs map[string]any
numPlanCalls int
}

func (a *testAgent) Plan(
_ context.Context,
intermediateSteps []schema.AgentStep,
inputs map[string]string,
) ([]schema.AgentAction, *schema.AgentFinish, error) {
inputs map[string]any, _ []llms.ChatMessage,
) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) {
a.recordedIntermediateSteps = intermediateSteps
a.recordedInputs = inputs
a.numPlanCalls++

return a.actions, a.finish, a.err
return a.actions, a.finish, nil, a.err
}

func (a testAgent) GetInputKeys() []string {
func (a *testAgent) GetInputKeys() []string {
return a.inputKeys
}

func (a testAgent) GetOutputKeys() []string {
func (a *testAgent) GetOutputKeys() []string {
return a.outputKeys
}

Expand Down
2 changes: 1 addition & 1 deletion agents/markl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ func TestMRKLOutputParser(t *testing.T) {

a := OneShotZeroAgent{}
for _, tc := range testCases {
actions, finish, err := a.parseOutput(tc.input)
actions, finish, _, err := a.parseOutput(tc.input)
require.ErrorIs(t, tc.expectedErr, err)
require.Equal(t, tc.expectedActions, actions)
require.Equal(t, tc.expectedFinish, finish)
Expand Down
15 changes: 8 additions & 7 deletions agents/mrkl.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,9 @@ func NewOneShotAgent(llm llms.Model, tools []tools.Tool, opts ...Option) *OneSho
func (a *OneShotZeroAgent) Plan(
ctx context.Context,
intermediateSteps []schema.AgentStep,
inputs map[string]string,
) ([]schema.AgentAction, *schema.AgentFinish, error) {
inputs map[string]any,
_ []llms.ChatMessage,
) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) {
fullInputs := make(map[string]any, len(inputs))
for key, value := range inputs {
fullInputs[key] = value
Expand All @@ -90,7 +91,7 @@ func (a *OneShotZeroAgent) Plan(
chains.WithStreamingFunc(stream),
)
if err != nil {
return nil, nil, err
return nil, nil, nil, err
}

return a.parseOutput(output)
Expand Down Expand Up @@ -131,7 +132,7 @@ func constructMrklScratchPad(steps []schema.AgentStep) string {
return scratchPad
}

func (a *OneShotZeroAgent) parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, error) {
func (a *OneShotZeroAgent) parseOutput(output string) ([]schema.AgentAction, *schema.AgentFinish, []llms.ChatMessage, error) {
if strings.Contains(output, _finalAnswerAction) {
splits := strings.Split(output, _finalAnswerAction)

Expand All @@ -140,16 +141,16 @@ func (a *OneShotZeroAgent) parseOutput(output string) ([]schema.AgentAction, *sc
a.OutputKey: splits[len(splits)-1],
},
Log: output,
}, nil
}, nil, nil
}

r := regexp.MustCompile(`Action:\s*(.+)\s*Action Input:\s(?s)*(.+)`)
matches := r.FindStringSubmatch(output)
if len(matches) == 0 {
return nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output)
return nil, nil, nil, fmt.Errorf("%w: %s", ErrUnableToParseOutput, output)
}

return []schema.AgentAction{
{Tool: strings.TrimSpace(matches[1]), ToolInput: strings.TrimSpace(matches[2]), Log: output},
}, nil, nil
}, nil, nil, nil
}
Loading

0 comments on commit 0b0343f

Please sign in to comment.