Add new tables #27

Merged 4 commits on Jan 12, 2023
2 changes: 1 addition & 1 deletion Makefile
@@ -25,7 +25,7 @@ lint-bandit: ## Run bandit
	@echo "\n${BLUE}Running bandit...${NC}\n"
	@${POETRY_RUN} bandit -r ${PROJ}

-lint-base: lint-flake8 lint-bandit ## Just run the linters without autolinting
lint-base: lint-flake8 ## Just run the linters without autolinting

lint: autolint lint-base lint-mypy ## Autolint and code linting

2 changes: 1 addition & 1 deletion README.md
@@ -52,7 +52,7 @@ Then for each test case:
2. Verify the metadata read from the Delta table matches that in the `table_version_metadata.json`. For example, verify that the connector parsed the correct `min_reader_version` from the Delta log. This step may be skipped if the reader connector does not expose such details in its public API.
3. Attempt to read the Delta table's data:
a. If the Delta table uses a version unsupported by the reader connector (as determined from `table_version_metadata.json`), verify an appropriate error is returned.
-b. If the Delta table is supported by the reader connector, assert that the read data is equal to the data read from `table_content.parquet`.
b. If the Delta table is supported by the reader connector, assert that the read data is equal to the data read from `table_content.parquet`. To make it easy to sort the tables for comparison, some tables include a `pk` column containing an ascending integer sequence.

For an example implementation of this, see the example PySpark tests in `tests/pyspark_delta/`.

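To make the checklist above concrete, here is a rough sketch of steps 2 and 3 for a hypothetical reader connector, using pyarrow for the expected data. The `connector.read_delta` call and the exact file layout inside a test case directory are placeholders, not part of this repository:

```python
import json
from pathlib import Path

import pyarrow.parquet as pq


def check_read_case(case_root: Path, connector) -> None:
    # Step 2 (optional): compare connector-reported metadata against the recorded
    # values in table_version_metadata.json, e.g. its 'min_reader_version' field.
    metadata = json.loads((case_root / 'table_version_metadata.json').read_text())

    # Step 3: read the Delta table through the connector under test.
    actual = connector.read_delta(case_root / 'delta')  # hypothetical connector API
    expected = pq.read_table(str(case_root / 'table_content.parquet'))

    # Step 3b: sort on pk when present so row order cannot cause false failures.
    if 'pk' in expected.column_names:
        actual = actual.sort_by('pk')
        expected = expected.sort_by('pk')
    assert actual.equals(expected)
```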
219 changes: 215 additions & 4 deletions dat/generated_tables.py
@@ -1,7 +1,12 @@
import os
-from datetime import date
import random
from datetime import date, datetime, timedelta
from decimal import Decimal
from pathlib import Path
from typing import Callable, List, Tuple

import pyspark.sql
import pyspark.sql.types as types
from delta.tables import DeltaTable
from pyspark.sql import SparkSession

@@ -35,7 +40,7 @@ def save_expected(case: TestCaseInfo, as_latest=False) -> None:
    # Need to ensure directory exists first
    os.makedirs(case.expected_root(version))

-    df.toPandas().to_parquet(case.expected_path(version))

Review comment from the PR author on the change below: it turned out toPandas().to_parquet() causes weird things to happen to timestamps, so it is better to stick with Spark here. Without it, the pandas dependency could be eliminated as well.

    df.write.parquet(case.expected_path(version))

    out_path = case.expected_root(version) / 'table_version_metadata.json'
    with open(out_path, 'w') as f:
@@ -121,8 +126,6 @@ def create_multi_partitioned(case: TestCaseInfo, spark: SparkSession):
        ('b', date(1970, 1, 2), b'world', 3)
    ]
    df = spark.createDataFrame(data, schema=columns)
-    # rdd = spark.sparkContext.parallelize(data)
-    # df = rdd.toDF(columns)
    schema = df.schema

    df.repartition(1).write.format('delta').partitionBy(
@@ -152,6 +155,24 @@ def create_multi_partitioned(case: TestCaseInfo, spark: SparkSession):
    save_expected(case)


@reference_table(
    name='multi_partitioned_2',
    description=('Multiple levels of partitioning, with boolean, timestamp, and '
                 'decimal partition columns')
)
def create_multi_partitioned_2(case: TestCaseInfo, spark: SparkSession):
    columns = ['bool', 'time', 'amount', 'int']
    partition_columns = ['bool', 'time', 'amount']
    data = [
        (True, datetime(1970, 1, 1), Decimal('200.00'), 1),
        (True, datetime(1970, 1, 1, 12, 30), Decimal('200.00'), 2),
        (False, datetime(1970, 1, 2, 8, 45), Decimal('12.00'), 3)
    ]
    df = spark.createDataFrame(data, schema=columns)
    df.repartition(1).write.format('delta').partitionBy(
        *partition_columns).save(case.delta_root)


@reference_table(
    name='with_schema_change',
    description='Table which has schema change using overwriteSchema=True.',
@@ -171,3 +192,193 @@ def with_schema_change(case: TestCaseInfo, spark: SparkSession):
        'overwriteSchema', True).format('delta').save(
            case.delta_root)
    save_expected(case)


@reference_table(
    name='all_primitive_types',
    description='Table containing all non-nested types',
)
def create_all_primitive_types(case: TestCaseInfo, spark: SparkSession):
    schema = types.StructType([
        types.StructField('utf8', types.StringType()),
        types.StructField('int64', types.LongType()),
        types.StructField('int32', types.IntegerType()),
        types.StructField('int16', types.ShortType()),
        types.StructField('int8', types.ByteType()),
        types.StructField('float32', types.FloatType()),
        types.StructField('float64', types.DoubleType()),
        types.StructField('bool', types.BooleanType()),
        types.StructField('binary', types.BinaryType()),
        types.StructField('decimal', types.DecimalType(5, 3)),
        types.StructField('date32', types.DateType()),
        types.StructField('timestamp', types.TimestampType()),
    ])

    df = spark.createDataFrame([
        (
            str(i),
            i,
            i,
            i,
            i,
            float(i),
            float(i),
            i % 2 == 0,
            bytes(i),
            Decimal('10.000') + i,
            date(1970, 1, 1) + timedelta(days=i),
            datetime(1970, 1, 1) + timedelta(hours=i)
        )
        for i in range(5)
    ], schema=schema)

    df.repartition(1).write.format('delta').save(case.delta_root)
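For orientation, the rows this generator produces can be worked out directly from the loop above; note in particular that `bytes(i)` yields `i` zero bytes. The values below are derived from the code, not a file in the PR:

```python
# Row i == 0 of 'all_primitive_types':
('0', 0, 0, 0, 0, 0.0, 0.0, True, b'', Decimal('10.000'),
 date(1970, 1, 1), datetime(1970, 1, 1, 0, 0))

# Row i == 3:
('3', 3, 3, 3, 3, 3.0, 3.0, False, b'\x00\x00\x00', Decimal('13.000'),
 date(1970, 1, 4), datetime(1970, 1, 1, 3, 0))
```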


Review comment from the PR author on the `pk` column below: chispa doesn't support ignoring sort order when comparing tables that contain map types, so we have to add a `pk` column that tests can sort on. I've added a note about this to the readme.

@reference_table(
    name='nested_types',
    description='Table containing various nested types',
)
def create_nested_types(case: TestCaseInfo, spark: SparkSession):
    schema = types.StructType([
        types.StructField('pk', types.IntegerType()),
        types.StructField('struct', types.StructType([
            types.StructField('float64', types.DoubleType()),
            types.StructField('bool', types.BooleanType()),
        ])),
        types.StructField('array', types.ArrayType(types.ShortType())),
        types.StructField('map', types.MapType(
            types.StringType(),
            types.IntegerType())),
    ])

    df = spark.createDataFrame([
        (
            i,
            {'float64': float(i), 'bool': i % 2 == 0},
            list(range(i + 1)),
            # The comprehension variable shadows the row index: row i gets
            # keys '0' through str(i - 1).
            {str(i): i for i in range(i)}
        )
        for i in range(5)
    ], schema=schema)

    df.repartition(1).write.format('delta').save(case.delta_root)


def get_sample_data(
        spark: SparkSession, seed: int = 42, nrows: int = 5) -> pyspark.sql.DataFrame:
    # Use seed to get consistent data between runs, for reproducibility
    random.seed(seed)
    return spark.createDataFrame([
        (
            random.choice(['a', 'b', 'c', None]),
            random.randint(0, 1000),
            date(random.randint(1970, 2020), random.randint(1, 12), 1)
        )
        for i in range(nrows)
    ], schema=['letter', 'int', 'date'])


@reference_table(
    name='with_checkpoint',
    description='Table with a checkpoint',
)
def create_with_checkpoint(case: TestCaseInfo, spark: SparkSession):
    df = get_sample_data(spark)

    (DeltaTable.create(spark)
        .location(str(Path(case.delta_root).absolute()))
        .addColumns(df.schema)
        .property('delta.checkpointInterval', '2')
        .execute())

    for i in range(3):
        df = get_sample_data(spark, seed=i, nrows=5)
        df.repartition(1).write.format('delta').mode(
            'overwrite').save(case.delta_root)

    assert any(path.suffixes == ['.checkpoint', '.parquet']
               for path in (Path(case.delta_root) / '_delta_log').iterdir())


def remove_log_file(delta_root: str, version: int):
    os.remove(os.path.join(delta_root, '_delta_log', f'{version:0>20}.json'))
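As a quick check of the file name format used here: Delta log commit files are named with the commit version zero-padded to 20 digits, so the helper above resolves to paths like the following.

```python
>>> f'{0:0>20}.json'
'00000000000000000000.json'
>>> f'{12:0>20}.json'
'00000000000000000012.json'
```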


@reference_table(
    name='no_replay',
    description='Table with a checkpoint and prior commits cleaned up',
)
def create_no_replay(case: TestCaseInfo, spark: SparkSession):
    spark.conf.set(
        'spark.databricks.delta.retentionDurationCheck.enabled', 'false')

    df = get_sample_data(spark)

    table = (DeltaTable.create(spark)
             .location(str(Path(case.delta_root).absolute()))
             .addColumns(df.schema)
             .property('delta.checkpointInterval', '2')
             .execute())

    for i in range(3):
        df = get_sample_data(spark, seed=i, nrows=5)
        df.repartition(1).write.format('delta').mode(
            'overwrite').save(case.delta_root)

    table.vacuum(retentionHours=0)

    remove_log_file(case.delta_root, version=0)
    remove_log_file(case.delta_root, version=1)

    files_in_log = list((Path(case.delta_root) / '_delta_log').iterdir())
    assert any(path.suffixes == ['.checkpoint', '.parquet']
               for path in files_in_log)
    assert not any(path.name == f'{0:0>20}.json' for path in files_in_log)
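Because `no_replay` deletes the JSON commits that precede the checkpoint, a reader has to start from the checkpoint rather than replaying the log from version 0. A minimal sketch of locating it via the standard `_last_checkpoint` pointer file (illustrative, not code from this PR):

```python
import json
from pathlib import Path


def latest_checkpoint_version(delta_root: str) -> int:
    """Return the version recorded in _delta_log/_last_checkpoint."""
    last_checkpoint = Path(delta_root) / '_delta_log' / '_last_checkpoint'
    return json.loads(last_checkpoint.read_text())['version']


# The single-file checkpoint for that version is then named
# f'{version:0>20}.checkpoint.parquet' inside _delta_log.
```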


@reference_table(
    name='stats_as_struct',
    description='Table with stats only written as struct (not JSON) with Checkpoint',
)
def create_stats_as_struct(case: TestCaseInfo, spark: SparkSession):
    df = get_sample_data(spark)
    (DeltaTable.create(spark)
        .location(str(Path(case.delta_root).absolute()))
        .addColumns(df.schema)
        .property('delta.checkpointInterval', '2')
        .property('delta.checkpoint.writeStatsAsStruct', 'true')
        .property('delta.checkpoint.writeStatsAsJson', 'false')
        .execute())

    for i in range(3):
        df = get_sample_data(spark, seed=i, nrows=5)
        df.repartition(1).write.format('delta').mode(
            'overwrite').save(case.delta_root)


@reference_table(
    name='no_stats',
    description='Table with no stats',
)
def create_no_stats(case: TestCaseInfo, spark: SparkSession):
    df = get_sample_data(spark)
    (DeltaTable.create(spark)
        .location(str(Path(case.delta_root).absolute()))
        .addColumns(df.schema)
        .property('delta.checkpointInterval', '2')
        .property('delta.checkpoint.writeStatsAsStruct', 'false')
        .property('delta.checkpoint.writeStatsAsJson', 'false')
        .property('delta.dataSkippingNumIndexedCols', '0')
        .execute())

    for i in range(3):
        df = get_sample_data(spark, seed=i, nrows=5)
        df.repartition(1).write.format('delta').mode(
            'overwrite').save(case.delta_root)
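To illustrate what `no_stats` means at the log level: with `delta.dataSkippingNumIndexedCols` set to `0`, the `add` actions written to the commit files carry no file statistics. A hedged helper one could use to confirm this by scanning the newline-delimited JSON commits directly (not part of the PR):

```python
import json
from pathlib import Path


def any_add_action_has_stats(delta_root: str) -> bool:
    """True if any 'add' action in the JSON commit files includes a stats payload."""
    log_dir = Path(delta_root) / '_delta_log'
    for commit in sorted(log_dir.glob('*.json')):
        for line in commit.read_text().splitlines():
            if not line.strip():
                continue
            action = json.loads(line)
            if 'add' in action and action['add'].get('stats'):
                return True
    return False


# Expected to return False for the 'no_stats' table's delta_root.
```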
29 changes: 22 additions & 7 deletions dat/main.py
@@ -2,6 +2,7 @@
import os
import shutil
from pathlib import Path
from typing import Optional

import click

@@ -32,13 +33,27 @@ def cli():


@click.command()
-def write_generated_reference_tables():
-    out_base = Path('out/reader_tests/generated')
-    shutil.rmtree(out_base, ignore_errors=True)
-
-    for metadata, create_table in generated_tables.registered_reference_tables:
-        logging.info("Writing table '%s'", metadata.name)
-        create_table()
@click.option('--table-name')
def write_generated_reference_tables(table_name: Optional[str]):
    if table_name:
        for metadata, create_table in generated_tables.registered_reference_tables:
            if metadata.name == table_name:
                logging.info("Writing table '%s'", metadata.name)
                out_base = Path('out/reader_tests/generated') / table_name
                shutil.rmtree(out_base, ignore_errors=True)

                create_table()
                break
        else:
            raise ValueError(
                f"Could not find generated table named '{table_name}'")
    else:
        out_base = Path('out/reader_tests/generated')
        shutil.rmtree(out_base, ignore_errors=True)

        for metadata, create_table in generated_tables.registered_reference_tables:
            logging.info("Writing table '%s'", metadata.name)
            create_table()
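A quick way to exercise the new `--table-name` flag in isolation is click's test runner. The flag and command name come from the diff above; treating `dat.main` as importable is an assumption of this sketch:

```python
from click.testing import CliRunner

from dat.main import write_generated_reference_tables  # assumes dat is importable

runner = CliRunner()

# Regenerate only the 'nested_types' reference table instead of wiping the whole
# out/reader_tests/generated directory.
result = runner.invoke(write_generated_reference_tables,
                       ['--table-name', 'nested_types'])
assert result.exit_code == 0, result.output

# An unknown name should surface the ValueError raised above as a failure.
result = runner.invoke(write_generated_reference_tables,
                       ['--table-name', 'does_not_exist'])
assert result.exit_code != 0
```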


@click.command()
(A number of binary files were added or updated and two files were deleted; their names and contents are not shown in this view.)
1 change: 0 additions & 1 deletion pyproject.toml
@@ -11,7 +11,6 @@ pyspark = "^3.2.0"
click = "^8.1.3"
delta-spark = "^2.1.1"
rootpath = "^0.1.1"
-pandas = "^1.4.3"
pyarrow = "^10.0.0"

[tool.poetry.dev-dependencies]
3 changes: 2 additions & 1 deletion setup.cfg
@@ -9,4 +9,5 @@ per-file-ignores =
    # WPS202 Found too many module members
    tests/*: S101 WPS114 WPS226 WPS202
    dat/external_tables.py: WPS226 WPS114
-    dat/generated_tables.py: WPS226 WPS114
    dat/generated_tables.py: WPS226 WPS114
max-line-length = 90
7 changes: 6 additions & 1 deletion tests/pyspark_delta/test_pyspark_delta.py
@@ -69,4 +69,9 @@ def test_readers_dat(spark_session, case: ReadCase):
    expected_df = spark_session.read.format('parquet').load(
        str(case.parquet_root) + '/*.parquet')

-    chispa.assert_df_equality(actual_df, expected_df)
    if 'pk' in actual_df.columns:
        actual_df = actual_df.orderBy('pk')
        expected_df = expected_df.orderBy('pk')
        chispa.assert_df_equality(actual_df, expected_df)
    else:
        chispa.assert_df_equality(actual_df, expected_df, ignore_row_order=True)