Skip to content

Commit df2dc62

Browse files
committed
WIP fix for serde + extract_fields
1 parent 1057a40 commit df2dc62

File tree

1 file changed

+58
-22
lines changed

1 file changed

+58
-22
lines changed

hamilton/function_modifiers/expanders.py

+58-22
Original file line numberDiff line numberDiff line change
@@ -731,6 +731,36 @@ def _validate_extract_fields(fields: dict):
731731
)
732732

733733

734+
async def dict_generator_async(
735+
*args,
736+
fn,
737+
fill_with,
738+
fields,
739+
**kwargs,
740+
):
741+
dict_generated = await fn(*args, **kwargs)
742+
if fill_with is not None:
743+
for field in fields:
744+
if field not in dict_generated:
745+
dict_generated[field] = fill_with
746+
return dict_generated
747+
748+
749+
async def dict_generator(
750+
*args,
751+
fn,
752+
fill_with,
753+
fields,
754+
**kwargs,
755+
):
756+
dict_generated = fn(*args, **kwargs)
757+
if fill_with is not None:
758+
for field in fields:
759+
if field not in dict_generated:
760+
dict_generated[field] = fill_with
761+
return dict_generated
762+
763+
734764
class extract_fields(base.SingleNodeNodeTransformer):
735765
"""Extracts fields from a dictionary of output."""
736766

@@ -804,29 +834,35 @@ def transform_node(
804834
"""
805835
fn = node_.callable
806836
base_doc = node_.documentation
807-
837+
dict_generator_fn = (
838+
functools.partial(dict_generator, fn=fn, fill_with=self.fill_with, fields=self.fields)
839+
if not (inspect.iscoroutinefunction(fn))
840+
else functools.partial(
841+
dict_generator_async, fn=fn, fill_with=self.fill_with, fields=self.fields
842+
)
843+
)
808844
# if fn is async
809-
if inspect.iscoroutinefunction(fn):
810-
811-
async def dict_generator(*args, **kwargs):
812-
dict_generated = await fn(*args, **kwargs)
813-
if self.fill_with is not None:
814-
for field in self.fields:
815-
if field not in dict_generated:
816-
dict_generated[field] = self.fill_with
817-
return dict_generated
818-
819-
else:
820-
821-
def dict_generator(*args, **kwargs):
822-
dict_generated = fn(*args, **kwargs)
823-
if self.fill_with is not None:
824-
for field in self.fields:
825-
if field not in dict_generated:
826-
dict_generated[field] = self.fill_with
827-
return dict_generated
828-
829-
output_nodes = [node_.copy_with(callabl=dict_generator)]
845+
# if inspect.iscoroutinefunction(fn):
846+
#
847+
# async def dict_generator(*args, **kwargs):
848+
# dict_generated = await fn(*args, **kwargs)
849+
# if self.fill_with is not None:
850+
# for field in self.fields:
851+
# if field not in dict_generated:
852+
# dict_generated[field] = self.fill_with
853+
# return dict_generated
854+
#
855+
# else:
856+
#
857+
# def dict_generator(*args, **kwargs):
858+
# dict_generated = fn(*args, **kwargs)
859+
# if self.fill_with is not None:
860+
# for field in self.fields:
861+
# if field not in dict_generated:
862+
# dict_generated[field] = self.fill_with
863+
# return dict_generated
864+
865+
output_nodes = [node_.copy_with(callabl=dict_generator_fn)]
830866

831867
for field, field_type in self.fields.items():
832868
doc_string = base_doc # default doc string of base function.

0 commit comments

Comments
 (0)