Add new tables #27
@@ -1,7 +1,12 @@
import os
-from datetime import date
+import random
+from datetime import date, datetime, timedelta
+from decimal import Decimal
from pathlib import Path
from typing import Callable, List, Tuple

import pyspark.sql
+import pyspark.sql.types as types
+from delta.tables import DeltaTable
+from pyspark.sql import SparkSession

@@ -35,7 +40,7 @@ def save_expected(case: TestCaseInfo, as_latest=False) -> None:
    # Need to ensure directory exists first
    os.makedirs(case.expected_root(version))

-   df.toPandas().to_parquet(case.expected_path(version))
+   df.write.parquet(case.expected_path(version))

    out_path = case.expected_root(version) / 'table_version_metadata.json'
    with open(out_path, 'w') as f:
@@ -121,8 +126,6 @@ def create_multi_partitioned(case: TestCaseInfo, spark: SparkSession):
        ('b', date(1970, 1, 2), b'world', 3)
    ]
    df = spark.createDataFrame(data, schema=columns)
-   # rdd = spark.sparkContext.parallelize(data)
-   # df = rdd.toDF(columns)
    schema = df.schema

    df.repartition(1).write.format('delta').partitionBy(
@@ -152,6 +155,24 @@ def create_multi_partitioned(case: TestCaseInfo, spark: SparkSession):
    save_expected(case)


@reference_table(
    name='multi_partitioned_2',
    description=('Multiple levels of partitioning, with boolean, timestamp, and '
                 'decimal partition columns')
)
def create_multi_partitioned_2(case: TestCaseInfo, spark: SparkSession):
    columns = ['bool', 'time', 'amount', 'int']
    partition_columns = ['bool', 'time', 'amount']
    data = [
        (True, datetime(1970, 1, 1), Decimal('200.00'), 1),
        (True, datetime(1970, 1, 1, 12, 30), Decimal('200.00'), 2),
        (False, datetime(1970, 1, 2, 8, 45), Decimal('12.00'), 3)
    ]
    df = spark.createDataFrame(data, schema=columns)
    df.repartition(1).write.format('delta').partitionBy(
        *partition_columns).save(case.delta_root)


@reference_table(
    name='with_schema_change',
    description='Table which has schema change using overwriteSchema=True.',
@@ -171,3 +192,193 @@ def with_schema_change(case: TestCaseInfo, spark: SparkSession):
        'overwriteSchema', True).format('delta').save(
        case.delta_root)
    save_expected(case)


@reference_table(
    name='all_primitive_types',
    description='Table containing all non-nested types',
)
def create_all_primitive_types(case: TestCaseInfo, spark: SparkSession):
    schema = types.StructType([
        types.StructField('utf8', types.StringType()),
        types.StructField('int64', types.LongType()),
        types.StructField('int32', types.IntegerType()),
        types.StructField('int16', types.ShortType()),
        types.StructField('int8', types.ByteType()),
        types.StructField('float32', types.FloatType()),
        types.StructField('float64', types.DoubleType()),
        types.StructField('bool', types.BooleanType()),
        types.StructField('binary', types.BinaryType()),
        types.StructField('decimal', types.DecimalType(5, 3)),
        types.StructField('date32', types.DateType()),
        types.StructField('timestamp', types.TimestampType()),
    ])

    df = spark.createDataFrame([
        (
            str(i),
            i,
            i,
            i,
            i,
            float(i),
            float(i),
            i % 2 == 0,
            bytes(i),
            Decimal('10.000') + i,
            date(1970, 1, 1) + timedelta(days=i),
            datetime(1970, 1, 1) + timedelta(hours=i)
        )
        for i in range(5)
    ], schema=schema)

    df.repartition(1).write.format('delta').save(case.delta_root)


@reference_table(
    name='nested_types',
    description='Table containing various nested types',
)
def create_nested_types(case: TestCaseInfo, spark: SparkSession):
    schema = types.StructType([
        types.StructField(
            # chispa doesn't support ignoring sort order when comparing tables
            # that contain map types, so we have to add a pk column to sort on.
            'pk', types.IntegerType()
        ),
        types.StructField(
            'struct', types.StructType([
                types.StructField('float64', types.DoubleType()),
                types.StructField('bool', types.BooleanType()),
            ])),
        types.StructField(
            'array', types.ArrayType(
                types.ShortType())),
        types.StructField(
            'map', types.MapType(
                types.StringType(),
                types.IntegerType())),
    ])

    df = spark.createDataFrame([
        (
            i,
            {'float64': float(i), 'bool': i % 2 == 0},
            list(range(i + 1)),
            {str(i): i for i in range(i)}
        )
        for i in range(5)
    ], schema=schema)

    df.repartition(1).write.format('delta').save(case.delta_root)


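The pk column above acts as a surrogate sort key for comparisons. As an illustration only (not part of this PR, and the helper name is made up), a chispa-based comparison could sort both frames by pk instead of relying on ignore_row_order:

from chispa import assert_df_equality


def assert_tables_equal(actual, expected):
    # Hypothetical helper: sort by the surrogate pk column so row order is
    # deterministic, since ignoring row order reportedly fails for map columns.
    assert_df_equality(
        actual.orderBy('pk'),
        expected.orderBy('pk'),
        ignore_nullable=True,
    )
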
def get_sample_data(
        spark: SparkSession, seed: int = 42, nrows: int = 5) -> pyspark.sql.DataFrame:
    # Use seed to get consistent data between runs, for reproducibility
    random.seed(seed)
    return spark.createDataFrame([
        (
            random.choice(['a', 'b', 'c', None]),
            random.randint(0, 1000),
            date(random.randint(1970, 2020), random.randint(1, 12), 1)
        )
        for i in range(nrows)
    ], schema=['letter', 'int', 'date'])


@reference_table(
    name='with_checkpoint',
    description='Table with a checkpoint',
)
def create_with_checkpoint(case: TestCaseInfo, spark: SparkSession):
    df = get_sample_data(spark)

    (DeltaTable.create(spark)
        .location(str(Path(case.delta_root).absolute()))
        .addColumns(df.schema)
        .property('delta.checkpointInterval', '2')
        .execute())

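    # With delta.checkpointInterval set to 2, the overwrite commits below are
    # enough to force at least one checkpoint file to be written to the log.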
    for i in range(3):
        df = get_sample_data(spark, seed=i, nrows=5)
        df.repartition(1).write.format('delta').mode(
            'overwrite').save(case.delta_root)

    assert any(path.suffixes == ['.checkpoint', '.parquet']
               for path in (Path(case.delta_root) / '_delta_log').iterdir())


def remove_log_file(delta_root: str, version: int):
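    # Delta commit files are named with the version number zero-padded to 20 digits.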
    os.remove(os.path.join(delta_root, '_delta_log', f'{version:0>20}.json'))


@reference_table(
    name='no_replay',
    description='Table with a checkpoint and prior commits cleaned up',
)
def create_no_replay(case: TestCaseInfo, spark: SparkSession):
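    # Vacuuming with retentionHours=0 below requires disabling Delta's
    # retention duration safety check (which enforces a 168 hour minimum).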
    spark.conf.set(
        'spark.databricks.delta.retentionDurationCheck.enabled', 'false')

    df = get_sample_data(spark)

    table = (DeltaTable.create(spark)
             .location(str(Path(case.delta_root).absolute()))
             .addColumns(df.schema)
             .property('delta.checkpointInterval', '2')
             .execute())

    for i in range(3):
        df = get_sample_data(spark, seed=i, nrows=5)
        df.repartition(1).write.format('delta').mode(
            'overwrite').save(case.delta_root)

    table.vacuum(retentionHours=0)

    remove_log_file(case.delta_root, version=0)
    remove_log_file(case.delta_root, version=1)

    files_in_log = list((Path(case.delta_root) / '_delta_log').iterdir())
    assert any(path.suffixes == ['.checkpoint', '.parquet']
               for path in files_in_log)
    assert not any(path.name == f'{0:0>20}.json' for path in files_in_log)


@reference_table(
    name='stats_as_struct',
    description='Table with stats only written as struct (not JSON) with checkpoint',
)
def create_stats_as_struct(case: TestCaseInfo, spark: SparkSession):
    df = get_sample_data(spark)
    (DeltaTable.create(spark)
        .location(str(Path(case.delta_root).absolute()))
        .addColumns(df.schema)
        .property('delta.checkpointInterval', '2')
        .property('delta.checkpoint.writeStatsAsStruct', 'true')
        .property('delta.checkpoint.writeStatsAsJson', 'false')
        .execute())

    for i in range(3):
        df = get_sample_data(spark, seed=i, nrows=5)
        df.repartition(1).write.format('delta').mode(
            'overwrite').save(case.delta_root)


@reference_table(
    name='no_stats',
    description='Table with no stats',
)
def create_no_stats(case: TestCaseInfo, spark: SparkSession):
    df = get_sample_data(spark)
    (DeltaTable.create(spark)
        .location(str(Path(case.delta_root).absolute()))
        .addColumns(df.schema)
        .property('delta.checkpointInterval', '2')
        .property('delta.checkpoint.writeStatsAsStruct', 'false')
        .property('delta.checkpoint.writeStatsAsJson', 'false')
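        # Setting dataSkippingNumIndexedCols to 0 stops per-column statistics
        # from being collected for data skipping.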
        .property('delta.dataSkippingNumIndexedCols', '0')
        .execute())

    for i in range(3):
        df = get_sample_data(spark, seed=i, nrows=5)
        df.repartition(1).write.format('delta').mode(
            'overwrite').save(case.delta_root)
It turned out toPandas().to_parquet() causes weird things to happen to timestamps, so it's better to stick to Spark here. Without that call, I was able to eliminate the pandas dependency as well.
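For context, a minimal sketch (not part of this PR; it assumes pandas and pyarrow are installed and uses throwaway /tmp paths) of the kind of round-trip comparison that motivates writing the expected data with Spark instead of pandas:

from datetime import datetime

from pyspark.sql import SparkSession

spark = SparkSession.builder.master('local[1]').getOrCreate()
df = spark.createDataFrame([(datetime(1970, 1, 1, 12, 30),)], schema=['ts'])

# Write the same timestamp column via both paths, then read each back with
# Spark and compare the displayed values side by side.
df.write.mode('overwrite').parquet('/tmp/expected_spark')
df.toPandas().to_parquet('/tmp/expected_pandas.parquet')

spark.read.parquet('/tmp/expected_spark').show(truncate=False)
spark.read.parquet('/tmp/expected_pandas.parquet').show(truncate=False)

Whether the two outputs actually differ depends on the Spark, pandas, and pyarrow versions and on timezone configuration; the point of the change is simply that writing with Spark keeps a single engine responsible for both the Delta tables and the expected Parquet files.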