@@ -1,33 +1,31 @@
-from pyspark.sql.types import Row
-import pytest
+from chispa.dataframe_comparer import assert_df_equality  # type: ignore
 from databricks.labs.dqx.utils import read_input_data
 
 
-@pytest.fixture()
-def setup(spark):
-    schema = "col1 STRING, col2 INT"
-    input_df = spark.createDataFrame([["k1", 1]], schema)
-
-    # write dataframe to catalog, create a catalog if it is not there
-    spark.sql("CREATE CATALOG IF NOT EXISTS dqx_catalog")
-    spark.sql("CREATE SCHEMA IF NOT EXISTS dqx_catalog.dqx_db")
-    input_df.write.format("delta").saveAsTable("dqx_catalog.dqx_db.dqx_table")
-
-    # write dataframe to file
-    input_df.write.format("delta").save("/tmp/dqx_table")
-
-
-def test_read_input_data_unity_catalog_table(setup, spark):
-    input_location = "dqx_catalog.dqx_db.dqx_table"
+def test_read_input_data_unity_catalog_table(spark, make_schema, make_random):
+    catalog_name = "main"
+    schema_name = make_schema(catalog_name=catalog_name).name
+    input_location = f"{catalog_name}.{schema_name}.{make_random(6).lower()}"
     input_format = None
 
-    result = read_input_data(spark, input_location, input_format)
-    assert result.collect() == [Row(col1='k1', col2=1)]
+    schema = "a: int, b: int"
+    input_df = spark.createDataFrame([[1, 2]], schema)
+    input_df.write.format("delta").saveAsTable(input_location)
+
+    result_df = read_input_data(spark, input_location, input_format)
+    assert_df_equality(input_df, result_df)
 
 
-def test_read_input_data_workspace_file(setup, spark):
-    input_location = "/tmp/dqx_table"
+def test_read_input_data_workspace_file(spark, make_schema, make_volume):
+    catalog_name = "main"
+    schema_name = make_schema(catalog_name=catalog_name).name
+    info = make_volume(catalog_name=catalog_name, schema_name=schema_name)
+    input_location = info.full_name
     input_format = "delta"
 
-    result = read_input_data(spark, input_location, input_format)
-    assert result.collect() == [Row(col1='k1', col2=1)]
+    schema = "a: int, b: int"
+    input_df = spark.createDataFrame([[1, 2]], schema)
+    input_df.write.format("delta").saveAsTable(input_location)
+
+    result_df = read_input_data(spark, input_location, input_format)
+    assert_df_equality(input_df, result_df)
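
For context on the helper under test: the first case passes a fully qualified Unity Catalog table name with input_format=None, while the second passes a location together with an explicit "delta" format, so read_input_data evidently dispatches on the shape of input_location. The sketch below illustrates one plausible dispatch rule; it is an assumption for illustration only, not the actual databricks.labs.dqx.utils implementation (read_input_data_sketch, TABLE_PATTERN, and the error messages are invented here):

import re

from pyspark.sql import DataFrame, SparkSession

# Assumed convention: three dot-separated identifiers -> Unity Catalog table.
# This pattern is a guess for the sketch, not dqx's real rule.
TABLE_PATTERN = re.compile(r"^\w+\.\w+\.\w+$")


def read_input_data_sketch(
    spark: SparkSession, input_location: str | None, input_format: str | None
) -> DataFrame:
    """Illustrative only: route table names to spark.read.table and
    everything else to a format-based load."""
    if not input_location:
        raise ValueError("input_location must be provided")
    if TABLE_PATTERN.match(input_location):
        # e.g. "main.my_schema.my_table" read directly as a table,
        # which is why input_format can stay None in the first test.
        return spark.read.table(input_location)
    if not input_format:
        raise ValueError("input_format is required for path-based input")
    # e.g. a storage path read with the caller-supplied format.
    return spark.read.format(input_format).load(input_location)

Note also that chispa's assert_df_equality checks both schema and row contents, which is why the rewritten tests can drop the old collect()-based comparison against Row objects.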
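
The spark, make_schema, make_random, and make_volume parameters are pytest fixtures defined outside this file, presumably supplied by a shared plugin in the style of databricks-labs-pytester that creates uniquely named, automatically cleaned-up catalog objects per test. A rough sketch of the fixture contract these tests rely on follows; only the attributes actually used in the diff are grounded, and the uniqueness and cleanup behavior are assumptions about the plugin:

# Rough sketch of the fixture contract assumed by the tests above.
# Only .name and .full_name are used in the diff; auto-cleanup and
# name uniqueness are assumptions about the fixture plugin.
def test_fixture_shapes(spark, make_schema, make_random, make_volume):
    schema_info = make_schema(catalog_name="main")   # temporary schema in "main"
    assert schema_info.name                          # used to qualify table names
    volume_info = make_volume(catalog_name="main", schema_name=schema_info.name)
    assert volume_info.full_name                     # "main.<schema>.<volume>"
    assert len(make_random(6)) == 6                  # random suffix for unique names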