11
11
Dict ,
12
12
Iterator ,
13
13
List ,
14
+ Literal ,
14
15
Mapping ,
15
16
Optional ,
16
17
Sequence ,
@@ -212,6 +213,33 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
212
213
return message_dict
213
214
214
215
216
+ _OPENAI_MODELS = [
217
+ "o1-mini" ,
218
+ "o1-preview" ,
219
+ "gpt-4o-mini" ,
220
+ "gpt-4o-mini-2024-07-18" ,
221
+ "gpt-4o" ,
222
+ "gpt-4o-2024-08-06" ,
223
+ "gpt-4o-2024-05-13" ,
224
+ "gpt-4-turbo" ,
225
+ "gpt-4-turbo-preview" ,
226
+ "gpt-4-0125-preview" ,
227
+ "gpt-4-1106-preview" ,
228
+ "gpt-3.5-turbo-1106" ,
229
+ "gpt-3.5-turbo" ,
230
+ "gpt-3.5-turbo-0301" ,
231
+ "gpt-3.5-turbo-0613" ,
232
+ "gpt-3.5-turbo-16k" ,
233
+ "gpt-3.5-turbo-16k-0613" ,
234
+ "gpt-4" ,
235
+ "gpt-4-0314" ,
236
+ "gpt-4-0613" ,
237
+ "gpt-4-32k" ,
238
+ "gpt-4-32k-0314" ,
239
+ "gpt-4-32k-0613" ,
240
+ ]
241
+
242
+
215
243
class ChatLiteLLM (BaseChatModel ):
216
244
"""Chat model that uses the LiteLLM API."""
217
245
@@ -465,6 +493,9 @@ async def _agenerate(
465
493
def bind_tools (
466
494
self ,
467
495
tools : Sequence [Union [Dict [str , Any ], Type [BaseModel ], Callable , BaseTool ]],
496
+ tool_choice : Optional [
497
+ Union [dict , str , Literal ["auto" , "none" , "required" , "any" ], bool ]
498
+ ] = None ,
468
499
** kwargs : Any ,
469
500
) -> Runnable [LanguageModelInput , BaseMessage ]:
470
501
"""Bind tool-like objects to this chat model.
@@ -476,17 +507,47 @@ def bind_tools(
476
507
Can be a dictionary, pydantic model, callable, or BaseTool. Pydantic
477
508
models, callables, and BaseTools will be automatically converted to
478
509
their schema dictionary representation.
479
- tool_choice: Which tool to require the model to call.
480
- Must be the name of the single provided function or
481
- "auto" to automatically determine which function to call
482
- (if any), or a dict of the form:
483
- {"type": "function", "function": {"name": <<tool_name>>}}.
510
+ tool_choice: Which tool to require the model to call. Options are:
511
+ - str of the form ``"<<tool_name>>"``: calls <<tool_name>> tool.
512
+ - ``"auto"``:
513
+ automatically selects a tool (including no tool).
514
+ - ``"none"``:
515
+ does not call a tool.
516
+ - ``"any"`` or ``"required"`` or ``True``:
517
+ forces least one tool to be called.
518
+ - dict of the form:
519
+ ``{"type": "function", "function": {"name": <<tool_name>>}}``
520
+ - ``False`` or ``None``: no effect
484
521
**kwargs: Any additional parameters to pass to the
485
522
:class:`~langchain.runnable.Runnable` constructor.
486
523
"""
487
524
488
525
formatted_tools = [convert_to_openai_tool (tool ) for tool in tools ]
489
- return super ().bind (tools = formatted_tools , ** kwargs )
526
+
527
+ # In case of openai if tool_choice is `any` or if bool has been provided we
528
+ # change it to `required` as that is suppored by openai.
529
+ if (
530
+ (self .model is not None and "azure" in self .model )
531
+ or (self .model_name is not None and "azure" in self .model_name )
532
+ or (self .model is not None and self .model in _OPENAI_MODELS )
533
+ or (self .model_name is not None and self .model_name in _OPENAI_MODELS )
534
+ ) and (tool_choice == "any" or isinstance (tool_choice , bool )):
535
+ tool_choice = "required"
536
+ # If tool_choice is bool apart from openai we make it `any`
537
+ elif isinstance (tool_choice , bool ):
538
+ tool_choice = "any"
539
+ elif isinstance (tool_choice , dict ):
540
+ tool_names = [
541
+ formatted_tool ["function" ]["name" ] for formatted_tool in formatted_tools
542
+ ]
543
+ if not any (
544
+ tool_name == tool_choice ["function" ]["name" ] for tool_name in tool_names
545
+ ):
546
+ raise ValueError (
547
+ f"Tool choice { tool_choice } was specified, but the only "
548
+ f"provided tools were { tool_names } ."
549
+ )
550
+ return super ().bind (tools = formatted_tools , tool_choice = tool_choice , ** kwargs )
490
551
491
552
@property
492
553
def _identifying_params (self ) -> Dict [str , Any ]:
0 commit comments