@@ -731,6 +731,36 @@ def _validate_extract_fields(fields: dict):
731
731
)
732
732
733
733
734
+ async def dict_generator_async (
735
+ * args ,
736
+ fn ,
737
+ fill_with ,
738
+ fields ,
739
+ ** kwargs ,
740
+ ):
741
+ dict_generated = await fn (* args , ** kwargs )
742
+ if fill_with is not None :
743
+ for field in fields :
744
+ if field not in dict_generated :
745
+ dict_generated [field ] = fill_with
746
+ return dict_generated
747
+
748
+
749
+ async def dict_generator (
750
+ * args ,
751
+ fn ,
752
+ fill_with ,
753
+ fields ,
754
+ ** kwargs ,
755
+ ):
756
+ dict_generated = fn (* args , ** kwargs )
757
+ if fill_with is not None :
758
+ for field in fields :
759
+ if field not in dict_generated :
760
+ dict_generated [field ] = fill_with
761
+ return dict_generated
762
+
763
+
734
764
class extract_fields (base .SingleNodeNodeTransformer ):
735
765
"""Extracts fields from a dictionary of output."""
736
766
@@ -804,29 +834,35 @@ def transform_node(
804
834
"""
805
835
fn = node_ .callable
806
836
base_doc = node_ .documentation
807
-
837
+ dict_generator_fn = (
838
+ functools .partial (dict_generator , fn = fn , fill_with = self .fill_with , fields = self .fields )
839
+ if not (inspect .iscoroutinefunction (fn ))
840
+ else functools .partial (
841
+ dict_generator_async , fn = fn , fill_with = self .fill_with , fields = self .fields
842
+ )
843
+ )
808
844
# if fn is async
809
- if inspect .iscoroutinefunction (fn ):
810
-
811
- async def dict_generator (* args , ** kwargs ):
812
- dict_generated = await fn (* args , ** kwargs )
813
- if self .fill_with is not None :
814
- for field in self .fields :
815
- if field not in dict_generated :
816
- dict_generated [field ] = self .fill_with
817
- return dict_generated
818
-
819
- else :
820
-
821
- def dict_generator (* args , ** kwargs ):
822
- dict_generated = fn (* args , ** kwargs )
823
- if self .fill_with is not None :
824
- for field in self .fields :
825
- if field not in dict_generated :
826
- dict_generated [field ] = self .fill_with
827
- return dict_generated
828
-
829
- output_nodes = [node_ .copy_with (callabl = dict_generator )]
845
+ # if inspect.iscoroutinefunction(fn):
846
+ #
847
+ # async def dict_generator(*args, **kwargs):
848
+ # dict_generated = await fn(*args, **kwargs)
849
+ # if self.fill_with is not None:
850
+ # for field in self.fields:
851
+ # if field not in dict_generated:
852
+ # dict_generated[field] = self.fill_with
853
+ # return dict_generated
854
+ #
855
+ # else:
856
+ #
857
+ # def dict_generator(*args, **kwargs):
858
+ # dict_generated = fn(*args, **kwargs)
859
+ # if self.fill_with is not None:
860
+ # for field in self.fields:
861
+ # if field not in dict_generated:
862
+ # dict_generated[field] = self.fill_with
863
+ # return dict_generated
864
+
865
+ output_nodes = [node_ .copy_with (callabl = dict_generator_fn )]
830
866
831
867
for field , field_type in self .fields .items ():
832
868
doc_string = base_doc # default doc string of base function.
0 commit comments