Skip to content

Commit 8476688

Browse files
jernejfrankelijahbenizzy
authored andcommitted
Add async support for pipe_family
Enables running pipe_input, pipe_output and mutate with asyncio.
1 parent 1b07790 commit 8476688

File tree

3 files changed

+54
-3
lines changed

3 files changed

+54
-3
lines changed

hamilton/function_modifiers/macros.py

+9-2
Original file line numberDiff line numberDiff line change
@@ -608,7 +608,7 @@ def resolve_namespace(self, default_namespace: str) -> Tuple[str, ...]:
608608
def bind_function_args(
609609
self, current_param: Optional[str]
610610
) -> Tuple[Dict[str, Any], Dict[str, Any]]:
611-
"""Binds function arguments, given current, chained parameeter
611+
"""Binds function arguments, given current, chained parameter
612612
613613
:param current_param: Current, chained parameter. None, if we're not chaining.
614614
:return: A tuple of (upstream_inputs, literal_inputs)
@@ -1302,10 +1302,17 @@ def transform_node(
13021302
# We pick a reserved prefix that ovoids clashes with user defined functions / nodes
13031303
original_node = node_.copy_with(name=f"{node_.name}.raw")
13041304

1305+
is_async = inspect.iscoroutinefunction(fn) # determine if its async
1306+
13051307
def __identity(foo: Any) -> Any:
13061308
return foo
13071309

1308-
transforms = transforms + (step(__identity).named(fn.__name__),)
1310+
async def async_function(**kwargs):
1311+
return await __identity(**kwargs)
1312+
1313+
fn_to_use = async_function if is_async else __identity
1314+
1315+
transforms = transforms + (step(fn_to_use).named(fn.__name__),)
13091316
nodes, _ = chain_transforms(
13101317
target_arg=original_node.name,
13111318
transforms=transforms,

hamilton/node.py

+9-1
Original file line numberDiff line numberDiff line change
@@ -360,15 +360,23 @@ def reassign_inputs(
360360
if input_values is None:
361361
input_values = {}
362362

363+
is_async = inspect.iscoroutinefunction(self.callable) # determine if its async
364+
363365
def new_callable(**kwargs) -> Any:
364366
reverse_input_names = {v: k for k, v in input_names.items()}
365367
kwargs = {**kwargs, **input_values}
366368
return self.callable(**{reverse_input_names.get(k, k): v for k, v in kwargs.items()})
367369

370+
async def async_function(**kwargs):
371+
return await new_callable(**kwargs)
372+
373+
fn_to_use = async_function if is_async else new_callable
374+
368375
new_input_types = {
369376
input_names.get(k, k): v for k, v in self.input_types.items() if k not in input_values
370377
}
371-
out = self.copy_with(callabl=new_callable, input_types=new_input_types)
378+
# out = self.copy_with(callabl=new_callable, input_types=new_input_types)
379+
out = self.copy_with(callabl=fn_to_use, input_types=new_input_types)
372380
return out
373381

374382
def transform_output(

tests/test_node.py

+36
Original file line numberDiff line numberDiff line change
@@ -137,3 +137,39 @@ def test_node_from_future_annotation_standard():
137137
def test_node_from_future_annotation_collected():
138138
collected = nodes_with_future_annotation.collected
139139
assert node.Node.from_fn(collected).node_role == node.NodeType.COLLECT
140+
141+
142+
def test_reassign_inputs():
143+
def foo(a: int, b: str) -> int:
144+
return a + len(b)
145+
146+
node_ = Node.from_fn(foo)
147+
148+
first_arg_node = node_.reassign_inputs(input_names={"a": "c"})
149+
new_first_arg = list(first_arg_node.input_types.keys())
150+
second_arg_node = node_.reassign_inputs(input_names={"b": "d"})
151+
new_second_arg = list(second_arg_node.input_types.keys())
152+
both_arg_node = node_.reassign_inputs(input_names={"a": "c", "b": "d"})
153+
new_both_arg = list(both_arg_node.input_types.keys())
154+
assert new_first_arg[0] == "c"
155+
assert new_first_arg[1] == "b"
156+
assert new_second_arg[0] == "a"
157+
assert new_second_arg[1] == "d"
158+
assert new_both_arg[0] == "c"
159+
assert new_both_arg[1] == "d"
160+
assert both_arg_node(**{"c": 2, "d": "abc"}) == 5
161+
162+
163+
@pytest.mark.asyncio
164+
async def test_subdag_async():
165+
async def foo(a: int, b: str) -> int:
166+
return a + len(b)
167+
168+
node_ = Node.from_fn(foo)
169+
170+
new_node = node_.reassign_inputs(input_names={"a": "c", "b": "d"})
171+
new_args = list(new_node.input_types.keys())
172+
assert new_args[0] == "c"
173+
assert new_args[1] == "d"
174+
assert inspect.iscoroutinefunction(new_node.callable)
175+
assert await new_node(**{"c": 2, "d": "abc"}) == 5

0 commit comments

Comments
 (0)