28
28
from contextvars import copy_context
29
29
from functools import partial
30
30
from inspect import signature
31
- from typing import Any , Awaitable , Callable , Dict , List , Optional , Tuple , Type , Union
31
+ from typing import (
32
+ Any ,
33
+ Awaitable ,
34
+ Callable ,
35
+ Dict ,
36
+ List ,
37
+ Optional ,
38
+ Sequence ,
39
+ Tuple ,
40
+ Type ,
41
+ Union ,
42
+ )
43
+
44
+ from typing_extensions import Annotated , get_args , get_origin
32
45
33
46
from langchain_core ._api import deprecated
34
47
from langchain_core .callbacks import (
@@ -76,11 +89,32 @@ class SchemaAnnotationError(TypeError):
76
89
"""Raised when 'args_schema' is missing or has an incorrect type annotation."""
77
90
78
91
92
def _is_annotated_type(typ: Type[Any]) -> bool:
    """Return True when ``typ`` is an ``Annotated[...]`` form.

    Args:
        typ: The annotation to inspect.

    Returns:
        True if the origin of ``typ`` is ``Annotated``; False for every other
        annotation, including generics that merely contain an ``Annotated``.
    """
    return get_origin(typ) is Annotated
94
+
95
+
96
def _get_annotation_description(arg: str, arg_type: Type[Any]) -> Optional[str]:
    """Return the first string metadata item attached to an ``Annotated`` type.

    Args:
        arg: Name of the argument the annotation belongs to. It is not used by
            the lookup; it is kept so the signature stays stable for callers.
        arg_type: The annotation to inspect.

    Returns:
        The first ``str`` found among the ``Annotated`` metadata, or ``None``
        when ``arg_type`` is not ``Annotated`` or carries no string metadata.
    """
    # NOTE: the return annotation is spelled Optional[str] rather than
    # ``str | None`` because the latter is evaluated at definition time and
    # raises TypeError on Python < 3.10.
    if not _is_annotated_type(arg_type):
        return None
    # get_args(Annotated[T, m1, m2, ...]) yields (T, m1, m2, ...); drop T.
    metadata = get_args(arg_type)[1:]
    return next((item for item in metadata if isinstance(item, str)), None)
105
+
106
+
79
107
def _create_subset_model (
80
- name : str , model : Type [BaseModel ], field_names : list
108
+ name : str ,
109
+ model : Type [BaseModel ],
110
+ field_names : list ,
111
+ * ,
112
+ descriptions : Optional [dict ] = None ,
113
+ fn_description : Optional [str ] = None ,
81
114
) -> Type [BaseModel ]:
82
115
"""Create a pydantic model with only a subset of model's fields."""
83
116
fields = {}
117
+
84
118
for field_name in field_names :
85
119
field = model .__fields__ [field_name ]
86
120
t = (
@@ -89,19 +123,89 @@ def _create_subset_model(
89
123
if field .required and not field .allow_none
90
124
else Optional [field .outer_type_ ]
91
125
)
126
+ if descriptions and field_name in descriptions :
127
+ field .field_info .description = descriptions [field_name ]
92
128
fields [field_name ] = (t , field .field_info )
129
+
93
130
rtn = create_model (name , ** fields ) # type: ignore
131
+ rtn .__doc__ = textwrap .dedent (fn_description or model .__doc__ or "" )
94
132
return rtn
95
133
96
134
97
135
def _get_filtered_args(
    inferred_model: Type[BaseModel],
    func: Callable,
    *,
    filter_args: Sequence[str],
) -> dict:
    """Get the schema properties that correspond to a function's arguments.

    Args:
        inferred_model: Model whose ``schema()["properties"]`` describes the
            arguments of ``func``.
        func: Function whose signature decides which properties are kept and
            in which order.
        filter_args: Parameter names to leave out of the result.

    Returns:
        A dict mapping each retained parameter name to its schema property,
        in signature order. A leading ``self`` or ``cls`` parameter is dropped;
        the same name in any later position is kept.

    Raises:
        KeyError: If a retained parameter has no entry in the model's schema
            properties.
    """
    properties = inferred_model.schema()["properties"]
    selected = {}
    for position, name in enumerate(signature(func).parameters):
        if name in filter_args:
            continue
        # Only the first positional slot can be the bound instance/class.
        if position == 0 and name in ("self", "cls"):
            continue
        selected[name] = properties[name]
    return selected
149
+
150
+
151
def _parse_python_function_docstring(function: Callable) -> Tuple[str, dict]:
    """Parse the function and argument descriptions from the docstring of a function.

    Assumes the function docstring follows Google Python style guide.

    Args:
        function: The callable whose docstring is parsed.

    Returns:
        A ``(description, arg_descriptions)`` tuple. ``description`` is the
        space-joined text of the blocks that precede the first ``Args:``,
        ``Returns:`` or ``Example:`` block. ``arg_descriptions`` maps argument
        names from the ``Args:`` block to their descriptions. Both are empty
        when the function has no docstring.
    """
    docstring = inspect.getdoc(function)
    if not docstring:
        return "", {}

    descriptors = []
    args_block = None
    past_descriptors = False
    # Blocks are the paragraphs separated by a blank line.
    for block in docstring.split("\n\n"):
        if block.startswith("Args:"):
            args_block = block
            break
        if block.startswith(("Returns:", "Example:")):
            # Don't break in case Args come after
            past_descriptors = True
        elif not past_descriptors:
            descriptors.append(block)
    description = " ".join(descriptors)

    arg_descriptions: dict = {}
    if args_block is None:
        return description, arg_descriptions

    current = None  # name of the argument currently being described
    base_indent = None  # indentation of the first argument entry
    for line in args_block.split("\n")[1:]:
        text = line.strip()
        if not text:
            continue
        indent = len(line) - len(line.lstrip())
        if base_indent is None:
            base_indent = indent
        # A line indented deeper than the argument entries continues the
        # previous description even if it contains a colon (e.g. a URL).
        is_continuation = current is not None and indent > base_indent
        if ":" in text and not is_continuation:
            name, _, desc = text.partition(":")
            current = name.strip()
            arg_descriptions[current] = desc.strip()
        elif current is not None:
            arg_descriptions[current] += " " + text
    return description, arg_descriptions
187
+
188
+
189
def _infer_arg_descriptions(
    fn: Callable, *, parse_docstring: bool = False
) -> Tuple[str, dict]:
    """Infer the function description and argument descriptions for ``fn``.

    Args:
        fn: The callable to inspect.
        parse_docstring: When True, the docstring is parsed into a description
            and per-argument descriptions. When False, the whole docstring is
            the description and no argument descriptions come from it.

    Returns:
        A ``(description, arg_descriptions)`` tuple. Descriptions parsed from
        the docstring take precedence; arguments without one fall back to the
        string metadata of their ``Annotated`` type hint, if present.
    """
    if parse_docstring:
        description, arg_descriptions = _parse_python_function_docstring(fn)
    else:
        description = inspect.getdoc(fn) or ""
        arg_descriptions = {}
    if hasattr(inspect, "get_annotations"):
        # inspect.get_annotations only exists on Python >= 3.10
        annotations = inspect.get_annotations(fn)  # type: ignore
    else:
        annotations = getattr(fn, "__annotations__", {})
    for arg, arg_type in annotations.items():
        # "return" holds the return annotation; it is not an argument.
        if arg == "return" or arg in arg_descriptions:
            continue
        if desc := _get_annotation_description(arg, arg_type):
            arg_descriptions[arg] = desc
    return description, arg_descriptions
105
209
106
210
107
211
class _SchemaConfig :
@@ -114,25 +218,40 @@ class _SchemaConfig:
114
218
def create_schema_from_function(
    model_name: str,
    func: Callable,
    *,
    filter_args: Optional[Sequence[str]] = None,
    parse_docstring: bool = False,
) -> Type[BaseModel]:
    """Create a pydantic schema from a function's signature.

    Args:
        model_name: Name to assign to the generated pydantic schema
        func: Function to generate the schema from
        filter_args: Optional list of arguments to exclude from the schema.
            Defaults to ``("run_manager", "callbacks")`` when not given.
        parse_docstring: Whether to parse the function's docstring for descriptions
            for each argument.

    Returns:
        A pydantic model with the same arguments as the function
    """
    # https://docs.pydantic.dev/latest/usage/validation_decorator/
    validated = validate_arguments(func, config=_SchemaConfig)  # type: ignore
    inferred_model = validated.model  # type: ignore
    excluded = filter_args if filter_args is not None else ("run_manager", "callbacks")
    for arg in excluded:
        if arg in inferred_model.__fields__:
            del inferred_model.__fields__[arg]
    description, arg_descriptions = _infer_arg_descriptions(
        func, parse_docstring=parse_docstring
    )
    # Pydantic adds placeholder virtual fields we need to strip
    valid_properties = _get_filtered_args(inferred_model, func, filter_args=excluded)
    return _create_subset_model(
        f"{model_name}Schema",
        inferred_model,
        list(valid_properties),
        descriptions=arg_descriptions,
        fn_description=description,
    )
137
256
138
257
0 commit comments