Skip to content

Commit

Permalink
Update
Browse files Browse the repository at this point in the history
  • Loading branch information
dirkbrnd committed Feb 7, 2025
1 parent 70350b7 commit 896ae72
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 59 deletions.
40 changes: 20 additions & 20 deletions libs/agno/agno/models/anthropic/claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from agno.media import Image
from agno.models.base import Model
from agno.models.message import Message
from agno.models.response import ProviderResponse
from agno.models.response import ModelResponse
from agno.utils.log import logger

try:
Expand Down Expand Up @@ -453,31 +453,31 @@ def get_system_message_for_model(self) -> Optional[str]:
return tool_call_prompt
return None

def parse_provider_response(self, response: AnthropicMessage) -> ProviderResponse:
def parse_provider_response(self, response: AnthropicMessage) -> ModelResponse:
"""
Parse the Claude response into a ModelProviderResponse.
Parse the Claude response into a ModelResponse.
Args:
response: Raw response from Anthropic
Returns:
ProviderResponse: Parsed response data
ModelResponse: Parsed response data
"""
provider_response = ProviderResponse()
model_response = ModelResponse()

# Add role (Claude always uses 'assistant')
provider_response.role = response.role or "assistant"
model_response.role = response.role or "assistant"

if response.content:
first_block = response.content[0]
if isinstance(first_block, TextBlock):
provider_response.content = first_block.text
model_response.content = first_block.text
elif isinstance(first_block, ToolUseBlock):
tool_name = first_block.name
tool_input = first_block.input

if tool_input and isinstance(tool_input, dict):
provider_response.content = tool_input.get("query", "")
model_response.content = tool_input.get("query", "")

# -*- Extract tool calls from the response
if response.stop_reason == "tool_use":
Expand All @@ -490,8 +490,8 @@ def parse_provider_response(self, response: AnthropicMessage) -> ProviderRespons
if tool_input:
function_def["arguments"] = json.dumps(tool_input)

