@@ -105,7 +105,7 @@ def round_value(value: Any, direction: str, opts: dict[str, Any]) -> Any:
     }
 
 
-def extract_min_max(  # pylint: disable=too-complex, too-many-statements
+def extract_min_max(
     dst: DataFrame,
     col_name: str,
     typ: T.DataType,
@@ -137,53 +137,7 @@ def extract_min_max(  # pylint: disable=too-complex, too-many-statements
         dst = dst.select(F.col(column).cast("bigint").alias(column))
         # TODO: do summary instead? to get percentiles, etc.?
         mn_mx = dst.agg(F.min(column), F.max(column), F.mean(column), F.stddev(column)).collect()
-        if mn_mx and len(mn_mx) > 0:
-            metrics["min"] = mn_mx[0][0]
-            metrics["max"] = mn_mx[0][1]
-            sigmas = opts.get("sigmas", 3)
-            avg = mn_mx[0][2]
-            stddev = mn_mx[0][3]
-            min_limit = avg - sigmas * stddev
-            max_limit = avg + sigmas * stddev
-            if min_limit > mn_mx[0][0] and max_limit < mn_mx[0][1]:
-                descr = (
-                    f"Range doesn't include outliers, capped by {sigmas} sigmas. avg={avg}, "
-                    f"stddev={stddev}, min={metrics.get('min')}, max={metrics.get('max')}"
-                )
-            elif min_limit < mn_mx[0][0] and max_limit > mn_mx[0][1]:  #
-                min_limit = mn_mx[0][0]
-                max_limit = mn_mx[0][1]
-                descr = "Real min/max values were used"
-            elif min_limit < mn_mx[0][0]:
-                min_limit = mn_mx[0][0]
-                descr = (
-                    f"Real min value was used. Max was capped by {sigmas} sigmas. avg={avg}, "
-                    f"stddev={stddev}, max={metrics.get('max')}"
-                )
-            elif max_limit > mn_mx[0][1]:
-                max_limit = mn_mx[0][1]
-                descr = (
-                    f"Real max value was used. Min was capped by {sigmas} sigmas. avg={avg}, "
-                    f"stddev={stddev}, min={metrics.get('min')}"
-                )
-            # we need to preserve type at the end
-            if typ == T.IntegerType() or typ == T.LongType():
-                min_limit = int(round_value(min_limit, "down", {"round": True}))
-                max_limit = int(round_value(max_limit, "up", {"round": True}))
-            elif typ == T.DateType():
-                min_limit = datetime.date.fromtimestamp(int(min_limit))
-                max_limit = datetime.date.fromtimestamp(int(max_limit))
-                metrics["min"] = datetime.date.fromtimestamp(int(metrics["min"]))
-                metrics["max"] = datetime.date.fromtimestamp(int(metrics["max"]))
-                metrics["mean"] = datetime.date.fromtimestamp(int(avg))
-            elif typ == T.TimestampType():
-                min_limit = round_value(datetime.datetime.fromtimestamp(int(min_limit)), "down", {"round": True})
-                max_limit = round_value(datetime.datetime.fromtimestamp(int(max_limit)), "up", {"round": True})
-                metrics["min"] = datetime.datetime.fromtimestamp(int(metrics["min"]))
-                metrics["max"] = datetime.datetime.fromtimestamp(int(metrics["max"]))
-                metrics["mean"] = datetime.datetime.fromtimestamp(int(avg))
-            else:
-                print(f"Can't get min/max for field {col_name}")
+        descr, max_limit, min_limit = get_min_max(col_name, descr, max_limit, metrics, min_limit, mn_mx, opts, typ)
     else:
         mn_mx = dst.agg(F.min(column), F.max(column)).collect()
         if mn_mx and len(mn_mx) > 0:
@@ -202,6 +156,57 @@ def extract_min_max(  # pylint: disable=too-complex, too-many-statements
     return None
 
 
+def get_min_max(col_name, descr, max_limit, metrics, min_limit, mn_mx, opts, typ):
+    if mn_mx and len(mn_mx) > 0:
+        metrics["min"] = mn_mx[0][0]
+        metrics["max"] = mn_mx[0][1]
+        sigmas = opts.get("sigmas", 3)
+        avg = mn_mx[0][2]
+        stddev = mn_mx[0][3]
+        min_limit = avg - sigmas * stddev
+        max_limit = avg + sigmas * stddev
+        if min_limit > mn_mx[0][0] and max_limit < mn_mx[0][1]:
+            descr = (
+                f"Range doesn't include outliers, capped by {sigmas} sigmas. avg={avg}, "
+                f"stddev={stddev}, min={metrics.get('min')}, max={metrics.get('max')}"
+            )
+        elif min_limit < mn_mx[0][0] and max_limit > mn_mx[0][1]:  #
+            min_limit = mn_mx[0][0]
+            max_limit = mn_mx[0][1]
+            descr = "Real min/max values were used"
+        elif min_limit < mn_mx[0][0]:
+            min_limit = mn_mx[0][0]
+            descr = (
+                f"Real min value was used. Max was capped by {sigmas} sigmas. avg={avg}, "
+                f"stddev={stddev}, max={metrics.get('max')}"
+            )
+        elif max_limit > mn_mx[0][1]:
+            max_limit = mn_mx[0][1]
+            descr = (
+                f"Real max value was used. Min was capped by {sigmas} sigmas. avg={avg}, "
+                f"stddev={stddev}, min={metrics.get('min')}"
+            )
+        # we need to preserve type at the end
+        if typ == T.IntegerType() or typ == T.LongType():
+            min_limit = int(round_value(min_limit, "down", {"round": True}))
+            max_limit = int(round_value(max_limit, "up", {"round": True}))
+        elif typ == T.DateType():
+            min_limit = datetime.date.fromtimestamp(int(min_limit))
+            max_limit = datetime.date.fromtimestamp(int(max_limit))
+            metrics["min"] = datetime.date.fromtimestamp(int(metrics["min"]))
+            metrics["max"] = datetime.date.fromtimestamp(int(metrics["max"]))
+            metrics["mean"] = datetime.date.fromtimestamp(int(avg))
+        elif typ == T.TimestampType():
+            min_limit = round_value(datetime.datetime.fromtimestamp(int(min_limit)), "down", {"round": True})
+            max_limit = round_value(datetime.datetime.fromtimestamp(int(max_limit)), "up", {"round": True})
+            metrics["min"] = datetime.datetime.fromtimestamp(int(metrics["min"]))
+            metrics["max"] = datetime.datetime.fromtimestamp(int(metrics["max"]))
+            metrics["mean"] = datetime.datetime.fromtimestamp(int(avg))
+        else:
+            print(f"Can't get min/max for field {col_name}")
+    return descr, max_limit, min_limit
+
+
 def get_fields(col_name: str, schema: T.StructType) -> list[T.StructField]:
     fields = []
     for f in schema.fields:
@@ -228,7 +233,7 @@ def get_columns_or_fields(cols: list[T.StructField]) -> list[T.StructField]:
 # TODO: split into managebale chunks
 # TODO: how to handle maps, arrays & structs?
 # TODO: return not only DQ rules, but also the profiling results - use named tuple?
-def profile_dataframe(  # pylint: disable=too-complex, too-many-locals
+def profile_dataframe(
     df: DataFrame, cols: list[str] | None = None, opts: dict[str, Any] | None = None
 ) -> tuple[dict[str, Any], list[DQRule]]:
     if opts is None:
@@ -249,6 +254,12 @@ def profile_dataframe(  # pylint: disable=too-complex, too-many-locals
     max_nulls = opts.get("max_null_ratio", 0)
     trim_strings = opts.get("trim_strings", True)
 
+    profile(df, df_cols, dq_rules, max_nulls, opts, summary_stats, total_count, trim_strings)
+
+    return summary_stats, dq_rules
+
+
+def profile(df, df_cols, dq_rules, max_nulls, opts, summary_stats, total_count, trim_strings):
     # TODO: think, how we can do it in fewer passes. Maybe only for specific things, like, min_max, etc.
     for field in get_columns_or_fields(df_cols):
         field_name = field.name
@@ -306,5 +317,3 @@ def profile_dataframe(  # pylint: disable=too-complex, too-many-locals
 
         # That should be the last one
         dst.unpersist()
-
-    return summary_stats, dq_rules