diff --git a/crates/core/src/operations/write/lazy.rs b/crates/core/src/operations/write/lazy.rs
new file mode 100644
index 0000000000..a16f579d65
--- /dev/null
+++ b/crates/core/src/operations/write/lazy.rs
@@ -0,0 +1,65 @@
+//! This module contains helper functions to create a LazyTableProvider from an ArrowArrayStreamReader
+
+use crate::DeltaResult;
+use arrow::ffi_stream::ArrowArrayStreamReader;
+use datafusion::catalog::TableProvider;
+use datafusion::physical_plan::memory::LazyBatchGenerator;
+use delta_datafusion::LazyTableProvider;
+use parking_lot::RwLock;
+use std::fmt::{self};
+use std::sync::{Arc, Mutex};
+
+use crate::delta_datafusion;
+
+/// Convert an [ArrowArrayStreamReader] into a [LazyTableProvider]
+pub fn to_lazy_table(source: ArrowArrayStreamReader) -> DeltaResult<Arc<dyn TableProvider>> {
+    use arrow::array::RecordBatchReader;
+    let schema = source.schema();
+    let arrow_stream: Arc<Mutex<ArrowArrayStreamReader>> = Arc::new(Mutex::new(source));
+    let arrow_stream_batch_generator: Arc<RwLock<dyn LazyBatchGenerator>> =
+        Arc::new(RwLock::new(ArrowStreamBatchGenerator::new(arrow_stream)));
+
+    Ok(Arc::new(LazyTableProvider::try_new(
+        schema.clone(),
+        vec![arrow_stream_batch_generator],
+    )?))
+}
+
+#[derive(Debug)]
+pub(crate) struct ArrowStreamBatchGenerator {
+    pub array_stream: Arc<Mutex<ArrowArrayStreamReader>>,
+}
+
+impl fmt::Display for ArrowStreamBatchGenerator {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(
+            f,
+            "ArrowStreamBatchGenerator {{ array_stream: {:?} }}",
+            self.array_stream
+        )
+    }
+}
+
+impl ArrowStreamBatchGenerator {
+    pub fn new(array_stream: Arc<Mutex<ArrowArrayStreamReader>>) -> Self {
+        Self { array_stream }
+    }
+}
+
+impl LazyBatchGenerator for ArrowStreamBatchGenerator {
+    fn generate_next_batch(
+        &mut self,
+    ) -> datafusion::error::Result<Option<arrow::array::RecordBatch>> {
+        let mut stream_reader = self.array_stream.lock().map_err(|_| {
+            datafusion::error::DataFusionError::Execution(
+                "Failed to lock the ArrowArrayStreamReader".to_string(),
+            )
+        })?;
+
+        match stream_reader.next() {
+            Some(Ok(record_batch)) => Ok(Some(record_batch)),
+            Some(Err(err)) => Err(datafusion::error::DataFusionError::ArrowError(err, None)),
+            None => Ok(None), // End of stream
+        }
+    }
+}
diff --git a/crates/core/src/operations/write/mod.rs b/crates/core/src/operations/write/mod.rs
index 2d0a5ffe62..74061b2c57 100644
--- a/crates/core/src/operations/write/mod.rs
+++ b/crates/core/src/operations/write/mod.rs
@@ -26,6 +26,7 @@
 pub mod configs;
 pub(crate) mod execution;
 pub(crate) mod generated_columns;
+pub mod lazy;
 pub(crate) mod schema_evolution;
 
 use arrow_schema::Schema;
@@ -44,7 +45,7 @@ use datafusion::datasource::MemTable;
 use datafusion::execution::context::{SessionContext, SessionState};
 use datafusion::prelude::DataFrame;
 use datafusion_common::{Column, DFSchema, Result, ScalarValue};
-use datafusion_expr::{cast, col, lit, Expr, LogicalPlan, UNNAMED_TABLE};
+use datafusion_expr::{cast, lit, Expr, LogicalPlan};
 use execution::{prepare_predicate_actions, write_execution_plan_with_predicate};
 use futures::future::BoxFuture;
 use parquet::file::properties::WriterProperties;
@@ -438,7 +439,7 @@
             .unwrap_or_default();
 
         let mut schema_drift = false;
-        let mut source = DataFrame::new(state.clone(), this.input.unwrap().as_ref().clone());
+        let source = DataFrame::new(state.clone(), this.input.unwrap().as_ref().clone());
 
         // Add missing generated columns to source_df
        let (mut source, missing_generated_columns) =
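For context on the new module: everything `lazy.rs` does is implement DataFusion's `LazyBatchGenerator` trait, which `LazyTableProvider` polls for batches on demand. A minimal sketch of the same contract against an in-memory queue, assuming only the `datafusion` and `arrow` crates at the versions delta-rs pins; `VecBatchGenerator` is hypothetical and not part of this change:

```rust
// Illustrative only: VecBatchGenerator shows the trait contract that
// ArrowStreamBatchGenerator fulfills above; it is not part of this diff.
use std::collections::VecDeque;
use std::fmt;

use arrow::array::RecordBatch;
use datafusion::physical_plan::memory::LazyBatchGenerator;

#[derive(Debug)]
struct VecBatchGenerator {
    batches: VecDeque<RecordBatch>,
}

// LazyBatchGenerator requires Debug + Display (+ Send + Sync).
impl fmt::Display for VecBatchGenerator {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "VecBatchGenerator {{ remaining: {} }}", self.batches.len())
    }
}

impl LazyBatchGenerator for VecBatchGenerator {
    fn generate_next_batch(&mut self) -> datafusion::error::Result<Option<RecordBatch>> {
        // Hand out one batch per call; Ok(None) signals end of stream,
        // exactly as the ArrowArrayStreamReader-backed generator does.
        Ok(self.batches.pop_front())
    }
}
```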
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 445e061ae1..1792785a6f 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -50,6 +50,7 @@ use deltalake::operations::transaction::{
 };
 use deltalake::operations::update::UpdateBuilder;
 use deltalake::operations::vacuum::VacuumBuilder;
+use deltalake::operations::write::WriteBuilder;
 use deltalake::operations::{collect_sendable_stream, CustomExecuteHandler};
 use deltalake::parquet::basic::Compression;
 use deltalake::parquet::errors::ParquetError;
@@ -2151,7 +2152,6 @@ fn write_to_deltalake(
     post_commithook_properties: Option<PyPostCommitHookProperties>,
 ) -> PyResult<()> {
     py.allow_threads(|| {
-        let batches = data.0.map(|batch| batch.unwrap()).collect::<Vec<RecordBatch>>();
         let save_mode = mode.parse().map_err(PythonError::from)?;
 
         let options = storage_options.clone().unwrap_or_default();
@@ -2164,7 +2164,33 @@ fn write_to_deltalake(
                 .map_err(PythonError::from)?
         };
 
-        let mut builder = table.write(batches).with_save_mode(save_mode);
+        let dont_be_so_lazy = match table.0.state.as_ref() {
+            Some(state) => state.table_config().enable_change_data_feed(),
+            // You don't have state somehow, so I guess it's okay to be lazy.
+            _ => false,
+        };
+
+        let mut builder =
+            WriteBuilder::new(table.0.log_store(), table.0.state).with_save_mode(save_mode);
+
+        if dont_be_so_lazy {
+            debug!(
+                "write_to_deltalake() is not able to lazily perform a write, collecting batches"
+            );
+            builder = builder.with_input_batches(data.0.map(|batch| batch.unwrap()));
+        } else {
+            use deltalake::datafusion::datasource::provider_as_source;
+            use deltalake::datafusion::logical_expr::LogicalPlanBuilder;
+            use deltalake::operations::write::lazy::to_lazy_table;
+            let table_provider = to_lazy_table(data.0).map_err(PythonError::from)?;
+
+            let plan = LogicalPlanBuilder::scan("source", provider_as_source(table_provider), None)
+                .map_err(PythonError::from)?
+                .build()
+                .map_err(PythonError::from)?;
+            builder = builder.with_input_execution_plan(Arc::new(plan));
+        }
+
         if let Some(schema_mode) = schema_mode {
             builder = builder.with_schema_mode(schema_mode.parse().map_err(PythonError::from)?);
         }
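Put together, the lazy branch of the binding above is roughly the following pure-Rust flow. This is a sketch only: `lazy_append` is a hypothetical helper, the reader is assumed to arrive over the Arrow C stream interface, and the `?` operators lean on the existing `From<DataFusionError>` conversion in `delta_datafusion`:

```rust
// Hypothetical helper mirroring the lazy path write_to_deltalake() takes;
// not part of this diff.
use std::sync::Arc;

use deltalake::arrow::ffi_stream::ArrowArrayStreamReader;
use deltalake::datafusion::datasource::provider_as_source;
use deltalake::datafusion::logical_expr::LogicalPlanBuilder;
use deltalake::operations::write::{lazy::to_lazy_table, WriteBuilder};
use deltalake::protocol::SaveMode;
use deltalake::{DeltaResult, DeltaTable};

async fn lazy_append(
    table: DeltaTable,
    reader: ArrowArrayStreamReader,
) -> DeltaResult<DeltaTable> {
    // Wrap the stream so DataFusion pulls record batches only as needed.
    let provider = to_lazy_table(reader)?;
    // Scan the provider into a LogicalPlan, as the binding does for "source".
    let plan =
        LogicalPlanBuilder::scan("source", provider_as_source(provider), None)?.build()?;
    // Despite its name, with_input_execution_plan receives the logical plan here.
    WriteBuilder::new(table.log_store(), table.state)
        .with_save_mode(SaveMode::Append)
        .with_input_execution_plan(Arc::new(plan))
        .await
}
```

Note the fallback: when change data feed is enabled, the batches are still collected eagerly (`with_input_batches`), since CDC generation needs the full input today.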
diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py
index 16479b5e20..6d09906866 100644
--- a/python/tests/test_writer.py
+++ b/python/tests/test_writer.py
@@ -1035,6 +1035,7 @@ def test_partition_overwrite(
         tmp_path, sample_data, mode="overwrite", predicate=f"p2 < {filter_string}"
     )
 
+
 @pytest.fixture()
 def sample_data_for_partitioning() -> pa.Table:
     return pa.table(
@@ -1590,9 +1591,13 @@ def test_schema_cols_diff_order(tmp_path: pathlib.Path, engine):
 
 def test_empty(existing_table: DeltaTable):
     schema = existing_table.schema().to_pyarrow()
+    expected = existing_table.to_pyarrow_table()
     empty_table = pa.Table.from_pylist([], schema=schema)
-    with pytest.raises(DeltaError, match="No data source supplied to write command"):
-        write_deltalake(existing_table, empty_table, mode="append", engine="rust")
+    write_deltalake(existing_table, empty_table, mode="append", engine="rust")
+
+    existing_table.update_incremental()
+    assert existing_table.version() == 1
+    assert expected == existing_table.to_pyarrow_table()
 
 
 def test_rust_decimal_cast(tmp_path: pathlib.Path):
@@ -1815,8 +1820,11 @@ def test_roundtrip_cdc_evolution(tmp_path: pathlib.Path):
 def test_empty_dataset_write(tmp_path: pathlib.Path, sample_data: pa.Table):
     empty_arrow_table = sample_data.schema.empty_table()
     empty_dataset = dataset(empty_arrow_table)
-    with pytest.raises(DeltaError, match="No data source supplied to write command"):
-        write_deltalake(tmp_path, empty_dataset, mode="append")
+    write_deltalake(tmp_path, empty_dataset, mode="append")
+    dt = DeltaTable(tmp_path)
+
+    new_dataset = dt.to_pyarrow_dataset()
+    assert new_dataset.count_rows() == 0
 
 
 @pytest.mark.pandas
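The behavioral change these tests pin down: appending a zero-row source now produces a new table version instead of raising "No data source supplied to write command". A rough Rust-level equivalent of what `test_empty` asserts, hedged: `append_empty` is a hypothetical helper and the schema is illustrative, assumed to match the existing table:

```rust
// Hypothetical sketch of the behavior test_empty now asserts: a zero-row
// append still commits a new table version. Not part of this diff.
use std::sync::Arc;

use deltalake::arrow::array::RecordBatch;
use deltalake::arrow::datatypes::{DataType, Field, Schema};
use deltalake::operations::write::WriteBuilder;
use deltalake::protocol::SaveMode;
use deltalake::{DeltaResult, DeltaTable};

async fn append_empty(table: DeltaTable) -> DeltaResult<DeltaTable> {
    // Illustrative schema; assumed identical to the target table's schema.
    let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::Int64, true)]));
    // A zero-row batch carries only the schema, like
    // pa.Table.from_pylist([], schema=schema) in the test.
    let empty = RecordBatch::new_empty(schema);
    WriteBuilder::new(table.log_store(), table.state)
        .with_save_mode(SaveMode::Append)
        .with_input_batches(vec![empty])
        .await
}
```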