diff --git a/.github/scripts/setup_spark_remote.sh b/.github/scripts/setup_spark_remote.sh
new file mode 100755
index 00000000..3e148e17
--- /dev/null
+++ b/.github/scripts/setup_spark_remote.sh
@@ -0,0 +1,61 @@
+#!/usr/bin/env bash
+
+set -xve
+echo "Setting up spark-connect"
+
+mkdir -p "$HOME"/spark
+cd "$HOME"/spark || exit 1
+
+version=$(wget -O - https://dlcdn.apache.org/spark/ | grep 'href="spark' | grep -v 'preview' | sed 's:</a>:\n:g' | sed -n 's/.*>//p' | tr -d spark- | tr -d / | sort -r --version-sort | head -1)
+if [ -z "$version" ]; then
+  echo "Failed to extract Spark version"
+  exit 1
+fi
+
+spark=spark-${version}-bin-hadoop3
+spark_connect="spark-connect_2.12"
+
+mkdir -p "${spark}"
+
+
+SERVER_SCRIPT=$HOME/spark/${spark}/sbin/start-connect-server.sh
+
+## check if this Spark version already exists; if not, download it
+if [ -f "${SERVER_SCRIPT}" ]; then
+  echo "Spark Version already exists"
+else
+  if [ -f "${spark}.tgz" ]; then
+    echo "${spark}.tgz already exists"
+  else
+    wget "https://dlcdn.apache.org/spark/spark-${version}/${spark}.tgz"
+  fi
+  tar -xvf "${spark}.tgz"
+fi
+
+cd "${spark}" || exit 1
+## check if spark remote is running; if not, start it
+result=$(${SERVER_SCRIPT} --packages org.apache.spark:${spark_connect}:"${version}" > "$HOME"/spark/log.out; echo $?)
+
+if [ "$result" -ne 0 ]; then
+  count=$(tail "${HOME}"/spark/log.out | grep -c "SparkConnectServer running as process")
+  if [ "${count}" == "0" ]; then
+    echo "Failed to start the server"
+    exit 1
+  fi
+  # Wait for the server to start by pinging localhost:4040
+  echo "Waiting for the server to start..."
+  for i in {1..30}; do
+    if nc -z localhost 4040; then
+      echo "Server is up and running"
+      break
+    fi
+    echo "Server not yet available, retrying in 5 seconds..."
+    sleep 5
+  done
+
+  if ! nc -z localhost 4040; then
+    echo "Failed to start the server within the expected time"
+    exit 1
+  fi
+fi
+echo "Started the Server"
diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml
index 31d27c0e..48663df7 100644
--- a/.github/workflows/push.yml
+++ b/.github/workflows/push.yml
@@ -35,10 +35,15 @@ jobs:
           cache-dependency-path: '**/pyproject.toml'
           python-version: ${{ matrix.pyVersion }}
 
+      - name: Setup Spark Remote
+        run: |
+          pip install hatch==1.9.4
+          make setup_spark_remote
+
       - name: Run unit tests
         run: |
           pip install hatch==1.9.4
-          make test
+          make ci-test
 
       - name: Publish test coverage
         uses: codecov/codecov-action@v5
diff --git a/Makefile b/Makefile
index c9713b52..e9ec38ac 100644
--- a/Makefile
+++ b/Makefile
@@ -17,12 +17,17 @@ lint:
 fmt:
 	hatch run fmt
 
-test:
+ci-test:
 	hatch run test
 
 integration:
 	hatch run integration
 
+setup_spark_remote:
+	.github/scripts/setup_spark_remote.sh
+
+test: setup_spark_remote ci-test
+
 coverage:
 	hatch run coverage && open htmlcov/index.html
 
diff --git a/demos/dqx_demo_library.py b/demos/dqx_demo_library.py
index 26a736c0..5140e420 100644
--- a/demos/dqx_demo_library.py
+++ b/demos/dqx_demo_library.py
@@ -59,7 +59,7 @@
 print(dlt_expectations)
 
 # save generated checks in a workspace file
-user_name = spark.sql('select current_user() as user').collect()[0]['user']
+user_name = spark.sql("select current_user() as user").collect()[0]["user"]
 checks_file = f"/Workspace/Users/{user_name}/dqx_demo_checks.yml"
 dq_engine = DQEngine(ws)
 dq_engine.save_checks_in_workspace_file(checks, workspace_path=checks_file)
@@ -143,7 +143,7 @@
       col_name: col3
 
 - criticality: error
-  filter: col1<3
+  filter: col1 < 3
   check:
     function: is_not_null_and_not_empty
     arguments:
@@ -193,17 +193,17 @@
            criticality="error",
            check_func=is_not_null).get_rules() + [
    DQRule( # define rule for a single column
-       name='col3_is_null_or_empty',
-       criticality='error',
-       check=is_not_null_and_not_empty('col3')),
+       name="col3_is_null_or_empty",
+       criticality="error",
+       check=is_not_null_and_not_empty("col3")),
    DQRule( # define rule with a filter
-       name='col_4_is_null_or_empty',
-       criticality='error',
-       filter='col1<3',
-       check=is_not_null_and_not_empty('col4')),
+       name="col_4_is_null_or_empty",
+       criticality="error",
+       filter="col1 < 3",
+       check=is_not_null_and_not_empty("col4")),
    DQRule( # name auto-generated if not provided
-       criticality='warn',
-       check=value_is_in_list('col4', ['1', '2']))
+       criticality="warn",
+       check=value_is_in_list("col4", ["1", "2"]))
 ]
 
 schema = "col1: int, col2: int, col3: int, col4 int"
