@@ -371,6 +371,8 @@ def _fabricate_spark_function(
     return FunctionType(func_code, {**globals(), **{"partial_fn": partial_fn}}, func_name)
 
 
+# TODO -- change this to have a different implementation based on the dataframe type. This will have
+# to likely be custom to each dataframe type
 def _lambda_udf(df: DataFrame, node_: node.Node, actual_kwargs: Dict[str, Any]) -> DataFrame:
     """Function to create a lambda UDF for a function.
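The TODO above calls for dispatching `_lambda_udf` on the dataframe type rather than hardcoding pyspark. A minimal sketch of one way that could look, assuming a hypothetical `_LAMBDA_UDF_IMPLEMENTATIONS` registry and `register_lambda_udf` decorator (neither is an existing Hamilton API):

```python
from typing import Any, Callable, Dict, Type

# Hypothetical registry: dataframe type -> backend-specific lambda-UDF builder.
_LAMBDA_UDF_IMPLEMENTATIONS: Dict[Type, Callable] = {}


def register_lambda_udf(df_type: Type) -> Callable:
    """Registers a backend-specific implementation of _lambda_udf."""

    def decorator(fn: Callable) -> Callable:
        _LAMBDA_UDF_IMPLEMENTATIONS[df_type] = fn
        return fn

    return decorator


def _lambda_udf_dispatch(df: Any, node_: Any, actual_kwargs: Dict[str, Any]) -> Any:
    """Looks up the implementation registered for type(df) and delegates to it."""
    impl = _LAMBDA_UDF_IMPLEMENTATIONS.get(type(df))
    if impl is None:
        raise NotImplementedError(f"No lambda-UDF implementation for {type(df)}")
    return impl(df, node_, actual_kwargs)
```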
@@ -1080,12 +1082,16 @@ def create_selector_node(
     """
 
     def new_callable(**kwargs) -> DataFrame:
+        # TODO -- change to have a `select` that's generic to the library
+        # Use the registry
         return kwargs[upstream_name].select(*columns)
 
     return node.Node(
         name=node_name,
+        # TODO -- change to have the right dataframe type (from the registry)
         typ=DataFrame,
         callabl=new_callable,
+        # TODO -- change to have the right dataframe type (from the registry)
         input_types={upstream_name: DataFrame},
     )
 
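Both TODOs in this hunk share a root cause: the `select` call and the `DataFrame` annotations are pyspark-specific. A sketch of a library-generic selector follows, where the `_SELECT_FNS` table and `select_columns` helper are illustrative assumptions, not an existing registry API; `create_selector_node` would then also source `typ` and `input_types` from the same registry instead of hardcoding `DataFrame`:

```python
from typing import Any, Callable, Dict, List

# Hypothetical per-library select implementations, keyed by library name.
_SELECT_FNS: Dict[str, Callable[[Any, List[str]], Any]] = {
    # pyspark DataFrames select via .select(*columns)
    "pyspark": lambda df, columns: df.select(*columns),
    # pandas DataFrames select via __getitem__ with a list of column names
    "pandas": lambda df, columns: df[columns],
}


def select_columns(df: Any, columns: List[str], library: str) -> Any:
    """Applies the library-appropriate column selection."""
    return _SELECT_FNS[library](df, columns)
```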
@@ -1107,8 +1113,10 @@ def new_callable(**kwargs) -> DataFrame:
 
     return node.Node(
         name=node_name,
+        # TODO -- change to have the right dataframe type (from the registry)
         typ=DataFrame,
         callabl=new_callable,
+        # TODO -- change to have the right dataframe type (from the registry)
         input_types={upstream_name: DataFrame},
     )
 
@@ -1195,7 +1203,9 @@ def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node
                 column for column in node_.input_types if column in columns_passed_in_from_dataframe
             }
             # In the case that we are using pyspark UDFs
+            # TODO -- use the right dataframe type to do this correctly
             if require_columns.is_decorated_pyspark_udf(node_):
+                # TODO -- change to use the right "sparkification" function that is dataframe-type-agnostic
                 sparkified = require_columns.sparkify_node(
                     node_,
                     current_dataframe_node,
@@ -1206,6 +1216,7 @@ def generate_nodes(self, fn: Callable, config: Dict[str, Any]) -> List[node.Node
                 )
             # otherwise we're using pandas/primitive UDFs
             else:
+                # TODO -- change to use the right "sparkification" function that is dataframe-type-agnostic
                 sparkified = sparkify_node_with_udf(
                     node_,
                     current_dataframe_node,
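The repeated "sparkification" TODOs in these last two hunks suggest a single dispatch point covering both branches of the if/else. A rough sketch, where `_SPARKIFIERS` and `get_sparkifier` are hypothetical names for the registry lookup, not functions that exist in the codebase:

```python
from typing import Callable, Dict, Tuple, Type

# Hypothetical dispatch table: (dataframe type, udf kind) -> sparkification function.
# "udf" mirrors the require_columns.sparkify_node branch; "primitive" mirrors
# the sparkify_node_with_udf branch for pandas/primitive UDFs.
_SPARKIFIERS: Dict[Tuple[Type, str], Callable] = {}


def get_sparkifier(df_type: Type, kind: str) -> Callable:
    """Returns the sparkification function registered for a dataframe type."""
    sparkifier = _SPARKIFIERS.get((df_type, kind))
    if sparkifier is None:
        raise NotImplementedError(f"No {kind!r} sparkifier registered for {df_type}")
    return sparkifier
```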