From 761be415f3d5ff9378954fee179bf2c251bb08b0 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Wed, 5 Feb 2025 11:16:44 +0530 Subject: [PATCH 01/16] Init Commit --- .github/scripts/setup_spark_remote.sh | 60 +++++++++++++++++++++++++++ .github/workflows/push.yml | 5 +++ Makefile | 3 ++ tests/unit/conftest.py | 3 +- tests/unit/test_utils.py | 5 ++- 5 files changed, 72 insertions(+), 4 deletions(-) create mode 100755 .github/scripts/setup_spark_remote.sh diff --git a/.github/scripts/setup_spark_remote.sh b/.github/scripts/setup_spark_remote.sh new file mode 100755 index 00000000..490e2248 --- /dev/null +++ b/.github/scripts/setup_spark_remote.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash + +set -xve + +mkdir -p "$HOME"/spark +cd "$HOME"/spark || exit 1 + +version=$(wget -O - https://dlcdn.apache.org/spark/ | grep 'href="spark' | grep -v 'preview' | sed 's::\n:g' | sed -n 's/.*>//p' | tr -d spark- | tr -d / | sort -r --version-sort | head -1) +if [ -z "$version" ]; then + echo "Failed to extract Spark version" + exit 1 +fi + +spark=spark-${version}-bin-hadoop3 +spark_connect="spark-connect_2.12" + +mkdir -p "${spark}" + + +SERVER_SCRIPT=$HOME/spark/${spark}/sbin/start-connect-server.sh + +## check the spark version already exist ,if not download the respective version +if [ -f "${SERVER_SCRIPT}" ];then + echo "Spark Version already exists" +else + if [ -f "${spark}.tgz" ];then + echo "${spark}.tgz already exists" + else + wget "https://dlcdn.apache.org/spark/spark-${version}/${spark}.tgz" + fi + tar -xvf "${spark}.tgz" +fi + +cd "${spark}" || exit 1 +## check spark remote is running,if not start the spark remote +result=$(${SERVER_SCRIPT} --packages org.apache.spark:${spark_connect}:"${version}" > "$HOME"/spark/log.out; echo $?) + +if [ "$result" -ne 0 ]; then + count=$(tail "${HOME}"/spark/log.out | grep -c "SparkConnectServer running as process") + if [ "${count}" == "0" ]; then + echo "Failed to start the server" + exit 1 + fi + # Wait for the server to start by pinging localhost:4040 + echo "Waiting for the server to start..." + for i in {1..30}; do + if nc -z localhost 4040; then + echo "Server is up and running" + break + fi + echo "Server not yet available, retrying in 5 seconds..." + sleep 5 + done + + if ! nc -z localhost 4040; then + echo "Failed to start the server within the expected time" + exit 1 + fi +fi +echo "Started the Server" diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index 31d27c0e..e68e30d3 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -35,6 +35,11 @@ jobs: cache-dependency-path: '**/pyproject.toml' python-version: ${{ matrix.pyVersion }} + - name: Setup Spark Remote + run: | + pip install hatch==1.9.4 + make setup_spark_remote + - name: Run unit tests run: | pip install hatch==1.9.4 diff --git a/Makefile b/Makefile index c9713b52..e7896145 100644 --- a/Makefile +++ b/Makefile @@ -23,6 +23,9 @@ test: integration: hatch run integration +setup_spark_remote: + .github/scripts/setup_spark_remote.sh + coverage: hatch run coverage && open htmlcov/index.html diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index d5fa1f94..a27e6ba9 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,13 +1,12 @@ import os from pathlib import Path -from unittest.mock import Mock from pyspark.sql import SparkSession import pytest @pytest.fixture def spark_session_mock(): - return Mock(spec=SparkSession) + return SparkSession.builder.appName("DQX Test").remote("sc://localhost").getOrCreate() @pytest.fixture diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 29437494..331e5c22 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -27,6 +27,7 @@ def test_get_col_name_longer(): assert actual == "local" +@pytest.mark.skip(reason="Ignore") def test_read_input_data_unity_catalog_table(spark_session_mock): input_location = "catalog.schema.table" input_format = None @@ -37,7 +38,7 @@ def test_read_input_data_unity_catalog_table(spark_session_mock): spark_session_mock.read.table.assert_called_once_with(input_location) assert result == "dataframe" - +@pytest.mark.skip(reason="Ignore") def test_read_input_data_storage_path(spark_session_mock): input_location = "s3://bucket/path" input_format = "delta" @@ -49,7 +50,7 @@ def test_read_input_data_storage_path(spark_session_mock): spark_session_mock.read.format.return_value.load.assert_called_once_with(input_location) assert result == "dataframe" - +@pytest.mark.skip(reason="Ignore") def test_read_input_data_workspace_file(spark_session_mock): input_location = "/folder/path" input_format = "delta" From 22c805de494fa4f48609af04f90d411f67317652 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Wed, 5 Feb 2025 11:42:21 +0530 Subject: [PATCH 02/16] Fmt fixes --- tests/unit/test_utils.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 331e5c22..b0bd8ac7 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -38,6 +38,7 @@ def test_read_input_data_unity_catalog_table(spark_session_mock): spark_session_mock.read.table.assert_called_once_with(input_location) assert result == "dataframe" + @pytest.mark.skip(reason="Ignore") def test_read_input_data_storage_path(spark_session_mock): input_location = "s3://bucket/path" @@ -50,6 +51,7 @@ def test_read_input_data_storage_path(spark_session_mock): spark_session_mock.read.format.return_value.load.assert_called_once_with(input_location) assert result == "dataframe" + @pytest.mark.skip(reason="Ignore") def test_read_input_data_workspace_file(spark_session_mock): input_location = "/folder/path" From add9545593d5a9df3310a06c50acd1aa0d059318 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Thu, 6 Feb 2025 11:06:42 +0530 Subject: [PATCH 03/16] temp commit --- .github/scripts/setup_spark_remote.sh | 3 +- tests/unit/conftest.py | 2 +- tests/unit/test_utils.py | 40 +++++++++++++-------------- 3 files changed, 23 insertions(+), 22 deletions(-) diff --git a/.github/scripts/setup_spark_remote.sh b/.github/scripts/setup_spark_remote.sh index 490e2248..3e148e17 100755 --- a/.github/scripts/setup_spark_remote.sh +++ b/.github/scripts/setup_spark_remote.sh @@ -1,6 +1,7 @@ #!/usr/bin/env bash set -xve +echo "Setting up spark-connect" mkdir -p "$HOME"/spark cd "$HOME"/spark || exit 1 @@ -19,7 +20,7 @@ mkdir -p "${spark}" SERVER_SCRIPT=$HOME/spark/${spark}/sbin/start-connect-server.sh -## check the spark version already exist ,if not download the respective version +## check the spark version already exist, if not download the respective version if [ -f "${SERVER_SCRIPT}" ];then echo "Spark Version already exists" else diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index a27e6ba9..4649719e 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -5,7 +5,7 @@ @pytest.fixture -def spark_session_mock(): +def spark_session(): return SparkSession.builder.appName("DQX Test").remote("sc://localhost").getOrCreate() diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index b0bd8ac7..1c044ce8 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -28,59 +28,59 @@ def test_get_col_name_longer(): @pytest.mark.skip(reason="Ignore") -def test_read_input_data_unity_catalog_table(spark_session_mock): +def test_read_input_data_unity_catalog_table(spark_session): input_location = "catalog.schema.table" input_format = None - spark_session_mock.read.table.return_value = "dataframe" + spark_session.read.table.return_value = "dataframe" - result = read_input_data(spark_session_mock, input_location, input_format) + result = read_input_data(spark_session, input_location, input_format) - spark_session_mock.read.table.assert_called_once_with(input_location) + spark_session.read.table.assert_called_once_with(input_location) assert result == "dataframe" @pytest.mark.skip(reason="Ignore") -def test_read_input_data_storage_path(spark_session_mock): +def test_read_input_data_storage_path(spark_session): input_location = "s3://bucket/path" input_format = "delta" - spark_session_mock.read.format.return_value.load.return_value = "dataframe" + spark_session.read.format.return_value.load.return_value = "dataframe" - result = read_input_data(spark_session_mock, input_location, input_format) + result = read_input_data(spark_session, input_location, input_format) - spark_session_mock.read.format.assert_called_once_with(input_format) - spark_session_mock.read.format.return_value.load.assert_called_once_with(input_location) + spark_session.read.format.assert_called_once_with(input_format) + spark_session.read.format.return_value.load.assert_called_once_with(input_location) assert result == "dataframe" @pytest.mark.skip(reason="Ignore") -def test_read_input_data_workspace_file(spark_session_mock): +def test_read_input_data_workspace_file(spark_session): input_location = "/folder/path" input_format = "delta" - spark_session_mock.read.format.return_value.load.return_value = "dataframe" + spark_session.read.format.return_value.load.return_value = "dataframe" - result = read_input_data(spark_session_mock, input_location, input_format) + result = read_input_data(spark_session, input_location, input_format) - spark_session_mock.read.format.assert_called_once_with(input_format) - spark_session_mock.read.format.return_value.load.assert_called_once_with(input_location) + spark_session.read.format.assert_called_once_with(input_format) + spark_session.read.format.return_value.load.assert_called_once_with(input_location) assert result == "dataframe" -def test_read_input_data_no_input_location(spark_session_mock): +def test_read_input_data_no_input_location(spark_session): with pytest.raises(ValueError, match="Input location not configured"): - read_input_data(spark_session_mock, None, None) + read_input_data(spark_session, None, None) -def test_read_input_data_no_input_format(spark_session_mock): +def test_read_input_data_no_input_format(spark_session): input_location = "s3://bucket/path" input_format = None with pytest.raises(ValueError, match="Input format not configured"): - read_input_data(spark_session_mock, input_location, input_format) + read_input_data(spark_session, input_location, input_format) -def test_read_invalid_input_location(spark_session_mock): +def test_read_invalid_input_location(spark_session): input_location = "invalid/location" input_format = None with pytest.raises(ValueError, match="Invalid input location."): - read_input_data(spark_session_mock, input_location, input_format) + read_input_data(spark_session, input_location, input_format) From 5519a1c216bdc80e419ebff81e7cc51c488b98ed Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Thu, 6 Feb 2025 17:17:49 +0530 Subject: [PATCH 04/16] Make file updates --- .github/workflows/push.yml | 2 +- Makefile | 4 +++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/.github/workflows/push.yml b/.github/workflows/push.yml index e68e30d3..48663df7 100644 --- a/.github/workflows/push.yml +++ b/.github/workflows/push.yml @@ -43,7 +43,7 @@ jobs: - name: Run unit tests run: | pip install hatch==1.9.4 - make test + make ci-test - name: Publish test coverage uses: codecov/codecov-action@v5 diff --git a/Makefile b/Makefile index e7896145..06da3bca 100644 --- a/Makefile +++ b/Makefile @@ -17,7 +17,7 @@ lint: fmt: hatch run fmt -test: +ci-test: hatch run test integration: @@ -26,6 +26,8 @@ integration: setup_spark_remote: .github/scripts/setup_spark_remote.sh +test: setup_spark_remote test + coverage: hatch run coverage && open htmlcov/index.html From 7204e2127ac65e1340bf44ebc8dcfb002b00fa88 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Thu, 6 Feb 2025 17:19:45 +0530 Subject: [PATCH 05/16] Make file updates --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index 06da3bca..e9ec38ac 100644 --- a/Makefile +++ b/Makefile @@ -26,7 +26,7 @@ integration: setup_spark_remote: .github/scripts/setup_spark_remote.sh -test: setup_spark_remote test +test: setup_spark_remote ci-test coverage: hatch run coverage && open htmlcov/index.html From 5d63d42dc3a155ac664471bef0046bae5c596751 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Fri, 7 Feb 2025 11:34:14 +0530 Subject: [PATCH 06/16] dqx test fixes --- src/databricks/labs/dqx/utils.py | 1 + tests/integration/test_utils.py | 33 +++++++++++++++++++++++ tests/unit/test_utils.py | 45 +++++++++----------------------- 3 files changed, 46 insertions(+), 33 deletions(-) create mode 100644 tests/integration/test_utils.py diff --git a/src/databricks/labs/dqx/utils.py b/src/databricks/labs/dqx/utils.py index c087df1c..de916610 100644 --- a/src/databricks/labs/dqx/utils.py +++ b/src/databricks/labs/dqx/utils.py @@ -38,6 +38,7 @@ def read_input_data(spark: SparkSession, input_location: str | None, input_forma if STORAGE_PATH_PATTERN.match(input_location): if not input_format: raise ValueError("Input format not configured") + # TODO handle spark options while reading data from a file location return spark.read.format(str(input_format)).load(input_location) raise ValueError( diff --git a/tests/integration/test_utils.py b/tests/integration/test_utils.py new file mode 100644 index 00000000..576b32cf --- /dev/null +++ b/tests/integration/test_utils.py @@ -0,0 +1,33 @@ +from pyspark.sql.types import Row +import pytest +from databricks.labs.dqx.utils import read_input_data + + +@pytest.fixture(scope="module") +def setup(spark): + schema = "col1: str, col2: int" + input_df = spark.createDataFrame([["k1", 1]], schema) + + # write dataframe to catalog, create a catalog if it is not there + spark.sql("CREATE CATALOG IF NOT EXISTS dqx_catalog") + spark.sql("CREATE DATABASE IF NOT EXISTS dqx_catalog.dqx_db") + input_df.write.format("delta").saveAsTable("dqx_catalog.dqx_db.dqx_table") + + # write dataframe to file + input_df.write.format("delta").save("/tmp/dqx_table") + + +def test_read_input_data_unity_catalog_table(spark_session): + input_location = "dqx_catalog.dqx_db.dqx_table" + input_format = None + + result = read_input_data(spark_session, input_location, input_format) + assert result.collect() == [Row(col1='k1', col2=1)] + + +def test_read_input_data_workspace_file(spark_session): + input_location = "/tmp/dqx_table" + input_format = "delta" + + result = read_input_data(spark_session, input_location, input_format) + assert result.collect() == [Row(col1='k1', col2=1)] diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 1c044ce8..c0a79b19 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,4 +1,7 @@ +import tempfile +import os import pyspark.sql.functions as F +from pyspark.sql.types import Row import pytest from databricks.labs.dqx.utils import read_input_data, get_column_name @@ -27,42 +30,18 @@ def test_get_col_name_longer(): assert actual == "local" -@pytest.mark.skip(reason="Ignore") -def test_read_input_data_unity_catalog_table(spark_session): - input_location = "catalog.schema.table" - input_format = None - spark_session.read.table.return_value = "dataframe" - - result = read_input_data(spark_session, input_location, input_format) - - spark_session.read.table.assert_called_once_with(input_location) - assert result == "dataframe" - - -@pytest.mark.skip(reason="Ignore") def test_read_input_data_storage_path(spark_session): - input_location = "s3://bucket/path" - input_format = "delta" - spark_session.read.format.return_value.load.return_value = "dataframe" - - result = read_input_data(spark_session, input_location, input_format) - - spark_session.read.format.assert_called_once_with(input_format) - spark_session.read.format.return_value.load.assert_called_once_with(input_location) - assert result == "dataframe" - - -@pytest.mark.skip(reason="Ignore") -def test_read_input_data_workspace_file(spark_session): - input_location = "/folder/path" - input_format = "delta" - spark_session.read.format.return_value.load.return_value = "dataframe" + with tempfile.NamedTemporaryFile(delete=False) as temp_file: + temp_file.write(b"val1,val2\n") + temp_file_path = temp_file.name - result = read_input_data(spark_session, input_location, input_format) + try: + input_location = temp_file_path + result = read_input_data(spark_session, input_location, "csv") + assert result.collect() == [Row(_c0='val1', _c1='val2')] - spark_session.read.format.assert_called_once_with(input_format) - spark_session.read.format.return_value.load.assert_called_once_with(input_location) - assert result == "dataframe" + finally: + os.remove(temp_file_path) def test_read_input_data_no_input_location(spark_session): From 7988f353b587f3d0755e88044b049697ac6e9826 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Fri, 7 Feb 2025 12:10:48 +0530 Subject: [PATCH 07/16] dqx test fixes --- tests/integration/test_utils.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/integration/test_utils.py b/tests/integration/test_utils.py index 576b32cf..a026ae7a 100644 --- a/tests/integration/test_utils.py +++ b/tests/integration/test_utils.py @@ -17,17 +17,17 @@ def setup(spark): input_df.write.format("delta").save("/tmp/dqx_table") -def test_read_input_data_unity_catalog_table(spark_session): +def test_read_input_data_unity_catalog_table(spark): input_location = "dqx_catalog.dqx_db.dqx_table" input_format = None - result = read_input_data(spark_session, input_location, input_format) + result = read_input_data(spark, input_location, input_format) assert result.collect() == [Row(col1='k1', col2=1)] -def test_read_input_data_workspace_file(spark_session): +def test_read_input_data_workspace_file(spark): input_location = "/tmp/dqx_table" input_format = "delta" - result = read_input_data(spark_session, input_location, input_format) + result = read_input_data(spark, input_location, input_format) assert result.collect() == [Row(col1='k1', col2=1)] From f3ad03d93c1878f573f3b563e549a36374754133 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Fri, 7 Feb 2025 12:24:11 +0530 Subject: [PATCH 08/16] fixed missing setup fixture dependency --- tests/integration/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_utils.py b/tests/integration/test_utils.py index a026ae7a..319aaf40 100644 --- a/tests/integration/test_utils.py +++ b/tests/integration/test_utils.py @@ -17,7 +17,7 @@ def setup(spark): input_df.write.format("delta").save("/tmp/dqx_table") -def test_read_input_data_unity_catalog_table(spark): +def test_read_input_data_unity_catalog_table(setup, spark): input_location = "dqx_catalog.dqx_db.dqx_table" input_format = None @@ -25,7 +25,7 @@ def test_read_input_data_unity_catalog_table(spark): assert result.collect() == [Row(col1='k1', col2=1)] -def test_read_input_data_workspace_file(spark): +def test_read_input_data_workspace_file(setup, spark): input_location = "/tmp/dqx_table" input_format = "delta" From 5f9eeca4c2d5a7c935c7c47700b680ebb5e149ab Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Fri, 7 Feb 2025 12:33:14 +0530 Subject: [PATCH 09/16] fixed missing setup fixture dependency --- tests/integration/test_utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/integration/test_utils.py b/tests/integration/test_utils.py index 319aaf40..0ce30a6f 100644 --- a/tests/integration/test_utils.py +++ b/tests/integration/test_utils.py @@ -3,7 +3,7 @@ from databricks.labs.dqx.utils import read_input_data -@pytest.fixture(scope="module") +@pytest.fixture() def setup(spark): schema = "col1: str, col2: int" input_df = spark.createDataFrame([["k1", 1]], schema) From b0180585f340f78facdcc663587d5f2c2a187905 Mon Sep 17 00:00:00 2001 From: "sundar.shankar" Date: Fri, 7 Feb 2025 12:41:27 +0530 Subject: [PATCH 10/16] test fixes --- tests/integration/test_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/integration/test_utils.py b/tests/integration/test_utils.py index 0ce30a6f..4aadf0c4 100644 --- a/tests/integration/test_utils.py +++ b/tests/integration/test_utils.py @@ -5,12 +5,12 @@ @pytest.fixture() def setup(spark): - schema = "col1: str, col2: int" + schema = "col1 STRING, col2 INT" input_df = spark.createDataFrame([["k1", 1]], schema) # write dataframe to catalog, create a catalog if it is not there spark.sql("CREATE CATALOG IF NOT EXISTS dqx_catalog") - spark.sql("CREATE DATABASE IF NOT EXISTS dqx_catalog.dqx_db") + spark.sql("CREATE SCHEMA IF NOT EXISTS dqx_catalog.dqx_db") input_df.write.format("delta").saveAsTable("dqx_catalog.dqx_db.dqx_table") # write dataframe to file From 6bc77019312b5fb529e7fb243cc8c0a4ac40c75b Mon Sep 17 00:00:00 2001 From: Marcin Wojtyczka Date: Fri, 7 Feb 2025 15:49:32 +0100 Subject: [PATCH 11/16] fixed tests --- tests/integration/test_utils.py | 46 ++++++++++++++++----------------- 1 file changed, 22 insertions(+), 24 deletions(-) diff --git a/tests/integration/test_utils.py b/tests/integration/test_utils.py index 4aadf0c4..9e4e7f44 100644 --- a/tests/integration/test_utils.py +++ b/tests/integration/test_utils.py @@ -1,33 +1,31 @@ -from pyspark.sql.types import Row -import pytest +from chispa.dataframe_comparer import assert_df_equality # type: ignore from databricks.labs.dqx.utils import read_input_data -@pytest.fixture() -def setup(spark): - schema = "col1 STRING, col2 INT" - input_df = spark.createDataFrame([["k1", 1]], schema) - - # write dataframe to catalog, create a catalog if it is not there - spark.sql("CREATE CATALOG IF NOT EXISTS dqx_catalog") - spark.sql("CREATE SCHEMA IF NOT EXISTS dqx_catalog.dqx_db") - input_df.write.format("delta").saveAsTable("dqx_catalog.dqx_db.dqx_table") - - # write dataframe to file - input_df.write.format("delta").save("/tmp/dqx_table") - - -def test_read_input_data_unity_catalog_table(setup, spark): - input_location = "dqx_catalog.dqx_db.dqx_table" +def test_read_input_data_unity_catalog_table(spark, make_schema, make_random): + catalog_name = "main" + schema_name = make_schema(catalog_name=catalog_name).name + input_location = f"{catalog_name}.{schema_name}.{make_random(6).lower()}" input_format = None - result = read_input_data(spark, input_location, input_format) - assert result.collect() == [Row(col1='k1', col2=1)] + schema = "a: int, b: int" + input_df = spark.createDataFrame([[1, 2]], schema) + input_df.write.format("delta").saveAsTable(input_location) + + result_df = read_input_data(spark, input_location, input_format) + assert_df_equality(input_df, result_df) -def test_read_input_data_workspace_file(setup, spark): - input_location = "/tmp/dqx_table" +def test_read_input_data_workspace_file(spark, make_schema, make_volume): + catalog_name = "main" + schema_name = make_schema(catalog_name=catalog_name).name + info = make_volume(catalog_name=catalog_name, schema_name=schema_name) + input_location = info.full_name input_format = "delta" - result = read_input_data(spark, input_location, input_format) - assert result.collect() == [Row(col1='k1', col2=1)] + schema = "a: int, b: int" + input_df = spark.createDataFrame([[1, 2]], schema) + input_df.write.format("delta").saveAsTable(input_location) + + result_df = read_input_data(spark, input_location, input_format) + assert_df_equality(input_df, result_df) From 1134cba9a866a48f461c7896c38154ae718e4d1e Mon Sep 17 00:00:00 2001 From: Marcin Wojtyczka Date: Fri, 7 Feb 2025 16:01:24 +0100 Subject: [PATCH 12/16] added comment --- src/databricks/labs/dqx/utils.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/databricks/labs/dqx/utils.py b/src/databricks/labs/dqx/utils.py index de916610..885906ec 100644 --- a/src/databricks/labs/dqx/utils.py +++ b/src/databricks/labs/dqx/utils.py @@ -39,6 +39,7 @@ def read_input_data(spark: SparkSession, input_location: str | None, input_forma if not input_format: raise ValueError("Input format not configured") # TODO handle spark options while reading data from a file location + # https://github.com/databrickslabs/dqx/issues/161 return spark.read.format(str(input_format)).load(input_location) raise ValueError( From d9937082e80d084d1479b04c303b6406bd4890e4 Mon Sep 17 00:00:00 2001 From: Marcin Wojtyczka Date: Fri, 7 Feb 2025 16:10:37 +0100 Subject: [PATCH 13/16] refactor added unit test for applying checks --- tests/unit/conftest.py | 2 +- tests/unit/test_apply_checks.py | 32 ++++++++++++++++++++++++++++++++ tests/unit/test_utils.py | 16 ++++++++-------- 3 files changed, 41 insertions(+), 9 deletions(-) create mode 100644 tests/unit/test_apply_checks.py diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 4649719e..a4c5499d 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -5,7 +5,7 @@ @pytest.fixture -def spark_session(): +def spark_local(): return SparkSession.builder.appName("DQX Test").remote("sc://localhost").getOrCreate() diff --git a/tests/unit/test_apply_checks.py b/tests/unit/test_apply_checks.py new file mode 100644 index 00000000..2da6ccb2 --- /dev/null +++ b/tests/unit/test_apply_checks.py @@ -0,0 +1,32 @@ +from unittest.mock import MagicMock + +from chispa.dataframe_comparer import assert_df_equality # type: ignore +from databricks.labs.dqx.col_functions import is_not_null_and_not_empty +from databricks.labs.dqx.engine import DQEngine +from databricks.labs.dqx.rule import DQRule +from databricks.sdk import WorkspaceClient + + +def test_apply_checks(spark_local): + ws = MagicMock(spec=WorkspaceClient, **{"catalogs.list.return_value": []}) + dq_engine = DQEngine(ws) + + schema = "a: int" + test_df = spark_local.createDataFrame([[1], [None]], schema) + + checks = [ + DQRule(name="col_a_is_null_or_empty", criticality="error", check=is_not_null_and_not_empty("a")), + ] + + checked = dq_engine.apply_checks(test_df, checks) + + expected_schema = schema + ", _errors: map, _warnings: map" + expected = spark_local.createDataFrame( + [ + [1, None, None], + [None, {"col_a_is_null_or_empty": "Column a is null or empty"}, None], + ], + expected_schema, + ) + + assert_df_equality(checked, expected, ignore_nullable=True) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index c0a79b19..8d71b855 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -30,36 +30,36 @@ def test_get_col_name_longer(): assert actual == "local" -def test_read_input_data_storage_path(spark_session): +def test_read_input_data_storage_path(spark_local): with tempfile.NamedTemporaryFile(delete=False) as temp_file: temp_file.write(b"val1,val2\n") temp_file_path = temp_file.name try: input_location = temp_file_path - result = read_input_data(spark_session, input_location, "csv") + result = read_input_data(spark_local, input_location, "csv") assert result.collect() == [Row(_c0='val1', _c1='val2')] finally: os.remove(temp_file_path) -def test_read_input_data_no_input_location(spark_session): +def test_read_input_data_no_input_location(spark_local): with pytest.raises(ValueError, match="Input location not configured"): - read_input_data(spark_session, None, None) + read_input_data(spark_local, None, None) -def test_read_input_data_no_input_format(spark_session): +def test_read_input_data_no_input_format(spark_local): input_location = "s3://bucket/path" input_format = None with pytest.raises(ValueError, match="Input format not configured"): - read_input_data(spark_session, input_location, input_format) + read_input_data(spark_local, input_location, input_format) -def test_read_invalid_input_location(spark_session): +def test_read_invalid_input_location(spark_local): input_location = "invalid/location" input_format = None with pytest.raises(ValueError, match="Invalid input location."): - read_input_data(spark_session, input_location, input_format) + read_input_data(spark_local, input_location, input_format) From 1ce4fe065b539cd8b92480055db26175117c3165 Mon Sep 17 00:00:00 2001 From: Marcin Wojtyczka Date: Fri, 7 Feb 2025 16:29:33 +0100 Subject: [PATCH 14/16] updated docs added apply checks test --- docs/dqx/docs/dev/contributing.mdx | 11 ++++++++--- docs/dqx/docs/reference.mdx | 6 ++++-- tests/unit/test_apply_checks.py | 24 ++++++++++-------------- 3 files changed, 22 insertions(+), 19 deletions(-) diff --git a/docs/dqx/docs/dev/contributing.mdx b/docs/dqx/docs/dev/contributing.mdx index 3a8f7c20..e4b67307 100644 --- a/docs/dqx/docs/dev/contributing.mdx +++ b/docs/dqx/docs/dev/contributing.mdx @@ -86,13 +86,18 @@ Before every commit, apply the consistent formatting of the code, as we want our make fmt ``` -Before every commit, run automated bug detector (`make lint`) and unit tests (`make test`) to ensure that automated -pull request checks do pass, before your code is reviewed by others: +Before every commit, run automated bug detector and unit tests to ensure that automated +pull request checks do pass, before your code is reviewed by others: ```shell make lint +make setup_spark_remote make test ``` +The command `make setup_spark_remote` sets up the environment for running tests locally and is required one time only. +DQX uses Databricks Connect for running integration tests, which restricts the creation of a Spark session in local mode. +To enable spark local execution for unit testing, we use spark remote. + ### Local setup for integration tests and code coverage Note that integration tests and code coverage are run automatically when you create a Pull Request in Github. @@ -215,7 +220,7 @@ Here are the example steps to submit your first contribution: 7. `make fmt` 8. `make lint` 9. .. fix if any -10. `make test` and `make integration`, optionally `make coverage` to get test coverage report +10. `make setup_spark_remote`, make test` and `make integration`, optionally `make coverage` to get test coverage report 11. .. fix if any issues 12. `git commit -S -a -m "message"`. Make sure to enter a meaningful commit message title. diff --git a/docs/dqx/docs/reference.mdx b/docs/dqx/docs/reference.mdx index fea48484..e8eb365d 100644 --- a/docs/dqx/docs/reference.mdx +++ b/docs/dqx/docs/reference.mdx @@ -252,7 +252,7 @@ def test_dq(): schema = "a: int, b: int, c: int" expected_schema = schema + ", _errors: map, _warnings: map" - test_df = spark.createDataFrame([[1, 3, 3]], schema) + test_df = spark.createDataFrame([[1, None, 3]], schema) checks = [ DQRule(name="col_a_is_null_or_empty", criticality="warn", check=is_not_null_and_not_empty("a")), @@ -262,6 +262,8 @@ def test_dq(): dq_engine = DQEngine(ws) df = dq_engine.apply_checks(test_df, checks) - expected_df = spark.createDataFrame([[1, 3, 3, None, None]], expected_schema) + expected_df = spark.createDataFrame( + [[1, None, 3, {"col_b_is_null_or_empty": "Column b is null or empty"}, None]], expected_schema + ) assert_df_equality(df, expected_df) ``` diff --git a/tests/unit/test_apply_checks.py b/tests/unit/test_apply_checks.py index 2da6ccb2..7d5a93af 100644 --- a/tests/unit/test_apply_checks.py +++ b/tests/unit/test_apply_checks.py @@ -9,24 +9,20 @@ def test_apply_checks(spark_local): ws = MagicMock(spec=WorkspaceClient, **{"catalogs.list.return_value": []}) - dq_engine = DQEngine(ws) - schema = "a: int" - test_df = spark_local.createDataFrame([[1], [None]], schema) + schema = "a: int, b: int, c: int" + expected_schema = schema + ", _errors: map, _warnings: map" + test_df = spark_local.createDataFrame([[1, None, 3]], schema) checks = [ - DQRule(name="col_a_is_null_or_empty", criticality="error", check=is_not_null_and_not_empty("a")), + DQRule(name="col_a_is_null_or_empty", criticality="warn", check=is_not_null_and_not_empty("a")), + DQRule(name="col_b_is_null_or_empty", criticality="error", check=is_not_null_and_not_empty("b")), ] - checked = dq_engine.apply_checks(test_df, checks) + dq_engine = DQEngine(ws) + df = dq_engine.apply_checks(test_df, checks) - expected_schema = schema + ", _errors: map, _warnings: map" - expected = spark_local.createDataFrame( - [ - [1, None, None], - [None, {"col_a_is_null_or_empty": "Column a is null or empty"}, None], - ], - expected_schema, + expected_df = spark_local.createDataFrame( + [[1, None, 3, {"col_b_is_null_or_empty": "Column b is null or empty"}, None]], expected_schema ) - - assert_df_equality(checked, expected, ignore_nullable=True) + assert_df_equality(df, expected_df) From 60e208c5ecc28573cfb88f5d8cb67b1996d6cbca Mon Sep 17 00:00:00 2001 From: Marcin Wojtyczka Date: Fri, 7 Feb 2025 18:56:26 +0100 Subject: [PATCH 15/16] updated docs --- demos/dqx_demo_library.py | 28 ++++++++++++++-------------- demos/dqx_demo_tool.py | 4 ++-- docs/dqx/docs/dev/contributing.mdx | 6 +++--- docs/dqx/docs/guide.mdx | 8 ++++---- docs/dqx/docs/reference.mdx | 8 ++++---- 5 files changed, 27 insertions(+), 27 deletions(-) diff --git a/demos/dqx_demo_library.py b/demos/dqx_demo_library.py index 26a736c0..5140e420 100644 --- a/demos/dqx_demo_library.py +++ b/demos/dqx_demo_library.py @@ -59,7 +59,7 @@ print(dlt_expectations) # save generated checks in a workspace file -user_name = spark.sql('select current_user() as user').collect()[0]['user'] +user_name = spark.sql("select current_user() as user").collect()[0]["user"] checks_file = f"/Workspace/Users/{user_name}/dqx_demo_checks.yml" dq_engine = DQEngine(ws) dq_engine.save_checks_in_workspace_file(checks, workspace_path=checks_file) @@ -143,7 +143,7 @@ col_name: col3 - criticality: error - filter: col1<3 + filter: col1 < 3 check: function: is_not_null_and_not_empty arguments: @@ -193,17 +193,17 @@ criticality="error", check_func=is_not_null).get_rules() + [ DQRule( # define rule for a single column - name='col3_is_null_or_empty', - criticality='error', - check=is_not_null_and_not_empty('col3')), + name="col3_is_null_or_empty", + criticality="error", + check=is_not_null_and_not_empty("col3")), DQRule( # define rule with a filter - name='col_4_is_null_or_empty', - criticality='error', - filter='col1<3', - check=is_not_null_and_not_empty('col4')), + name="col_4_is_null_or_empty", + criticality="error", + filter="col1 < 3", + check=is_not_null_and_not_empty("col4")), DQRule( # name auto-generated if not provided - criticality='warn', - check=value_is_in_list('col4', ['1', '2'])) + criticality="warn", + check=value_is_in_list("col4", ["1", "2"])) ] schema = "col1: int, col2: int, col3: int, col4 int" @@ -384,9 +384,9 @@ def ends_with_foo(col_name: str) -> Column: input_df = spark.createDataFrame([["str1"], ["foo"], ["str3"]], schema) checks = [ DQRule( - name='col_1_is_null_or_empty', - criticality='error', - check=is_not_null_and_not_empty('col1')), + name="col_1_is_null_or_empty", + criticality="error", + check=is_not_null_and_not_empty("col1")), ] valid_and_quarantined_df = dq_engine.apply_checks(input_df, checks) diff --git a/demos/dqx_demo_tool.py b/demos/dqx_demo_tool.py index 9c9c970d..3821bd07 100644 --- a/demos/dqx_demo_tool.py +++ b/demos/dqx_demo_tool.py @@ -45,7 +45,7 @@ import glob import os -user_name = spark.sql('select current_user() as user').collect()[0]['user'] +user_name = spark.sql("select current_user() as user").collect()[0]["user"] dqx_wheel_files = glob.glob(f"/Workspace/Users/{user_name}/.dqx/wheels/databricks_labs_dqx-*.whl") dqx_latest_wheel = max(dqx_wheel_files, key=os.path.getctime) %pip install {dqx_latest_wheel} @@ -210,7 +210,7 @@ # COMMAND ---------- print(f"Saving quarantined data to {run_config.quarantine_table}") -quarantine_catalog, quarantine_schema, _ = run_config.quarantine_table.split('.') +quarantine_catalog, quarantine_schema, _ = run_config.quarantine_table.split(".") spark.sql(f"CREATE CATALOG IF NOT EXISTS {quarantine_catalog}") spark.sql(f"CREATE SCHEMA IF NOT EXISTS {quarantine_catalog}.{quarantine_schema}") diff --git a/docs/dqx/docs/dev/contributing.mdx b/docs/dqx/docs/dev/contributing.mdx index e4b67307..cfd18b02 100644 --- a/docs/dqx/docs/dev/contributing.mdx +++ b/docs/dqx/docs/dev/contributing.mdx @@ -94,9 +94,9 @@ make setup_spark_remote make test ``` -The command `make setup_spark_remote` sets up the environment for running tests locally and is required one time only. -DQX uses Databricks Connect for running integration tests, which restricts the creation of a Spark session in local mode. -To enable spark local execution for unit testing, we use spark remote. +The command `make setup_spark_remote` sets up the environment for running unit tests and is required one time only. +DQX uses Databricks Connect as a test dependency, which restricts the creation of a Spark session in local mode. +To enable spark local execution for unit testing, the command install spark remote. ### Local setup for integration tests and code coverage diff --git a/docs/dqx/docs/guide.mdx b/docs/dqx/docs/guide.mdx index 76e3269c..732e8d6a 100644 --- a/docs/dqx/docs/guide.mdx +++ b/docs/dqx/docs/guide.mdx @@ -251,11 +251,11 @@ checks = DQRuleColSet( # define rule for multiple columns at once DQRule( # define rule with a filter name="col_4_is_null_or_empty", criticality="error", - filter="col1<3", + filter="col1 < 3", check=is_not_null_and_not_empty("col4")), DQRule( # name auto-generated if not provided - criticality='warn', - check=value_is_in_list('col4', ['1', '2'])) + criticality="warn", + check=value_is_in_list("col4", ["1", "2"])) ] input_df = spark.read.table("catalog1.schema1.table1") @@ -294,7 +294,7 @@ checks = yaml.safe_load(""" col_name: col3 - criticality: error - filter: col1<3 + filter: col1 < 3 check: function: is_not_null_and_not_empty arguments: diff --git a/docs/dqx/docs/reference.mdx b/docs/dqx/docs/reference.mdx index 613962d2..6dc2ee51 100644 --- a/docs/dqx/docs/reference.mdx +++ b/docs/dqx/docs/reference.mdx @@ -36,13 +36,13 @@ The following quality rules / functions are currently available: You can check implementation details of the rules [here](https://github.com/databrickslabs/dqx/blob/main/src/databricks/labs/dqx/col_functions.py). -#### Apply Filter on quality rule +### Apply filters on checks -If you want to apply a filter to a part of the dataframe, you can add a `filter` to the rule. -For example, if you want to check that a col `a` is not null when `b` is positive, you can do it like this: +You can apply checks to a part of the DataFrame by using a `filter`. +For example, to ensure that a column `a` is not null only when a column `b` is positive, you can define the check as follows: ```yaml - criticality: "error" - filter: b>0 + filter: b > 0 check: function: "is_not_null" arguments: From 546b36201af83013c47baeb21341a41239e2a285 Mon Sep 17 00:00:00 2001 From: Marcin Wojtyczka Date: Fri, 7 Feb 2025 19:05:02 +0100 Subject: [PATCH 16/16] updated docs --- docs/dqx/docs/reference.mdx | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/docs/dqx/docs/reference.mdx b/docs/dqx/docs/reference.mdx index 6dc2ee51..b7f8c248 100644 --- a/docs/dqx/docs/reference.mdx +++ b/docs/dqx/docs/reference.mdx @@ -41,12 +41,12 @@ You can check implementation details of the rules [here](https://github.com/data You can apply checks to a part of the DataFrame by using a `filter`. For example, to ensure that a column `a` is not null only when a column `b` is positive, you can define the check as follows: ```yaml -- criticality: "error" +- criticality: error filter: b > 0 check: - function: "is_not_null" + function: is_not_null arguments: - col_name: "a" + col_name: a ``` ### Creating your own checks