diff --git a/tools/tool.go b/tools/tool.go index 4cfca7594..c76f7f398 100644 --- a/tools/tool.go +++ b/tools/tool.go @@ -1,6 +1,9 @@ package tools -import "context" +import ( + "context" + "fmt" +) // Tool is a tool for the llm agent to interact with different applications. type Tool interface { @@ -8,3 +11,14 @@ type Tool interface { Description() string Call(ctx context.Context, input string) (string, error) } + +type Kit []Tool + +func (tb *Kit) UseTool(ctx context.Context, toolName string, toolArgs string) (string, error) { + for _, tool := range *tb { + if tool.Name() == toolName { + return tool.Call(ctx, toolArgs) + } + } + return "", fmt.Errorf("invalid tool %v", toolName) +} diff --git a/tools/tool_test.go b/tools/tool_test.go new file mode 100644 index 000000000..f5362d8cb --- /dev/null +++ b/tools/tool_test.go @@ -0,0 +1,47 @@ +package tools + +import ( + "context" + "testing" +) + +type SomeTool struct{} + +func (st *SomeTool) Name() string { + return "An awesome tool" +} + +func (st *SomeTool) Description() string { + return "This tool is awesome" +} + +func (st *SomeTool) Call(ctx context.Context, _ string) (string, error) { + if ctx.Err() != nil { + return "", ctx.Err() + } + return "test", nil +} + +func TestTool(t *testing.T) { + t.Parallel() + t.Run("Tool Exists in Kit", func(t *testing.T) { + t.Parallel() + kit := Kit{ + &SomeTool{}, + } + _, err := kit.UseTool(context.Background(), "An awesome tool", "test") + if err != nil { + t.Errorf("Error using tool: %v", err) + } + }) + t.Run("Tool Does Not Exist in Kit", func(t *testing.T) { + t.Parallel() + kit := Kit{ + &SomeTool{}, + } + _, err := kit.UseTool(context.Background(), "A tool that does not exist", "test") + if err == nil { + t.Errorf("Expected error, got nil") + } + }) +}