-
Notifications
You must be signed in to change notification settings - Fork 32
/
Copy pathutils.py
60 lines (46 loc) · 2.11 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import re
import yaml
from pyspark.sql import Column
from pyspark.sql import SparkSession
# Matches locations that start with "/", "s3:/", "abfss:/" or "gs:/" (only the prefix is checked).
STORAGE_PATH_PATTERN = re.compile(r"^(/|s3:/|abfss:/|gs:/)")
# Matches a whole string made of exactly three dot-separated parts, each consisting only of
# ASCII letters, digits and underscores (e.g. "catalog.schema.table").
UNITY_CATALOG_TABLE_PATTERN = re.compile(r"^[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+\.[a-zA-Z0-9_]+$")
def get_column_name(col: Column) -> str:
    """
    Extract the effective column name (the last alias, if any) from an unbound column.

    PySpark does not expose the alias of an unbound column directly, so the name is parsed
    out of the column's string representation: the "Column<'" prefix and "'>" suffix are
    stripped, and the text following the last " AS " is returned. A column without an alias
    yields the whole remaining expression.

    :param col: Column
    :return: Col name alias as str
    """
    rendered = str(col)
    expression = rendered.removeprefix("Column<'").removesuffix("'>")
    # split() always yields at least one piece; the last piece is the outermost alias.
    *_, alias = expression.split(" AS ")
    return alias
def read_input_data(spark: SparkSession, input_location: str | None, input_format: str | None):
    """
    Load input data from a Unity Catalog table / view or from a storage path.

    The location is first tested against the Unity Catalog 3-level name pattern; when it
    matches, the table is read directly and the format is ignored. Otherwise the location
    must look like a storage path, in which case the format is mandatory.

    :param spark: SparkSession
    :param input_location: The input data location.
    :param input_format: The input data format (required only for storage locations).
    :return: The result of the Spark reader call for the given location.
    :raises ValueError: If the location is missing or unrecognized, or if the format is
        missing for a storage location.
    """
    if not input_location:
        raise ValueError("Input location not configured")

    is_catalog_table = UNITY_CATALOG_TABLE_PATTERN.match(input_location) is not None
    if is_catalog_table:
        # must provide 3-level Unity Catalog namespace
        return spark.read.table(input_location)

    is_storage_path = STORAGE_PATH_PATTERN.match(input_location) is not None
    if not is_storage_path:
        message = (
            "Invalid input location. It must be Unity Catalog table / view or storage location, "
            f"given {input_location}"
        )
        raise ValueError(message)

    if not input_format:
        raise ValueError("Input format not configured")

    # TODO handle spark options while reading data from a file location
    # https://github.com/databrickslabs/dqx/issues/161
    reader = spark.read.format(str(input_format))
    return reader.load(input_location)
def deserialize_dicts(checks: list[dict[str, str]]) -> list[dict]:
    """
    Deserialize string fields that contain dictionaries.

    Every string value that starts with "{" and ends with "}" is parsed with
    ``yaml.safe_load`` (after single quotes are replaced by double quotes) and the parsed
    result is stored back under the same key. All other values, including non-string values
    such as already deserialized dictionaries, are left untouched. The dictionaries in
    ``checks`` are modified in place.

    :param checks: list of checks
    :return: the same list of checks, with dictionary-like strings replaced by parsed values
    """
    for item in checks:
        for key, value in item.items():
            # Only strings can hold a serialized dictionary; skipping other types avoids an
            # AttributeError on values that are already parsed (e.g. on a repeated call).
            if not isinstance(value, str):
                continue
            if value.startswith("{") and value.endswith("}"):
                # NOTE(review): every single quote is replaced, including apostrophes inside
                # string values - confirm that inputs never contain them.
                # Overwriting an existing key does not resize the dict, so assigning here is
                # safe while iterating over item.items().
                item[key] = yaml.safe_load(value.replace("'", '"'))
    return checks