diff --git a/hamilton/function_modifiers/base.py b/hamilton/function_modifiers/base.py index a721d945a..e269b3076 100644 --- a/hamilton/function_modifiers/base.py +++ b/hamilton/function_modifiers/base.py @@ -1,9 +1,11 @@ import abc import collections import functools +import inspect import itertools import logging from abc import ABC +from inspect import unwrap try: from types import EllipsisType @@ -99,18 +101,40 @@ def __call__(self, fn: Callable): :param fn: Function to decorate :return: The function again, with the desired properties. """ - self.validate(fn) + # # stop unwrapping if not a hamilton function + # # should only be one level of "hamilton wrapping" - and that's what we attach things to. + self.validate(unwrap(fn, stop=lambda f: not hasattr(f, "__hamilton__"))) + if not hasattr(fn, "__hamilton__"): + if inspect.iscoroutinefunction(fn): + + @functools.wraps(fn) + async def wrapper(*args, **kwargs): + return await fn(*args, **kwargs) + else: + + @functools.wraps(fn) + def wrapper(*args, **kwargs): + return fn(*args, **kwargs) + + wrapper.__hamilton__ = True + if not hasattr(fn, "__hamilton_wrappers__"): + fn.__hamilton_wrappers__ = [wrapper] + else: + fn.__hamilton_wrappers__.append(wrapper) + else: + wrapper = fn + lifecycle_name = self.__class__.get_lifecycle_name() - if hasattr(fn, self.get_lifecycle_name()): + if hasattr(wrapper, self.get_lifecycle_name()): if not self.allows_multiple(): raise ValueError( f"Got multiple decorators for decorator @{self.__class__}. Only one allowed." ) - curr_value = getattr(fn, lifecycle_name) - setattr(fn, lifecycle_name, curr_value + [self]) + curr_value = getattr(wrapper, lifecycle_name) + setattr(wrapper, lifecycle_name, curr_value + [self]) else: - setattr(fn, lifecycle_name, [self]) - return fn + setattr(wrapper, lifecycle_name, [self]) + return wrapper def required_config(self) -> Optional[List[str]]: """Declares the required configuration keys for this decorator. 
@@ -805,6 +829,9 @@ def resolve_nodes(fn: Callable, config: Dict[str, Any]) -> Collection[node.Node] which configuration they need. :return: A list of nodes into which this function transforms. """ + # check for mutate... + fn = handle_mutate_hack(fn) + try: function_decorators = get_node_decorators(fn, config) node_resolvers = function_decorators[NodeResolver.get_lifecycle_name()] @@ -833,6 +860,36 @@ def resolve_nodes(fn: Callable, config: Dict[str, Any]) -> Collection[node.Node] raise e +def handle_mutate_hack(fn): + """Function that encapsulates the mutate hack check. + + This isn't pretty. It's a hack to get around how special + mutate is. + + This will return the "wrapped function" if this is + a vanilla python function that has a pointer to a pipe_output + decorated function. This is because all other decorators + directly wrap the function, but mutate does not. It adds + a pointer to the function in question that we follow here. + + This code depends on what the mutate class does + and what the pipe_output decorator does. + + :param fn: Function to check + :return: Function or wrapped function to use if applicable. 
+ """ + if hasattr(fn, "__hamilton_wrappers__"): + wrapper = fn.__hamilton_wrappers__[0] # assume first one + if hasattr(wrapper, "transform"): + for decorator in wrapper.transform: + from hamilton.function_modifiers import macros + + if isinstance(decorator, macros.pipe_output) and decorator.is_via_mutate: + fn = wrapper # overwrite callable with right one + break + return fn + + class InvalidDecoratorException(Exception): pass diff --git a/hamilton/function_modifiers/expanders.py b/hamilton/function_modifiers/expanders.py index 62610f898..12c9a5a3c 100644 --- a/hamilton/function_modifiers/expanders.py +++ b/hamilton/function_modifiers/expanders.py @@ -231,6 +231,65 @@ def replacement_function( # now the error will be clear enough return node_.callable(*args, **new_kwargs) + async def async_replacement_function( + *args, + upstream_dependencies=upstream_dependencies, + literal_dependencies=literal_dependencies, + grouped_list_dependencies=grouped_list_dependencies, + grouped_dict_dependencies=grouped_dict_dependencies, + former_inputs=list(node_.input_types.keys()), # noqa + **kwargs, + ): + """This function rewrites what is passed in kwargs to the right kwarg for the function. + The passed in kwargs are all the dependencies of this node. Note that we actually have the "former inputs", + which are what the node declares as its dependencies. So, we just have to loop through all of them to + get the "new" value. This "new" value comes from the parameterization. + + Note that much of this code should *probably* live within the source/value/grouped functions, but + it is here as we're not 100% sure about the abstraction. + + TODO -- think about how the grouped/source/literal functions should be able to grab the values from kwargs/args. + Should be easy -- they should just have something like a "resolve(**kwargs)" function that they can call. 
+ """ + new_kwargs = {} + for node_input in former_inputs: + if node_input in upstream_dependencies: + # If the node is specified by `source`, then we get the value from the kwargs + new_kwargs[node_input] = kwargs[upstream_dependencies[node_input].source] + elif node_input in literal_dependencies: + # If the node is specified by `value`, then we get the literal value (no need for kwargs) + new_kwargs[node_input] = literal_dependencies[node_input].value + elif node_input in grouped_list_dependencies: + # If the node is specified by `group`, then we get the list of values from the kwargs or the literal + new_kwargs[node_input] = [] + for replacement in grouped_list_dependencies[node_input].sources: + resolved_value = ( + kwargs[replacement.source] + if replacement.get_dependency_type() + == ParametrizedDependencySource.UPSTREAM + else replacement.value + ) + new_kwargs[node_input].append(resolved_value) + elif node_input in grouped_dict_dependencies: + # If the node is specified by `group`, then we get the dict of values from the kwargs or the literal + new_kwargs[node_input] = {} + for dependency, replacement in grouped_dict_dependencies[ + node_input + ].sources.items(): + resolved_value = ( + kwargs[replacement.source] + if replacement.get_dependency_type() + == ParametrizedDependencySource.UPSTREAM + else replacement.value + ) + new_kwargs[node_input][dependency] = resolved_value + elif node_input in kwargs: + new_kwargs[node_input] = kwargs[node_input] + # This case is left blank for optional parameters. If we error here, we'll break + # the (supported) case of optionals. 
We do know whether it's optional but for + # now the error will be clear enough + return await node_.callable(*args, **new_kwargs) + new_input_types = {} grouped_dependencies = { **grouped_list_dependencies, @@ -271,7 +330,9 @@ def replacement_function( name=output_node, doc_string=docstring, # TODO -- change docstring callabl=functools.partial( - replacement_function, + replacement_function + if not inspect.iscoroutinefunction(node_.callable) + else async_replacement_function, **{parameter: val.value for parameter, val in literal_dependencies.items()}, ), input_types=new_input_types, diff --git a/hamilton/function_modifiers/macros.py b/hamilton/function_modifiers/macros.py index ae877db25..a02b45de0 100644 --- a/hamilton/function_modifiers/macros.py +++ b/hamilton/function_modifiers/macros.py @@ -1256,6 +1256,11 @@ def __init__( if self.chain: raise NotImplementedError("@flow() is not yet supported -- this is ") + self.is_via_mutate = False # flag to know how this was instantiated. + + def set_is_via_mutate(self): + self.is_via_mutate = True + def _filter_individual_target(self, node_): """Resolves target option on the transform level. Adds option that we can decide for each applicable which output node it will target. @@ -1605,15 +1610,25 @@ def __call__(self, mutating_fn: Callable): mutating_fn=mutating_fn, remote_applicable_builder=remote_applicable ) found_pipe_output = False - if hasattr(remote_applicable.target_fn, base.NodeTransformer.get_lifecycle_name()): - for decorator in remote_applicable.target_fn.transform: + wrapper_fn = None + # Assumptions: + # 1. This code depends on the `__call__()` definition in the Hamilton base decorator class + # 2. This is then used in `handle_mutate_hack()` in the Hamilton function modifier base.py. 
+ if hasattr(remote_applicable.target_fn, "__hamilton_wrappers__"): + # get first wrapper + wrapper_fn = remote_applicable.target_fn.__hamilton_wrappers__[0] + elif hasattr(remote_applicable.target_fn, "__hamilton__"): + wrapper_fn = remote_applicable.target_fn + + if wrapper_fn: + for decorator in wrapper_fn.transform: if isinstance(decorator, pipe_output): decorator.transforms = decorator.transforms + (new_pipe_step,) found_pipe_output = True if not found_pipe_output: - remote_applicable.target_fn = pipe_output( - new_pipe_step, collapse=self.collapse, _chain=self.chain - )(remote_applicable.target_fn) + decorator = pipe_output(new_pipe_step, collapse=self.collapse, _chain=self.chain) + decorator.set_is_via_mutate() + remote_applicable.target_fn = decorator(remote_applicable.target_fn) return mutating_fn diff --git a/hamilton/function_modifiers/validation.py b/hamilton/function_modifiers/validation.py index 93c54003e..bf543e1b4 100644 --- a/hamilton/function_modifiers/validation.py +++ b/hamilton/function_modifiers/validation.py @@ -1,4 +1,5 @@ import abc +import inspect from collections import defaultdict from typing import Any, Callable, Collection, Dict, List, Type @@ -46,6 +47,14 @@ def validation_function(validator_to_call: dq_base.DataValidator = validator, ** result = list(kwargs.values())[0] # This should just have one kwarg return validator_to_call.validate(result) + async def async_validation_function( + validator_to_call: dq_base.DataValidator = validator, **kwargs + ): + result = list(kwargs.values())[0] # This should just have one kwarg + if inspect.isawaitable(result): + result = await result + return validator_to_call.validate(result) + validator_node_name = node_.name + "_" + validator.name() validator_name_count[validator_node_name] = ( validator_name_count[validator_node_name] + 1 @@ -58,7 +67,9 @@ def validation_function(validator_to_call: dq_base.DataValidator = validator, ** name=validator_node_name, # TODO -- determine a good approach 
towards naming this typ=dq_base.ValidationResult, doc_string=validator.description(), - callabl=validation_function, + callabl=validation_function + if not inspect.iscoroutinefunction(node_.callable) + else async_validation_function, node_source=node.NodeType.STANDARD, input_types={raw_node.name: (node_.type, node.DependencyType.REQUIRED)}, tags={ diff --git a/hamilton/plugins/h_ddog.py b/hamilton/plugins/h_ddog.py index aacd1c1cf..923997311 100644 --- a/hamilton/plugins/h_ddog.py +++ b/hamilton/plugins/h_ddog.py @@ -8,6 +8,7 @@ logger = logging.getLogger(__name__) try: + # TODO: this works for ddtrace < 3.0; Span got moved somewhere else in 3.0. from ddtrace import Span, context, tracer except ImportError as e: logger.error("ImportError: %s", e) diff --git a/pyproject.toml b/pyproject.toml index f155dabf3..a3aa6256d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -41,7 +41,7 @@ dask-core = ["dask-core"] dask-dataframe = ["dask[dataframe]"] dask-diagnostics = ["dask[diagnostics]"] dask-distributed = ["dask[distributed]"] -datadog = ["ddtrace"] +datadog = ["ddtrace<3.0"] # Temporary pin until h_ddog.py import is fixed for >=3.0 versions dev = [ "pre-commit", "ruff==0.5.7", # this should match `.pre-commit-config.yaml` @@ -53,7 +53,7 @@ docs = [ "commonmark==0.9.1", # read the docs pins "dask-expr; python_version == '3.9'", "dask[distributed]", - "ddtrace", + "ddtrace<3.0", "diskcache", # required for all the plugins "dlt", diff --git a/tests/function_modifiers/test_base.py b/tests/function_modifiers/test_base.py index ecab08c70..4afa43db3 100644 --- a/tests/function_modifiers/test_base.py +++ b/tests/function_modifiers/test_base.py @@ -1,12 +1,15 @@ +from inspect import unwrap from typing import Collection, Dict, List +from unittest.mock import Mock -import pytest as pytest +import pytest from hamilton import node, settings from hamilton.function_modifiers import InvalidDecoratorException, base from hamilton.function_modifiers.base import ( 
MissingConfigParametersException, NodeTransformer, + NodeTransformLifecycle, TargetType, ) from hamilton.node import Node @@ -149,3 +152,86 @@ def test_add_fn_metadata(): ] assert len(nodes_with_fn_pointer) == len(nodes) assert all([n.originating_functions == (test_add_fn_metadata,) for n in nodes]) + + +class MockNodeTransformLifecycle(NodeTransformLifecycle): + @classmethod + def get_lifecycle_name(cls): + return "mock_lifecycle" + + @classmethod + def allows_multiple(cls): + return True + + def validate(self, fn): + pass + + +def test_decorator_adds_attributes(): + mock_decorator = MockNodeTransformLifecycle() + + def my_function(a: int) -> int: + pass + + decorated_fn = mock_decorator(my_function) + + assert hasattr(decorated_fn, "mock_lifecycle") + assert decorated_fn.mock_lifecycle == [mock_decorator] + assert hasattr(decorated_fn, "__hamilton__") + + +def test_decorator_allows_multiple_raises_error(): + class MockMultipleNodeTransformLifecycle(NodeTransformLifecycle): + @classmethod + def get_lifecycle_name(cls): + return "mock_lifecycle" + + @classmethod + def allows_multiple(cls): + return False + + def validate(self, fn): + pass + + mock_decorator = MockMultipleNodeTransformLifecycle() + mock_fn = Mock() + decorated_fn = mock_decorator(mock_fn) + + with pytest.raises(ValueError): + mock_decorator(decorated_fn) + + +def test_decorator_only_wraps_once(): + """Tests that the decorator only wraps once.""" + mock_decorator = MockNodeTransformLifecycle() + + def my_function(a: int) -> int: + pass + + decorated_fn = mock_decorator(my_function) + decorated_fn = mock_decorator(decorated_fn) + decorated_fn = mock_decorator(decorated_fn) + + assert decorated_fn.__hamilton__ is True + assert decorated_fn.__wrapped__ == my_function # one level of wrapping only + + +def test_wrapping_and_unwrapping_logic(): + """Tests unwrapping logic works as expected.""" + + def my_function(a: int) -> int: + pass + + # Wrap the function + wrapped_fn = 
MockNodeTransformLifecycle()(my_function) + # Unwrap the function + unwrapped_fn = unwrap(wrapped_fn, stop=lambda f: not hasattr(f, "__hamilton__")) + + # Ensure the function is unwrapped correctly + assert unwrapped_fn == my_function + assert not hasattr(unwrapped_fn, "__hamilton__") + + wrapped_fn2 = MockNodeTransformLifecycle()(wrapped_fn) + unwrapped_fn2 = unwrap(wrapped_fn2, stop=lambda f: not hasattr(f, "__hamilton__")) + assert wrapped_fn2 == wrapped_fn # these should be the same + assert unwrapped_fn2 == my_function # these should be the same diff --git a/tests/function_modifiers/test_macros.py b/tests/function_modifiers/test_macros.py index dc0963203..5e97696cd 100644 --- a/tests/function_modifiers/test_macros.py +++ b/tests/function_modifiers/test_macros.py @@ -1107,8 +1107,6 @@ def test_mutate_local_kwargs_override_global_ones(_downstream_result_to_mutate): def test_mutate_end_to_end_simple(import_mutate_module): - dr = driver.Builder().with_config({"calc_c": True}).build() - dr = ( driver.Builder() .with_modules(import_mutate_module) diff --git a/tests/resources/decorator_related/__init__.py b/tests/resources/decorator_related/__init__.py new file mode 100644 index 000000000..215a0b123 --- /dev/null +++ b/tests/resources/decorator_related/__init__.py @@ -0,0 +1,2 @@ +from .base import * # noqa: F401, F403 +from .base_extended import * # noqa: F401, F403 diff --git a/tests/resources/decorator_related/base.py b/tests/resources/decorator_related/base.py new file mode 100644 index 000000000..df710280e --- /dev/null +++ b/tests/resources/decorator_related/base.py @@ -0,0 +1,14 @@ +def a(input: int) -> int: + return input * 2 + + +def z(input: int) -> int: + return input * 3 + + +async def aa(input: int) -> int: + return input * 4 + + +async def zz(aa: int) -> int: + return aa * 5 diff --git a/tests/resources/decorator_related/base_extended.py b/tests/resources/decorator_related/base_extended.py new file mode 100644 index 000000000..0a8b06b8a --- /dev/null 
+++ b/tests/resources/decorator_related/base_extended.py @@ -0,0 +1,22 @@ +from hamilton.function_modifiers import check_output, parameterize, value + +from tests.resources.decorator_related import base + +b_p = parameterize(b={"input": value(1)}, c={"input": value(2)})(base.a) + +b_p2 = parameterize(q={"input": value(4)}, r={"input": value(5)})(base.a) + +b_p3 = check_output( + range=(0, 10), +)(base.aa) +b_p3.__name__ = "b_p3" # required to register this as `b_p3` in the graph + +b_p4 = parameterize(aaa={"input": value(4)}, aab={"input": value(5)})(base.aa) + + +def d(b: int, c: int) -> int: + return b + c + + +def e(input: int, a: int) -> int: + return input * 4 diff --git a/tests/test_end_to_end.py b/tests/test_end_to_end.py index b149673c5..3d2efd35c 100644 --- a/tests/test_end_to_end.py +++ b/tests/test_end_to_end.py @@ -7,7 +7,7 @@ import pandas as pd import pytest -from hamilton import ad_hoc_utils, base, driver, settings +from hamilton import ad_hoc_utils, async_driver, base, driver, settings from hamilton.base import DefaultAdapter from hamilton.data_quality.base import DataValidationError, ValidationResult from hamilton.execution import executors, grouping @@ -15,6 +15,7 @@ from hamilton.io.materialization import from_, to import tests.resources.data_quality +import tests.resources.decorator_related import tests.resources.dynamic_config import tests.resources.example_module import tests.resources.overrides @@ -556,3 +557,18 @@ def test_driver_v2_inputs_can_be_none(): with pytest.raises(ValueError): # validate that None doesn't cause issues dr.execute(["e"], inputs=None) + + +def test_function_decorator_reuse(): + """Tests we can reuse a function with multiple decorators""" + dr = driver.Builder().with_modules(tests.resources.decorator_related).build() + result = dr.execute(["a", "b", "c", "e", "q"], inputs={"input": 2}) + assert result == {"a": 4, "b": 2, "c": 4, "e": 8, "q": 8} + + +@pytest.mark.asyncio +async def test_function_decorator_reuse_async(): 
+ """Tests we can reuse a function with multiple decorators when using the async driver""" + dr = await async_driver.Builder().with_modules(tests.resources.decorator_related).build() + result = await dr.execute(["a", "b", "c", "e", "q", "zz", "b_p3", "aaa"], inputs={"input": 2}) + assert result == {"a": 4, "aaa": 16, "b": 2, "b_p3": 8, "c": 4, "e": 8, "q": 8, "zz": 40}