@@ -7,7 +7,7 @@
 from collections.abc import Callable
 from dataclasses import dataclass, field
 from enum import Enum
-from typing import Any
+from typing import Any, Optional
 import yaml
 
 import pyspark.sql.functions as F
@@ -18,19 +18,26 @@
 from databricks.labs.dqx.base import DQEngineBase
 from databricks.labs.dqx.config import WorkspaceConfig
 from databricks.labs.dqx.utils import get_column_name
+from databricks.sdk import WorkspaceClient
 from databricks.sdk.errors import NotFound
 
 logger = logging.getLogger(__name__)
 
 
-# TODO: make this configurable
-class Columns(Enum):
+class DefaultColumnNames(Enum):
     """Enum class to represent columns in the dataframe that will be used for error and warning reporting."""
 
     ERRORS = "_errors"
     WARNINGS = "_warnings"
 
 
+class ColumnArguments(Enum):
+    """Enum class that is used as input parsing for custom column naming."""
+
+    ERRORS = "errors"
+    WARNINGS = "warnings"
+
+
 class Criticality(Enum):
     """Enum class to represent criticality of the check."""
 
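The two enums above play different roles: DefaultColumnNames holds the reporting column names that end up in the dataframe (_errors, _warnings), while ColumnArguments holds the keys accepted when a caller overrides those names. A minimal sketch of that relationship, assuming both enums are importable from databricks.labs.dqx.engine (the module being changed here):

# Assumed import path: the enums are defined in this module.
from databricks.labs.dqx.engine import ColumnArguments, DefaultColumnNames

# Defaults used when no override is supplied.
assert DefaultColumnNames.ERRORS.value == "_errors"
assert DefaultColumnNames.WARNINGS.value == "_warnings"

# Keys accepted in a column_names override (consumed via ExtraParams, defined further down).
overrides = {ColumnArguments.ERRORS.value: "dq_errors", ColumnArguments.WARNINGS.value: "dq_warnings"}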
@@ -142,9 +149,32 @@ def get_rules(self) -> list[DQRule]:
             rules.append(rule)
         return rules
 
+@dataclass(frozen=True)
+class ExtraParams:
+    """Class to represent extra parameters for DQEngine."""
+
+    column_names: Optional[dict[str, str]] = field(default_factory=dict)
+
 
 class DQEngine(DQEngineBase):
-    """Data Quality Engine class to apply data quality checks to a given dataframe."""
+    """Data Quality Engine class to apply data quality checks to a given dataframe.
+
+    Args:
+        workspace_client (WorkspaceClient): WorkspaceClient instance to use for accessing the workspace.
+        extra_params (ExtraParams): Extra parameters for the DQEngine.
+    """
+
+    def __init__(self, workspace_client: WorkspaceClient, extra_params: Optional[ExtraParams] = None):
+        super().__init__(workspace_client)
+
+        extra_params = extra_params or ExtraParams()
+
+        self._column_names = {
+            ColumnArguments.ERRORS: extra_params.column_names.get(ColumnArguments.ERRORS.value, DefaultColumnNames.ERRORS.value),
+            ColumnArguments.WARNINGS: extra_params.column_names.get(
+                ColumnArguments.WARNINGS.value, DefaultColumnNames.WARNINGS.value
+            ),
+        }
 
     @staticmethod
     def _get_check_columns(checks: list[DQRule], criticality: str) -> list[DQRule]:
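The constructor above resolves the reporting column names once per engine instance, falling back to the defaults when no override is given. A minimal usage sketch, assuming DQEngine and ExtraParams are importable from databricks.labs.dqx.engine and that a Databricks workspace is reachable for WorkspaceClient():

from databricks.sdk import WorkspaceClient
from databricks.labs.dqx.engine import DQEngine, ExtraParams  # assumed import path

ws = WorkspaceClient()

# Default reporting columns: _errors and _warnings.
engine = DQEngine(ws)

# Custom reporting columns, keyed by the ColumnArguments values "errors" and "warnings".
custom_params = ExtraParams(column_names={"errors": "dq_errors", "warnings": "dq_warnings"})
engine_custom = DQEngine(ws, extra_params=custom_params)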
@@ -156,17 +186,16 @@ def _get_check_columns(checks: list[DQRule], criticality: str) -> list[DQRule]:
         """
         return [check for check in checks if check.rule_criticality == criticality]
 
-    @staticmethod
-    def _append_empty_checks(df: DataFrame) -> DataFrame:
+    def _append_empty_checks(self, df: DataFrame) -> DataFrame:
         """Append empty checks at the end of dataframe.
 
         :param df: dataframe without checks
         :return: dataframe with checks
         """
         return df.select(
             "*",
-            F.lit(None).cast("map<string, string>").alias(Columns.ERRORS.value),
-            F.lit(None).cast("map<string, string>").alias(Columns.WARNINGS.value),
+            F.lit(None).cast("map<string, string>").alias(self._column_names[ColumnArguments.ERRORS]),
+            F.lit(None).cast("map<string, string>").alias(self._column_names[ColumnArguments.WARNINGS]),
         )
 
     @staticmethod
@@ -204,8 +233,8 @@ def apply_checks(self, df: DataFrame, checks: list[DQRule]) -> DataFrame:
 
         warning_checks = self._get_check_columns(checks, Criticality.WARN.value)
         error_checks = self._get_check_columns(checks, Criticality.ERROR.value)
-        ndf = self._create_results_map(df, error_checks, Columns.ERRORS.value)
-        ndf = self._create_results_map(ndf, warning_checks, Columns.WARNINGS.value)
+        ndf = self._create_results_map(df, error_checks, self._column_names[ColumnArguments.ERRORS])
+        ndf = self._create_results_map(ndf, warning_checks, self._column_names[ColumnArguments.WARNINGS])
 
         return ndf
 
@@ -228,23 +257,26 @@ def apply_checks_and_split(self, df: DataFrame, checks: list[DQRule]) -> tuple[D
 
         return good_df, bad_df
 
-    @staticmethod
-    def get_invalid(df: DataFrame) -> DataFrame:
+    def get_invalid(self, df: DataFrame) -> DataFrame:
         """
         Get records that violate data quality checks (records with warnings and errors).
         @param df: input DataFrame.
+        @param column_names: dictionary with column names for errors and warnings.
         @return: dataframe with error and warning rows and corresponding reporting columns.
         """
-        return df.where(F.col(Columns.ERRORS.value).isNotNull() | F.col(Columns.WARNINGS.value).isNotNull())
+        return df.where(
+            F.col(self._column_names[ColumnArguments.ERRORS]).isNotNull()
+            | F.col(self._column_names[ColumnArguments.WARNINGS]).isNotNull()
+        )
 
-    @staticmethod
-    def get_valid(df: DataFrame) -> DataFrame:
+    def get_valid(self, df: DataFrame) -> DataFrame:
         """
         Get records that don't violate data quality checks (records with warnings but no errors).
         @param df: input DataFrame.
+        @param column_names: dictionary with column names for errors and warnings.
         @return: dataframe with warning rows but no reporting columns.
         """
-        return df.where(F.col(Columns.ERRORS.value).isNull()).drop(Columns.ERRORS.value, Columns.WARNINGS.value)
+        return df.where(F.col(self._column_names[ColumnArguments.ERRORS]).isNull()).drop(self._column_names[ColumnArguments.ERRORS], self._column_names[ColumnArguments.WARNINGS])
 
     @staticmethod
     def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> ChecksValidationStatus:
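With the column names stored on the instance, the reporting columns produced by apply_checks and consumed by get_valid and get_invalid follow whatever was configured on the engine, which is why the two helpers above lose their @staticmethod decorators. A sketch of the flow, assuming df is an input Spark DataFrame, checks is a list of DQRule objects built elsewhere, and engine_custom is the instance from the previous sketch:

# apply_checks adds the configured reporting columns (here dq_errors / dq_warnings).
checked_df = engine_custom.apply_checks(df, checks)

good_df = engine_custom.get_valid(checked_df)    # rows without errors, reporting columns dropped
bad_df = engine_custom.get_invalid(checked_df)   # rows with errors or warnings, reporting columns kept

# Equivalent single call returning both dataframes:
good_df, bad_df = engine_custom.apply_checks_and_split(df, checks)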