Skip to content

Commit 6cd5682

Browse files
authored
[Core] Unify function schema parsing (langchain-ai#23370)
Use pydantic to infer nested schemas and all that fun. Include bagatur's convenient docstring parser. Include annotation support. Previously we didn't adequately support many type hints in the bind_tools() method on raw functions (like optionals/unions, nested types, etc.).
1 parent 2a2c0d1 commit 6cd5682

File tree

4 files changed

+221
-149
lines changed

4 files changed

+221
-149
lines changed

libs/core/langchain_core/tools.py

+128-9
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,20 @@
2828
from contextvars import copy_context
2929
from functools import partial
3030
from inspect import signature
31-
from typing import Any, Awaitable, Callable, Dict, List, Optional, Tuple, Type, Union
31+
from typing import (
32+
Any,
33+
Awaitable,
34+
Callable,
35+
Dict,
36+
List,
37+
Optional,
38+
Sequence,
39+
Tuple,
40+
Type,
41+
Union,
42+
)
43+
44+
from typing_extensions import Annotated, get_args, get_origin
3245

3346
from langchain_core._api import deprecated
3447
from langchain_core.callbacks import (
@@ -76,11 +89,32 @@ class SchemaAnnotationError(TypeError):
7689
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
7790

7891

92+
def _is_annotated_type(typ: Type[Any]) -> bool:
    """Tell whether ``typ`` is an ``Annotated[...]`` type hint."""
    origin = get_origin(typ)
    return origin is Annotated
94+
95+
96+
def _get_annotation_description(arg: str, arg_type: Type[Any]) -> Optional[str]:
    """Return the first string found in the metadata of an ``Annotated`` hint.

    Args:
        arg: Name of the argument the hint belongs to. The lookup does not use
            it; it is kept so existing callers keep working.
        arg_type: The type hint to inspect.

    Returns:
        The first ``str`` among the ``Annotated`` metadata, or ``None`` when
        ``arg_type`` is not ``Annotated`` or carries no string metadata.
    """
    if not _is_annotated_type(arg_type):
        return None
    # get_args() yields (wrapped_type, *metadata); only the metadata is scanned.
    for annotation in get_args(arg_type)[1:]:
        if isinstance(annotation, str):
            return annotation
    return None
105+
106+
79107
def _create_subset_model(
80-
name: str, model: Type[BaseModel], field_names: list
108+
name: str,
109+
model: Type[BaseModel],
110+
field_names: list,
111+
*,
112+
descriptions: Optional[dict] = None,
113+
fn_description: Optional[str] = None,
81114
) -> Type[BaseModel]:
82115
"""Create a pydantic model with only a subset of model's fields."""
83116
fields = {}
117+
84118
for field_name in field_names:
85119
field = model.__fields__[field_name]
86120
t = (
@@ -89,19 +123,89 @@ def _create_subset_model(
89123
if field.required and not field.allow_none
90124
else Optional[field.outer_type_]
91125
)
126+
if descriptions and field_name in descriptions:
127+
field.field_info.description = descriptions[field_name]
92128
fields[field_name] = (t, field.field_info)
129+
93130
rtn = create_model(name, **fields) # type: ignore
131+
rtn.__doc__ = textwrap.dedent(fn_description or model.__doc__ or "")
94132
return rtn
95133

96134

97135
def _get_filtered_args(
    inferred_model: Type[BaseModel],
    func: Callable,
    *,
    filter_args: Sequence[str],
) -> dict:
    """Pick the schema properties that match ``func``'s own parameters.

    Args:
        inferred_model: Model whose ``schema()["properties"]`` is read.
        func: Function whose signature decides which properties are kept.
        filter_args: Parameter names to leave out of the result.

    Returns:
        A dict mapping each kept parameter name to its property schema, in
        signature order.
    """
    properties = inferred_model.schema()["properties"]
    selected = {}
    for position, (name, param) in enumerate(signature(func).parameters.items()):
        if name in filter_args:
            continue
        # A leading ``self``/``cls`` is the receiver, not a real argument.
        if position == 0 and param.name in ("self", "cls"):
            continue
        selected[name] = properties[name]
    return selected
149+
150+
151+
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
    """Parse the function and argument descriptions from the docstring of a function.

    Assumes the function docstring follows Google Python style guide, with the
    description paragraphs and the ``Args:`` section separated by blank lines.

    Args:
        function: Callable whose docstring is read with ``inspect.getdoc``.

    Returns:
        A ``(description, arg_descriptions)`` tuple. ``description`` joins with
        single spaces every paragraph that precedes the first ``Args:``,
        ``Returns:`` or ``Example:`` block. ``arg_descriptions`` maps argument
        names to their text, with continuation lines appended. Both are empty
        when the function has no docstring.
    """
    docstring = inspect.getdoc(function)
    if not docstring:
        return "", {}

    descriptors = []
    args_block = None
    past_descriptors = False
    for block in docstring.split("\n\n"):
        if block.startswith("Args:"):
            args_block = block
            break
        if block.startswith(("Returns:", "Example:")):
            # Don't break in case Args come after
            past_descriptors = True
        elif not past_descriptors:
            descriptors.append(block)
    description = " ".join(descriptors)

    # Headers that end the Args section when no blank line separates them from
    # the last argument; without this check they would be parsed as arguments.
    section_headers = (
        "Returns:",
        "Return:",
        "Yields:",
        "Raises:",
        "Example:",
        "Examples:",
        "Note:",
        "Notes:",
    )
    arg_descriptions: dict = {}
    if args_block:
        arg = None
        for line in args_block.split("\n")[1:]:
            # getdoc() dedents the docstring, so section headers sit at column 0
            # while argument entries stay indented.
            if not line[:1].isspace() and line.strip() in section_headers:
                break
            if ":" in line:
                arg, desc = line.split(":", maxsplit=1)
                arg = arg.strip()
                arg_descriptions[arg] = desc.strip()
            elif arg:
                # Continuation of the previous argument's description.
                arg_descriptions[arg] += " " + line.strip()
    return description, arg_descriptions
187+
188+
189+
def _infer_arg_descriptions(
    fn: Callable, *, parse_docstring: bool = False
) -> Tuple[str, dict]:
    """Infer a function description and per-argument descriptions.

    Args:
        fn: Function to inspect.
        parse_docstring: When True, the docstring is parsed for a description
            and an ``Args:`` section. Otherwise the whole docstring is used as
            the description and no argument text is taken from it.

    Returns:
        A ``(description, arg_descriptions)`` tuple. Descriptions parsed from
        the docstring take precedence; remaining arguments fall back to the
        first string in their ``Annotated`` metadata, if any.
    """
    if parse_docstring:
        description, arg_descriptions = _parse_python_function_docstring(fn)
    else:
        description = inspect.getdoc(fn) or ""
        arg_descriptions = {}
    if hasattr(inspect, "get_annotations"):
        # inspect.get_annotations() only exists on Python >= 3.10.
        annotations = inspect.get_annotations(fn)  # type: ignore
    else:
        # Older interpreters: read the raw annotations mapping directly.
        annotations = getattr(fn, "__annotations__", {})
    for arg, arg_type in annotations.items():
        # "return" annotates the result, not an argument.
        if arg == "return" or arg in arg_descriptions:
            continue
        if desc := _get_annotation_description(arg, arg_type):
            arg_descriptions[arg] = desc
    return description, arg_descriptions
105209

106210

107211
class _SchemaConfig:
@@ -114,25 +218,40 @@ class _SchemaConfig:
114218
def create_schema_from_function(
    model_name: str,
    func: Callable,
    *,
    filter_args: Optional[Sequence[str]] = None,
    parse_docstring: bool = False,
) -> Type[BaseModel]:
    """Build a pydantic schema describing the arguments of a function.

    Args:
        model_name: Name to assign to the generated pydantic schema
        func: Function to generate the schema from
        filter_args: Optional list of arguments to exclude from the schema.
            When None, ``run_manager`` and ``callbacks`` are excluded.
        parse_docstring: Whether to parse the function's docstring for descriptions
            for each argument.

    Returns:
        A pydantic model with the same arguments as the function
    """
    # https://docs.pydantic.dev/latest/usage/validation_decorator/
    inferred_model = validate_arguments(func, config=_SchemaConfig).model  # type: ignore
    if filter_args is None:
        filter_args = ("run_manager", "callbacks")
    model_fields = inferred_model.__fields__
    for excluded in filter_args:
        model_fields.pop(excluded, None)
    description, arg_descriptions = _infer_arg_descriptions(
        func, parse_docstring=parse_docstring
    )
    # Pydantic adds placeholder virtual fields we need to strip
    kept_properties = _get_filtered_args(inferred_model, func, filter_args=filter_args)
    schema_name = f"{model_name}Schema"
    return _create_subset_model(
        schema_name,
        inferred_model,
        list(kept_properties),
        descriptions=arg_descriptions,
        fn_description=description,
    )
137256

138257

libs/core/langchain_core/utils/function_calling.py

+14-132
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,8 @@
22

33
from __future__ import annotations
44

5-
import inspect
65
import logging
76
import uuid
8-
from types import FunctionType, MethodType
97
from typing import (
108
TYPE_CHECKING,
119
Any,
@@ -14,13 +12,12 @@
1412
List,
1513
Literal,
1614
Optional,
17-
Tuple,
1815
Type,
1916
Union,
2017
cast,
2118
)
2219

23-
from typing_extensions import Annotated, TypedDict, get_args, get_origin
20+
from typing_extensions import TypedDict
2421

2522
from langchain_core._api import deprecated
2623
from langchain_core.messages import (
@@ -123,146 +120,31 @@ def _get_python_function_name(function: Callable) -> str:
123120
return function.__name__
124121

125122

126-
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
    """Parse the function and argument descriptions from the docstring of a function.

    Assumes the function docstring follows Google Python style guide, with the
    description paragraphs and the ``Args:`` section separated by blank lines.

    Args:
        function: Callable whose docstring is read with ``inspect.getdoc``.

    Returns:
        A ``(description, arg_descriptions)`` tuple. ``description`` joins with
        single spaces every paragraph that precedes the first ``Args:``,
        ``Returns:`` or ``Example:`` block. ``arg_descriptions`` maps argument
        names to their text, with continuation lines appended. Both are empty
        when the function has no docstring.
    """
    docstring = inspect.getdoc(function)
    if not docstring:
        return "", {}

    descriptors = []
    args_block = None
    past_descriptors = False
    for block in docstring.split("\n\n"):
        if block.startswith("Args:"):
            args_block = block
            break
        if block.startswith(("Returns:", "Example:")):
            # Don't break in case Args come after
            past_descriptors = True
        elif not past_descriptors:
            descriptors.append(block)
    description = " ".join(descriptors)

    # Headers that end the Args section when no blank line separates them from
    # the last argument; without this check they would be parsed as arguments.
    section_headers = (
        "Returns:",
        "Return:",
        "Yields:",
        "Raises:",
        "Example:",
        "Examples:",
        "Note:",
        "Notes:",
    )
    arg_descriptions: dict = {}
    if args_block:
        arg = None
        for line in args_block.split("\n")[1:]:
            # getdoc() dedents the docstring, so section headers sit at column 0
            # while argument entries stay indented.
            if not line[:1].isspace() and line.strip() in section_headers:
                break
            if ":" in line:
                arg, desc = line.split(":", maxsplit=1)
                arg = arg.strip()
                arg_descriptions[arg] = desc.strip()
            elif arg:
                # Continuation of the previous argument's description.
                arg_descriptions[arg] += " " + line.strip()
    return description, arg_descriptions
162-
163-
164-
def _is_annotated_type(typ: Type[Any]) -> bool:
    """Tell whether ``typ`` is an ``Annotated[...]`` type hint."""
    origin = get_origin(typ)
    return origin is Annotated
166-
167-
168-
def _get_python_function_arguments(function: Callable, arg_descriptions: dict) -> dict:
    """Get JsonSchema describing a Python functions arguments.

    Assumes all function arguments are of primitive types (int, float, str, bool) or
    are subclasses of pydantic.BaseModel.

    Args:
        function: Function whose annotations are converted.
        arg_descriptions: Mapping of argument name to description. It is updated
            in place with the first string found in an argument's ``Annotated``
            metadata.

    Returns:
        A dict mapping each annotated argument name to its JSON schema. An
        argument whose type cannot be converted is logged; it only gets an
        entry (holding just ``description``) if a description exists for it.
    """
    properties: dict = {}
    annotations = inspect.getfullargspec(function).annotations
    for arg, arg_type in annotations.items():
        # The return annotation is not an argument.
        if arg == "return":
            continue

        if _is_annotated_type(arg_type):
            annotated_args = get_args(arg_type)
            # Unwrap to the real type; the rest of the tuple is metadata.
            arg_type = annotated_args[0]
            for annotation in annotated_args[1:]:
                if isinstance(annotation, str):
                    arg_descriptions[arg] = annotation
                    break
        if (
            isinstance(arg_type, type)
            and hasattr(arg_type, "model_json_schema")
            and callable(arg_type.model_json_schema)
        ):
            properties[arg] = arg_type.model_json_schema()
        elif (
            isinstance(arg_type, type)
            and hasattr(arg_type, "schema")
            and callable(arg_type.schema)
        ):
            properties[arg] = arg_type.schema()
        elif (
            hasattr(arg_type, "__name__")
            and arg_type.__name__ in PYTHON_TO_JSON_TYPES
        ):
            properties[arg] = {"type": PYTHON_TO_JSON_TYPES[arg_type.__name__]}
        elif (
            hasattr(arg_type, "__dict__")
            and arg_type.__dict__.get("__origin__", None) == Literal
        ):
            # The JSON type is taken from the first literal value.
            properties[arg] = {
                "enum": list(arg_type.__args__),
                "type": PYTHON_TO_JSON_TYPES[arg_type.__args__[0].__class__.__name__],
            }
        else:
            logger.warning(
                f"Argument {arg} of type {arg_type} from function {function.__name__} "
                "could not be converted to a JSON schema."
            )

        if arg in arg_descriptions:
            if arg not in properties:
                properties[arg] = {}
            properties[arg]["description"] = arg_descriptions[arg]

    return properties
225-
226-
227-
def _get_python_function_required_args(function: Callable) -> List[str]:
    """Get the required arguments for a Python function.

    Args:
        function: Function or method to inspect.

    Returns:
        Names of the positional parameters without defaults followed by the
        keyword-only parameters without defaults. A leading ``self`` or ``cls``
        is dropped.
    """
    spec = inspect.getfullargspec(function)
    positional = spec.args[: -len(spec.defaults)] if spec.defaults else spec.args
    kwonly_defaults = spec.kwonlydefaults or {}
    required = list(positional) + [
        k for k in spec.kwonlyargs if k not in kwonly_defaults
    ]
    # getfullargspec() keeps the receiver even for bound methods, so drop a
    # leading self/cls whether it comes from a plain function or a method.
    if required and required[0] in ("self", "cls"):
        required = required[1:]
    return required
240-
241-
242123
@deprecated(
    "0.1.16",
    alternative="langchain_core.utils.function_calling.convert_to_openai_function()",
    removal="0.3.0",
)
def convert_python_function_to_openai_function(
    function: Callable,
) -> FunctionDescription:
    """Describe a Python function in the OpenAI function-calling format.

    The function is expected to carry type hints and a docstring holding its
    description. Argument descriptions written in Google Python style are picked
    up as well, because the docstring is parsed.
    """
    # Imported here rather than at module level; presumably this avoids a
    # circular import with langchain_core.tools — confirm before hoisting.
    from langchain_core import tools

    name = _get_python_function_name(function)
    # An empty filter_args keeps every parameter in the generated schema.
    schema_model = tools.create_schema_from_function(
        name,
        function,
        filter_args=(),
        parse_docstring=True,
    )
    model_description = schema_model.__doc__
    return convert_pydantic_to_openai_function(
        schema_model,
        name=name,
        description=model_description,
    )
266148

267149

268150
@deprecated(
@@ -343,7 +225,7 @@ def convert_to_openai_function(
343225
elif isinstance(function, BaseTool):
344226
return cast(Dict, format_tool_to_openai_function(function))
345227
elif callable(function):
346-
return convert_python_function_to_openai_function(function)
228+
return cast(Dict, convert_python_function_to_openai_function(function))
347229
else:
348230
raise ValueError(
349231
f"Unsupported function\n\n{function}\n\nFunctions must be passed in"

0 commit comments

Comments
 (0)