5
5
import itertools
6
6
from pathlib import Path
7
7
from collections .abc import Callable
8
- from typing import Any
8
+ from typing import Any , Optional
9
+ from dataclasses import dataclass , field
9
10
import yaml
10
11
import pyspark .sql .functions as F
11
12
from pyspark .sql import DataFrame
12
- from databricks .labs .dqx .rule import DQRule , Criticality , Columns , DQRuleColSet , ChecksValidationStatus
13
+ from databricks .labs .dqx .rule import DQRule , Criticality , DQRuleColSet , ChecksValidationStatus , ColumnArguments , \
14
+ ExtraParams , DefaultColumnNames
13
15
from databricks .labs .dqx .utils import deserialize_dicts
14
16
from databricks .labs .dqx import col_functions
15
17
from databricks .labs .blueprint .installation import Installation
18
+
16
19
from databricks .labs .dqx .base import DQEngineBase , DQEngineCoreBase
17
20
from databricks .labs .dqx .config import WorkspaceConfig , RunConfig
18
21
from databricks .sdk .errors import NotFound
24
27
25
28
26
29
class DQEngineCore (DQEngineCoreBase ):
27
- """Data Quality Engine Core class to apply data quality checks to a given dataframe."""
30
+ """Data Quality Engine Core class to apply data quality checks to a given dataframe.
31
+ Args:
32
+ workspace_client (WorkspaceClient): WorkspaceClient instance to use for accessing the workspace.
33
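+ extra_params (ExtraParams | None): Optional extra parameters, e.g. custom names for the errors and warnings result columns; the default column names are used when omitted.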
34
+ """
35
+
36
def __init__(self, workspace_client: WorkspaceClient, extra_params: ExtraParams | None = None):
    """Create the core engine and resolve the names of the result columns.

    :param workspace_client: client forwarded to the base class initializer
    :param extra_params: optional extra parameters; when omitted (or falsy) a
        default ``ExtraParams()`` is used
    """
    super().__init__(workspace_client)

    params = extra_params if extra_params else ExtraParams()
    overrides = params.column_names

    # A name supplied in ``column_names`` wins; otherwise the default name
    # for that result column is used.
    self._column_names = {
        argument: overrides.get(argument.value, fallback.value)
        for argument, fallback in (
            (ColumnArguments.ERRORS, DefaultColumnNames.ERRORS),
            (ColumnArguments.WARNINGS, DefaultColumnNames.WARNINGS),
        )
    }
49
+
28
50
29
51
def apply_checks(self, df: DataFrame, checks: list[DQRule]) -> DataFrame:
    """Apply the given checks to the dataframe.

    :param df: dataframe to check
    :param checks: rules to apply; when empty, empty result columns are appended instead
    :return: dataframe produced by building the errors and then the warnings result column
    """
    if not checks:
        return self._append_empty_checks(df)

    # Split the rules by criticality; each group feeds its own result column.
    warn_rules = self._get_check_columns(checks, Criticality.WARN.value)
    error_rules = self._get_check_columns(checks, Criticality.ERROR.value)

    errors_column = self._column_names[ColumnArguments.ERRORS]
    warnings_column = self._column_names[ColumnArguments.WARNINGS]

    with_errors = self._create_results_map(df, error_rules, errors_column)
    return self._create_results_map(with_errors, warn_rules, warnings_column)
39
61
@@ -57,12 +79,13 @@ def apply_checks_by_metadata_and_split(
57
79
58
80
return good_df , bad_df
59
81
82
+
60
83
def apply_checks_by_metadata(
    self,
    df: DataFrame,
    checks: list[dict],
    glbs: dict[str, Any] | None = None,
) -> DataFrame:
    """Build rules from their metadata description and apply them.

    :param df: dataframe to check
    :param checks: checks described as dictionaries
    :param glbs: optional mapping passed through to ``build_checks_by_metadata``
    :return: result of ``apply_checks`` for the built rules
    """
    return self.apply_checks(df, self.build_checks_by_metadata(checks, glbs))
66
89
67
90
@staticmethod
68
91
def validate_checks (checks : list [dict ], glbs : dict [str , Any ] | None = None ) -> ChecksValidationStatus :
@@ -77,13 +100,11 @@ def validate_checks(checks: list[dict], glbs: dict[str, Any] | None = None) -> C
77
100
78
101
return status
79
102
80
- @staticmethod
81
- def get_invalid (df : DataFrame ) -> DataFrame :
82
- return df .where (F .col (Columns .ERRORS .value ).isNotNull () | F .col (Columns .WARNINGS .value ).isNotNull ())
103
def get_invalid(self, df: DataFrame) -> DataFrame:
    """Keep only the rows whose errors or warnings result column is not null.

    :param df: dataframe containing the result columns
    :return: filtered dataframe
    """
    has_errors = F.col(self._column_names[ColumnArguments.ERRORS]).isNotNull()
    has_warnings = F.col(self._column_names[ColumnArguments.WARNINGS]).isNotNull()
    return df.where(has_errors | has_warnings)
83
105
84
- @staticmethod
85
- def get_valid (df : DataFrame ) -> DataFrame :
86
- return df .where (F .col (Columns .ERRORS .value ).isNull ()).drop (Columns .ERRORS .value , Columns .WARNINGS .value )
106
def get_valid(self, df: DataFrame) -> DataFrame:
    """Keep the rows whose errors result column is null and drop both result columns.

    Only the errors column is tested, so rows that carry warnings alone are kept.

    :param df: dataframe containing the result columns
    :return: filtered dataframe without the errors and warnings columns
    """
    errors_column = self._column_names[ColumnArguments.ERRORS]
    warnings_column = self._column_names[ColumnArguments.WARNINGS]

    without_errors = df.where(F.col(errors_column).isNull())
    return without_errors.drop(errors_column, warnings_column)
87
108
88
109
@staticmethod
89
110
def load_checks_from_local_file (path : str ) -> list [dict ]:
@@ -177,17 +198,16 @@ def _get_check_columns(checks: list[DQRule], criticality: str) -> list[DQRule]:
177
198
"""
178
199
return [check for check in checks if check .rule_criticality == criticality ]
179
200
180
- @staticmethod
181
- def _append_empty_checks (df : DataFrame ) -> DataFrame :
201
def _append_empty_checks(self, df: DataFrame) -> DataFrame:
    """Append null errors and warnings result columns after the existing columns.

    :param df: dataframe without checks
    :return: dataframe with the two result columns added as null ``map<string, string>`` values
    """
    empty_result_columns = [
        F.lit(None).cast("map<string, string>").alias(self._column_names[argument])
        for argument in (ColumnArguments.ERRORS, ColumnArguments.WARNINGS)
    ]
    return df.select("*", *empty_result_columns)
192
212
193
213
@staticmethod
@@ -350,9 +370,9 @@ def _resolve_function(func_name: str, glbs: dict[str, Any] | None = None, fail_o
350
370
class DQEngine (DQEngineBase ):
351
371
"""Data Quality Engine class to apply data quality checks to a given dataframe."""
352
372
353
- def __init__ (self , workspace_client : WorkspaceClient , engine : DQEngineCoreBase | None = None ):
373
def __init__(
    self,
    workspace_client: WorkspaceClient,
    engine: DQEngineCoreBase | None = None,
    extra_params: ExtraParams | None = None,
):
    """Create the engine facade around a core engine.

    :param workspace_client: client forwarded to the base class initializer
    :param engine: core engine to delegate to; when omitted (or falsy) a
        ``DQEngineCore`` is created
    :param extra_params: passed to the created ``DQEngineCore``; not used when
        ``engine`` is supplied
    """
    super().__init__(workspace_client)
    if engine:
        self._engine = engine
    else:
        self._engine = DQEngineCore(workspace_client, extra_params)
356
376
357
377
def apply_checks(self, df: DataFrame, checks: list[DQRule]) -> DataFrame:
    """Delegate to the wrapped engine's ``apply_checks``.

    :param df: dataframe to check
    :param checks: rules to apply
    :return: dataframe returned by the wrapped engine
    """
    checked_df = self._engine.apply_checks(df, checks)
    return checked_df
@@ -374,13 +394,11 @@ def apply_checks_by_metadata(
374
394
def validate_checks (checks : list [dict ], glbs : dict [str , Any ] | None = None ) -> ChecksValidationStatus :
375
395
return DQEngineCore .validate_checks (checks , glbs )
376
396
377
- @staticmethod
378
- def get_invalid (df : DataFrame ) -> DataFrame :
379
- return DQEngineCore .get_invalid (df )
397
def get_invalid(self, df: DataFrame) -> DataFrame:
    """Delegate to the wrapped engine's ``get_invalid``.

    :param df: dataframe containing the result columns
    :return: dataframe returned by the wrapped engine
    """
    invalid_df = self._engine.get_invalid(df)
    return invalid_df
380
399
381
- @staticmethod
382
- def get_valid (df : DataFrame ) -> DataFrame :
383
- return DQEngineCore .get_valid (df )
400
def get_valid(self, df: DataFrame) -> DataFrame:
    """Delegate to the wrapped engine's ``get_valid``.

    :param df: dataframe containing the result columns
    :return: dataframe returned by the wrapped engine
    """
    valid_df = self._engine.get_valid(df)
    return valid_df
384
402
385
403
@staticmethod
386
404
def load_checks_from_local_file (path : str ) -> list [dict ]:
0 commit comments