feat: streamed write execution except cdf
Signed-off-by: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com>
ion-elgreco committed Feb 19, 2025
1 parent b1776c7 commit 31b7880
Showing 4 changed files with 108 additions and 8 deletions.
65 changes: 65 additions & 0 deletions crates/core/src/operations/write/lazy.rs
@@ -0,0 +1,65 @@
//! This module contains helper functions to create a LazyTableProvider from an ArrowArrayStreamReader
use crate::DeltaResult;
use arrow::ffi_stream::ArrowArrayStreamReader;
use datafusion::catalog::TableProvider;
use datafusion::physical_plan::memory::LazyBatchGenerator;
use delta_datafusion::LazyTableProvider;
use parking_lot::RwLock;
use std::fmt::{self};
use std::sync::{Arc, Mutex};

use crate::delta_datafusion;

/// Convert an [ArrowArrayStreamReader] into a [LazyTableProvider]
pub fn to_lazy_table(source: ArrowArrayStreamReader) -> DeltaResult<Arc<dyn TableProvider>> {
use arrow::array::RecordBatchReader;
let schema = source.schema();
let arrow_stream: Arc<Mutex<ArrowArrayStreamReader>> = Arc::new(Mutex::new(source));
let arrow_stream_batch_generator: Arc<RwLock<dyn LazyBatchGenerator>> =
Arc::new(RwLock::new(ArrowStreamBatchGenerator::new(arrow_stream)));

Ok(Arc::new(LazyTableProvider::try_new(
schema.clone(),
vec![arrow_stream_batch_generator],
)?))
}

#[derive(Debug)]
pub(crate) struct ArrowStreamBatchGenerator {
pub array_stream: Arc<Mutex<ArrowArrayStreamReader>>,
}

impl fmt::Display for ArrowStreamBatchGenerator {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(
f,
"ArrowStreamBatchGenerator {{ array_stream: {:?} }}",
self.array_stream
)
}
}

impl ArrowStreamBatchGenerator {
pub fn new(array_stream: Arc<Mutex<ArrowArrayStreamReader>>) -> Self {
Self { array_stream }
}
}

impl LazyBatchGenerator for ArrowStreamBatchGenerator {
fn generate_next_batch(
&mut self,
) -> datafusion::error::Result<Option<arrow::array::RecordBatch>> {
let mut stream_reader = self.array_stream.lock().map_err(|_| {
datafusion::error::DataFusionError::Execution(
"Failed to lock the ArrowArrayStreamReader".to_string(),
)
})?;

match stream_reader.next() {
Some(Ok(record_batch)) => Ok(Some(record_batch)),
Some(Err(err)) => Err(datafusion::error::DataFusionError::ArrowError(err, None)),
None => Ok(None), // End of stream
}
}
}
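
For illustration, here is a minimal sketch of how the new helper is meant to be consumed (it mirrors the Python binding change further down): the FFI stream is wrapped in a LazyTableProvider and scanned through a DataFusion logical plan, so record batches are pulled from the reader only while the write actually executes. The function name lazy_scan and the External error mapping are illustrative and not part of this commit.

use deltalake::arrow::ffi_stream::ArrowArrayStreamReader;
use deltalake::datafusion::datasource::provider_as_source;
use deltalake::datafusion::error::{DataFusionError, Result};
use deltalake::datafusion::logical_expr::{LogicalPlan, LogicalPlanBuilder};
use deltalake::operations::write::lazy::to_lazy_table;

/// Wrap an FFI stream reader in a lazy table scan; batches are only
/// pulled from the reader when the resulting plan is executed.
fn lazy_scan(reader: ArrowArrayStreamReader) -> Result<LogicalPlan> {
    let provider = to_lazy_table(reader)
        .map_err(|err| DataFusionError::External(Box::new(err)))?;
    LogicalPlanBuilder::scan("source", provider_as_source(provider), None)?.build()
}
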
5 changes: 3 additions & 2 deletions crates/core/src/operations/write/mod.rs
@@ -26,6 +26,7 @@
pub mod configs;
pub(crate) mod execution;
pub(crate) mod generated_columns;
pub mod lazy;
pub(crate) mod schema_evolution;

use arrow_schema::Schema;
@@ -44,7 +45,7 @@ use datafusion::datasource::MemTable;
use datafusion::execution::context::{SessionContext, SessionState};
use datafusion::prelude::DataFrame;
use datafusion_common::{Column, DFSchema, Result, ScalarValue};
use datafusion_expr::{cast, col, lit, Expr, LogicalPlan, UNNAMED_TABLE};
use datafusion_expr::{cast, lit, Expr, LogicalPlan};
use execution::{prepare_predicate_actions, write_execution_plan_with_predicate};
use futures::future::BoxFuture;
use parquet::file::properties::WriterProperties;
@@ -438,7 +439,7 @@ impl std::future::IntoFuture for WriteBuilder {
.unwrap_or_default();

let mut schema_drift = false;
let mut source = DataFrame::new(state.clone(), this.input.unwrap().as_ref().clone());
let source = DataFrame::new(state.clone(), this.input.unwrap().as_ref().clone());

// Add missing generated columns to source_df
let (mut source, missing_generated_columns) =
30 changes: 28 additions & 2 deletions python/src/lib.rs
@@ -50,6 +50,7 @@ use deltalake::operations::transaction::{
};
use deltalake::operations::update::UpdateBuilder;
use deltalake::operations::vacuum::VacuumBuilder;
use deltalake::operations::write::WriteBuilder;
use deltalake::operations::{collect_sendable_stream, CustomExecuteHandler};
use deltalake::parquet::basic::Compression;
use deltalake::parquet::errors::ParquetError;
@@ -2151,7 +2152,6 @@ fn write_to_deltalake(
post_commithook_properties: Option<PyPostCommitHookProperties>,
) -> PyResult<()> {
py.allow_threads(|| {
let batches = data.0.map(|batch| batch.unwrap()).collect::<Vec<_>>();
let save_mode = mode.parse().map_err(PythonError::from)?;

let options = storage_options.clone().unwrap_or_default();
@@ -2164,7 +2164,33 @@
.map_err(PythonError::from)?
};

let mut builder = table.write(batches).with_save_mode(save_mode);
let dont_be_so_lazy = match table.0.state.as_ref() {
Some(state) => state.table_config().enable_change_data_feed(),
// You don't have state somehow, so I guess it's okay to be lazy.
_ => false,
};

let mut builder =
WriteBuilder::new(table.0.log_store(), table.0.state).with_save_mode(save_mode);

if dont_be_so_lazy {
debug!(
"write_to_deltalake() is not able to lazily perform a write, collecting batches"
);
builder = builder.with_input_batches(data.0.map(|batch| batch.unwrap()));
} else {
use deltalake::datafusion::datasource::provider_as_source;
use deltalake::datafusion::logical_expr::LogicalPlanBuilder;
use deltalake::operations::write::lazy::to_lazy_table;
let table_provider = to_lazy_table(data.0).map_err(PythonError::from)?;

let plan = LogicalPlanBuilder::scan("source", provider_as_source(table_provider), None)
.map_err(PythonError::from)?
.build()
.map_err(PythonError::from)?;
builder = builder.with_input_execution_plan(Arc::new(plan));
}

if let Some(schema_mode) = schema_mode {
builder = builder.with_schema_mode(schema_mode.parse().map_err(PythonError::from)?);
}
16 changes: 12 additions & 4 deletions python/tests/test_writer.py
@@ -1035,6 +1035,7 @@ def test_partition_overwrite(
tmp_path, sample_data, mode="overwrite", predicate=f"p2 < {filter_string}"
)


@pytest.fixture()
def sample_data_for_partitioning() -> pa.Table:
return pa.table(
@@ -1590,9 +1591,13 @@ def test_schema_cols_diff_order(tmp_path: pathlib.Path, engine):

def test_empty(existing_table: DeltaTable):
schema = existing_table.schema().to_pyarrow()
expected = existing_table.to_pyarrow_table()
empty_table = pa.Table.from_pylist([], schema=schema)
with pytest.raises(DeltaError, match="No data source supplied to write command"):
write_deltalake(existing_table, empty_table, mode="append", engine="rust")
write_deltalake(existing_table, empty_table, mode="append", engine="rust")

existing_table.update_incremental()
assert existing_table.version() == 1
assert expected == existing_table.to_pyarrow_table()


def test_rust_decimal_cast(tmp_path: pathlib.Path):
@@ -1815,8 +1820,11 @@ def test_roundtrip_cdc_evolution(tmp_path: pathlib.Path):
def test_empty_dataset_write(tmp_path: pathlib.Path, sample_data: pa.Table):
empty_arrow_table = sample_data.schema.empty_table()
empty_dataset = dataset(empty_arrow_table)
with pytest.raises(DeltaError, match="No data source supplied to write command"):
write_deltalake(tmp_path, empty_dataset, mode="append")
write_deltalake(tmp_path, empty_dataset, mode="append")
dt = DeltaTable(tmp_path)

new_dataset = dt.to_pyarrow_dataset()
assert new_dataset.count_rows() == 0


@pytest.mark.pandas
