-
Notifications
You must be signed in to change notification settings - Fork 138
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: decorators should wrap functions with vanilla wrapper #1282
base: main
Are you sure you want to change the base?
Changes from all commits
1f05f2e
d2052b1
cd1f317
1d057a6
448ce0b
17c298e
22509fd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,9 +1,11 @@ | ||
import abc | ||
import collections | ||
import functools | ||
import inspect | ||
import itertools | ||
import logging | ||
from abc import ABC | ||
from inspect import unwrap | ||
|
||
try: | ||
from types import EllipsisType | ||
|
@@ -99,18 +101,40 @@ def __call__(self, fn: Callable): | |
:param fn: Function to decorate | ||
:return: The function again, with the desired properties. | ||
""" | ||
self.validate(fn) | ||
# # stop unwrapping if not a hamilton function | ||
# # should only be one level of "hamilton wrapping" - and that's what we attach things to. | ||
self.validate(unwrap(fn, stop=lambda f: not hasattr(f, "__hamilton__"))) | ||
if not hasattr(fn, "__hamilton__"): | ||
if inspect.iscoroutinefunction(fn): | ||
|
||
@functools.wraps(fn) | ||
async def wrapper(*args, **kwargs): | ||
return await fn(*args, **kwargs) | ||
else: | ||
|
||
@functools.wraps(fn) | ||
def wrapper(*args, **kwargs): | ||
return fn(*args, **kwargs) | ||
|
||
wrapper.__hamilton__ = True | ||
if not hasattr(fn, "__hamilton_wrappers__"): | ||
fn.__hamilton_wrappers__ = [wrapper] | ||
else: | ||
fn.__hamilton_wrappers__.append(wrapper) | ||
else: | ||
wrapper = fn | ||
|
||
lifecycle_name = self.__class__.get_lifecycle_name() | ||
if hasattr(fn, self.get_lifecycle_name()): | ||
if hasattr(wrapper, self.get_lifecycle_name()): | ||
if not self.allows_multiple(): | ||
raise ValueError( | ||
f"Got multiple decorators for decorator @{self.__class__}. Only one allowed." | ||
) | ||
curr_value = getattr(fn, lifecycle_name) | ||
setattr(fn, lifecycle_name, curr_value + [self]) | ||
curr_value = getattr(wrapper, lifecycle_name) | ||
setattr(wrapper, lifecycle_name, curr_value + [self]) | ||
else: | ||
setattr(fn, lifecycle_name, [self]) | ||
return fn | ||
setattr(wrapper, lifecycle_name, [self]) | ||
return wrapper | ||
|
||
def required_config(self) -> Optional[List[str]]: | ||
"""Declares the required configuration keys for this decorator. | ||
|
@@ -805,6 +829,9 @@ def resolve_nodes(fn: Callable, config: Dict[str, Any]) -> Collection[node.Node] | |
which configuration they need. | ||
:return: A list of nodes into which this function transforms. | ||
""" | ||
# check for mutate... | ||
fn = handle_mutate_hack(fn) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Take away the name "hack" |
||
|
||
try: | ||
function_decorators = get_node_decorators(fn, config) | ||
node_resolvers = function_decorators[NodeResolver.get_lifecycle_name()] | ||
|
@@ -833,6 +860,36 @@ def resolve_nodes(fn: Callable, config: Dict[str, Any]) -> Collection[node.Node] | |
raise e | ||
|
||
|
||
def handle_mutate_hack(fn): | ||
"""Function that encapsulates the mutate hack check. | ||
|
||
This isn't pretty. It's a hack to get around how special | ||
mutate is. | ||
|
||
This will return the "wrapped function" if this is | ||
a vanilla python function that has a pointer to a pipe_output | ||
decorated function. This is because all other decorators | ||
directly wrap the function, but mutate does not. It adds | ||
a pointer to the function in question that we follow here. | ||
|
||
This code depends on what the mutate class does | ||
and what the pipe_output decorator does. | ||
|
||
:param fn: Function to check | ||
:return: Function or wrapped function to use if applicable. | ||
""" | ||
if hasattr(fn, "__hamilton_wrappers__"): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Is this worth the complexity it adds? Is there a simpler way to approach it? Need to grok better. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The problem is that with mutate the current design was to attach to the function and it would be resolved properly via that way. I think the better approach would be to register mutates via a registry that is applied to functions that are found ... and that way we wouldn't need to attach it to the function directly... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Curious, if we use a registry for mutate, would it be sensible to just have one global registry for all decorators instead of the additional 1-layer wrapper? |
||
wrapper = fn.__hamilton_wrappers__[0] # assume first one | ||
if hasattr(wrapper, "transform"): | ||
for decorator in wrapper.transform: | ||
from hamilton.function_modifiers import macros | ||
|
||
if isinstance(decorator, macros.pipe_output) and decorator.is_via_mutate: | ||
fn = wrapper # overwrite callable with right one | ||
break | ||
return fn | ||
|
||
|
||
class InvalidDecoratorException(Exception): | ||
pass | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -231,6 +231,65 @@ def replacement_function( | |
# now the error will be clear enough | ||
return node_.callable(*args, **new_kwargs) | ||
|
||
async def async_replacement_function( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. note -- merge with the above when you get to polishing |
||
*args, | ||
upstream_dependencies=upstream_dependencies, | ||
literal_dependencies=literal_dependencies, | ||
grouped_list_dependencies=grouped_list_dependencies, | ||
grouped_dict_dependencies=grouped_dict_dependencies, | ||
former_inputs=list(node_.input_types.keys()), # noqa | ||
**kwargs, | ||
): | ||
"""This function rewrites what is passed in kwargs to the right kwarg for the function. | ||
The passed in kwargs are all the dependencies of this node. Note that we actually have the "former inputs", | ||
which are what the node declares as its dependencies. So, we just have to loop through all of them to | ||
get the "new" value. This "new" value comes from the parameterization. | ||
|
||
Note that much of this code should *probably* live within the source/value/grouped functions, but | ||
it is here as we're not 100% sure about the abstraction. | ||
|
||
TODO -- think about how the grouped/source/literal functions should be able to grab the values from kwargs/args. | ||
Should be easy -- they should just have something like a "resolve(**kwargs)" function that they can call. | ||
""" | ||
new_kwargs = {} | ||
for node_input in former_inputs: | ||
if node_input in upstream_dependencies: | ||
# If the node is specified by `source`, then we get the value from the kwargs | ||
new_kwargs[node_input] = kwargs[upstream_dependencies[node_input].source] | ||
elif node_input in literal_dependencies: | ||
# If the node is specified by `value`, then we get the literal value (no need for kwargs) | ||
new_kwargs[node_input] = literal_dependencies[node_input].value | ||
elif node_input in grouped_list_dependencies: | ||
# If the node is specified by `group`, then we get the list of values from the kwargs or the literal | ||
new_kwargs[node_input] = [] | ||
for replacement in grouped_list_dependencies[node_input].sources: | ||
resolved_value = ( | ||
kwargs[replacement.source] | ||
if replacement.get_dependency_type() | ||
== ParametrizedDependencySource.UPSTREAM | ||
else replacement.value | ||
) | ||
new_kwargs[node_input].append(resolved_value) | ||
elif node_input in grouped_dict_dependencies: | ||
# If the node is specified by `group`, then we get the dict of values from the kwargs or the literal | ||
new_kwargs[node_input] = {} | ||
for dependency, replacement in grouped_dict_dependencies[ | ||
node_input | ||
].sources.items(): | ||
resolved_value = ( | ||
kwargs[replacement.source] | ||
if replacement.get_dependency_type() | ||
== ParametrizedDependencySource.UPSTREAM | ||
else replacement.value | ||
) | ||
new_kwargs[node_input][dependency] = resolved_value | ||
elif node_input in kwargs: | ||
new_kwargs[node_input] = kwargs[node_input] | ||
# This case is left blank for optional parameters. If we error here, we'll break | ||
# the (supported) case of optionals. We do know whether its optional but for | ||
# now the error will be clear enough | ||
return await node_.callable(*args, **new_kwargs) | ||
|
||
new_input_types = {} | ||
grouped_dependencies = { | ||
**grouped_list_dependencies, | ||
|
@@ -271,7 +330,9 @@ def replacement_function( | |
name=output_node, | ||
doc_string=docstring, # TODO -- change docstring | ||
callabl=functools.partial( | ||
replacement_function, | ||
replacement_function | ||
if not inspect.iscoroutinefunction(node_.callable) | ||
else async_replacement_function, | ||
**{parameter: val.value for parameter, val in literal_dependencies.items()}, | ||
), | ||
input_types=new_input_types, | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from .base import * # noqa: F401, F403 | ||
from .base_extended import * # noqa: F401, F403 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Why do we only do one level deep? Aren't there some weird cases with multiple imports, imports of imports, etc...?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
before we attached to the function itself. What I'm trying to do here is replicate that by using one wrapper layer around a function.
E.g.
Should only have one "wrapper" around
some_func
- which is what is being modified by the decorators.This then enables someone to do this:
This above would enable
a
,x
, anda_validated
to be in the DAG. Before it wasn't possible.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, I got that. Why just one layer was the question?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.