provider_response.extra.setdefault("tool_ids", []).append(block.id)
provider_response.tool_calls.append(
model_response.extra.setdefault("tool_ids", []).append(block.id)
model_response.tool_calls.append(
{
"id": block.id,
"type": "function",
Expand All @@ -501,28 +501,28 @@ def parse_provider_response(self, response: AnthropicMessage) -> ProviderRespons

# Add usage metrics
if response.usage is not None:
provider_response.response_usage = response.usage
model_response.response_usage = response.usage

return provider_response
return model_response

def parse_provider_response_delta(
self, response: Union[ContentBlockDeltaEvent, ContentBlockStopEvent, MessageDeltaEvent]
) -> ProviderResponse:
) -> ModelResponse:
"""
Parse the Claude streaming response into ModelProviderResponse objects.
Args:
response: Raw response chunk from Anthropic
Returns:
ProviderResponse: Iterator of parsed response data
ModelResponse: Iterator of parsed response data
"""
provider_response = ProviderResponse()
model_response = ModelResponse()

if isinstance(response, ContentBlockDeltaEvent):
# Handle text content
if isinstance(response.delta, TextDelta):
provider_response.content = response.delta.text
model_response.content = response.delta.text

elif isinstance(response, ContentBlockStopEvent):
# Handle tool calls
Expand All @@ -535,9 +535,9 @@ def parse_provider_response_delta(
if tool_input:
function_def["arguments"] = json.dumps(tool_input)

provider_response.extra.setdefault("tool_ids", []).append(tool_use.id)
model_response.extra.setdefault("tool_ids", []).append(tool_use.id)

provider_response.tool_calls = [
model_response.tool_calls = [
{
"id": tool_use.id,
"type": "function",
Expand All @@ -548,6 +548,6 @@ def parse_provider_response_delta(
# Handle message completion and usage metrics
elif isinstance(response, MessageStopEvent):
if response.message.usage is not None:
provider_response.response_usage = response.message.usage
model_response.response_usage = response.message.usage

return provider_response
return model_response
22 changes: 11 additions & 11 deletions libs/agno/agno/models/aws/claude.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from agno.models.aws.bedrock import AwsBedrock
from agno.models.base import MessageData
from agno.models.message import Message
from agno.models.response import ProviderResponse, ModelResponse
from agno.models.response import ModelResponse
from agno.utils.log import logger

@dataclass
Expand Down Expand Up @@ -157,37 +157,37 @@ def format_messages(self, messages: List[Message]) -> Dict[str, Any]:

return request_body

def parse_provider_response(self, response: Dict[str, Any]) -> ProviderResponse:
def parse_provider_response(self, response: Dict[str, Any]) -> ModelResponse:
"""
Parse the response from the Bedrock API.
Args:
response (Dict[str, Any]): The response from the Bedrock API.
Returns:
ProviderResponse: The parsed response.
ModelResponse: The parsed response.
"""
provider_response = ProviderResponse()
model_response = ModelResponse()

# Extract message from output
if "output" in response and "message" in response["output"]:
message = response["output"]["message"]

# Add role
if "role" in message:
provider_response.role = message["role"]
model_response.role = message["role"]

# Extract and join text content from content list
if "content" in message:
content = message["content"]
if isinstance(content, list) and content:
text_content = [item.get("text", "") for item in content if "text" in item]
provider_response.content = "\n".join(text_content)
model_response.content = "\n".join(text_content)

# Add usage metrics
if "usage" in response:
# This ensures that the usage can be parsed upstream
provider_response.response_usage = BedrockResponseUsage(
model_response.response_usage = BedrockResponseUsage(
input_tokens=response.get("usage", {}).get("inputTokens", 0),
output_tokens=response.get("usage", {}).get("outputTokens", 0),
total_tokens=response.get("usage", {}).get("totalTokens", 0),
Expand Down Expand Up @@ -222,12 +222,12 @@ def parse_provider_response(self, response: Dict[str, Any]) -> ProviderResponse:
}
)
if tool_calls:
provider_response.tool_calls = tool_calls
model_response.tool_calls = tool_calls
if tool_requests:
provider_response.content = tool_requests[0]["text"]
provider_response.extra["tool_ids"] = tool_ids
model_response.content = tool_requests[0]["text"]
model_response.extra["tool_ids"] = tool_ids

return provider_response
return model_response

# Override the base class method
def format_function_call_results(self, messages: List[Message], function_call_results: List[Message], tool_ids: List[str]) -> None:
Expand Down
20 changes: 10 additions & 10 deletions libs/agno/agno/models/cohere/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from agno.exceptions import ModelProviderError
from agno.models.base import MessageData, Model
from agno.models.message import Message
from agno.models.response import ModelResponse, ProviderResponse
from agno.models.response import ModelResponse
from agno.utils.log import logger

try:
Expand Down Expand Up @@ -201,43 +201,43 @@ async def ainvoke_stream(
AsyncIterator[StreamedChatResponseV2]: An async iterator of streamed chat responses.
"""
request_kwargs = self.request_kwargs

try:
async for response in self.get_async_client().chat_stream(model=self.id, messages=self._format_messages(messages), **request_kwargs):
yield response
except Exception as e:
logger.error(f"Unexpected error calling Cohere API: {str(e)}")
raise ModelProviderError(e, self.name, self.id) from e

def parse_provider_response(self, response: ChatResponse) -> ProviderResponse:
def parse_provider_response(self, response: ChatResponse) -> ModelResponse:
"""
Parse the model provider response.
Args:
response (ChatResponse): The response from the Cohere API.
"""
provider_response = ProviderResponse()
model_response = ModelResponse()

provider_response.role = response.message.role
model_response.role = response.message.role

response_message = response.message
if response_message.content is not None:
full_content = ""
for item in response_message.content:
full_content += item.text
provider_response.content = full_content
model_response.content = full_content

if response_message.tool_calls is not None:
provider_response.tool_calls = [t.model_dump() for t in response_message.tool_calls]
model_response.tool_calls = [t.model_dump() for t in response_message.tool_calls]

if response.usage is not None and response.usage.tokens is not None:
provider_response.response_usage = CohereResponseUsage(
model_response.response_usage = CohereResponseUsage(
input_tokens=int(response.usage.tokens.input_tokens) or 0,
output_tokens=int(response.usage.tokens.output_tokens) or 0,
total_tokens=int(response.usage.tokens.input_tokens + response.usage.tokens.output_tokens) or 0,
)

return provider_response
return model_response

def _process_stream_response(
self,
Expand Down Expand Up @@ -340,5 +340,5 @@ async def aprocess_response_stream(

def parse_provider_response_delta(
self, response: Any
) -> Iterator[ProviderResponse]:
) -> Iterator[ModelResponse]:
pass
36 changes: 18 additions & 18 deletions libs/agno/agno/models/groq/groq.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from agno.exceptions import ModelProviderError
from agno.models.base import Model
from agno.models.message import Message
from agno.models.response import ProviderResponse
from agno.models.response import ModelResponse
from agno.utils.log import logger
from agno.utils.openai import add_images_to_message

Expand Down Expand Up @@ -341,68 +341,68 @@ def parse_tool_calls(tool_calls_data: List[ChoiceDeltaToolCall]) -> List[Dict[st

def parse_provider_response(
self, response: ChatCompletion
) -> ProviderResponse:
) -> ModelResponse:
"""
Parse the Groq response into a ModelProviderResponse.
Parse the Groq response into a ModelResponse.
Args:
response: Raw response from Groq
Returns:
ProviderResponse: Parsed response data
ModelResponse: Parsed response data
"""
provider_response = ProviderResponse()
model_response = ModelResponse()

# Get response message
response_message = response.choices[0].message

# Add role
if response_message.role is not None:
provider_response.role = response_message.role
model_response.role = response_message.role

# Add content
if response_message.content is not None:
provider_response.content = response_message.content
model_response.content = response_message.content

# Add tool calls
if response_message.tool_calls is not None and len(response_message.tool_calls) > 0:
try:
provider_response.tool_calls = [t.model_dump() for t in response_message.tool_calls]
model_response.tool_calls = [t.model_dump() for t in response_message.tool_calls]
except Exception as e:
logger.warning(f"Error processing tool calls: {e}")

# Add usage metrics if present
if response.usage is not None:
provider_response.response_usage = response.usage
model_response.response_usage = response.usage

return provider_response
return model_response

def parse_provider_response_delta(
self, response: ChatCompletionChunk
) -> ProviderResponse:
) -> ModelResponse:
"""
Parse the Groq streaming response into ModelProviderResponse objects.
Parse the Groq streaming response into ModelResponse objects.
Args:
response: Raw response chunk from Groq
Returns:
ProviderResponse: Iterator of parsed response data
ModelResponse: Iterator of parsed response data
"""
provider_response = ProviderResponse()
model_response = ModelResponse()
if len(response.choices) > 0:
delta: ChoiceDelta = response.choices[0].delta

# Add content
if delta.content is not None:
provider_response.content = delta.content
model_response.content = delta.content

# Add tool calls
if delta.tool_calls is not None:
provider_response.tool_calls = delta.tool_calls
model_response.tool_calls = delta.tool_calls

# Add usage metrics if present
if response.usage is not None:
provider_response.response_usage = response.usage
model_response.response_usage = response.usage

return provider_response
return model_response

0 comments on commit 896ae72

Please sign in to comment.