Skip to content

Commit faa847e

Browse files
committed
Refactored to remove the try_cast function
1 parent 39042be commit faa847e

File tree

4 files changed

+30
-22
lines changed

4 files changed

+30
-22
lines changed

src/databricks/labs/dqx/col_functions.py

+12-14
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ def is_not_null_and_not_empty(col_name: str, trim_strings: bool = False) -> Colu
3737
column = F.col(col_name)
3838
if trim_strings:
3939
column = F.trim(column).alias(col_name)
40-
condition = column.isNull() | (column.try_cast("string") == F.lit(""))
40+
condition = column.isNull() | (column.cast("string").isNull() | (column.cast("string") == F.lit("")))
4141
return make_condition(condition, f"Column {col_name} is null or empty", f"{col_name}_is_null_or_empty")
4242

4343

@@ -48,8 +48,8 @@ def is_not_empty(col_name: str) -> Column:
4848
:return: Column object for condition
4949
"""
5050
column = F.col(col_name)
51-
column = column.try_cast("string")
52-
return make_condition((column == ""), f"Column {col_name} is empty", f"{col_name}_is_empty")
51+
condition = column.cast("string") == F.lit("")
52+
return make_condition(condition, f"Column {col_name} is empty", f"{col_name}_is_empty")
5353

5454

5555
def is_not_null(col_name: str) -> Column:
@@ -77,7 +77,7 @@ def value_is_not_null_and_is_in_list(col_name: str, allowed: list) -> Column:
7777
F.concat_ws(
7878
"",
7979
F.lit("Value "),
80-
F.when(column.isNull(), F.lit("null")).otherwise(column.try_cast("string")),
80+
F.when(column.isNull(), F.lit("null")).otherwise(column.cast("string")),
8181
F.lit(" is not in the allowed list: ["),
8282
F.concat_ws(", ", *allowed_cols),
8383
F.lit("]"),
@@ -381,15 +381,15 @@ def is_valid_date(col_name: str, date_format: str | None = None) -> Column:
381381
:param date_format: date format (e.g. 'yyyy-MM-dd'; note 'MM' is month — 'mm' means minutes in Spark datetime patterns)
382382
:return: Column object for condition
383383
"""
384-
str_col = F.col(col_name)
385-
date_col = str_col.try_cast("date") if date_format is None else F.try_to_timestamp(str_col, F.lit(date_format))
386-
condition = F.when(str_col.isNull(), F.lit(None)).otherwise(date_col.isNull())
384+
column = F.col(col_name)
385+
date_col = F.try_to_timestamp(column) if date_format is None else F.try_to_timestamp(column, F.lit(date_format))
386+
condition = F.when(column.isNull(), F.lit(None)).otherwise(date_col.isNull())
387387
condition_str = "' is not a valid date"
388388
if date_format is not None:
389389
condition_str += f" with format '{date_format}'"
390390
return make_condition(
391391
condition,
392-
F.concat_ws("", F.lit("Value '"), str_col, F.lit(condition_str)),
392+
F.concat_ws("", F.lit("Value '"), column, F.lit(condition_str)),
393393
f"{col_name}_is_not_valid_date",
394394
)
395395

