
Commit 39042be

Introducing Spark Remote for Unit Tests (#151)
## Changes

Introduced a managed Spark Remote (Spark Connect) server to run tests effectively. In the future, we can also introduce a UC dependency to resolve other UC-related integrations if and when required. Note: this makes the unit tests more like integration tests and moves us away from a mocked SparkSession.

### Linked issues

Resolves #..

### Tests

- [x] manually tested
- [x] added unit tests
- [x] added integration tests

---------

Co-authored-by: Marcin Wojtyczka <marcin.wojtyczka@databricks.com>
1 parent c61f6d4 commit 39042be
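
With the mocked SparkSession gone, the unit tests now run against a real session served over Spark Connect. Below is a minimal sketch of the resulting setup, assuming the server started by `.github/scripts/setup_spark_remote.sh` is reachable at `sc://localhost` (the URL used by the new `spark_local` fixture); the test function itself is hypothetical and only illustrative:

```python
import pytest
from pyspark.sql import SparkSession


@pytest.fixture
def spark_local():
    # Connect to the local Spark Connect server started by setup_spark_remote.sh
    return SparkSession.builder.appName("DQX Test").remote("sc://localhost").getOrCreate()


def test_dataframe_logic_runs_for_real(spark_local):
    # A real session evaluates DataFrame logic, which a Mock(spec=SparkSession) cannot
    df = spark_local.createDataFrame([[1, None], [2, 3]], "a: int, b: int")
    assert df.where("b IS NULL").count() == 1
```

The fixture mirrors the one added in `tests/unit/conftest.py` below; running against Spark Connect is what lets `DQEngine.apply_checks` and `assert_df_equality` execute unchanged in unit tests.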

File tree

13 files changed, +194 -74 lines changed


.github/scripts/setup_spark_remote.sh

+61
@@ -0,0 +1,61 @@
+#!/usr/bin/env bash
+
+set -xve
+echo "Setting up spark-connect"
+
+mkdir -p "$HOME"/spark
+cd "$HOME"/spark || exit 1
+
+version=$(wget -O - https://dlcdn.apache.org/spark/ | grep 'href="spark' | grep -v 'preview' | sed 's:</a>:\n:g' | sed -n 's/.*>//p' | tr -d spark- | tr -d / | sort -r --version-sort | head -1)
+if [ -z "$version" ]; then
+  echo "Failed to extract Spark version"
+  exit 1
+fi
+
+spark=spark-${version}-bin-hadoop3
+spark_connect="spark-connect_2.12"
+
+mkdir -p "${spark}"
+
+
+SERVER_SCRIPT=$HOME/spark/${spark}/sbin/start-connect-server.sh
+
+## check the spark version already exist, if not download the respective version
+if [ -f "${SERVER_SCRIPT}" ];then
+  echo "Spark Version already exists"
+else
+  if [ -f "${spark}.tgz" ];then
+    echo "${spark}.tgz already exists"
+  else
+    wget "https://dlcdn.apache.org/spark/spark-${version}/${spark}.tgz"
+  fi
+  tar -xvf "${spark}.tgz"
+fi
+
+cd "${spark}" || exit 1
+## check spark remote is running, if not start the spark remote
+result=$(${SERVER_SCRIPT} --packages org.apache.spark:${spark_connect}:"${version}" > "$HOME"/spark/log.out; echo $?)
+
+if [ "$result" -ne 0 ]; then
+  count=$(tail "${HOME}"/spark/log.out | grep -c "SparkConnectServer running as process")
+  if [ "${count}" == "0" ]; then
+    echo "Failed to start the server"
+    exit 1
+  fi
+  # Wait for the server to start by pinging localhost:4040
+  echo "Waiting for the server to start..."
+  for i in {1..30}; do
+    if nc -z localhost 4040; then
+      echo "Server is up and running"
+      break
+    fi
+    echo "Server not yet available, retrying in 5 seconds..."
+    sleep 5
+  done
+
+  if ! nc -z localhost 4040; then
+    echo "Failed to start the server within the expected time"
+    exit 1
+  fi
+fi
+echo "Started the Server"

.github/workflows/push.yml

+6-1
@@ -35,10 +35,15 @@ jobs:
           cache-dependency-path: '**/pyproject.toml'
           python-version: ${{ matrix.pyVersion }}
 
+      - name: Setup Spark Remote
+        run: |
+          pip install hatch==1.9.4
+          make setup_spark_remote
+
       - name: Run unit tests
         run: |
          pip install hatch==1.9.4
-          make test
+          make ci-test
 
       - name: Publish test coverage
        uses: codecov/codecov-action@v5

Makefile

+6-1
@@ -17,12 +17,17 @@ lint:
 fmt:
 	hatch run fmt
 
-test:
+ci-test:
 	hatch run test
 
 integration:
 	hatch run integration
 
+setup_spark_remote:
+	.github/scripts/setup_spark_remote.sh
+
+test: setup_spark_remote ci-test
+
 coverage:
 	hatch run coverage && open htmlcov/index.html
 

demos/dqx_demo_library.py

+14-14
@@ -59,7 +59,7 @@
 print(dlt_expectations)
 
 # save generated checks in a workspace file
-user_name = spark.sql('select current_user() as user').collect()[0]['user']
+user_name = spark.sql("select current_user() as user").collect()[0]["user"]
 checks_file = f"/Workspace/Users/{user_name}/dqx_demo_checks.yml"
 dq_engine = DQEngine(ws)
 dq_engine.save_checks_in_workspace_file(checks, workspace_path=checks_file)
@@ -143,7 +143,7 @@
       col_name: col3
 
 - criticality: error
-  filter: col1<3
+  filter: col1 < 3
   check:
     function: is_not_null_and_not_empty
     arguments:
@@ -193,17 +193,17 @@
         criticality="error",
         check_func=is_not_null).get_rules() + [
     DQRule(  # define rule for a single column
-        name='col3_is_null_or_empty',
-        criticality='error',
-        check=is_not_null_and_not_empty('col3')),
+        name="col3_is_null_or_empty",
+        criticality="error",
+        check=is_not_null_and_not_empty("col3")),
     DQRule(  # define rule with a filter
-        name='col_4_is_null_or_empty',
-        criticality='error',
-        filter='col1<3',
-        check=is_not_null_and_not_empty('col4')),
+        name="col_4_is_null_or_empty",
+        criticality="error",
+        filter="col1 < 3",
+        check=is_not_null_and_not_empty("col4")),
     DQRule(  # name auto-generated if not provided
-        criticality='warn',
-        check=value_is_in_list('col4', ['1', '2']))
+        criticality="warn",
+        check=value_is_in_list("col4", ["1", "2"]))
 ]
 
 schema = "col1: int, col2: int, col3: int, col4 int"
@@ -384,9 +384,9 @@ def ends_with_foo(col_name: str) -> Column:
 input_df = spark.createDataFrame([["str1"], ["foo"], ["str3"]], schema)
 
 checks = [ DQRule(
-        name='col_1_is_null_or_empty',
-        criticality='error',
-        check=is_not_null_and_not_empty('col1')),
+        name="col_1_is_null_or_empty",
+        criticality="error",
+        check=is_not_null_and_not_empty("col1")),
 ]
 
 valid_and_quarantined_df = dq_engine.apply_checks(input_df, checks)

demos/dqx_demo_tool.py

+2-2
@@ -45,7 +45,7 @@
 import glob
 import os
 
-user_name = spark.sql('select current_user() as user').collect()[0]['user']
+user_name = spark.sql("select current_user() as user").collect()[0]["user"]
 dqx_wheel_files = glob.glob(f"/Workspace/Users/{user_name}/.dqx/wheels/databricks_labs_dqx-*.whl")
 dqx_latest_wheel = max(dqx_wheel_files, key=os.path.getctime)
 %pip install {dqx_latest_wheel}
@@ -210,7 +210,7 @@
 # COMMAND ----------
 
 print(f"Saving quarantined data to {run_config.quarantine_table}")
-quarantine_catalog, quarantine_schema, _ = run_config.quarantine_table.split('.')
+quarantine_catalog, quarantine_schema, _ = run_config.quarantine_table.split(".")
 
 spark.sql(f"CREATE CATALOG IF NOT EXISTS {quarantine_catalog}")
 spark.sql(f"CREATE SCHEMA IF NOT EXISTS {quarantine_catalog}.{quarantine_schema}")

docs/dqx/docs/dev/contributing.mdx

+8-3
@@ -86,13 +86,18 @@ Before every commit, apply the consistent formatting of the code, as we want our
 make fmt
 ```
 
-Before every commit, run automated bug detector (`make lint`) and unit tests (`make test`) to ensure that automated
-pull request checks do pass, before your code is reviewed by others:
+Before every commit, run automated bug detector and unit tests to ensure that automated
+pull request checks do pass, before your code is reviewed by others:
 ```shell
 make lint
+make setup_spark_remote
 make test
 ```
 
+The command `make setup_spark_remote` sets up the environment for running unit tests and is required one time only.
+DQX uses Databricks Connect as a test dependency, which restricts the creation of a Spark session in local mode.
+To enable local Spark execution for unit testing, the command installs Spark remote.
+
 ### Local setup for integration tests and code coverage
 
 Note that integration tests and code coverage are run automatically when you create a Pull Request in Github.
@@ -215,7 +220,7 @@ Here are the example steps to submit your first contribution:
 7. `make fmt`
 8. `make lint`
 9. .. fix if any
-10. `make test` and `make integration`, optionally `make coverage` to get test coverage report
+10. `make setup_spark_remote`, `make test` and `make integration`, optionally `make coverage` to get test coverage report
 11. .. fix if any issues
 12. `git commit -S -a -m "message"`.
     Make sure to enter a meaningful commit message title.
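
For a quick local sanity check that the Spark Connect server set up by `make setup_spark_remote` is reachable before running `make test`, a minimal sketch (assumption: the server listens on the default Spark Connect port 15002 behind the `sc://localhost` URL; the setup script itself only probes the Spark UI on port 4040):

```python
import socket


def spark_connect_is_up(host: str = "localhost", port: int = 15002, timeout: float = 2.0) -> bool:
    """Return True if something is listening on the assumed Spark Connect port."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:
        return False


if __name__ == "__main__":
    print("Spark Connect reachable:", spark_connect_is_up())
```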

docs/dqx/docs/guide.mdx

+4-4
@@ -251,11 +251,11 @@ checks = DQRuleColSet( # define rule for multiple columns at once
     DQRule(  # define rule with a filter
         name="col_4_is_null_or_empty",
         criticality="error",
-        filter="col1<3",
+        filter="col1 < 3",
         check=is_not_null_and_not_empty("col4")),
     DQRule(  # name auto-generated if not provided
-        criticality='warn',
-        check=value_is_in_list('col4', ['1', '2']))
+        criticality="warn",
+        check=value_is_in_list("col4", ["1", "2"]))
 ]
 
 input_df = spark.read.table("catalog1.schema1.table1")
@@ -294,7 +294,7 @@ checks = yaml.safe_load("""
       col_name: col3
 
 - criticality: error
-  filter: col1<3
+  filter: col1 < 3
   check:
     function: is_not_null_and_not_empty
     arguments:

docs/dqx/docs/reference.mdx

+11-9
@@ -36,17 +36,17 @@ The following quality rules / functions are currently available:
 
 You can check implementation details of the rules [here](https://github.com/databrickslabs/dqx/blob/main/src/databricks/labs/dqx/col_functions.py).
 
-#### Apply Filter on quality rule
+### Apply filters on checks
 
-If you want to apply a filter to a part of the dataframe, you can add a `filter` to the rule.
-For example, if you want to check that a col `a` is not null when `b` is positive, you can do it like this:
+You can apply checks to a part of the DataFrame by using a `filter`.
+For example, to ensure that a column `a` is not null only when a column `b` is positive, you can define the check as follows:
 ```yaml
-- criticality: "error"
-  filter: b>0
+- criticality: error
+  filter: b > 0
   check:
-    function: "is_not_null"
+    function: is_not_null
     arguments:
-      col_name: "a"
+      col_name: a
 ```
 
 ### Creating your own checks
@@ -265,7 +265,7 @@ def test_dq():
 
     schema = "a: int, b: int, c: int"
     expected_schema = schema + ", _errors: map<string,string>, _warnings: map<string,string>"
-    test_df = spark.createDataFrame([[1, 3, 3]], schema)
+    test_df = spark.createDataFrame([[1, None, 3]], schema)
 
     checks = [
         DQRule(name="col_a_is_null_or_empty", criticality="warn", check=is_not_null_and_not_empty("a")),
@@ -275,6 +275,8 @@ def test_dq():
     dq_engine = DQEngine(ws)
     df = dq_engine.apply_checks(test_df, checks)
 
-    expected_df = spark.createDataFrame([[1, 3, 3, None, None]], expected_schema)
+    expected_df = spark.createDataFrame(
+        [[1, None, 3, {"col_b_is_null_or_empty": "Column b is null or empty"}, None]], expected_schema
+    )
     assert_df_equality(df, expected_df)
 ```

src/databricks/labs/dqx/utils.py

+2
@@ -38,6 +38,8 @@ def read_input_data(spark: SparkSession, input_location: str | None, input_format
     if STORAGE_PATH_PATTERN.match(input_location):
         if not input_format:
             raise ValueError("Input format not configured")
+        # TODO handle spark options while reading data from a file location
+        # https://github.com/databrickslabs/dqx/issues/161
         return spark.read.format(str(input_format)).load(input_location)
 
     raise ValueError(

tests/integration/test_utils.py

+31
@@ -0,0 +1,31 @@
+from chispa.dataframe_comparer import assert_df_equality  # type: ignore
+from databricks.labs.dqx.utils import read_input_data
+
+
+def test_read_input_data_unity_catalog_table(spark, make_schema, make_random):
+    catalog_name = "main"
+    schema_name = make_schema(catalog_name=catalog_name).name
+    input_location = f"{catalog_name}.{schema_name}.{make_random(6).lower()}"
+    input_format = None
+
+    schema = "a: int, b: int"
+    input_df = spark.createDataFrame([[1, 2]], schema)
+    input_df.write.format("delta").saveAsTable(input_location)
+
+    result_df = read_input_data(spark, input_location, input_format)
+    assert_df_equality(input_df, result_df)
+
+
+def test_read_input_data_workspace_file(spark, make_schema, make_volume):
+    catalog_name = "main"
+    schema_name = make_schema(catalog_name=catalog_name).name
+    info = make_volume(catalog_name=catalog_name, schema_name=schema_name)
+    input_location = info.full_name
+    input_format = "delta"
+
+    schema = "a: int, b: int"
+    input_df = spark.createDataFrame([[1, 2]], schema)
+    input_df.write.format("delta").saveAsTable(input_location)
+
+    result_df = read_input_data(spark, input_location, input_format)
+    assert_df_equality(input_df, result_df)

tests/unit/conftest.py

+2-3
@@ -1,13 +1,12 @@
 import os
 from pathlib import Path
-from unittest.mock import Mock
 from pyspark.sql import SparkSession
 import pytest
 
 
 @pytest.fixture
-def spark_session_mock():
-    return Mock(spec=SparkSession)
+def spark_local():
+    return SparkSession.builder.appName("DQX Test").remote("sc://localhost").getOrCreate()
 
 
 @pytest.fixture

tests/unit/test_apply_checks.py

+28
@@ -0,0 +1,28 @@
+from unittest.mock import MagicMock
+
+from chispa.dataframe_comparer import assert_df_equality  # type: ignore
+from databricks.labs.dqx.col_functions import is_not_null_and_not_empty
+from databricks.labs.dqx.engine import DQEngine
+from databricks.labs.dqx.rule import DQRule
+from databricks.sdk import WorkspaceClient
+
+
+def test_apply_checks(spark_local):
+    ws = MagicMock(spec=WorkspaceClient, **{"catalogs.list.return_value": []})
+
+    schema = "a: int, b: int, c: int"
+    expected_schema = schema + ", _errors: map<string,string>, _warnings: map<string,string>"
+    test_df = spark_local.createDataFrame([[1, None, 3]], schema)
+
+    checks = [
+        DQRule(name="col_a_is_null_or_empty", criticality="warn", check=is_not_null_and_not_empty("a")),
+        DQRule(name="col_b_is_null_or_empty", criticality="error", check=is_not_null_and_not_empty("b")),
+    ]
+
+    dq_engine = DQEngine(ws)
+    df = dq_engine.apply_checks(test_df, checks)
+
+    expected_df = spark_local.createDataFrame(
+        [[1, None, 3, {"col_b_is_null_or_empty": "Column b is null or empty"}, None]], expected_schema
+    )
+    assert_df_equality(df, expected_df)
