
Remove robust tool chunk stream handling - added in Haystack 2.10, force tool calling for test robustness
vblagoje committed Feb 12, 2025
1 parent 3f4e171 commit 35f484e
Showing 2 changed files with 9 additions and 67 deletions.
@@ -128,66 +128,4 @@ def to_dict(self) -> Dict[str, Any]:
             generation_kwargs=self.generation_kwargs,
             api_key=self.api_key.to_dict(),
             tools=[tool.to_dict() for tool in self.tools] if self.tools else None,
         )
-
-    def _convert_streaming_chunks_to_chat_message(self, chunk: Any, chunks: List[StreamingChunk]) -> ChatMessage:
-        """
-        Connects the streaming chunks into a single ChatMessage.
-        :param chunk: The last chunk returned by the OpenAI API.
-        :param chunks: The list of all `StreamingChunk` objects.
-        """
-
-        # to have streaming support and tool calls we need to do some extra work here because the superclass
-        # looks for tool calls in the first chunk only, while Mistral does not return tool calls in the first chunk
-        # so we need to find the chunk that has tool calls and use it to create the ChatMessage
-        # after we implement https://github.com/deepset-ai/haystack/pull/8829 we'll be able to remove this code
-        # and use the superclass implementation
-        text = "".join([chunk.content for chunk in chunks])
-        tool_calls = []
-
-        # are there any tool calls in the chunks?
-        if any(chunk.meta.get("tool_calls") for chunk in chunks):
-            payloads = {}  # Use a dict to track tool calls by ID
-            for chunk_payload in chunks:
-                deltas = chunk_payload.meta.get("tool_calls") or []
-
-                # deltas is a list of ChoiceDeltaToolCall
-                for delta in deltas:
-                    if delta.id not in payloads:
-                        payloads[delta.id] = {"id": delta.id, "arguments": "", "name": "", "type": None}
-                    # ChoiceDeltaToolCall has a 'function' field of type ChoiceDeltaToolCallFunction
-                    if delta.function:
-                        # For tool calls with the same ID, use the latest values
-                        if delta.function.name is not None:
-                            payloads[delta.id]["name"] = delta.function.name
-                        if delta.function.arguments is not None:
-                            # Use the latest arguments value
-                            payloads[delta.id]["arguments"] = delta.function.arguments
-                        if delta.type is not None:
-                            payloads[delta.id]["type"] = delta.type
-
-            for payload in payloads.values():
-                arguments_str = payload["arguments"]
-                try:
-                    # Try to parse the concatenated arguments string as JSON
-                    arguments = json.loads(arguments_str)
-                    tool_calls.append(ToolCall(id=payload["id"], tool_name=payload["name"], arguments=arguments))
-                except json.JSONDecodeError:
-                    logger.warning(
-                        "Mistral returned a malformed JSON string for tool call arguments. This tool call "
-                        "will be skipped. Tool call ID: {_id}, Tool name: {_name}, Arguments: {_arguments}",
-                        _id=payload["id"],
-                        _name=payload["name"],
-                        _arguments=arguments_str,
-                    )
-
-        meta = {
-            "model": chunk.model,
-            "index": 0,
-            "finish_reason": chunk.choices[0].finish_reason,
-            "completion_start_time": chunks[0].meta.get("received_at"),  # first chunk received
-            "usage": {},  # we don't have usage data for streaming responses
-        }
-
-        return ChatMessage.from_assistant(text=text, tool_calls=tool_calls, meta=meta)
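For context on what the deleted override did: a streamed tool call arrives as a sequence of deltas that share an ID, and the pieces must be merged by ID before the JSON arguments can be parsed. Below is a minimal standalone sketch of that merge; `FunctionDelta` and `ToolCallDelta` are simplified stand-ins for the OpenAI SDK's `ChoiceDeltaToolCallFunction` and `ChoiceDeltaToolCall`, not the real types.

```python
import json
from dataclasses import dataclass
from typing import Dict, List, Optional


@dataclass
class FunctionDelta:
    # Stand-in for ChoiceDeltaToolCallFunction: either field may be absent in a given delta.
    name: Optional[str] = None
    arguments: Optional[str] = None


@dataclass
class ToolCallDelta:
    # Stand-in for ChoiceDeltaToolCall: deltas belonging to one call share an id.
    id: str
    function: Optional[FunctionDelta] = None


def merge_tool_call_deltas(deltas: List[ToolCallDelta]) -> List[dict]:
    """Fold a stream of deltas into complete tool calls, keyed by delta ID."""
    merged: Dict[str, dict] = {}
    for delta in deltas:
        entry = merged.setdefault(delta.id, {"id": delta.id, "name": "", "arguments": ""})
        if delta.function:
            if delta.function.name is not None:
                entry["name"] = delta.function.name
            if delta.function.arguments is not None:
                # Mistral sends the full arguments string, so the latest value wins.
                entry["arguments"] = delta.function.arguments
    complete = []
    for entry in merged.values():
        try:
            complete.append({**entry, "arguments": json.loads(entry["arguments"])})
        except json.JSONDecodeError:
            continue  # skip malformed tool calls, mirroring the removed method's warning path
    return complete


# Example: two deltas for the same call ID merge into one complete tool call.
deltas = [
    ToolCallDelta(id="call_1", function=FunctionDelta(name="get_weather")),
    ToolCallDelta(id="call_1", function=FunctionDelta(arguments='{"city": "Paris"}')),
]
print(merge_tool_call_deltas(deltas))
# [{'id': 'call_1', 'name': 'get_weather', 'arguments': {'city': 'Paris'}}]
```

After the superclass handles late-arriving tool-call deltas (the fix tracked in deepset-ai/haystack#8829), this per-integration merge becomes redundant, which is why the commit deletes it.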
integrations/mistral/tests/test_mistral_chat_generator.py (12 changes: 8 additions and 4 deletions)
@@ -276,7 +276,8 @@ def __call__(self, chunk: StreamingChunk) -> None:
 
         callback = Callback()
         component = MistralChatGenerator(streaming_callback=callback)
-        results = component.run([ChatMessage.from_user("What's the capital of France?")])
+        results = component.run([ChatMessage.from_user("What's the capital of France?")],
+                                generation_kwargs={"tool_choice": "any"})
 
         assert len(results["replies"]) == 1
         message: ChatMessage = results["replies"][0]
@@ -314,7 +315,8 @@ def test_live_run_with_tools_and_response(self, tools):
         """
         initial_messages = [ChatMessage.from_user("What's the weather like in Paris?")]
         component = MistralChatGenerator(tools=tools)
-        results = component.run(messages=initial_messages)
+        results = component.run(messages=initial_messages,
+                                generation_kwargs={"tool_choice": "any"})
 
         assert len(results["replies"]) > 0, "No replies received"
 
@@ -374,7 +376,8 @@ def __call__(self, chunk: StreamingChunk) -> None:
 
         callback = Callback()
         component = MistralChatGenerator(tools=tools, streaming_callback=callback)
-        results = component.run([ChatMessage.from_user("What's the weather like in Paris?")])
+        results = component.run([ChatMessage.from_user("What's the weather like in Paris?")],
+                                generation_kwargs={"tool_choice": "any"})
 
         assert len(results["replies"]) > 0, "No replies received"
         assert callback.counter > 1, "Streaming callback was not called multiple times"
@@ -413,7 +416,8 @@ def test_pipeline_with_mistral_chat_generator(self, tools):
         pipeline.connect("generator", "tool_invoker")
 
         results = pipeline.run(
-            data={"generator": {"messages": [ChatMessage.from_user("What's the weather like in Paris?")]}}
+            data={"generator": {"messages": [ChatMessage.from_user("What's the weather like in Paris?")],
+                                "generation_kwargs": {"tool_choice": "any"}}}
         )
 
         assert (
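The test-side changes all pass `tool_choice: "any"` through `generation_kwargs`, which requires the model to call one of the supplied tools instead of occasionally answering in plain text, making the live tests deterministic. A rough usage sketch of the same pattern outside the test suite, assuming a valid `MISTRAL_API_KEY` in the environment; `tools` here stands in for the fixture defined in the test file:

```python
from haystack.dataclasses import ChatMessage
from haystack_integrations.components.generators.mistral import MistralChatGenerator

# `tools` is a stand-in for the test fixture (a list of Haystack Tool objects).
generator = MistralChatGenerator(tools=tools)
result = generator.run(
    messages=[ChatMessage.from_user("What's the weather like in Paris?")],
    # "any" forces a tool call; "auto" (the default) lets the model decide.
    generation_kwargs={"tool_choice": "any"},
)
print(result["replies"][0].tool_calls)
```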
