Skip to content

Commit

Permalink
anthropic: Spruce up tool calling example (#921)
Browse files Browse the repository at this point in the history
  • Loading branch information
tmc authored Jun 21, 2024
1 parent f8762e7 commit 183687d
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import (

func main() {
llm, err := anthropic.New(
anthropic.WithModel("claude-3-opus-20240229"),
anthropic.WithModel("claude-3-5-sonnet-20240620"),
)
if err != nil {
log.Fatal(err)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,9 +70,8 @@ func main() {
assistantResponse = llms.TextParts(llms.ChatMessageTypeAI, resp.Choices[0].Content)
messageHistory = append(messageHistory, assistantResponse)

fmt.Println("asking again... and again")
// Human asks again
humanQuestion = llms.TextParts(llms.ChatMessageTypeHuman, "How about the weather in chicago?")
// Compare responsses
humanQuestion = llms.TextParts(llms.ChatMessageTypeHuman, "How do these compare?")
messageHistory = append(messageHistory, humanQuestion)

// Send Request
Expand All @@ -82,7 +81,7 @@ func main() {
}
// Perform Tool call
messageHistory = executeToolCalls(ctx, llm, messageHistory, resp)
fmt.Println("Querying with tool response...")
fmt.Println("Asking for comparison...")
resp, err = llm.GenerateContent(ctx, messageHistory, llms.WithTools(availableTools))
if err != nil {
log.Fatal(err)
Expand Down
59 changes: 33 additions & 26 deletions llms/anthropic/anthropicllm.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,12 @@ import (
)

var (
ErrEmptyResponse = errors.New("no response")
ErrMissingToken = errors.New("missing the Anthropic API key, set it in the ANTHROPIC_API_KEY environment variable")

ErrEmptyResponse = errors.New("no response")
ErrMissingToken = errors.New("missing the Anthropic API key, set it in the ANTHROPIC_API_KEY environment variable")
ErrUnexpectedResponseLength = errors.New("unexpected length of response")
ErrInvalidContentType = errors.New("invalid content type")
ErrUnsupportedMessageType = errors.New("unsupported message type")
ErrUnsupportedContentType = errors.New("unsupported content type")
)

const (
Expand All @@ -36,9 +38,12 @@ var _ llms.Model = (*LLM)(nil)
// New returns a new Anthropic LLM.
func New(opts ...Option) (*LLM, error) {
c, err := newClient(opts...)
if err != nil {
return nil, fmt.Errorf("anthropic: failed to create client: %w", err)
}
return &LLM{
client: c,
}, err
}, nil
}

func newClient(opts ...Option) (*anthropicclient.Client, error) {
Expand Down Expand Up @@ -85,11 +90,15 @@ func (o *LLM) GenerateContent(ctx context.Context, messages []llms.MessageConten
}

func generateCompletionsContent(ctx context.Context, o *LLM, messages []llms.MessageContent, opts *llms.CallOptions) (*llms.ContentResponse, error) {
if len(messages) == 0 || len(messages[0].Parts) == 0 {
return nil, ErrEmptyResponse
}

msg0 := messages[0]
part := msg0.Parts[0]
partText, ok := part.(llms.TextContent)
if !ok {
return nil, fmt.Errorf("unexpected message type: %T", part)
return nil, fmt.Errorf("anthropic: unexpected message type: %T", part)
}
prompt := fmt.Sprintf("\n\nHuman: %s\n\nAssistant:", partText.Text)
result, err := o.client.CreateCompletion(ctx, &anthropicclient.CompletionRequest{
Expand All @@ -105,7 +114,7 @@ func generateCompletionsContent(ctx context.Context, o *LLM, messages []llms.Mes
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
}
return nil, err
return nil, fmt.Errorf("anthropic: failed to create completion: %w", err)
}

resp := &llms.ContentResponse{
Expand All @@ -121,7 +130,7 @@ func generateCompletionsContent(ctx context.Context, o *LLM, messages []llms.Mes
func generateMessagesContent(ctx context.Context, o *LLM, messages []llms.MessageContent, opts *llms.CallOptions) (*llms.ContentResponse, error) {
chatMessages, systemPrompt, err := processMessages(messages)
if err != nil {
return nil, err
return nil, fmt.Errorf("anthropic: failed to process messages: %w", err)
}

tools := toolsToTools(opts.Tools)
Expand All @@ -140,7 +149,7 @@ func generateMessagesContent(ctx context.Context, o *LLM, messages []llms.Messag
if o.CallbacksHandler != nil {
o.CallbacksHandler.HandleLLMError(ctx, err)
}
return nil, err
return nil, fmt.Errorf("anthropic: failed to create message: %w", err)
}

choices := make([]*llms.ContentChoice, len(result.Content))
Expand All @@ -157,13 +166,13 @@ func generateMessagesContent(ctx context.Context, o *LLM, messages []llms.Messag
},
}
} else {
return nil, errors.New("invalid content type for text message")
return nil, fmt.Errorf("anthropic: %w for text message", ErrInvalidContentType)
}
case "tool_use":
if toolUseContent, ok := content.(*anthropicclient.ToolUseContent); ok {
argumentsJSON, err := json.Marshal(toolUseContent.Input)
if err != nil {
return nil, err
return nil, fmt.Errorf("anthropic: failed to marshal tool use arguments: %w", err)
}
choices[i] = &llms.ContentChoice{
ToolCalls: []llms.ToolCall{
Expand All @@ -182,10 +191,10 @@ func generateMessagesContent(ctx context.Context, o *LLM, messages []llms.Messag
},
}
} else {
return nil, errors.New("invalid content type for tool use message")
return nil, fmt.Errorf("anthropic: %w for tool use message", ErrInvalidContentType)
}
default:
return nil, fmt.Errorf("unsupported content type: %v", content.GetType())
return nil, fmt.Errorf("anthropic: %w: %v", ErrUnsupportedContentType, content.GetType())
}
}

Expand Down Expand Up @@ -215,31 +224,31 @@ func processMessages(messages []llms.MessageContent) ([]anthropicclient.ChatMess
case llms.ChatMessageTypeSystem:
content, err := handleSystemMessage(msg)
if err != nil {
return nil, "", err
return nil, "", fmt.Errorf("anthropic: failed to handle system message: %w", err)
}
systemPrompt += content
case llms.ChatMessageTypeHuman:
chatMessage, err := handleHumanMessage(msg)
if err != nil {
return nil, "", err
return nil, "", fmt.Errorf("anthropic: failed to handle human message: %w", err)
}
chatMessages = append(chatMessages, chatMessage)
case llms.ChatMessageTypeAI:
chatMessage, err := handleAIMessage(msg)
if err != nil {
return nil, "", err
return nil, "", fmt.Errorf("anthropic: failed to handle AI message: %w", err)
}
chatMessages = append(chatMessages, chatMessage)
case llms.ChatMessageTypeTool:
chatMessage, err := handleToolMessage(msg)
if err != nil {
return nil, "", err
return nil, "", fmt.Errorf("anthropic: failed to handle tool message: %w", err)
}
chatMessages = append(chatMessages, chatMessage)
case llms.ChatMessageTypeGeneric, llms.ChatMessageTypeFunction:
return nil, "", fmt.Errorf("unsupported message type: %v", msg.Role)
return nil, "", fmt.Errorf("anthropic: %w: %v", ErrUnsupportedMessageType, msg.Role)
default:
return nil, "", fmt.Errorf("unsupported message type: %v", msg.Role)
return nil, "", fmt.Errorf("anthropic: %w: %v", ErrUnsupportedMessageType, msg.Role)
}
}
return chatMessages, systemPrompt, nil
Expand All @@ -249,7 +258,7 @@ func handleSystemMessage(msg llms.MessageContent) (string, error) {
if textContent, ok := msg.Parts[0].(llms.TextContent); ok {
return textContent.Text, nil
}
return "", errors.New("invalid content type for system message")
return "", fmt.Errorf("anthropic: %w for system message", ErrInvalidContentType)
}

func handleHumanMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, error) {
Expand All @@ -259,17 +268,15 @@ func handleHumanMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, e
Content: textContent.Text,
}, nil
}
return anthropicclient.ChatMessage{}, errors.New("invalid content type for human message")
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for human message", ErrInvalidContentType)
}

func handleAIMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, error) {
if toolCall, ok := msg.Parts[0].(llms.ToolCall); ok {
var inputStruct map[string]interface{}
err := json.Unmarshal([]byte(toolCall.FunctionCall.Arguments), &inputStruct)
if err != nil {
return anthropicclient.ChatMessage{
Role: RoleAssistant,
}, err
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: failed to unmarshal tool call arguments: %w", err)
}
toolUse := anthropicclient.ToolUseContent{
Type: "tool_use",
Expand All @@ -286,13 +293,13 @@ func handleAIMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, erro
if textContent, ok := msg.Parts[0].(llms.TextContent); ok {
return anthropicclient.ChatMessage{
Role: RoleAssistant,
Content: []anthropicclient.Content{anthropicclient.TextContent{
Content: []anthropicclient.Content{&anthropicclient.TextContent{
Type: "text",
Text: textContent.Text,
}},
}, nil
}
return anthropicclient.ChatMessage{}, errors.New("invalid content type for AI message")
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for AI message", ErrInvalidContentType)
}

type ToolResult struct {
Expand All @@ -314,5 +321,5 @@ func handleToolMessage(msg llms.MessageContent) (anthropicclient.ChatMessage, er
Content: []anthropicclient.Content{toolContent},
}, nil
}
return anthropicclient.ChatMessage{}, errors.New("invalid content type for tool message")
return anthropicclient.ChatMessage{}, fmt.Errorf("anthropic: %w for tool message", ErrInvalidContentType)
}
8 changes: 4 additions & 4 deletions llms/anthropic/internal/anthropicclient/messages.go
Original file line number Diff line number Diff line change
Expand Up @@ -123,14 +123,14 @@ func (m *MessageResponsePayload) UnmarshalJSON(data []byte) error {

switch typeStruct.Type {
case "text":
var tc TextContent
if err := json.Unmarshal(raw, &tc); err != nil {
tc := &TextContent{}
if err := json.Unmarshal(raw, tc); err != nil {
return err
}
m.Content = append(m.Content, tc)
case "tool_use":
var tuc ToolUseContent
if err := json.Unmarshal(raw, &tuc); err != nil {
tuc := &ToolUseContent{}
if err := json.Unmarshal(raw, tuc); err != nil {
return err
}
m.Content = append(m.Content, tuc)
Expand Down

0 comments on commit 183687d

Please sign in to comment.