@@ -384,9 +384,9 @@ def ends_with_foo(col_name: str) -> Column:
 input_df = spark.createDataFrame([["str1"], ["foo"], ["str3"]], schema)
 checks = [
     DQRule(
-        name='col_1_is_null_or_empty',
-        criticality='error',
-        check=is_not_null_and_not_empty('col1')),
+        name="col_1_is_null_or_empty",
+        criticality="error",
+        check=is_not_null_and_not_empty("col1")),
 ]
 
 valid_and_quarantined_df = dq_engine.apply_checks(input_df, checks)
diff --git a/demos/dqx_demo_tool.py b/demos/dqx_demo_tool.py
index 9c9c970d..3821bd07 100644
--- a/demos/dqx_demo_tool.py
+++ b/demos/dqx_demo_tool.py
@@ -45,7 +45,7 @@
 import glob
 import os
 
-user_name = spark.sql('select current_user() as user').collect()[0]['user']
+user_name = spark.sql("select current_user() as user").collect()[0]["user"]
 dqx_wheel_files = glob.glob(f"/Workspace/Users/{user_name}/.dqx/wheels/databricks_labs_dqx-*.whl")
 dqx_latest_wheel = max(dqx_wheel_files, key=os.path.getctime)
 %pip install {dqx_latest_wheel}
@@ -210,7 +210,7 @@
 # COMMAND ----------
 
 print(f"Saving quarantined data to {run_config.quarantine_table}")
-quarantine_catalog, quarantine_schema, _ = run_config.quarantine_table.split('.')
+quarantine_catalog, quarantine_schema, _ = run_config.quarantine_table.split(".")
 spark.sql(f"CREATE CATALOG IF NOT EXISTS {quarantine_catalog}")
 spark.sql(f"CREATE SCHEMA IF NOT EXISTS {quarantine_catalog}.{quarantine_schema}")
 
diff --git a/docs/dqx/docs/dev/contributing.mdx b/docs/dqx/docs/dev/contributing.mdx
index 3a8f7c20..cfd18b02 100644
--- a/docs/dqx/docs/dev/contributing.mdx
+++ b/docs/dqx/docs/dev/contributing.mdx
@@ -86,13 +86,18 @@ Before every commit, apply the consistent formatting of the code, as we want our
 make fmt
 ```
 
-Before every commit, run automated bug detector (`make lint`) and unit tests (`make test`) to ensure that automated
-pull request checks do pass, before your code is reviewed by others:
+Before every commit, run the automated bug detector and unit tests to ensure that the automated
+pull request checks pass before your code is reviewed by others:
 
 ```shell
 make lint
