Commit 6bc7701

fixed tests
1 parent b018058 commit 6bc7701

File tree

1 file changed (+22, -24)

tests/integration/test_utils.py

@@ -1,33 +1,31 @@
-from pyspark.sql.types import Row
-import pytest
+from chispa.dataframe_comparer import assert_df_equality  # type: ignore
 from databricks.labs.dqx.utils import read_input_data
 
 
-@pytest.fixture()
-def setup(spark):
-    schema = "col1 STRING, col2 INT"
-    input_df = spark.createDataFrame([["k1", 1]], schema)
-
-    # write dataframe to catalog, create a catalog if it is not there
-    spark.sql("CREATE CATALOG IF NOT EXISTS dqx_catalog")
-    spark.sql("CREATE SCHEMA IF NOT EXISTS dqx_catalog.dqx_db")
-    input_df.write.format("delta").saveAsTable("dqx_catalog.dqx_db.dqx_table")
-
-    # write dataframe to file
-    input_df.write.format("delta").save("/tmp/dqx_table")
-
-
-def test_read_input_data_unity_catalog_table(setup, spark):
-    input_location = "dqx_catalog.dqx_db.dqx_table"
+def test_read_input_data_unity_catalog_table(spark, make_schema, make_random):
+    catalog_name = "main"
+    schema_name = make_schema(catalog_name=catalog_name).name
+    input_location = f"{catalog_name}.{schema_name}.{make_random(6).lower()}"
     input_format = None
 
-    result = read_input_data(spark, input_location, input_format)
-    assert result.collect() == [Row(col1='k1', col2=1)]
+    schema = "a: int, b: int"
+    input_df = spark.createDataFrame([[1, 2]], schema)
+    input_df.write.format("delta").saveAsTable(input_location)
+
+    result_df = read_input_data(spark, input_location, input_format)
+    assert_df_equality(input_df, result_df)
 
 
-def test_read_input_data_workspace_file(setup, spark):
-    input_location = "/tmp/dqx_table"
+def test_read_input_data_workspace_file(spark, make_schema, make_volume):
+    catalog_name = "main"
+    schema_name = make_schema(catalog_name=catalog_name).name
+    info = make_volume(catalog_name=catalog_name, schema_name=schema_name)
+    input_location = info.full_name
     input_format = "delta"
 
-    result = read_input_data(spark, input_location, input_format)
-    assert result.collect() == [Row(col1='k1', col2=1)]
+    schema = "a: int, b: int"
+    input_df = spark.createDataFrame([[1, 2]], schema)
+    input_df.write.format("delta").saveAsTable(input_location)
+
+    result_df = read_input_data(spark, input_location, input_format)
+    assert_df_equality(input_df, result_df)
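
Note: the make_schema, make_random, and make_volume fixtures used in the new tests are not defined in this file. Presumably they come from a shared pytest plugin; the databricks-labs-pytester package exposes fixtures with exactly these names. A hypothetical conftest.py sketch, assuming that plugin is the source (the registration line below is an assumption, not part of this commit):

    # conftest.py (hypothetical): register the Databricks fixture plugin that
    # presumably supplies make_schema, make_random, and make_volume.
    pytest_plugins = ["databricks.labs.pytester.fixtures.plugin"]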

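For readers without a Databricks workspace, the write/read/compare round trip that both tests verify can be sketched against a plain local SparkSession. This is a minimal sketch, not the commit's code: it assumes read_input_data loads a storage path via the given format (as the removed workspace-file test implied) and swaps Delta for parquet so no delta-spark configuration is needed.

    # Minimal local sketch of the round trip exercised above. Assumptions:
    # parquet instead of Delta (avoids delta-spark setup), a local
    # SparkSession, and a pytest tmp_path instead of a Unity Catalog volume.
    from pyspark.sql import SparkSession
    from chispa.dataframe_comparer import assert_df_equality

    from databricks.labs.dqx.utils import read_input_data


    def test_read_input_data_local_path(tmp_path):
        spark = SparkSession.builder.master("local[1]").getOrCreate()
        input_location = str(tmp_path / "roundtrip")

        # Write a tiny DataFrame, read it back through the function under
        # test, and compare schema and contents with chispa.
        input_df = spark.createDataFrame([[1, 2]], "a: int, b: int")
        input_df.write.format("parquet").save(input_location)

        result_df = read_input_data(spark, input_location, "parquet")
        assert_df_equality(input_df, result_df)

On a mismatch, chispa's assert_df_equality fails with a row-by-row comparison of the two DataFrames, which is why it replaces the collect()-based assertion from the removed fixture version.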