diff --git a/tools/tool.go b/tools/tool.go index 4cfca7594..b84ad1d81 100644 --- a/tools/tool.go +++ b/tools/tool.go @@ -1,6 +1,11 @@ package tools -import "context" +import ( + "context" + "errors" +) + +const ErrInvalidTool = "invalid_tool" // Tool is a tool for the llm agent to interact with different applications. type Tool interface { @@ -8,3 +13,14 @@ type Tool interface { Description() string Call(ctx context.Context, input string) (string, error) } + +type Kit []Tool + +func (tb *Kit) UseTool(ctx context.Context, toolName string, toolArgs string) (string, error) { + for _, tool := range *tb { + if tool.Name() == toolName { + return tool.Call(ctx, toolArgs) + } + } + return "", errors.New(ErrInvalidTool) +} diff --git a/tools/tool_test.go b/tools/tool_test.go new file mode 100644 index 000000000..b5f5e5791 --- /dev/null +++ b/tools/tool_test.go @@ -0,0 +1,48 @@ +package tools + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +type someTool struct{} + +func (st *someTool) Name() string { + return "An awesome tool" +} + +func (st *someTool) Description() string { + return "This tool is awesome" +} + +func (st *someTool) Call(ctx context.Context, _ string) (string, error) { + if ctx.Err() != nil { + return "", ctx.Err() + } + return "test", nil +} + +func TestToolWithTestify(t *testing.T) { + t.Parallel() + kit := Kit{ + &someTool{}, + } + + // Test when the tool exists + t.Run("Tool Exists in Kit", func(t *testing.T) { + t.Parallel() + result, err := kit.UseTool(context.Background(), "An awesome tool", "test") + assert.NoError(t, err) + assert.Equal(t, "test", result) + }) + + // Test when the tool does not exist + t.Run("Tool Does Not Exist in Kit", func(t *testing.T) { + t.Parallel() + _, err := kit.UseTool(context.Background(), "A tool that does not exist", "test") + assert.Error(t, err) + assert.Equal(t, ErrInvalidTool, err.Error()) + }) +}