+make setup_spark_remote
 make test
 ```
+The command `make setup_spark_remote` sets up the environment for running unit tests and only needs to be run once.
+DQX uses Databricks Connect as a test dependency, which prevents creating a Spark session in local mode.
+To enable local Spark execution for unit testing, the command downloads Spark and starts a local Spark Connect server.
+
 ### Local setup for integration tests and code coverage
 
 Note that integration tests and code coverage are run automatically when you create a Pull Request in Github.
@@ -215,7 +220,7 @@ Here are the example steps to submit your first contribution:
 7. `make fmt`
 8. `make lint`
 9. .. fix if any
-10. `make test` and `make integration`, optionally `make coverage` to get test coverage report
+10. `make setup_spark_remote`, `make test` and `make integration`, optionally `make coverage` to get a test coverage report
 11. .. fix if any issues
 12. `git commit -S -a -m "message"`.
     Make sure to enter a meaningful commit message title.
diff --git a/docs/dqx/docs/guide.mdx b/docs/dqx/docs/guide.mdx
index 76e3269c..732e8d6a 100644
--- a/docs/dqx/docs/guide.mdx
+++ b/docs/dqx/docs/guide.mdx
@@ -251,11 +251,11 @@ checks = DQRuleColSet( # define rule for multiple columns at once
    DQRule( # define rule with a filter
        name="col_4_is_null_or_empty",
        criticality="error",
-       filter="col1<3",
+       filter="col1 < 3",
        check=is_not_null_and_not_empty("col4")),
    DQRule( # name auto-generated if not provided
-       criticality='warn',
-       check=value_is_in_list('col4', ['1', '2']))
+       criticality="warn",
+       check=value_is_in_list("col4", ["1", "2"]))
 ]
 
 input_df = spark.read.table("catalog1.schema1.table1")
@@ -294,7 +294,7 @@ checks = yaml.safe_load("""
       col_name: col3
 
 - criticality: error
-  filter: col1<3
+  filter: col1 < 3
   check:
     function: is_not_null_and_not_empty
     arguments:
diff --git a/docs/dqx/docs/reference.mdx b/docs/dqx/docs/reference.mdx
index 1c56c216..b7f8c248 100644
--- a/docs/dqx/docs/reference.mdx
+++ b/docs/dqx/docs/reference.mdx
@@ -36,17 +36,17 @@ The following quality rules / functions are currently available:
 
 You can check implementation details of the rules [here](https://github.com/databrickslabs/dqx/blob/main/src/databricks/labs/dqx/col_functions.py).
 
-#### Apply Filter on quality rule
+### Apply filters on checks
 
-If you want to apply a filter to a part of the dataframe, you can add a `filter` to the rule.
-For example, if you want to check that a col `a` is not null when `b` is positive, you can do it like this:
+You can apply checks to a part of the DataFrame by using a `filter`.
+For example, to ensure that a column `a` is not null only when a column `b` is positive, you can define the check as follows:
 ```yaml
-- criticality: "error"
-  filter: b>0
+- criticality: error
+  filter: b > 0
   check:
-    function: "is_not_null"
+    function: is_not_null
     arguments:
-      col_name: "a"
+      col_name: a
 ```
 
 ### Creating your own checks
@@ -265,7 +265,7 @@ def test_dq():
 
     schema = "a: int, b: int, c: int"
     expected_schema = schema + ", _errors: map<string,string>, _warnings: map<string,string>"
-    test_df = spark.createDataFrame([[1, 3, 3]], schema)
+    test_df = spark.createDataFrame([[1, None, 3]], schema)
 
     checks = [
         DQRule(name="col_a_is_null_or_empty", criticality="warn", check=is_not_null_and_not_empty("a")),
@@ -275,6 +275,8 @@
     dq_engine = DQEngine(ws)
     df = dq_engine.apply_checks(test_df, checks)
 
-    expected_df = spark.createDataFrame([[1, 3, 3, None, None]], expected_schema)
+    expected_df = spark.createDataFrame(
+        [[1, None, 3, {"col_b_is_null_or_empty": "Column b is null or empty"}, None]], expected_schema
+    )
     assert_df_equality(df, expected_df)
 ```
