
Commit dbc3dd2

Author: hrfmartins (committed)

feat: customizable column names for DQEngine along with a placeholder for other future configurations

1 parent 982bd60 · commit dbc3dd2

2 files changed: +74 −17 lines


src/databricks/labs/dqx/engine.py (+48 −16)

```diff
@@ -7,7 +7,7 @@
 from collections.abc import Callable
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any
+from typing import Any, Optional
 import yaml
 
 import pyspark.sql.functions as F
@@ -18,19 +18,26 @@
 from databricks.labs.dqx.base import DQEngineBase
 from databricks.labs.dqx.config import WorkspaceConfig
 from databricks.labs.dqx.utils import get_column_name
+from databricks.sdk import WorkspaceClient
 from databricks.sdk.errors import NotFound
 
 logger = logging.getLogger(__name__)
 
 
-# TODO: make this configurable
-class Columns(Enum):
+class DefaultColumnNames(Enum):
     """Enum class to represent columns in the dataframe that will be used for error and warning reporting."""
 
     ERRORS = "_errors"
     WARNINGS = "_warnings"
 
 
+class ColumnArguments(Enum):
+    """Enum class that is used as input parsing for custom column naming."""
+
+    ERRORS = "errors"
+    WARNINGS = "warnings"
+
+
 class Criticality(Enum):
     """Enum class to represent criticality of the check."""
 
@@ -142,9 +149,32 @@ def get_rules(self) -> list[DQRule]:
             rules.append(rule)
         return rules
 
+@dataclass(frozen=True)
+class ExtraParams:
+    """Class to represent extra parameters for DQEngine."""
+
+    column_names: Optional[dict[str, str]] = field(default_factory=dict)
+
 
 class DQEngine(DQEngineBase):
-    """Data Quality Engine class to apply data quality checks to a given dataframe."""
+    """Data Quality Engine class to apply data quality checks to a given dataframe.
+
+    Args:
+        workspace_client (WorkspaceClient): WorkspaceClient instance to use for accessing the workspace.
+        extra_params (ExtraParams): Extra parameters for the DQEngine.
+    """
+
+    def __init__(self, workspace_client: WorkspaceClient, extra_params: Optional[ExtraParams] = None):
+        super().__init__(workspace_client)
+
+        extra_params = extra_params or ExtraParams()
+
+        self._column_names = {
+            ColumnArguments.ERRORS: extra_params.column_names.get(ColumnArguments.ERRORS.value, DefaultColumnNames.ERRORS.value),
+            ColumnArguments.WARNINGS: extra_params.column_names.get(
+                ColumnArguments.WARNINGS.value, DefaultColumnNames.WARNINGS.value
+            ),
+        }
 
     @staticmethod
     def _get_check_columns(checks: list[DQRule], criticality: str) -> list[DQRule]:
```
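
The `__init__` above resolves each reporting column by preferring the user-supplied name and falling back to the default. A standalone sketch of that resolution logic, runnable without Spark (the `resolve_column_names` helper is illustrative, not part of the commit):

```python
from enum import Enum


class DefaultColumnNames(Enum):
    """Default reporting column names, as defined in engine.py."""

    ERRORS = "_errors"
    WARNINGS = "_warnings"


class ColumnArguments(Enum):
    """Keys accepted in ExtraParams.column_names, as defined in engine.py."""

    ERRORS = "errors"
    WARNINGS = "warnings"


def resolve_column_names(overrides: dict[str, str]) -> dict[ColumnArguments, str]:
    """Illustrative helper mirroring DQEngine.__init__:
    a user-supplied name wins; any missing key falls back to the default."""
    return {
        ColumnArguments.ERRORS: overrides.get(ColumnArguments.ERRORS.value, DefaultColumnNames.ERRORS.value),
        ColumnArguments.WARNINGS: overrides.get(ColumnArguments.WARNINGS.value, DefaultColumnNames.WARNINGS.value),
    }


# Partial override: only the errors column is renamed; warnings keeps its default.
names = resolve_column_names({"errors": "ERROR"})
assert names[ColumnArguments.ERRORS] == "ERROR"
assert names[ColumnArguments.WARNINGS] == "_warnings"
```

The remaining hunks in engine.py route every reference to the reporting columns through this instance-level mapping:
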
```diff
@@ -156,17 +186,16 @@ def _get_check_columns(checks: list[DQRule], criticality: str) -> list[DQRule]:
         """
         return [check for check in checks if check.rule_criticality == criticality]
 
-    @staticmethod
-    def _append_empty_checks(df: DataFrame) -> DataFrame:
+    def _append_empty_checks(self, df: DataFrame) -> DataFrame:
         """Append empty checks at the end of dataframe.
 
         :param df: dataframe without checks
         :return: dataframe with checks
         """
         return df.select(
             "*",
-            F.lit(None).cast("map<string, string>").alias(Columns.ERRORS.value),
-            F.lit(None).cast("map<string, string>").alias(Columns.WARNINGS.value),
+            F.lit(None).cast("map<string, string>").alias(self._column_names[ColumnArguments.ERRORS]),
+            F.lit(None).cast("map<string, string>").alias(self._column_names[ColumnArguments.WARNINGS]),
         )
 
     @staticmethod
@@ -204,8 +233,8 @@ def apply_checks(self, df: DataFrame, checks: list[DQRule]) -> DataFrame:
 
         warning_checks = self._get_check_columns(checks, Criticality.WARN.value)
         error_checks = self._get_check_columns(checks, Criticality.ERROR.value)
-        ndf = self._create_results_map(df, error_checks, Columns.ERRORS.value)
-        ndf = self._create_results_map(ndf, warning_checks, Columns.WARNINGS.value)
+        ndf = self._create_results_map(df, error_checks, self._column_names[ColumnArguments.ERRORS])
+        ndf = self._create_results_map(ndf, warning_checks, self._column_names[ColumnArguments.WARNINGS])
 
         return ndf
 
@@ -228,23 +257,26 @@ def apply_checks_and_split(self, df: DataFrame, checks: list[DQRule]) -> tuple[D
 
         return good_df, bad_df
 
-    @staticmethod
-    def get_invalid(df: DataFrame) -> DataFrame:
+    def get_invalid(self, df: DataFrame) -> DataFrame:
         """
         Get records that violate data quality checks (records with warnings and errors).
         @param df: input DataFrame.
+        The reporting column names are taken from the engine's configured column names.
         @return: dataframe with error and warning rows and corresponding reporting columns.
         """
-        return df.where(F.col(Columns.ERRORS.value).isNotNull() | F.col(Columns.WARNINGS.value).isNotNull())
+        return df.where(
+            F.col(self._column_names[ColumnArguments.ERRORS]).isNotNull()
+            | F.col(self._column_names[ColumnArguments.WARNINGS]).isNotNull()
+        )
 
-    @staticmethod
-    def get_valid(df: DataFrame) -> DataFrame:
+    def get_valid(self, df: DataFrame) -> DataFrame:
         """
         Get records that don't violate data quality checks (records with warnings but no errors).
         @param df: input DataFrame.
+        The reporting column names are taken from the engine's configured column names.
         @return: dataframe with warning rows but no reporting columns.
         """
-        return df.where(F.col(Columns.ERRORS.value).isNull()).drop(Columns.ERRORS.value, Columns.WARNINGS.value)
+        return df.where(F.col(self._column_names[ColumnArguments.ERRORS]).isNull()).drop(self._column_names[ColumnArguments.ERRORS], self._column_names[ColumnArguments.WARNINGS])
 
     @staticmethod
     def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> ChecksValidationStatus:
```
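
Taken together, the new constructor makes the reporting column names configurable per engine instance. A minimal usage sketch, not part of the commit (it assumes an existing `WorkspaceClient` named `ws`, an input DataFrame `df` with a column `a`, and that the built-in check resolves by name via the repo's metadata format; `dq_errors`/`dq_warnings` are illustrative names):

```python
from databricks.labs.dqx.engine import DQEngine, ExtraParams

# Default engine: reporting columns keep the names "_errors" and "_warnings".
default_engine = DQEngine(ws)  # `ws` is an assumed, pre-built WorkspaceClient

# Custom engine: rename both reporting columns via ExtraParams.
custom_engine = DQEngine(
    ws,
    extra_params=ExtraParams(column_names={"errors": "dq_errors", "warnings": "dq_warnings"}),
)

checks = [
    {
        "criticality": "error",
        "check": {"function": "is_not_null_and_not_empty", "arguments": {"col_name": "a"}},
    }
]
checked = custom_engine.apply_checks_by_metadata(df, checks)
# `checked` now carries map<string,string> reporting columns named
# "dq_errors" and "dq_warnings" instead of the defaults.
```
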

tests/integration/test_apply_checks.py (+26 −1)

```diff
@@ -6,11 +6,12 @@
 from databricks.labs.dqx.col_functions import is_not_null_and_not_empty, make_condition
 from databricks.labs.dqx.engine import (
     DQRule,
-    DQEngine,
+    DQEngine, ExtraParams,
 )
 
 SCHEMA = "a: int, b: int, c: int"
 EXPECTED_SCHEMA = SCHEMA + ", _errors: map<string,string>, _warnings: map<string,string>"
+EXPECTED_SCHEMA_WITH_CUSTOM_NAMES = SCHEMA + ", ERROR: map<string,string>, WARN: map<string,string>"
 
 
 def test_apply_checks_on_empty_checks(ws, spark):
@@ -442,3 +443,27 @@ def col_test_check_func(col_name: str) -> Column:
     check_col = F.col(col_name)
     condition = check_col.isNull() | (check_col == "") | (check_col == "null")
     return make_condition(condition, "new check failed", f"{col_name}_is_null_or_empty")
+
+
+def test_apply_checks_with_custom_column_naming(ws, spark):
+    dq_engine = DQEngine(ws, ExtraParams(column_names={'errors': 'ERROR', 'warnings': 'WARN'}))
+    test_df = spark.createDataFrame([[1, 3, 3], [2, None, 4], [None, 4, None], [None, None, None]], SCHEMA)
+
+    checks = [{"criticality": "warn", "check": {"function": "col_test_check_func", "arguments": {"col_name": "a"}}}]
+    checked = dq_engine.apply_checks_by_metadata(test_df, checks, globals())
+
+    assert 'ERROR' in checked.columns
+    assert 'WARN' in checked.columns
+
+    expected = spark.createDataFrame(
+        [
+            [1, 3, 3, None, None],
+            [2, None, 4, None, None],
+            [None, 4, None, None, {"col_a_is_null_or_empty": "new check failed"}],
+            [None, None, None, None, {"col_a_is_null_or_empty": "new check failed"}],
+        ],
+        EXPECTED_SCHEMA_WITH_CUSTOM_NAMES,
+    )
+
+    assert_df_equality(checked, expected, ignore_nullable=True)
```
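
Because the column mapping lives on the engine instance, the accessors converted away from `@staticmethod` honor the custom names automatically. A hypothetical continuation of the test above (not part of this commit; it reuses `ws`, `test_df`, and `checks` from that test):

```python
# Hypothetical follow-up to test_apply_checks_with_custom_column_naming.
dq_engine = DQEngine(ws, ExtraParams(column_names={'errors': 'ERROR', 'warnings': 'WARN'}))
checked = dq_engine.apply_checks_by_metadata(test_df, checks, globals())

# get_invalid keeps rows where either reporting column is populated.
invalid_df = dq_engine.get_invalid(checked)

# get_valid keeps rows with no errors and drops both reporting columns.
valid_df = dq_engine.get_valid(checked)
assert 'ERROR' not in valid_df.columns
assert 'WARN' not in valid_df.columns
```
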
