Skip to content

Commit c370d2f

Browse files
committed
Annotates the code locations that need to change in order to support a generic with_columns decorator
1 parent 3ca633e commit c370d2f

File tree

1 file changed

+11
-0
lines changed

1 file changed

+11
-0
lines changed

hamilton/plugins/h_spark.py

+11
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,8 @@ def _fabricate_spark_function(
371371
return FunctionType(func_code, {**globals(), **{"partial_fn": partial_fn}}, func_name)
372372

373373

374+
# TODO -- change this to have a different implementation based on the dataframe type. This will
375+
# likely have to be custom to each dataframe type
374376
def _lambda_udf(df: DataFrame, node_: node.Node, actual_kwargs: Dict[str, Any]) -> DataFrame:
375377
"""Function to create a lambda UDF for a function.
376378
@@ -1080,12 +1082,16 @@ def create_selector_node(
10801082
"""
10811083

10821084
def new_callable(**kwargs) -> DataFrame:
1085+
# TODO -- change to have a `select` that's generic to the library
1086+
# Use the registry
10831087
return kwargs[upstream_name].select(*columns)
10841088

10851089
return node.Node(
10861090
name=node_name,
1091+
# TODO -- change to have the right dataframe type (from the registry)
10871092
typ=DataFrame,
10881093
callabl=new_callable,
1094+
# TODO -- change to have the right dataframe type (from the registry)
10891095
input_types={upstream_name: DataFrame},
10901096
)
10911097

@@ -1107,8 +1113,10 @@ def new_callable(**kwargs) -> DataFrame:
11071113

11081114
return node.Node(
11091115
name=node_name,
1116+
# TODO -- change to have the right dataframe type (from the registry)
11101117
typ=DataFrame,
11111118
callabl=new_callable,
1119+
# TODO -- change to have the right dataframe type (from the registry)
11121120
input_types={upstream_name: DataFrame},
11131121
)
11141122

@@ -1195,7 +1203,9 @@ def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node
11951203
column for column in node_.input_types if column in columns_passed_in_from_dataframe
11961204
}
11971205
# In the case that we are using pyspark UDFs
1206+
# TODO -- use the right dataframe type to do this correctly
11981207
if require_columns.is_decorated_pyspark_udf(node_):
1208+
# TODO -- change to use the right "sparkification" function that is dataframe-type-agnostic
11991209
sparkified = require_columns.sparkify_node(
12001210
node_,
12011211
current_dataframe_node,
@@ -1206,6 +1216,7 @@ def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node
12061216
)
12071217
# otherwise we're using pandas/primitive UDFs
12081218
else:
1219+
# TODO -- change to use the right "sparkification" function that is dataframe-type-agnostic
12091220
sparkified = sparkify_node_with_udf(
12101221
node_,
12111222
current_dataframe_node,

0 commit comments

Comments
 (0)