diff --git a/src/databricks/labs/dqx/utils.py b/src/databricks/labs/dqx/utils.py
index c087df1c..885906ec 100644
--- a/src/databricks/labs/dqx/utils.py
+++ b/src/databricks/labs/dqx/utils.py
@@ -38,6 +38,8 @@ def read_input_data(spark: SparkSession, input_location: str | None, input_forma
     if STORAGE_PATH_PATTERN.match(input_location):
         if not input_format:
             raise ValueError("Input format not configured")
+        # TODO handle spark options while reading data from a file location
+        # https://github.com/databrickslabs/dqx/issues/161
         return spark.read.format(str(input_format)).load(input_location)
 
     raise ValueError(
diff --git a/tests/integration/test_utils.py b/tests/integration/test_utils.py
new file mode 100644
index 00000000..9e4e7f44
--- /dev/null
+++ b/tests/integration/test_utils.py
@@ -0,0 +1,31 @@
+from chispa.dataframe_comparer import assert_df_equality  # type: ignore
+from databricks.labs.dqx.utils import read_input_data
+
+
+def test_read_input_data_unity_catalog_table(spark, make_schema, make_random):
+    catalog_name = "main"
+    schema_name = make_schema(catalog_name=catalog_name).name
+    input_location = f"{catalog_name}.{schema_name}.{make_random(6).lower()}"
+    input_format = None
+
+    schema = "a: int, b: int"
+    input_df = spark.createDataFrame([[1, 2]], schema)
+    input_df.write.format("delta").saveAsTable(input_location)
+
+    result_df = read_input_data(spark, input_location, input_format)
+    assert_df_equality(input_df, result_df)
+
+
+def test_read_input_data_workspace_file(spark, make_schema, make_volume):
+    catalog_name = "main"
+    schema_name = make_schema(catalog_name=catalog_name).name
+    info = make_volume(catalog_name=catalog_name, schema_name=schema_name)
+    input_location = info.full_name
+    input_format = "delta"
+
+    schema = "a: int, b: int"
+    input_df = spark.createDataFrame([[1, 2]], schema)
+    input_df.write.format("delta").saveAsTable(input_location)
+
+    result_df = read_input_data(spark, input_location, input_format)
+    assert_df_equality(input_df, result_df)
diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py
index d5fa1f94..a4c5499d 100644
--- a/tests/unit/conftest.py
+++ b/tests/unit/conftest.py
@@ -1,13 +1,12 @@
 import os
 from pathlib import Path
-from unittest.mock import Mock
 from pyspark.sql import SparkSession
 import pytest
 
 
 @pytest.fixture
-def spark_session_mock():
-    return Mock(spec=SparkSession)
+def spark_local():
+    return SparkSession.builder.appName("DQX Test").remote("sc://localhost").getOrCreate()
 
 
 @pytest.fixture
diff --git a/tests/unit/test_apply_checks.py b/tests/unit/test_apply_checks.py
new file mode 100644
index 00000000..7d5a93af
--- /dev/null
+++ b/tests/unit/test_apply_checks.py
@@ -0,0 +1,28 @@
+from unittest.mock import MagicMock
+
+from chispa.dataframe_comparer import assert_df_equality  # type: ignore
+from databricks.labs.dqx.col_functions import is_not_null_and_not_empty
+from databricks.labs.dqx.engine import DQEngine
+from databricks.labs.dqx.rule import DQRule
+from databricks.sdk import WorkspaceClient
+
+
+def test_apply_checks(spark_local):
+    ws = MagicMock(spec=WorkspaceClient, **{"catalogs.list.return_value": []})
+
+    schema = "a: int, b: int, c: int"
+    expected_schema = schema + ", _errors: map<string,string>, _warnings: map<string,string>"
+    test_df = spark_local.createDataFrame([[1, None, 3]], schema)
+
+    checks = [
+        DQRule(name="col_a_is_null_or_empty", criticality="warn", check=is_not_null_and_not_empty("a")),
+        DQRule(name="col_b_is_null_or_empty", criticality="error", check=is_not_null_and_not_empty("b")),
+    ]
+
+    dq_engine = DQEngine(ws)
+    df = dq_engine.apply_checks(test_df, checks)
+
+    expected_df = spark_local.createDataFrame(
+        [[1, None, 3, {"col_b_is_null_or_empty": "Column b is null or empty"}, None]], expected_schema
+    )
+    assert_df_equality(df, expected_df)
diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py
index 29437494..8d71b855 100644
--- a/tests/unit/test_utils.py
+++ b/tests/unit/test_utils.py
@@ -1,4 +1,7 @@
+import tempfile
+import os
 import pyspark.sql.functions as F
+from pyspark.sql.types import Row
 import pytest
 
 from databricks.labs.dqx.utils import read_input_data, get_column_name
@@ -27,57 +30,36 @@ def test_get_col_name_longer():
     assert actual == "local"
 
 
-def test_read_input_data_unity_catalog_table(spark_session_mock):
-    input_location = "catalog.schema.table"
-    input_format = None
-    spark_session_mock.read.table.return_value = "dataframe"
-
-    result = read_input_data(spark_session_mock, input_location, input_format)
-
-    spark_session_mock.read.table.assert_called_once_with(input_location)
-    assert result == "dataframe"
-
-
-def test_read_input_data_storage_path(spark_session_mock):
-    input_location = "s3://bucket/path"
-    input_format = "delta"
-    spark_session_mock.read.format.return_value.load.return_value = "dataframe"
-
-    result = read_input_data(spark_session_mock, input_location, input_format)
-
-    spark_session_mock.read.format.assert_called_once_with(input_format)
-    spark_session_mock.read.format.return_value.load.assert_called_once_with(input_location)
-    assert result == "dataframe"
-
-
-def test_read_input_data_workspace_file(spark_session_mock):
-    input_location = "/folder/path"
-    input_format = "delta"
-    spark_session_mock.read.format.return_value.load.return_value = "dataframe"
+def test_read_input_data_storage_path(spark_local):
+    with tempfile.NamedTemporaryFile(delete=False) as temp_file:
+        temp_file.write(b"val1,val2\n")
+        temp_file_path = temp_file.name
 
-    result = read_input_data(spark_session_mock, input_location, input_format)
+    try:
+        input_location = temp_file_path
+        result = read_input_data(spark_local, input_location, "csv")
+        assert result.collect() == [Row(_c0='val1', _c1='val2')]
 
-    spark_session_mock.read.format.assert_called_once_with(input_format)
-    spark_session_mock.read.format.return_value.load.assert_called_once_with(input_location)
-    assert result == "dataframe"
+    finally:
+        os.remove(temp_file_path)
 
 
-def test_read_input_data_no_input_location(spark_session_mock):
+def test_read_input_data_no_input_location(spark_local):
     with pytest.raises(ValueError, match="Input location not configured"):
-        read_input_data(spark_session_mock, None, None)
+        read_input_data(spark_local, None, None)
 
 
-def test_read_input_data_no_input_format(spark_session_mock):
+def test_read_input_data_no_input_format(spark_local):
     input_location = "s3://bucket/path"
     input_format = None
 
     with pytest.raises(ValueError, match="Input format not configured"):
-        read_input_data(spark_session_mock, input_location, input_format)
+        read_input_data(spark_local, input_location, input_format)
 
 
-def test_read_invalid_input_location(spark_session_mock):
+def test_read_invalid_input_location(spark_local):
     input_location = "invalid/location"
     input_format = None
 
     with pytest.raises(ValueError, match="Invalid input location."):
-        read_input_data(spark_session_mock, input_location, input_format)
+        read_input_data(spark_local, input_location, input_format)