WIP: decorators should wrap functions with vanilla wrapper #1282

Draft: wants to merge 7 commits into base `main`.
69 changes: 63 additions & 6 deletions hamilton/function_modifiers/base.py
@@ -1,9 +1,11 @@
import abc
import collections
import functools
import inspect
import itertools
import logging
from abc import ABC
from inspect import unwrap

try:
from types import EllipsisType
@@ -99,18 +101,40 @@ def __call__(self, fn: Callable):
:param fn: Function to decorate
:return: The function again, with the desired properties.
"""
-        self.validate(fn)
+        # stop unwrapping if not a hamilton function
+        # should only be one level of "hamilton wrapping" - and that's what we attach things to.
+        self.validate(unwrap(fn, stop=lambda f: not hasattr(f, "__hamilton__")))
Collaborator:
Why do we only do one level deep? Aren't there some weird cases with multiple imports, imports of imports, etc...?

Collaborator Author (@skrawcz, Feb 28, 2025):
Before, we attached attributes to the function itself. What I'm trying to do here is replicate that by using one wrapper layer around the function.

E.g.

    @tag(...)
    @extract_columns(...)
    @parameterize(...)
    def some_func(...) -> ...:
        ...

There should only be one "wrapper" around some_func, and that wrapper is what the decorators modify.

This then enables someone to do this:

    def a(...) -> ...:
        # base func
        ...

    q = parameterize(x=...)(a)
    a_validated = check_output(range=...)(a)
    a_validated.__name__ = "a_validated"  # need to rename so it can stand alone

The above would enable a, x, and a_validated to be in the DAG. Before, that wasn't possible.

Collaborator:
Yes, I got that. Why just one layer was the question?

Collaborator Author:
> Why just one layer

  1. Doesn't seem like we need more?
  2. This mirrors the prior behavior of directly attaching to the function - there were no layers before.
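The one-layer contract plus the `unwrap(..., stop=...)` call in this hunk can be sketched in isolation. Everything below is an illustrative toy (`make_hamilton_wrapper` is an invented name, not a Hamilton API):

```python
import functools
from inspect import unwrap


def make_hamilton_wrapper(fn):
    """Wrap fn exactly once; re-wrapping an existing wrapper is a no-op."""
    if hasattr(fn, "__hamilton__"):
        return fn  # already a wrapper -- don't stack another layer

    @functools.wraps(fn)  # sets wrapper.__wrapped__ = fn
    def wrapper(*args, **kwargs):
        return fn(*args, **kwargs)

    wrapper.__hamilton__ = True
    return wrapper


def base(a: int) -> int:
    return a + 1


w1 = make_hamilton_wrapper(base)
w2 = make_hamilton_wrapper(w1)  # no second layer is added
assert w1 is w2
assert w1(3) == 4
# unwrap stops at the first non-hamilton function, so one hop recovers base:
assert unwrap(w1, stop=lambda f: not hasattr(f, "__hamilton__")) is base
```

The stop predicate is what keeps `unwrap` from walking through `__wrapped__` chains added by unrelated decorators.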

if not hasattr(fn, "__hamilton__"):
if inspect.iscoroutinefunction(fn):

@functools.wraps(fn)
async def wrapper(*args, **kwargs):
return await fn(*args, **kwargs)
else:

@functools.wraps(fn)
def wrapper(*args, **kwargs):
return fn(*args, **kwargs)

wrapper.__hamilton__ = True
if not hasattr(fn, "__hamilton_wrappers__"):
fn.__hamilton_wrappers__ = [wrapper]
else:
fn.__hamilton_wrappers__.append(wrapper)
else:
wrapper = fn

        lifecycle_name = self.__class__.get_lifecycle_name()
-       if hasattr(fn, self.get_lifecycle_name()):
+       if hasattr(wrapper, self.get_lifecycle_name()):
            if not self.allows_multiple():
                raise ValueError(
                    f"Got multiple decorators for decorator @{self.__class__}. Only one allowed."
                )
-           curr_value = getattr(fn, lifecycle_name)
-           setattr(fn, lifecycle_name, curr_value + [self])
+           curr_value = getattr(wrapper, lifecycle_name)
+           setattr(wrapper, lifecycle_name, curr_value + [self])
        else:
-           setattr(fn, lifecycle_name, [self])
-       return fn
+           setattr(wrapper, lifecycle_name, [self])
+       return wrapper
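The coroutine branch in this hunk matters because a plain `def` wrapper would hide the fact that the underlying function is async. A minimal standalone sketch of that behavior (not the Hamilton implementation itself):

```python
import asyncio
import functools
import inspect


def wrap_preserving_asyncness(fn):
    """Coroutine functions get an async wrapper so that
    inspect.iscoroutinefunction still reports correctly downstream."""
    if inspect.iscoroutinefunction(fn):

        @functools.wraps(fn)
        async def wrapper(*args, **kwargs):
            return await fn(*args, **kwargs)
    else:

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            return fn(*args, **kwargs)

    return wrapper


async def afn(x):
    return x + 1


wrapped = wrap_preserving_asyncness(afn)
assert inspect.iscoroutinefunction(wrapped)
assert asyncio.run(wrapped(1)) == 2
```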

def required_config(self) -> Optional[List[str]]:
"""Declares the required configuration keys for this decorator.
@@ -805,6 +829,9 @@ def resolve_nodes(fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]
which configuration they need.
:return: A list of nodes into which this function transforms.
"""
# check for mutate...
fn = handle_mutate_hack(fn)
Collaborator:
Take away the name "hack"


try:
function_decorators = get_node_decorators(fn, config)
node_resolvers = function_decorators[NodeResolver.get_lifecycle_name()]
@@ -833,6 +860,36 @@ def resolve_nodes(fn: Callable, config: Dict[str, Any]) -> Collection[node.Node]
raise e


def handle_mutate_hack(fn):
"""Function that encapsulates the mutate hack check.

This isn't pretty. It's a hack to get around how special
mutate is.

This will return the "wrapped function" if this is
a vanilla python function that has a pointer to a pipe_output
decorated function. This is because all other decorators
directly wrap the function, but mutate does not. It adds
a pointer to the function in question that we follow here.

This code depends on what the mutate class does
and what the pipe_output decorator does.

:param fn: Function to check
:return: Function or wrapped function to use if applicable.
"""
if hasattr(fn, "__hamilton_wrappers__"):
Collaborator:
Is this worth the complexity it adds? Is there a simpler way to approach it? Need to grok better.

Collaborator Author (@skrawcz, Feb 28, 2025):

The problem is that with mutate, the current design was to attach to the function, and it would be resolved properly that way.

I think the better approach would be to register mutates via a registry that is applied to functions that are found ... and that way we wouldn't need to attach it to the function directly...

Contributor:
Curious, if we use a registry for mutate, would it be sensible to just have one global registry for all decorators instead of the additional 1-layer wrapper?

wrapper = fn.__hamilton_wrappers__[0] # assume first one
if hasattr(wrapper, "transform"):
for decorator in wrapper.transform:
from hamilton.function_modifiers import macros

if isinstance(decorator, macros.pipe_output) and decorator.is_via_mutate:
fn = wrapper # overwrite callable with right one
break
return fn
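For what it's worth, the registry idea floated in the comments above could look something like the following. This is purely hypothetical; nothing in this PR implements it, and `register_mutation`/`mutations_for` are invented names:

```python
from collections import defaultdict
from typing import Any, Callable, List

# Hypothetical module-level registry keyed by the target function's
# qualified name, instead of attaching __hamilton_wrappers__ pointers.
_MUTATE_REGISTRY = defaultdict(list)


def _key(fn: Callable) -> str:
    return f"{fn.__module__}.{fn.__qualname__}"


def register_mutation(target_fn: Callable, transform: Any) -> None:
    """Record a transform against a function without touching the function."""
    _MUTATE_REGISTRY[_key(target_fn)].append(transform)


def mutations_for(fn: Callable) -> List[Any]:
    """Look up transforms at node-resolution time."""
    return list(_MUTATE_REGISTRY.get(_key(fn), []))


def target(x: int) -> int:
    return x


register_mutation(target, "some-step")
assert mutations_for(target) == ["some-step"]
```

One global registry for all decorators (as the contributor asks) would be the same shape, just keyed per lifecycle name as well.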


class InvalidDecoratorException(Exception):
pass

63 changes: 62 additions & 1 deletion hamilton/function_modifiers/expanders.py
@@ -231,6 +231,65 @@ def replacement_function(
# now the error will be clear enough
return node_.callable(*args, **new_kwargs)

Collaborator: note -- merge with the above when you get to polishing

async def async_replacement_function(
*args,
upstream_dependencies=upstream_dependencies,
literal_dependencies=literal_dependencies,
grouped_list_dependencies=grouped_list_dependencies,
grouped_dict_dependencies=grouped_dict_dependencies,
former_inputs=list(node_.input_types.keys()), # noqa
**kwargs,
):
"""This function rewrites what is passed in kwargs to the right kwarg for the function.
The passed in kwargs are all the dependencies of this node. Note that we actually have the "former inputs",
which are what the node declares as its dependencies. So, we just have to loop through all of them to
get the "new" value. This "new" value comes from the parameterization.

Note that much of this code should *probably* live within the source/value/grouped functions, but
it is here as we're not 100% sure about the abstraction.

TODO -- think about how the grouped/source/literal functions should be able to grab the values from kwargs/args.
Should be easy -- they should just have something like a "resolve(**kwargs)" function that they can call.
"""
new_kwargs = {}
for node_input in former_inputs:
if node_input in upstream_dependencies:
# If the node is specified by `source`, then we get the value from the kwargs
new_kwargs[node_input] = kwargs[upstream_dependencies[node_input].source]
elif node_input in literal_dependencies:
# If the node is specified by `value`, then we get the literal value (no need for kwargs)
new_kwargs[node_input] = literal_dependencies[node_input].value
elif node_input in grouped_list_dependencies:
# If the node is specified by `group`, then we get the list of values from the kwargs or the literal
new_kwargs[node_input] = []
for replacement in grouped_list_dependencies[node_input].sources:
resolved_value = (
kwargs[replacement.source]
if replacement.get_dependency_type()
== ParametrizedDependencySource.UPSTREAM
else replacement.value
)
new_kwargs[node_input].append(resolved_value)
elif node_input in grouped_dict_dependencies:
# If the node is specified by `group`, then we get the dict of values from the kwargs or the literal
new_kwargs[node_input] = {}
for dependency, replacement in grouped_dict_dependencies[
node_input
].sources.items():
resolved_value = (
kwargs[replacement.source]
if replacement.get_dependency_type()
== ParametrizedDependencySource.UPSTREAM
else replacement.value
)
new_kwargs[node_input][dependency] = resolved_value
elif node_input in kwargs:
new_kwargs[node_input] = kwargs[node_input]
# This case is left blank for optional parameters. If we error here, we'll break
# the (supported) case of optionals. We do know whether it's optional, but for
# now the error will be clear enough
return await node_.callable(*args, **new_kwargs)

new_input_types = {}
grouped_dependencies = {
**grouped_list_dependencies,
@@ -271,7 +330,9 @@ def replacement_function(
name=output_node,
doc_string=docstring, # TODO -- change docstring
callabl=functools.partial(
-            replacement_function,
+            replacement_function
+            if not inspect.iscoroutinefunction(node_.callable)
+            else async_replacement_function,
**{parameter: val.value for parameter, val in literal_dependencies.items()},
),
input_types=new_input_types,
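The `functools.partial` dispatch above (sync vs. async replacement function, chosen via `inspect.iscoroutinefunction`) can be exercised standalone. The function names and the `scale` parameter below are illustrative, not Hamilton internals:

```python
import asyncio
import functools
import inspect


def pick_callable(fn, **bound):
    """Bind literal kwargs with functools.partial, dispatching to an
    async shim when the underlying callable is a coroutine function."""

    def sync_replacement(x, scale):
        return fn(x) * scale

    async def async_replacement(x, scale):
        return await fn(x) * scale

    chosen = sync_replacement if not inspect.iscoroutinefunction(fn) else async_replacement
    return functools.partial(chosen, **bound)


def double(x):
    return x * 2


async def adouble(x):
    return x * 2


sync_call = pick_callable(double, scale=10)
async_call = pick_callable(adouble, scale=10)
assert sync_call(3) == 60
assert asyncio.run(async_call(3)) == 60
```

Dispatching once at node-construction time (rather than checking inside the call) keeps the per-invocation path free of the coroutine check.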
25 changes: 20 additions & 5 deletions hamilton/function_modifiers/macros.py
@@ -1256,6 +1256,11 @@ def __init__(
if self.chain:
raise NotImplementedError("@flow() is not yet supported -- this is ")

self.is_via_mutate = False # flag to know how this was instantiated.

def set_is_via_mutate(self):
self.is_via_mutate = True

def _filter_individual_target(self, node_):
"""Resolves target option on the transform level.
Adds option that we can decide for each applicable which output node it will target.
@@ -1605,15 +1610,25 @@ def __call__(self, mutating_fn: Callable):
mutating_fn=mutating_fn, remote_applicable_builder=remote_applicable
)
found_pipe_output = False
-        if hasattr(remote_applicable.target_fn, base.NodeTransformer.get_lifecycle_name()):
-            for decorator in remote_applicable.target_fn.transform:
+        wrapper_fn = None
+        # Assumptions:
+        # 1. This code depends on the `__call__()` definition in the Hamilton base decorator class.
+        # 2. This is then used in `handle_mutate_hack()` in the Hamilton function-modifiers base.py.
+        if hasattr(remote_applicable.target_fn, "__hamilton_wrappers__"):
+            # get the first wrapper
+            wrapper_fn = remote_applicable.target_fn.__hamilton_wrappers__[0]
+        elif hasattr(remote_applicable.target_fn, "__hamilton__"):
+            wrapper_fn = remote_applicable.target_fn
+
+        if wrapper_fn:
+            for decorator in wrapper_fn.transform:
if isinstance(decorator, pipe_output):
decorator.transforms = decorator.transforms + (new_pipe_step,)
found_pipe_output = True

if not found_pipe_output:
-            remote_applicable.target_fn = pipe_output(
-                new_pipe_step, collapse=self.collapse, _chain=self.chain
-            )(remote_applicable.target_fn)
+            decorator = pipe_output(new_pipe_step, collapse=self.collapse, _chain=self.chain)
+            decorator.set_is_via_mutate()
+            remote_applicable.target_fn = decorator(remote_applicable.target_fn)

return mutating_fn
13 changes: 12 additions & 1 deletion hamilton/function_modifiers/validation.py
@@ -1,4 +1,5 @@
import abc
import inspect
from collections import defaultdict
from typing import Any, Callable, Collection, Dict, List, Type

@@ -46,6 +47,14 @@ def validation_function(validator_to_call: dq_base.DataValidator = validator, **
result = list(kwargs.values())[0] # This should just have one kwarg
return validator_to_call.validate(result)

async def async_validation_function(
validator_to_call: dq_base.DataValidator = validator, **kwargs
):
result = list(kwargs.values())[0] # This should just have one kwarg
if inspect.isawaitable(result):
result = await result
return validator_to_call.validate(result)

validator_node_name = node_.name + "_" + validator.name()
validator_name_count[validator_node_name] = (
validator_name_count[validator_node_name] + 1
@@ -58,7 +67,9 @@ def validation_function(validator_to_call: dq_base.DataValidator = validator, **
name=validator_node_name, # TODO -- determine a good approach towards naming this
typ=dq_base.ValidationResult,
doc_string=validator.description(),
-            callabl=validation_function,
+            callabl=validation_function
+            if not inspect.iscoroutinefunction(node_.callable)
+            else async_validation_function,
node_source=node.NodeType.STANDARD,
input_types={raw_node.name: (node_.type, node.DependencyType.REQUIRED)},
tags={
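The `inspect.isawaitable` guard in `async_validation_function` is the key detail: the single upstream value may arrive as a plain result or as an unawaited awaitable. A self-contained sketch, with the validator inlined as a simple positivity check rather than the dq_base API:

```python
import asyncio
import inspect


def validate_positive(value):
    return value > 0  # stand-in for validator_to_call.validate(result)


async def async_validation(**kwargs):
    """The single kwarg may be a plain value or an awaitable,
    so await it before validating."""
    result = list(kwargs.values())[0]
    if inspect.isawaitable(result):
        result = await result
    return validate_positive(result)


async def produce():
    return 5


# Works with a plain value...
assert asyncio.run(async_validation(data=5)) is True
# ...and with an awaitable that hasn't been awaited yet.
assert asyncio.run(async_validation(data=produce())) is True
```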
1 change: 1 addition & 0 deletions hamilton/plugins/h_ddog.py
@@ -8,6 +8,7 @@

logger = logging.getLogger(__name__)
try:
# TODO: this works for ddtrace < 3.0; Span was moved elsewhere in 3.0.
from ddtrace import Span, context, tracer
except ImportError as e:
logger.error("ImportError: %s", e)
4 changes: 2 additions & 2 deletions pyproject.toml
@@ -41,7 +41,7 @@ dask-core = ["dask-core"]
dask-dataframe = ["dask[dataframe]"]
dask-diagnostics = ["dask[diagnostics]"]
dask-distributed = ["dask[distributed]"]
-datadog = ["ddtrace"]
+datadog = ["ddtrace<3.0"]  # Temporary pin until the h_ddog.py import is fixed for ddtrace 3.0+
dev = [
"pre-commit",
"ruff==0.5.7", # this should match `.pre-commit-config.yaml`
@@ -53,7 +53,7 @@ docs = [
"commonmark==0.9.1", # read the docs pins
"dask-expr; python_version == '3.9'",
"dask[distributed]",
-    "ddtrace",
+    "ddtrace<3.0",
"diskcache",
# required for all the plugins
"dlt",
88 changes: 87 additions & 1 deletion tests/function_modifiers/test_base.py
@@ -1,12 +1,15 @@
from inspect import unwrap
from typing import Collection, Dict, List
from unittest.mock import Mock

-import pytest as pytest
+import pytest

from hamilton import node, settings
from hamilton.function_modifiers import InvalidDecoratorException, base
from hamilton.function_modifiers.base import (
MissingConfigParametersException,
NodeTransformer,
NodeTransformLifecycle,
TargetType,
)
from hamilton.node import Node
@@ -149,3 +152,86 @@ def test_add_fn_metadata():
]
assert len(nodes_with_fn_pointer) == len(nodes)
assert all([n.originating_functions == (test_add_fn_metadata,) for n in nodes])


class MockNodeTransformLifecycle(NodeTransformLifecycle):
@classmethod
def get_lifecycle_name(cls):
return "mock_lifecycle"

@classmethod
def allows_multiple(cls):
return True

def validate(self, fn):
pass


def test_decorator_adds_attributes():
mock_decorator = MockNodeTransformLifecycle()

def my_function(a: int) -> int:
pass

decorated_fn = mock_decorator(my_function)

assert hasattr(decorated_fn, "mock_lifecycle")
assert decorated_fn.mock_lifecycle == [mock_decorator]
assert hasattr(decorated_fn, "__hamilton__")


def test_decorator_allows_multiple_raises_error():
class MockMultipleNodeTransformLifecycle(NodeTransformLifecycle):
@classmethod
def get_lifecycle_name(cls):
return "mock_lifecycle"

@classmethod
def allows_multiple(cls):
return False

def validate(self, fn):
pass

mock_decorator = MockMultipleNodeTransformLifecycle()
mock_fn = Mock()
decorated_fn = mock_decorator(mock_fn)

with pytest.raises(ValueError):
mock_decorator(decorated_fn)


def test_decorator_only_wraps_once():
"""Tests that the decorator only wraps once."""
mock_decorator = MockNodeTransformLifecycle()

def my_function(a: int) -> int:
pass

decorated_fn = mock_decorator(my_function)
decorated_fn = mock_decorator(decorated_fn)
decorated_fn = mock_decorator(decorated_fn)

assert decorated_fn.__hamilton__ is True
assert decorated_fn.__wrapped__ == my_function # one level of wrapping only


def test_wrapping_and_unwrapping_logic():
"""Tests unwrapping logic works as expected."""

def my_function(a: int) -> int:
pass

# Wrap the function
wrapped_fn = MockNodeTransformLifecycle()(my_function)
# Unwrap the function
unwrapped_fn = unwrap(wrapped_fn, stop=lambda f: not hasattr(f, "__hamilton__"))

# Ensure the function is unwrapped correctly
assert unwrapped_fn == my_function
assert not hasattr(unwrapped_fn, "__hamilton__")

wrapped_fn2 = MockNodeTransformLifecycle()(wrapped_fn)
unwrapped_fn2 = unwrap(wrapped_fn2, stop=lambda f: not hasattr(f, "__hamilton__"))
assert wrapped_fn2 == wrapped_fn # these should be the same
assert unwrapped_fn2 == my_function # these should be the same
2 changes: 0 additions & 2 deletions tests/function_modifiers/test_macros.py
@@ -1107,8 +1107,6 @@ def test_mutate_local_kwargs_override_global_ones(_downstream_result_to_mutate):


def test_mutate_end_to_end_simple(import_mutate_module):
-    dr = driver.Builder().with_config({"calc_c": True}).build()
-
dr = (
driver.Builder()
.with_modules(import_mutate_module)
2 changes: 2 additions & 0 deletions tests/resources/decorator_related/__init__.py
@@ -0,0 +1,2 @@
from .base import * # noqa: F401, F403
from .base_extended import * # noqa: F401, F403