Fix for asyncstream
dirkbrnd committed Feb 7, 2025
1 parent ddd3fe8 commit 2090182
Showing 16 changed files with 158 additions and 492 deletions.
12 changes: 12 additions & 0 deletions cookbook/models/huggingface/async_basic.py
@@ -0,0 +1,12 @@
import asyncio
from agno.agent import Agent
from agno.models.huggingface import HuggingFace

agent = Agent(
    model=HuggingFace(
        id="mistralai/Mistral-7B-Instruct-v0.2", max_tokens=4096, temperature=0
    ),
)
asyncio.run(agent.aprint_response(
    "What is the meaning of life and then recommend 5 best books to read about it"
))
13 changes: 13 additions & 0 deletions cookbook/models/huggingface/async_basic_stream.py
@@ -0,0 +1,13 @@
import asyncio
from agno.agent import Agent
from agno.models.huggingface import HuggingFace

agent = Agent(
    model=HuggingFace(
        id="mistralai/Mistral-7B-Instruct-v0.2", max_tokens=4096, temperature=0
    ),
)
asyncio.run(agent.aprint_response(
    "What is the meaning of life and then recommend 5 best books to read about it",
    stream=True,
))
File renamed without changes.
File renamed without changes.
12 changes: 12 additions & 0 deletions cookbook/models/openai/async_basic.py
@@ -0,0 +1,12 @@
import asyncio
from agno.agent import Agent, RunResponse # noqa
from agno.models.openai import OpenAIChat

agent = Agent(model=OpenAIChat(id="gpt-4o"), markdown=True)

# Get the response in a variable
# run: RunResponse = agent.run("Share a 2 sentence horror story")
# print(run.content)

# Print the response in the terminal
asyncio.run(agent.aprint_response("Share a 2 sentence horror story"))
14 changes: 14 additions & 0 deletions cookbook/models/openai/async_basic_stream.py
@@ -0,0 +1,14 @@
import asyncio
from typing import Iterator # noqa
from agno.agent import Agent, RunResponse # noqa
from agno.models.openai import OpenAIChat

agent = Agent(model=OpenAIChat(id="gpt-4o"), markdown=True)

# Get the response in a variable
# run_response: Iterator[RunResponse] = agent.run("Share a 2 sentence horror story", stream=True)
# for chunk in run_response:
# print(chunk.content)

# Print the response in the terminal
asyncio.run(agent.aprint_response("Share a 2 sentence horror story", stream=True))
20 changes: 9 additions & 11 deletions libs/agno/agno/models/anthropic/claude.py
@@ -1,3 +1,4 @@
from collections.abc import AsyncIterator
import json
from dataclasses import dataclass
from os import getenv
@@ -393,7 +394,7 @@ async def ainvoke(self, messages: List[Message]) -> AnthropicMessage:
            logger.error(f"Unexpected error calling Claude API: {str(e)}")
            raise ModelProviderError(e, self.name, self.id) from e

    async def ainvoke_stream(self, messages: List[Message]) -> Any:
    async def ainvoke_stream(self, messages: List[Message]) -> AsyncIterator[Any]:
        """
        Stream an asynchronous response from the Anthropic API.
@@ -406,16 +407,13 @@ async def ainvoke_stream(self, messages: List[Message]) -> Any:
        try:
            chat_messages, system_message = _format_messages(messages)
            request_kwargs = self._prepare_request_kwargs(system_message)

            return (
                await self.get_async_client()
                .messages.stream(
                    model=self.id,
                    messages=chat_messages,  # type: ignore
                    **request_kwargs,
                )
                .__aenter__()
            )
            async with self.get_async_client().messages.stream(
                model=self.id,
                messages=chat_messages,  # type: ignore
                **request_kwargs,
            ) as stream:
                async for chunk in stream:
                    yield chunk
        except APIConnectionError as e:
            logger.error(f"Connection error while calling Claude API: {str(e)}")
            raise ModelProviderError(e, self.name, self.id) from e
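This hunk is the heart of the commit. The old body awaited messages.stream(...) and returned the bare result of .__aenter__(), so the context manager was never exited and callers had to await ainvoke_stream(...) before iterating. Rewriting the method as an async generator keeps the async with block open for exactly as long as chunks are consumed. A minimal, self-contained sketch of the same pattern (FakeStream is an illustrative stand-in, not the Anthropic SDK):

import asyncio
from collections.abc import AsyncIterator


class FakeStream:
    """Illustrative stand-in for an SDK streaming context manager."""

    async def __aenter__(self) -> "FakeStream":
        print("stream opened")
        return self

    async def __aexit__(self, *exc) -> None:
        # With the async-generator version this cleanup runs reliably
        # once the consumer finishes (or abandons) the iteration.
        print("stream closed")

    def __aiter__(self) -> AsyncIterator[str]:
        async def chunks() -> AsyncIterator[str]:
            for piece in ("stream", "ing"):
                yield piece
        return chunks()


async def ainvoke_stream() -> AsyncIterator[str]:
    # The fixed pattern from claude.py: enter the context manager here
    # and re-yield chunks, instead of returning .__aenter__() to the caller.
    async with FakeStream() as stream:
        async for chunk in stream:
            yield chunk


async def main() -> None:
    # Callers iterate directly; no await before the loop is needed.
    async for chunk in ainvoke_stream():
        print(chunk)


asyncio.run(main())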
2 changes: 1 addition & 1 deletion libs/agno/agno/models/aws/claude.py
@@ -325,7 +325,7 @@ def process_response_stream(
        )

        # Update metrics
        self.add_usage_metrics_to_assistant_message(
        self._add_usage_metrics_to_assistant_message(
            assistant_message=assistant_message, response_usage=response_usage
        )

13 changes: 8 additions & 5 deletions libs/agno/agno/models/base.py
@@ -349,7 +349,10 @@ async def _aprocess_model_response(

        # Update model response with assistant message content and audio
        if assistant_message.content is not None:
            model_response.content += assistant_message.get_content_string()
            if model_response.content is None:
                model_response.content = assistant_message.get_content_string()
            else:
                model_response.content += assistant_message.get_content_string()
        if assistant_message.audio_output is not None:
            model_response.audio = assistant_message.audio_output
        if provider_response.extra is not None:
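The new guard matters because model_response.content starts out as None, as the added None-check itself implies, and appending a string to None raises a TypeError; the first chunk must assign rather than append. A tiny illustration of the failure mode the change avoids:

content = None  # ModelResponse.content before any chunk has arrived
chunk = "hello"
# content += chunk would raise TypeError: unsupported operand type(s) for +=: 'NoneType' and 'str'
content = chunk if content is None else content + chunk
print(content)  # hello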
@@ -394,7 +397,7 @@ def _populate_assistant_message(

        # Add usage metrics if provided
        if provider_response.response_usage is not None:
            self.add_usage_metrics_to_assistant_message(
            self._add_usage_metrics_to_assistant_message(
                assistant_message=assistant_message, response_usage=provider_response.response_usage
            )

@@ -491,7 +494,7 @@ async def aprocess_response_stream(
        """
        Process a streaming response from the model.
        """
        async for response_delta in await self.ainvoke_stream(messages=messages):
        async for response_delta in self.ainvoke_stream(messages=messages):
            model_response_delta = self.parse_provider_response_delta(response_delta)
            if model_response_delta:
                for model_response in self._populate_stream_data_and_assistant_message(
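Dropping the await follows from the provider-side changes: calling an async generator function returns an async iterator immediately rather than a coroutine, so the caller loops over it directly. A minimal sketch of the calling convention:

import asyncio
from collections.abc import AsyncIterator


async def deltas() -> AsyncIterator[int]:
    # An async generator function, like the reworked ainvoke_stream.
    for i in range(3):
        yield i


async def main() -> None:
    # Correct: calling deltas() returns an async iterator to loop over.
    async for d in deltas():
        print(d)
    # By contrast, `await deltas()` would raise TypeError, because calling
    # an async generator function does not produce an awaitable coroutine.


asyncio.run(main())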
@@ -604,7 +607,7 @@ def _populate_stream_data_and_assistant_message(
            stream_data.extra.update(model_response.extra)

        if model_response.response_usage is not None:
            self.add_usage_metrics_to_assistant_message(
            self._add_usage_metrics_to_assistant_message(
                assistant_message=assistant_message, response_usage=model_response.response_usage
            )

@@ -956,7 +959,7 @@ def format_function_call_results(
        if len(function_call_results) > 0:
            messages.extend(function_call_results)

    def add_usage_metrics_to_assistant_message(self, assistant_message: Message, response_usage: Any) -> None:
    def _add_usage_metrics_to_assistant_message(self, assistant_message: Message, response_usage: Any) -> None:
        """
        Add usage metrics from the model provider to the assistant message.
2 changes: 1 addition & 1 deletion libs/agno/agno/models/cohere/chat.py
@@ -292,7 +292,7 @@ def _process_stream_response(
            and response.delta.usage is not None
            and response.delta.usage.tokens is not None
        ):
            self.add_usage_metrics_to_assistant_message(
            self._add_usage_metrics_to_assistant_message(
                assistant_message=assistant_message,
                response_usage=CohereResponseUsage(
                    input_tokens=response.delta.usage.tokens.input_tokens,
4 changes: 3 additions & 1 deletion libs/agno/agno/models/groq/groq.py
@@ -284,12 +284,14 @@ async def ainvoke_stream(self, messages: List[Message]) -> Any:
            Any: An asynchronous iterator of chat completion chunks.
        """
        try:
            return await self.get_async_client().chat.completions.create(
            stream = await self.get_async_client().chat.completions.create(
                model=self.id,
                messages=[format_message(m) for m in messages],  # type: ignore
                stream=True,
                **self.request_kwargs,
            )
            async for chunk in stream:
                yield chunk
        except Exception as e:
            logger.error(f"Unexpected error calling Groq API: {str(e)}")
            raise ModelProviderError(e, self.name, self.id) from e
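Groq's async client keeps the OpenAI-style two-step protocol: the create(..., stream=True) call is awaited to open the stream, and the returned object is then iterated. Re-yielding the chunks inside ainvoke_stream hides that protocol, so base.py can consume every provider with a plain async for. A provider-agnostic sketch under that assumption, with a stub standing in for the real client:

import asyncio
from collections.abc import AsyncIterator


async def fake_create(stream: bool = True) -> AsyncIterator[str]:
    """Stub for client.chat.completions.create(..., stream=True), which
    must be awaited before the returned stream can be iterated."""
    async def chunks() -> AsyncIterator[str]:
        for piece in ("a", "b", "c"):
            yield piece
    return chunks()


async def ainvoke_stream() -> AsyncIterator[str]:
    # Await the call that opens the stream, then re-yield its chunks so
    # callers never see the two-step await-then-iterate protocol.
    stream = await fake_create(stream=True)
    async for chunk in stream:
        yield chunk


async def main() -> None:
    async for chunk in ainvoke_stream():
        print(chunk, end=" ")


asyncio.run(main())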