Commit a09b0fe

hrfmartins authored and committed
feat: customizable column names for DQEngine along with placeholder for other future configurations
1 parent 27bc038 commit a09b0fe

3 files changed: +94 -36 lines changed
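For orientation before the diff: a minimal usage sketch of the API this commit introduces. The custom names ERROR and WARN mirror the integration test below; ws, input_df, and checks are assumed to already exist (ws being a WorkspaceClient).

from databricks.labs.dqx.engine import DQEngine, ExtraParams

# Remap the reporting columns away from the defaults ("_errors"/"_warnings").
# Keys are the ColumnArguments values ("errors"/"warnings"); any key left out
# falls back to the DefaultColumnNames enum.
dq_engine = DQEngine(ws, extra_params=ExtraParams(column_names={"errors": "ERROR", "warnings": "WARN"}))

# The appended result columns are now named "ERROR" and "WARN".
checked_df = dq_engine.apply_checks_by_metadata(input_df, checks)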

src/databricks/labs/dqx/engine.py (+45 -27)
@@ -5,14 +5,17 @@
 import itertools
 from pathlib import Path
 from collections.abc import Callable
-from typing import Any
+from typing import Any, Optional
+from dataclasses import dataclass, field
 import yaml
 import pyspark.sql.functions as F
 from pyspark.sql import DataFrame
-from databricks.labs.dqx.rule import DQRule, Criticality, Columns, DQRuleColSet, ChecksValidationStatus
+from databricks.labs.dqx.rule import DQRule, Criticality, DQRuleColSet, ChecksValidationStatus, ColumnArguments, \
+    ExtraParams, DefaultColumnNames
 from databricks.labs.dqx.utils import deserialize_dicts
 from databricks.labs.dqx import col_functions
 from databricks.labs.blueprint.installation import Installation
+
 from databricks.labs.dqx.base import DQEngineBase, DQEngineCoreBase
 from databricks.labs.dqx.config import WorkspaceConfig, RunConfig
 from databricks.sdk.errors import NotFound
@@ -24,16 +27,35 @@
 
 
 class DQEngineCore(DQEngineCoreBase):
-    """Data Quality Engine Core class to apply data quality checks to a given dataframe."""
+    """Data Quality Engine Core class to apply data quality checks to a given dataframe.
+    Args:
+        workspace_client (WorkspaceClient): WorkspaceClient instance to use for accessing the workspace.
+        extra_params (ExtraParams): Extra parameters for the DQEngine.
+    """
+
+    def __init__(self, workspace_client: WorkspaceClient, extra_params: ExtraParams | None = None):
+        super().__init__(workspace_client)
+
+        extra_params = extra_params or ExtraParams()
+
+        self._column_names = {
+            ColumnArguments.ERRORS: extra_params.column_names.get(
+                ColumnArguments.ERRORS.value, DefaultColumnNames.ERRORS.value
+            ),
+            ColumnArguments.WARNINGS: extra_params.column_names.get(
+                ColumnArguments.WARNINGS.value, DefaultColumnNames.WARNINGS.value
+            ),
+        }
+
 
     def apply_checks(self, df: DataFrame, checks: list[DQRule]) -> DataFrame:
         if not checks:
             return self._append_empty_checks(df)
 
         warning_checks = self._get_check_columns(checks, Criticality.WARN.value)
         error_checks = self._get_check_columns(checks, Criticality.ERROR.value)
-        ndf = self._create_results_map(df, error_checks, Columns.ERRORS.value)
-        ndf = self._create_results_map(ndf, warning_checks, Columns.WARNINGS.value)
+        ndf = self._create_results_map(df, error_checks, self._column_names[ColumnArguments.ERRORS])
+        ndf = self._create_results_map(ndf, warning_checks, self._column_names[ColumnArguments.WARNINGS])
 
         return ndf
 

@@ -57,12 +79,13 @@ def apply_checks_by_metadata_and_split(
 
         return good_df, bad_df
 
+
     def apply_checks_by_metadata(
-            self, df: DataFrame, checks: list[dict], glbs: dict[str, Any] | None = None
-    ) -> DataFrame:
-        dq_rule_checks = self.build_checks_by_metadata(checks, glbs)
+        self, df: DataFrame, checks: list[dict], glbs: dict[str, Any] | None = None
+    ) -> DataFrame:
+        dq_rule_checks = self.build_checks_by_metadata(checks, glbs)
 
-        return self.apply_checks(df, dq_rule_checks)
+        return self.apply_checks(df, dq_rule_checks)
 
     @staticmethod
     def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> ChecksValidationStatus:
@@ -77,13 +100,11 @@ def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> ChecksValidationStatus:
 
         return status
 
-    @staticmethod
-    def get_invalid(df: DataFrame) -> DataFrame:
-        return df.where(F.col(Columns.ERRORS.value).isNotNull() | F.col(Columns.WARNINGS.value).isNotNull())
+    def get_invalid(self, df: DataFrame) -> DataFrame:
+        return df.where(F.col(self._column_names[ColumnArguments.ERRORS]).isNotNull() | F.col(self._column_names[ColumnArguments.WARNINGS]).isNotNull())
 
-    @staticmethod
-    def get_valid(df: DataFrame) -> DataFrame:
-        return df.where(F.col(Columns.ERRORS.value).isNull()).drop(Columns.ERRORS.value, Columns.WARNINGS.value)
+    def get_valid(self, df: DataFrame) -> DataFrame:
+        return df.where(F.col(self._column_names[ColumnArguments.ERRORS]).isNull()).drop(self._column_names[ColumnArguments.ERRORS], self._column_names[ColumnArguments.WARNINGS])
 
     @staticmethod
     def load_checks_from_local_file(path: str) -> list[dict]:
@@ -177,17 +198,16 @@ def _get_check_columns(checks: list[DQRule], criticality: str) -> list[DQRule]:
         """
         return [check for check in checks if check.rule_criticality == criticality]
 
-    @staticmethod
-    def _append_empty_checks(df: DataFrame) -> DataFrame:
+    def _append_empty_checks(self, df: DataFrame) -> DataFrame:
         """Append empty checks at the end of dataframe.
 
         :param df: dataframe without checks
         :return: dataframe with checks
         """
         return df.select(
             "*",
-            F.lit(None).cast("map<string, string>").alias(Columns.ERRORS.value),
-            F.lit(None).cast("map<string, string>").alias(Columns.WARNINGS.value),
+            F.lit(None).cast("map<string, string>").alias(self._column_names[ColumnArguments.ERRORS]),
+            F.lit(None).cast("map<string, string>").alias(self._column_names[ColumnArguments.WARNINGS]),
         )
 
     @staticmethod
@@ -350,9 +370,9 @@ def _resolve_function(func_name: str, glbs: dict[str, Any] | None = None, fail_o
 class DQEngine(DQEngineBase):
     """Data Quality Engine class to apply data quality checks to a given dataframe."""
 
-    def __init__(self, workspace_client: WorkspaceClient, engine: DQEngineCoreBase | None = None):
+    def __init__(self, workspace_client: WorkspaceClient, engine: DQEngineCoreBase | None = None, extra_params: ExtraParams | None = None):
         super().__init__(workspace_client)
-        self._engine = engine or DQEngineCore(workspace_client)
+        self._engine = engine or DQEngineCore(workspace_client, extra_params)
 
     def apply_checks(self, df: DataFrame, checks: list[DQRule]) -> DataFrame:
         return self._engine.apply_checks(df, checks)
@@ -374,13 +394,11 @@ def apply_checks_by_metadata(
     def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> ChecksValidationStatus:
         return DQEngineCore.validate_checks(checks, glbs)
 
-    @staticmethod
-    def get_invalid(df: DataFrame) -> DataFrame:
-        return DQEngineCore.get_invalid(df)
+    def get_invalid(self, df: DataFrame) -> DataFrame:
+        return self._engine.get_invalid(df)
 
-    @staticmethod
-    def get_valid(df: DataFrame) -> DataFrame:
-        return DQEngineCore.get_valid(df)
+    def get_valid(self, df: DataFrame) -> DataFrame:
+        return self._engine.get_valid(df)
 
     @staticmethod
     def load_checks_from_local_file(path: str) -> list[dict]:
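A behavioral consequence of the diff above: get_valid and get_invalid lose @staticmethod because they must read the per-instance column mapping, so callers now need an engine instance. A short sketch of the call-site change, with ws and df assumed to exist:

# Before this commit: static call, hard-wired "_errors"/"_warnings" columns.
# valid_df = DQEngine.get_valid(df)

# After this commit: instance call, so custom column names are honored.
dq_engine = DQEngine(ws)
valid_df = dq_engine.get_valid(df)      # errors column null; reporting columns dropped
invalid_df = dq_engine.get_invalid(df)  # errors or warnings map is non-null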

src/databricks/labs/dqx/rule.py (+19 -7)
@@ -1,27 +1,39 @@
 from enum import Enum
 from dataclasses import dataclass, field
 import functools as ft
-from typing import Any
+from typing import Any, Optional
 from collections.abc import Callable
 from pyspark.sql import Column
 import pyspark.sql.functions as F
 from databricks.labs.dqx.utils import get_column_name
 
 
-# TODO: make this configurable
-class Columns(Enum):
+class Criticality(Enum):
+    """Enum class to represent criticality of the check."""
+
+    WARN = "warn"
+    ERROR = "error"
+
+
+class DefaultColumnNames(Enum):
     """Enum class to represent columns in the dataframe that will be used for error and warning reporting."""
 
     ERRORS = "_errors"
     WARNINGS = "_warnings"
 
 
-class Criticality(Enum):
-    """Enum class to represent criticality of the check."""
+class ColumnArguments(Enum):
+    """Enum class that is used as input parsing for custom column naming."""
 
-    WARN = "warn"
-    ERROR = "error"
+    ERRORS = "errors"
+    WARNINGS = "warnings"
+
+
+@dataclass(frozen=True)
+class ExtraParams:
+    """Class to represent extra parameters for DQEngine."""
 
+    column_names: Optional[dict[str, str]] = field(default_factory=dict)
 
 @dataclass(frozen=True)
 class DQRule:
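To make the interplay of the two new enums concrete: ColumnArguments defines the keys accepted in ExtraParams.column_names, while DefaultColumnNames supplies the fallbacks. A small sketch of the same resolution DQEngineCore.__init__ performs; the name data_issues is illustrative only:

from databricks.labs.dqx.rule import ColumnArguments, DefaultColumnNames, ExtraParams

params = ExtraParams(column_names={"errors": "data_issues"})  # "warnings" deliberately omitted

# Mirrors DQEngineCore.__init__: look up each argument, fall back to the default.
resolved = {
    arg: params.column_names.get(arg.value, default.value)
    for arg, default in (
        (ColumnArguments.ERRORS, DefaultColumnNames.ERRORS),
        (ColumnArguments.WARNINGS, DefaultColumnNames.WARNINGS),
    )
}
assert resolved[ColumnArguments.ERRORS] == "data_issues"
assert resolved[ColumnArguments.WARNINGS] == "_warnings"  # default retained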

tests/integration/test_apply_checks.py (+30 -2)
@@ -4,12 +4,16 @@
 from pyspark.sql import Column
 from chispa.dataframe_comparer import assert_df_equality  # type: ignore
 from databricks.labs.dqx.col_functions import is_not_null_and_not_empty, make_condition
-from databricks.labs.dqx.engine import DQEngine
+from databricks.labs.dqx.engine import (
+    DQRule,
+    DQEngine,
+    ExtraParams,
+)
 from databricks.labs.dqx.rule import DQRule
 
-
 SCHEMA = "a: int, b: int, c: int"
 EXPECTED_SCHEMA = SCHEMA + ", _errors: map<string,string>, _warnings: map<string,string>"
+EXPECTED_SCHEMA_WITH_CUSTOM_NAMES = SCHEMA + ", ERROR: map<string,string>, WARN: map<string,string>"
 
 
 def test_apply_checks_on_empty_checks(ws, spark):
@@ -491,3 +495,27 @@ def test_get_invalid_records(ws, spark):
     )
 
     assert_df_equality(invalid_df, expected_invalid_df)
+
+
+def test_apply_checks_with_custom_column_naming(ws, spark):
+    dq_engine = DQEngine(ws, extra_params=ExtraParams(column_names={'errors': 'ERROR', 'warnings': 'WARN'}))
+    test_df = spark.createDataFrame([[1, 3, 3], [2, None, 4], [None, 4, None], [None, None, None]], SCHEMA)
+
+    checks = [{"criticality": "warn", "check": {"function": "col_test_check_func", "arguments": {"col_name": "a"}}}]
+    checked = dq_engine.apply_checks_by_metadata(test_df, checks, globals())
+
+    assert 'ERROR' in checked.columns
+    assert 'WARN' in checked.columns
+
+    expected = spark.createDataFrame(
+        [
+            [1, 3, 3, None, None],
+            [2, None, 4, None, None],
+            [None, 4, None, None, {"col_a_is_null_or_empty": "new check failed"}],
+            [None, None, None, None, {"col_a_is_null_or_empty": "new check failed"}],
+        ],
+        EXPECTED_SCHEMA_WITH_CUSTOM_NAMES,
+    )
+
+    assert_df_equality(checked, expected, ignore_nullable=True)
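The new test resolves col_test_check_func by name via the globals() argument; its definition lives earlier in this test module and is not part of this diff. A hedged reconstruction from the imports and the expected "new check failed" message might look like:

import pyspark.sql.functions as F
from databricks.labs.dqx.col_functions import make_condition

def col_test_check_func(col_name: str):
    # Truthy condition marks bad rows: null or empty string in col_name.
    condition = F.col(col_name).isNull() | (F.col(col_name).cast("string") == F.lit(""))
    return make_condition(condition, "new check failed", f"col_{col_name}_is_null_or_empty")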