@@ -401,18 +401,16 @@ def is_valid_timestamp(col_name: str, timestamp_format: str | None = None) -> Co
401401
:param timestamp_format: timestamp format (e.g. 'yyyy-MM-dd HH:mm:ss'; note 'MM' is month — 'mm' means minutes in Spark datetime patterns)
402402
:return: Column object for condition
403403
"""
404-
str_col = F.col(col_name)
404+
column = F.col(col_name)
405405
ts_col = (
406-
str_col.try_cast("timestamp")
407-
if timestamp_format is None
408-
else F.try_to_timestamp(str_col, F.lit(timestamp_format))
406+
F.try_to_timestamp(column) if timestamp_format is None else F.try_to_timestamp(column, F.lit(timestamp_format))
409407
)
410-
condition = F.when(str_col.isNull(), F.lit(None)).otherwise(ts_col.isNull())
408+
condition = F.when(column.isNull(), F.lit(None)).otherwise(ts_col.isNull())
411409
condition_str = "' is not a valid timestamp"
412410
if timestamp_format is not None:
413411
condition_str += f" with format '{timestamp_format}'"
414412
return make_condition(
415413
condition,
416-
F.concat_ws("", F.lit("Value '"), str_col, F.lit(condition_str)),
414+
F.concat_ws("", F.lit("Value '"), column, F.lit(condition_str)),
417415
f"{col_name}_is_not_valid_timestamp",
418416
)

tests/integration/conftest.py

+11
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,17 @@ def product_info():
3939
return "dqx", __version__
4040

4141

42+
@pytest.fixture
43+
def set_utc_timezone():
44+
"""
45+
Set the timezone to UTC for the duration of the test to make sure spark timestamps
46+
are handled the same way regardless of the environment.
47+
"""
48+
os.environ["TZ"] = "UTC"
49+
yield
50+
os.environ.pop("TZ")
51+
52+
4253
@pytest.fixture
4354
def make_check_file_as_yaml(ws, make_random, make_directory):
4455
def create(**kwargs):

tests/integration/test_apply_checks.py

+1-2
Original file line numberDiff line numberDiff line change
@@ -509,8 +509,7 @@ def test_apply_checks_by_metadata_with_func_defined_outside_framework(ws, spark)
509509

510510
def col_test_check_func(col_name: str) -> Column:
511511
check_col = F.col(col_name)
512-
check_col = check_col.try_cast("string")
513-
condition = check_col.isNull() | (check_col == "") | (check_col == "null")
512+
condition = check_col.isNull() | (check_col.cast("string").isNull() | (check_col.cast("string") == F.lit("")))
514513
return make_condition(condition, "new check failed", f"{col_name}_is_null_or_empty")
515514

516515

tests/integration/test_functions.py → tests/integration/test_col_functions.py (renamed)

+6-6
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ def test_is_col_older_than_n_days_cur(spark):
220220
assert_df_equality(actual, expected, ignore_nullable=True)
221221

222222

223-
def test_col_not_less_than(spark):
223+
def test_col_not_less_than(spark, set_utc_timezone):
224224
schema_num = "a: int, b: date, c: timestamp"
225225
test_df = spark.createDataFrame(
226226
[
@@ -254,7 +254,7 @@ def test_col_not_less_than(spark):
254254
assert_df_equality(actual, expected, ignore_nullable=True)
255255

256256

257-
def test_col_not_greater_than(spark):
257+
def test_col_not_greater_than(spark, set_utc_timezone):
258258
schema_num = "a: int, b: date, c: timestamp"
259259
test_df = spark.createDataFrame(
260260
[
@@ -288,7 +288,7 @@ def test_col_not_greater_than(spark):
288288
assert_df_equality(actual, expected, ignore_nullable=True)
289289

290290

291-
def test_col_is_in_range(spark):
291+
def test_col_is_in_range(spark, set_utc_timezone):
292292
schema_num = "a: int, b: date, c: timestamp"
293293
test_df = spark.createDataFrame(
294294
[
@@ -334,7 +334,7 @@ def test_col_is_in_range(spark):
334334
assert_df_equality(actual, expected, ignore_nullable=True)
335335

336336

337-
def test_col_is_not_in_range(spark):
337+
def test_col_is_not_in_range(spark, set_utc_timezone):
338338
schema_num = "a: int, b: date, c: timestamp"
339339
test_df = spark.createDataFrame(
340340
[
@@ -486,7 +486,7 @@ def test_col_is_not_null_and_not_empty_array(spark):
486486
assert_df_equality(actual, expected, ignore_nullable=True)
487487

488488

489-
def test_col_is_valid_date(spark):
489+
def test_col_is_valid_date(spark, set_utc_timezone):
490490
schema_array = "a: string, b: string, c: string, d: string"
491491
data = [
492492
["2024-01-01", "12/31/2025", "invalid_date", None],
@@ -526,7 +526,7 @@ def test_col_is_valid_date(spark):
526526
assert_df_equality(actual, expected, ignore_nullable=True)
527527

528528

529-
def test_col_is_valid_timestamp(spark):
529+
def test_col_is_valid_timestamp(spark, set_utc_timezone):
530530
schema_array = "a: string, b: string, c: string, d: string, e: string"
531531
data = [
532532
["2024-01-01 00:00:00", "12/31/2025 00:00:00", "invalid_timestamp", None, "2025-01-31T00:00:00"],

0 commit comments

Comments (0)