Skip to content

Commit fe8105c

Browse files
committed
refactor
1 parent 3ac56df commit fe8105c

File tree

2 files changed

+61
-52
lines changed

2 files changed

+61
-52
lines changed

pyproject.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,7 @@ load-plugins = [
181181
"pylint.extensions.docparams",
182182
"pylint.extensions.dunder",
183183
"pylint.extensions.for_any_all",
184-
"pylint.extensions.mccabe",
184+
#"pylint.extensions.mccabe",
185185
"pylint.extensions.overlapping_exceptions",
186186
"pylint.extensions.private_import",
187187
"pylint.extensions.redefined_variable_type",

src/databricks/labs/dqx/profiler/profiler.py

+60-51
Original file line numberDiff line numberDiff line change
@@ -105,7 +105,7 @@ def round_value(value: Any, direction: str, opts: dict[str, Any]) -> Any:
105105
}
106106

107107

108-
def extract_min_max( # pylint: disable=too-complex, too-many-statements
108+
def extract_min_max(
109109
dst: DataFrame,
110110
col_name: str,
111111
typ: T.DataType,
@@ -137,53 +137,7 @@ def extract_min_max( # pylint: disable=too-complex, too-many-statements
137137
dst = dst.select(F.col(column).cast("bigint").alias(column))
138138
# TODO: do summary instead? to get percentiles, etc.?
139139
mn_mx = dst.agg(F.min(column), F.max(column), F.mean(column), F.stddev(column)).collect()
140-
if mn_mx and len(mn_mx) > 0:
141-
metrics["min"] = mn_mx[0][0]
142-
metrics["max"] = mn_mx[0][1]
143-
sigmas = opts.get("sigmas", 3)
144-
avg = mn_mx[0][2]
145-
stddev = mn_mx[0][3]
146-
min_limit = avg - sigmas * stddev
147-
max_limit = avg + sigmas * stddev
148-
if min_limit > mn_mx[0][0] and max_limit < mn_mx[0][1]:
149-
descr = (
150-
f"Range doesn't include outliers, capped by {sigmas} sigmas. avg={avg}, "
151-
f"stddev={stddev}, min={metrics.get('min')}, max={metrics.get('max')}"
152-
)
153-
elif min_limit < mn_mx[0][0] and max_limit > mn_mx[0][1]: #
154-
min_limit = mn_mx[0][0]
155-
max_limit = mn_mx[0][1]
156-
descr = "Real min/max values were used"
157-
elif min_limit < mn_mx[0][0]:
158-
min_limit = mn_mx[0][0]
159-
descr = (
160-
f"Real min value was used. Max was capped by {sigmas} sigmas. avg={avg}, "
161-
f"stddev={stddev}, max={metrics.get('max')}"
162-
)
163-
elif max_limit > mn_mx[0][1]:
164-
max_limit = mn_mx[0][1]
165-
descr = (
166-
f"Real max value was used. Min was capped by {sigmas} sigmas. avg={avg}, "
167-
f"stddev={stddev}, min={metrics.get('min')}"
168-
)
169-
# we need to preserve type at the end
170-
if typ == T.IntegerType() or typ == T.LongType():
171-
min_limit = int(round_value(min_limit, "down", {"round": True}))
172-
max_limit = int(round_value(max_limit, "up", {"round": True}))
173-
elif typ == T.DateType():
174-
min_limit = datetime.date.fromtimestamp(int(min_limit))
175-
max_limit = datetime.date.fromtimestamp(int(max_limit))
176-
metrics["min"] = datetime.date.fromtimestamp(int(metrics["min"]))
177-
metrics["max"] = datetime.date.fromtimestamp(int(metrics["max"]))
178-
metrics["mean"] = datetime.date.fromtimestamp(int(avg))
179-
elif typ == T.TimestampType():
180-
min_limit = round_value(datetime.datetime.fromtimestamp(int(min_limit)), "down", {"round": True})
181-
max_limit = round_value(datetime.datetime.fromtimestamp(int(max_limit)), "up", {"round": True})
182-
metrics["min"] = datetime.datetime.fromtimestamp(int(metrics["min"]))
183-
metrics["max"] = datetime.datetime.fromtimestamp(int(metrics["max"]))
184-
metrics["mean"] = datetime.datetime.fromtimestamp(int(avg))
185-
else:
186-
print(f"Can't get min/max for field {col_name}")
140+
descr, max_limit, min_limit = get_min_max(col_name, descr, max_limit, metrics, min_limit, mn_mx, opts, typ)
187141
else:
188142
mn_mx = dst.agg(F.min(column), F.max(column)).collect()
189143
if mn_mx and len(mn_mx) > 0:
@@ -202,6 +156,57 @@ def extract_min_max( # pylint: disable=too-complex, too-many-statements
202156
return None
203157

204158

159+
def get_min_max(col_name, descr, max_limit, metrics, min_limit, mn_mx, opts, typ):
160+
if mn_mx and len(mn_mx) > 0:
161+
metrics["min"] = mn_mx[0][0]
162+
metrics["max"] = mn_mx[0][1]
163+
sigmas = opts.get("sigmas", 3)
164+
avg = mn_mx[0][2]
165+
stddev = mn_mx[0][3]
166+
min_limit = avg - sigmas * stddev
167+
max_limit = avg + sigmas * stddev
168+
if min_limit > mn_mx[0][0] and max_limit < mn_mx[0][1]:
169+
descr = (
170+
f"Range doesn't include outliers, capped by {sigmas} sigmas. avg={avg}, "
171+
f"stddev={stddev}, min={metrics.get('min')}, max={metrics.get('max')}"
172+
)
173+
elif min_limit < mn_mx[0][0] and max_limit > mn_mx[0][1]: #
174+
min_limit = mn_mx[0][0]
175+
max_limit = mn_mx[0][1]
176+
descr = "Real min/max values were used"
177+
elif min_limit < mn_mx[0][0]:
178+
min_limit = mn_mx[0][0]
179+
descr = (
180+
f"Real min value was used. Max was capped by {sigmas} sigmas. avg={avg}, "
181+
f"stddev={stddev}, max={metrics.get('max')}"
182+
)
183+
elif max_limit > mn_mx[0][1]:
184+
max_limit = mn_mx[0][1]
185+
descr = (
186+
f"Real max value was used. Min was capped by {sigmas} sigmas. avg={avg}, "
187+
f"stddev={stddev}, min={metrics.get('min')}"
188+
)
189+
# we need to preserve type at the end
190+
if typ == T.IntegerType() or typ == T.LongType():
191+
min_limit = int(round_value(min_limit, "down", {"round": True}))
192+
max_limit = int(round_value(max_limit, "up", {"round": True}))
193+
elif typ == T.DateType():
194+
min_limit = datetime.date.fromtimestamp(int(min_limit))
195+
max_limit = datetime.date.fromtimestamp(int(max_limit))
196+
metrics["min"] = datetime.date.fromtimestamp(int(metrics["min"]))
197+
metrics["max"] = datetime.date.fromtimestamp(int(metrics["max"]))
198+
metrics["mean"] = datetime.date.fromtimestamp(int(avg))
199+
elif typ == T.TimestampType():
200+
min_limit = round_value(datetime.datetime.fromtimestamp(int(min_limit)), "down", {"round": True})
201+
max_limit = round_value(datetime.datetime.fromtimestamp(int(max_limit)), "up", {"round": True})
202+
metrics["min"] = datetime.datetime.fromtimestamp(int(metrics["min"]))
203+
metrics["max"] = datetime.datetime.fromtimestamp(int(metrics["max"]))
204+
metrics["mean"] = datetime.datetime.fromtimestamp(int(avg))
205+
else:
206+
print(f"Can't get min/max for field {col_name}")
207+
return descr, max_limit, min_limit
208+
209+
205210
def get_fields(col_name: str, schema: T.StructType) -> list[T.StructField]:
206211
fields = []
207212
for f in schema.fields:
@@ -228,7 +233,7 @@ def get_columns_or_fields(cols: list[T.StructField]) -> list[T.StructField]:
228233
# TODO: split into managebale chunks
229234
# TODO: how to handle maps, arrays & structs?
230235
# TODO: return not only DQ rules, but also the profiling results - use named tuple?
231-
def profile_dataframe( # pylint: disable=too-complex, too-many-locals
236+
def profile_dataframe(
232237
df: DataFrame, cols: list[str] | None = None, opts: dict[str, Any] | None = None
233238
) -> tuple[dict[str, Any], list[DQRule]]:
234239
if opts is None:
@@ -249,6 +254,12 @@ def profile_dataframe( # pylint: disable=too-complex, too-many-locals
249254
max_nulls = opts.get("max_null_ratio", 0)
250255
trim_strings = opts.get("trim_strings", True)
251256

257+
profile(df, df_cols, dq_rules, max_nulls, opts, summary_stats, total_count, trim_strings)
258+
259+
return summary_stats, dq_rules
260+
261+
262+
def profile(df, df_cols, dq_rules, max_nulls, opts, summary_stats, total_count, trim_strings):
252263
# TODO: think, how we can do it in fewer passes. Maybe only for specific things, like, min_max, etc.
253264
for field in get_columns_or_fields(df_cols):
254265
field_name = field.name
@@ -306,5 +317,3 @@ def profile_dataframe( # pylint: disable=too-complex, too-many-locals
306317

307318
# That should be the last one
308319
dst.unpersist()
309-
310-
return summary_stats, dq_rules

0 commit comments

Comments
 (0)