Skip to content

Commit 524ee6d

Browse files
Invalid tool_choice being passed to ChatLiteLLM (langchain-ai#28198)
- **Description:** Invalid `tool_choice` is given to `ChatLiteLLM` to `bind_tools` due to it's parent's class default value being pass through `with_structured_output`. - **Issue:** langchain-ai#28176
1 parent dd0085a commit 524ee6d

File tree

1 file changed

+67
-6
lines changed
  • libs/community/langchain_community/chat_models

1 file changed

+67
-6
lines changed

libs/community/langchain_community/chat_models/litellm.py

+67-6
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
Dict,
1212
Iterator,
1313
List,
14+
Literal,
1415
Mapping,
1516
Optional,
1617
Sequence,
@@ -212,6 +213,33 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
212213
return message_dict
213214

214215

216+
_OPENAI_MODELS = [
217+
"o1-mini",
218+
"o1-preview",
219+
"gpt-4o-mini",
220+
"gpt-4o-mini-2024-07-18",
221+
"gpt-4o",
222+
"gpt-4o-2024-08-06",
223+
"gpt-4o-2024-05-13",
224+
"gpt-4-turbo",
225+
"gpt-4-turbo-preview",
226+
"gpt-4-0125-preview",
227+
"gpt-4-1106-preview",
228+
"gpt-3.5-turbo-1106",
229+
"gpt-3.5-turbo",
230+
"gpt-3.5-turbo-0301",
231+
"gpt-3.5-turbo-0613",
232+
"gpt-3.5-turbo-16k",
233+
"gpt-3.5-turbo-16k-0613",
234+
"gpt-4",
235+
"gpt-4-0314",
236+
"gpt-4-0613",
237+
"gpt-4-32k",
238+
"gpt-4-32k-0314",
239+
"gpt-4-32k-0613",
240+
]
241+
242+
215243
class ChatLiteLLM(BaseChatModel):
216244
"""Chat model that uses the LiteLLM API."""
217245

@@ -465,6 +493,9 @@ async def _agenerate(
465493
def bind_tools(
466494
self,
467495
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
496+
tool_choice: Optional[
497+
Union[dict, str, Literal["auto", "none", "required", "any"], bool]
498+
] = None,
468499
**kwargs: Any,
469500
) -> Runnable[LanguageModelInput, BaseMessage]:
470501
"""Bind tool-like objects to this chat model.
@@ -476,17 +507,47 @@ def bind_tools(
476507
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
477508
models, callables, and BaseTools will be automatically converted to
478509
their schema dictionary representation.
479-
tool_choice: Which tool to require the model to call.
480-
Must be the name of the single provided function or
481-
"auto" to automatically determine which function to call
482-
(if any), or a dict of the form:
483-
{"type": "function", "function": {"name": <<tool_name>>}}.
510+
tool_choice: Which tool to require the model to call. Options are:
511+
- str of the form ``"<<tool_name>>"``: calls <<tool_name>> tool.
512+
- ``"auto"``:
513+
automatically selects a tool (including no tool).
514+
- ``"none"``:
515+
does not call a tool.
516+
- ``"any"`` or ``"required"`` or ``True``:
517+
forces least one tool to be called.
518+
- dict of the form:
519+
``{"type": "function", "function": {"name": <<tool_name>>}}``
520+
- ``False`` or ``None``: no effect
484521
**kwargs: Any additional parameters to pass to the
485522
:class:`~langchain.runnable.Runnable` constructor.
486523
"""
487524

488525
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
489-
return super().bind(tools=formatted_tools, **kwargs)
526+
527+
# In case of openai if tool_choice is `any` or if bool has been provided we
528+
# change it to `required` as that is suppored by openai.
529+
if (
530+
(self.model is not None and "azure" in self.model)
531+
or (self.model_name is not None and "azure" in self.model_name)
532+
or (self.model is not None and self.model in _OPENAI_MODELS)
533+
or (self.model_name is not None and self.model_name in _OPENAI_MODELS)
534+
) and (tool_choice == "any" or isinstance(tool_choice, bool)):
535+
tool_choice = "required"
536+
# If tool_choice is bool apart from openai we make it `any`
537+
elif isinstance(tool_choice, bool):
538+
tool_choice = "any"
539+
elif isinstance(tool_choice, dict):
540+
tool_names = [
541+
formatted_tool["function"]["name"] for formatted_tool in formatted_tools
542+
]
543+
if not any(
544+
tool_name == tool_choice["function"]["name"] for tool_name in tool_names
545+
):
546+
raise ValueError(
547+
f"Tool choice {tool_choice} was specified, but the only "
548+
f"provided tools were {tool_names}."
549+
)
550+
return super().bind(tools=formatted_tools, tool_choice=tool_choice, **kwargs)
490551

491552
@property
492553
def _identifying_params(self) -> Dict[str, Any]:

0 commit comments

Comments
 (0)