diff --git a/crates/core/src/operations/delete.rs b/crates/core/src/operations/delete.rs index f1a3e1af3e..ecf27819a9 100644 --- a/crates/core/src/operations/delete.rs +++ b/crates/core/src/operations/delete.rs @@ -53,7 +53,8 @@ use crate::delta_datafusion::{ use crate::errors::DeltaResult; use crate::kernel::{Action, Add, Remove}; use crate::logstore::LogStoreRef; -use crate::operations::write::{write_execution_plan, write_execution_plan_cdc, WriterStatsConfig}; +use crate::operations::write::execution::{write_execution_plan, write_execution_plan_cdc}; +use crate::operations::write::WriterStatsConfig; use crate::operations::CustomExecuteHandler; use crate::protocol::DeltaOperation; use crate::table::state::DeltaTableState; diff --git a/crates/core/src/operations/merge/mod.rs b/crates/core/src/operations/merge/mod.rs index 8622160318..de5367e682 100644 --- a/crates/core/src/operations/merge/mod.rs +++ b/crates/core/src/operations/merge/mod.rs @@ -52,8 +52,7 @@ use datafusion_common::tree_node::{Transformed, TreeNode}; use datafusion_common::{Column, DFSchema, ExprSchema, ScalarValue, TableReference}; use datafusion_expr::{col, conditional_expressions::CaseBuilder, lit, when, Expr, JoinType}; use datafusion_expr::{ - ExprSchemable, Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, - UNNAMED_TABLE, + Extension, LogicalPlan, LogicalPlanBuilder, UserDefinedLogicalNode, UNNAMED_TABLE, }; use delta_kernel::schema::{ColumnMetadataKey, StructType}; @@ -78,16 +77,18 @@ use crate::delta_datafusion::{ DeltaSessionConfig, DeltaTableProvider, }; -use crate::kernel::{Action, DataCheck, Metadata, StructTypeExt}; +use crate::kernel::{Action, Metadata, StructTypeExt}; use crate::logstore::LogStoreRef; use crate::operations::cast::merge_schema::{merge_arrow_field, merge_arrow_schema}; use crate::operations::cdc::*; use crate::operations::merge::barrier::find_node; use crate::operations::transaction::CommitBuilder; +use crate::operations::write::generated_columns::{ + add_generated_columns, add_missing_generated_columns, +}; use crate::operations::write::WriterStatsConfig; use crate::protocol::{DeltaOperation, MergePredicate}; use crate::table::state::DeltaTableState; -use crate::table::GeneratedColumn; use crate::{DeltaResult, DeltaTable, DeltaTableError}; use writer::write_execution_plan_v2; @@ -776,72 +777,6 @@ async fn execute( None => TableReference::bare(UNNAMED_TABLE), }; - /// Add generated column expressions to a dataframe - fn add_missing_generated_columns( - mut df: DataFrame, - generated_cols: &Vec, - ) -> DeltaResult<(DataFrame, Vec)> { - let mut missing_cols = vec![]; - for generated_col in generated_cols { - let col_name = generated_col.get_name(); - - if df - .clone() - .schema() - .field_with_unqualified_name(col_name) - .is_err() - // implies it doesn't exist - { - debug!( - "Adding missing generated column {} in source as placeholder", - col_name - ); - // If column doesn't exist, we add a null column, later we will generate the values after - // all the merge is projected. 
- // Other generated columns that were provided upon the start we only validate during write - missing_cols.push(col_name.to_string()); - df = df - .clone() - .with_column(col_name, Expr::Literal(ScalarValue::Null))?; - } - } - Ok((df, missing_cols)) - } - - /// Add generated column expressions to a dataframe - fn add_generated_columns( - mut df: DataFrame, - generated_cols: &Vec, - generated_cols_missing_in_source: &[String], - state: &SessionState, - ) -> DeltaResult { - debug!("Generating columns in dataframe"); - for generated_col in generated_cols { - // We only validate columns that were missing from the start. We don't update - // update generated columns that were provided during runtime - if !generated_cols_missing_in_source.contains(&generated_col.name) { - continue; - } - - let generation_expr = state.create_logical_expr( - generated_col.get_generation_expression(), - df.clone().schema(), - )?; - let col_name = generated_col.get_name(); - - df = df.clone().with_column( - generated_col.get_name(), - when(col(col_name).is_null(), generation_expr) - .otherwise(col(col_name))? - .cast_to( - &arrow_schema::DataType::try_from(&generated_col.data_type)?, - df.schema(), - )?, - )? - } - Ok(df) - } - let generated_col_expressions = snapshot .schema() .get_generated_columns() diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs index b014c0214f..eba990014e 100644 --- a/crates/core/src/operations/update.rs +++ b/crates/core/src/operations/update.rs @@ -51,7 +51,7 @@ use super::{ }; use super::{transaction::PROTOCOL, write::WriterStatsConfig}; use super::{ - write::{write_execution_plan, write_execution_plan_cdc}, + write::execution::{write_execution_plan, write_execution_plan_cdc}, CustomExecuteHandler, Operation, }; use crate::delta_datafusion::{find_files, planner::DeltaPlanner, register_store}; diff --git a/crates/core/src/operations/write/configs.rs b/crates/core/src/operations/write/configs.rs new file mode 100644 index 0000000000..931f99189f --- /dev/null +++ b/crates/core/src/operations/write/configs.rs @@ -0,0 +1,18 @@ +/// Configuration for the writer on how to collect stats +#[derive(Clone)] +pub struct WriterStatsConfig { + /// Number of columns to collect stats for, idx based + pub num_indexed_cols: i32, + /// Optional list of columns which to collect stats for, takes precedende over num_index_cols + pub stats_columns: Option>, +} + +impl WriterStatsConfig { + /// Create new writer stats config + pub fn new(num_indexed_cols: i32, stats_columns: Option>) -> Self { + Self { + num_indexed_cols, + stats_columns, + } + } +} diff --git a/crates/core/src/operations/write/execution.rs b/crates/core/src/operations/write/execution.rs new file mode 100644 index 0000000000..cdf57a56d0 --- /dev/null +++ b/crates/core/src/operations/write/execution.rs @@ -0,0 +1,391 @@ +use std::sync::Arc; +use std::vec; + +use arrow_array::RecordBatch; +use arrow_schema::SchemaRef as ArrowSchemaRef; +use datafusion::datasource::provider_as_source; +use datafusion::execution::context::{SessionState, TaskContext}; +use datafusion::prelude::DataFrame; +use datafusion_expr::{lit, Expr, LogicalPlanBuilder}; +use datafusion_physical_plan::ExecutionPlan; +use futures::StreamExt; +use object_store::prefix::PrefixStore; +use parquet::file::properties::WriterProperties; +use tracing::log::*; +use uuid::Uuid; + +use crate::delta_datafusion::expr::fmt_expr_to_sql; +use crate::delta_datafusion::{find_files, DeltaScanConfigBuilder, DeltaTableProvider}; +use 
crate::delta_datafusion::{DataFusionMixins, DeltaDataChecker}; +use crate::errors::DeltaResult; +use crate::kernel::{Action, Add, AddCDCFile, Remove, StructType, StructTypeExt}; +use crate::logstore::LogStoreRef; +use crate::operations::cdc::should_write_cdc; +use crate::operations::writer::{DeltaWriter, WriterConfig}; +use crate::storage::ObjectStoreRef; +use crate::table::state::DeltaTableState; +use crate::table::Constraint as DeltaConstraint; +use tokio::sync::mpsc::Sender; + +use super::configs::WriterStatsConfig; +use super::WriteError; + +#[allow(clippy::too_many_arguments)] +pub(crate) async fn write_execution_plan_with_predicate( + predicate: Option, + snapshot: Option<&DeltaTableState>, + state: SessionState, + plan: Arc, + partition_columns: Vec, + object_store: ObjectStoreRef, + target_file_size: Option, + write_batch_size: Option, + writer_properties: Option, + writer_stats_config: WriterStatsConfig, + sender: Option>, +) -> DeltaResult> { + // We always take the plan Schema since the data may contain Large/View arrow types, + // the schema and batches were prior constructed with this in mind. + let schema: ArrowSchemaRef = plan.schema(); + let checker = if let Some(snapshot) = snapshot { + DeltaDataChecker::new(snapshot) + } else { + debug!("Using plan schema to derive generated columns, since no snapshot was provided. Implies first write."); + let delta_schema: StructType = schema.as_ref().try_into()?; + DeltaDataChecker::new_with_generated_columns( + delta_schema.get_generated_columns().unwrap_or_default(), + ) + }; + let checker = match predicate { + Some(pred) => { + // TODO: get the name of the outer-most column? `*` will also work but would it be slower? + let chk = DeltaConstraint::new("*", &fmt_expr_to_sql(&pred)?); + checker.with_extra_constraints(vec![chk]) + } + _ => checker, + }; + // Write data to disk + let mut tasks = vec![]; + for i in 0..plan.properties().output_partitioning().partition_count() { + let inner_plan = plan.clone(); + let inner_schema = schema.clone(); + let task_ctx = Arc::new(TaskContext::from(&state)); + let config = WriterConfig::new( + inner_schema.clone(), + partition_columns.clone(), + writer_properties.clone(), + target_file_size, + write_batch_size, + writer_stats_config.num_indexed_cols, + writer_stats_config.stats_columns.clone(), + ); + let mut writer = DeltaWriter::new(object_store.clone(), config); + let checker_stream = checker.clone(); + let sender_stream = sender.clone(); + let mut stream = inner_plan.execute(i, task_ctx)?; + + let handle: tokio::task::JoinHandle>> = tokio::task::spawn( + async move { + let sendable = sender_stream.clone(); + while let Some(maybe_batch) = stream.next().await { + let batch = maybe_batch?; + + checker_stream.check_batch(&batch).await?; + + if let Some(s) = sendable.as_ref() { + if let Err(e) = s.send(batch.clone()).await { + error!("Failed to send data to observer: {e:#?}"); + } + } else { + debug!("write_execution_plan_with_predicate did not send any batches, no sender."); + } + writer.write(&batch).await?; + } + let add_actions = writer.close().await; + match add_actions { + Ok(actions) => Ok(actions.into_iter().map(Action::Add).collect::>()), + Err(err) => Err(err), + } + }, + ); + + tasks.push(handle); + } + let actions = futures::future::join_all(tasks) + .await + .into_iter() + .collect::, _>>() + .map_err(|err| WriteError::WriteTask { source: err })? + .into_iter() + .collect::, _>>()? 
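+        // Each writer task returns its own Vec<Action>; flatten the per-partition
+        // results into a single list of Add actions for the commit.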
+ .concat() + .into_iter() + .collect::>(); + // Collect add actions to add to commit + Ok(actions) +} + +#[allow(clippy::too_many_arguments)] +pub(crate) async fn write_execution_plan_cdc( + snapshot: Option<&DeltaTableState>, + state: SessionState, + plan: Arc, + partition_columns: Vec, + object_store: ObjectStoreRef, + target_file_size: Option, + write_batch_size: Option, + writer_properties: Option, + writer_stats_config: WriterStatsConfig, + sender: Option>, +) -> DeltaResult> { + let cdc_store = Arc::new(PrefixStore::new(object_store, "_change_data")); + + Ok(write_execution_plan( + snapshot, + state, + plan, + partition_columns, + cdc_store, + target_file_size, + write_batch_size, + writer_properties, + writer_stats_config, + sender, + ) + .await? + .into_iter() + .map(|add| { + // Modify add actions into CDC actions + match add { + Action::Add(add) => { + Action::Cdc(AddCDCFile { + // This is a gnarly hack, but the action needs the nested path, not the + // path isnide the prefixed store + path: format!("_change_data/{}", add.path), + size: add.size, + partition_values: add.partition_values, + data_change: false, + tags: add.tags, + }) + } + _ => panic!("Expected Add action"), + } + }) + .collect::>()) +} + +#[allow(clippy::too_many_arguments)] +pub(crate) async fn write_execution_plan( + snapshot: Option<&DeltaTableState>, + state: SessionState, + plan: Arc, + partition_columns: Vec, + object_store: ObjectStoreRef, + target_file_size: Option, + write_batch_size: Option, + writer_properties: Option, + writer_stats_config: WriterStatsConfig, + sender: Option>, +) -> DeltaResult> { + write_execution_plan_with_predicate( + None, + snapshot, + state, + plan, + partition_columns, + object_store, + target_file_size, + write_batch_size, + writer_properties, + writer_stats_config, + sender, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +pub(crate) async fn execute_non_empty_expr( + snapshot: &DeltaTableState, + log_store: LogStoreRef, + state: SessionState, + partition_columns: Vec, + expression: &Expr, + rewrite: &[Add], + writer_properties: Option, + writer_stats_config: WriterStatsConfig, + partition_scan: bool, + insert_df: DataFrame, + operation_id: Uuid, +) -> DeltaResult> { + // For each identified file perform a parquet scan + filter + limit (1) + count. + // If returned count is not zero then append the file to be rewritten and removed from the log. Otherwise do nothing to the file. + let mut actions: Vec = Vec::new(); + + // Take the insert plan schema since it might have been schema evolved, if its not + // it is simply the table schema + let scan_config = DeltaScanConfigBuilder::new() + .with_schema(snapshot.input_schema()?) + .build(snapshot)?; + + let target_provider = Arc::new( + DeltaTableProvider::try_new(snapshot.clone(), log_store.clone(), scan_config.clone())? + .with_files(rewrite.to_vec()), + ); + + let target_provider = provider_as_source(target_provider); + let source = LogicalPlanBuilder::scan("target", target_provider.clone(), None)?.build()?; + // We don't want to verify the predicate against existing data + + let df = DataFrame::new(state.clone(), source); + + if !partition_scan { + // Apply the negation of the filter and rewrite files + let negated_expression = Expr::Not(Box::new(Expr::IsTrue(Box::new(expression.clone())))); + + let filter = df + .clone() + .filter(negated_expression)? 
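+            // NOTE: `NOT (expr IS TRUE)` is true when the predicate evaluates to
+            // false *or* NULL, so rows that do not match the predicate (including
+            // NULL results) are kept and rewritten into new files.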
+ .create_physical_plan() + .await?; + + let add_actions: Vec = write_execution_plan( + Some(snapshot), + state.clone(), + filter, + partition_columns.clone(), + log_store.object_store(Some(operation_id)), + Some(snapshot.table_config().target_file_size() as usize), + None, + writer_properties.clone(), + writer_stats_config.clone(), + None, + ) + .await?; + + actions.extend(add_actions); + } + + // CDC logic, simply filters data with predicate and adds the _change_type="delete" as literal column + // Only write when CDC actions when it was not a partition scan, load_cdf can deduce the deletes in that case + // based on the remove actions if a partition got deleted + if !partition_scan { + // We only write deletions when it was not a partition scan + if let Some(cdc_actions) = execute_non_empty_expr_cdc( + snapshot, + log_store, + state.clone(), + df, + expression, + partition_columns, + writer_properties, + writer_stats_config, + insert_df, + operation_id, + ) + .await? + { + actions.extend(cdc_actions) + } + } + Ok(actions) +} + +/// If CDC is enabled it writes all the deletions based on predicate into _change_data directory +#[allow(clippy::too_many_arguments)] +pub(crate) async fn execute_non_empty_expr_cdc( + snapshot: &DeltaTableState, + log_store: LogStoreRef, + state: SessionState, + scan: DataFrame, + expression: &Expr, + table_partition_cols: Vec, + writer_properties: Option, + writer_stats_config: WriterStatsConfig, + insert_df: DataFrame, + operation_id: Uuid, +) -> DeltaResult>> { + match should_write_cdc(snapshot) { + // Create CDC scan + Ok(true) => { + let filter = scan.clone().filter(expression.clone())?; + + // Add literal column "_change_type" + let delete_change_type_expr = lit("delete").alias("_change_type"); + + let insert_change_type_expr = lit("insert").alias("_change_type"); + + let delete_df = filter.with_column("_change_type", delete_change_type_expr)?; + + let insert_df = insert_df.with_column("_change_type", insert_change_type_expr)?; + + let cdc_df = delete_df.union(insert_df)?; + + let cdc_actions = write_execution_plan_cdc( + Some(snapshot), + state.clone(), + cdc_df.create_physical_plan().await?, + table_partition_cols.clone(), + log_store.object_store(Some(operation_id)), + Some(snapshot.table_config().target_file_size() as usize), + None, + writer_properties, + writer_stats_config, + None, + ) + .await?; + Ok(Some(cdc_actions)) + } + _ => Ok(None), + } +} + +// This should only be called with a valid predicate +#[allow(clippy::too_many_arguments)] +pub(crate) async fn prepare_predicate_actions( + predicate: Expr, + log_store: LogStoreRef, + snapshot: &DeltaTableState, + state: SessionState, + partition_columns: Vec, + writer_properties: Option, + deletion_timestamp: i64, + writer_stats_config: WriterStatsConfig, + insert_df: DataFrame, + operation_id: Uuid, +) -> DeltaResult> { + let candidates = + find_files(snapshot, log_store.clone(), &state, Some(predicate.clone())).await?; + + let mut actions = execute_non_empty_expr( + snapshot, + log_store, + state, + partition_columns, + &predicate, + &candidates.candidates, + writer_properties, + writer_stats_config, + candidates.partition_scan, + insert_df, + operation_id, + ) + .await?; + + let remove = candidates.candidates; + + for action in remove { + actions.push(Action::Remove(Remove { + path: action.path, + deletion_timestamp: Some(deletion_timestamp), + data_change: true, + extended_file_metadata: Some(true), + partition_values: Some(action.partition_values), + size: Some(action.size), + deletion_vector: 
action.deletion_vector, + tags: None, + base_row_id: action.base_row_id, + default_row_commit_version: action.default_row_commit_version, + })) + } + Ok(actions) +} diff --git a/crates/core/src/operations/write/generated_columns.rs b/crates/core/src/operations/write/generated_columns.rs new file mode 100644 index 0000000000..f39866f734 --- /dev/null +++ b/crates/core/src/operations/write/generated_columns.rs @@ -0,0 +1,72 @@ +use datafusion::{execution::SessionState, prelude::DataFrame}; +use datafusion_common::ScalarValue; +use datafusion_expr::{col, when, Expr, ExprSchemable}; +use tracing::debug; + +use crate::{kernel::DataCheck, table::GeneratedColumn, DeltaResult}; + +/// Add generated column expressions to a dataframe +pub fn add_missing_generated_columns( + mut df: DataFrame, + generated_cols: &Vec, +) -> DeltaResult<(DataFrame, Vec)> { + let mut missing_cols = vec![]; + for generated_col in generated_cols { + let col_name = generated_col.get_name(); + + if df + .clone() + .schema() + .field_with_unqualified_name(col_name) + .is_err() + // implies it doesn't exist + { + debug!( + "Adding missing generated column {} in source as placeholder", + col_name + ); + // If column doesn't exist, we add a null column, later we will generate the values after + // all the merge is projected. + // Other generated columns that were provided upon the start we only validate during write + missing_cols.push(col_name.to_string()); + df = df + .clone() + .with_column(col_name, Expr::Literal(ScalarValue::Null))?; + } + } + Ok((df, missing_cols)) +} + +/// Add generated column expressions to a dataframe +pub fn add_generated_columns( + mut df: DataFrame, + generated_cols: &Vec, + generated_cols_missing_in_source: &[String], + state: &SessionState, +) -> DeltaResult { + debug!("Generating columns in dataframe"); + for generated_col in generated_cols { + // We only validate columns that were missing from the start. We don't update + // update generated columns that were provided during runtime + if !generated_cols_missing_in_source.contains(&generated_col.name) { + continue; + } + + let generation_expr = state.create_logical_expr( + generated_col.get_generation_expression(), + df.clone().schema(), + )?; + let col_name = generated_col.get_name(); + + df = df.clone().with_column( + generated_col.get_name(), + when(col(col_name).is_null(), generation_expr) + .otherwise(col(col_name))? + .cast_to( + &arrow_schema::DataType::try_from(&generated_col.data_type)?, + df.schema(), + )?, + )? + } + Ok(df) +} diff --git a/crates/core/src/operations/write.rs b/crates/core/src/operations/write/mod.rs similarity index 68% rename from crates/core/src/operations/write.rs rename to crates/core/src/operations/write/mod.rs index 5d67a637fc..2d0a5ffe62 100644 --- a/crates/core/src/operations/write.rs +++ b/crates/core/src/operations/write/mod.rs @@ -1,12 +1,11 @@ //! //! New Table Semantics -//! - The schema of the [RecordBatch] is used to initialize the table. +//! - The schema of the [Plan] is used to initialize the table. //! - The partition columns will be used to partition the table. //! //! Existing Table Semantics //! - The save mode will control how existing data is handled (i.e. overwrite, append, etc) -//! - (NOT YET IMPLEMENTED) The schema of the RecordBatch will be checked and if there are new columns present -//! they will be added to the tables schema. Conflicting columns (i.e. a INT, and a STRING) +//! - Conflicting columns (i.e. a INT, and a STRING) //! will result in an exception. //! 
- The partition columns, if present, are validated against the existing metadata. If not //! present, then the partitioning of the table is respected. @@ -24,6 +23,15 @@ //! let table = ops.write(vec![batch]).await?; //! ```` +pub mod configs; +pub(crate) mod execution; +pub(crate) mod generated_columns; +pub(crate) mod schema_evolution; + +use arrow_schema::Schema; +pub use configs::WriterStatsConfig; +use datafusion::execution::SessionStateBuilder; +use generated_columns::{add_generated_columns, add_missing_generated_columns}; use std::collections::HashMap; use std::str::FromStr; use std::sync::Arc; @@ -31,50 +39,34 @@ use std::time::{Instant, SystemTime, UNIX_EPOCH}; use std::vec; use arrow_array::RecordBatch; -use arrow_cast::can_cast_types; -use arrow_schema::{ArrowError, DataType, Fields, SchemaRef as ArrowSchemaRef}; use datafusion::catalog::TableProvider; -use datafusion::datasource::{provider_as_source, MemTable}; -use datafusion::execution::context::{SessionContext, SessionState, TaskContext}; +use datafusion::datasource::MemTable; +use datafusion::execution::context::{SessionContext, SessionState}; use datafusion::prelude::DataFrame; -use datafusion_common::DFSchema; -use datafusion_expr::{col, lit, when, Expr, ExprSchemable, LogicalPlan, LogicalPlanBuilder}; -use datafusion_physical_plan::ExecutionPlan; +use datafusion_common::{Column, DFSchema, Result, ScalarValue}; +use datafusion_expr::{cast, col, lit, Expr, LogicalPlan, UNNAMED_TABLE}; +use execution::{prepare_predicate_actions, write_execution_plan_with_predicate}; use futures::future::BoxFuture; -use futures::StreamExt; -use object_store::prefix::PrefixStore; use parquet::file::properties::WriterProperties; +use schema_evolution::try_cast_schema; use serde::{Deserialize, Serialize}; use tracing::log::*; -use uuid::Uuid; -use super::cdc::should_write_cdc; use super::datafusion_utils::Expression; use super::transaction::{CommitBuilder, CommitProperties, TableReference, PROTOCOL}; -use super::writer::{DeltaWriter, WriterConfig}; use super::{CreateBuilder, CustomExecuteHandler, Operation}; use crate::delta_datafusion::expr::fmt_expr_to_sql; use crate::delta_datafusion::expr::parse_predicate_expression; -use crate::delta_datafusion::{ - find_files, register_store, DeltaScanConfigBuilder, DeltaTableProvider, -}; -use crate::delta_datafusion::{DataFusionMixins, DeltaDataChecker}; +use crate::delta_datafusion::register_store; +use crate::delta_datafusion::DataFusionMixins; use crate::errors::{DeltaResult, DeltaTableError}; -use crate::kernel::{ - Action, ActionType, Add, AddCDCFile, DataCheck, Metadata, PartitionsExt, Remove, StructType, - StructTypeExt, -}; +use crate::kernel::{Action, ActionType, Metadata, StructType, StructTypeExt}; use crate::logstore::LogStoreRef; -use crate::operations::cast::{cast_record_batch, merge_schema::merge_arrow_schema}; +use crate::operations::cast::merge_schema::merge_arrow_schema; use crate::protocol::{DeltaOperation, SaveMode}; -use crate::storage::ObjectStoreRef; use crate::table::state::DeltaTableState; -use crate::table::{Constraint as DeltaConstraint, GeneratedColumn}; -use crate::writer::record_batch::divide_by_partition_values; use crate::DeltaTable; -use tokio::sync::mpsc::Sender; - #[derive(thiserror::Error, Debug)] pub(crate) enum WriteError { #[error("No data source supplied to write command.")] @@ -147,8 +139,6 @@ pub struct WriteBuilder { target_file_size: Option, /// Number of records to be written in single batch to underlying writer write_batch_size: Option, - /// RecordBatches 
to be written into the table - batches: Option>, /// whether to overwrite the schema or to merge it. None means to fail on schmema drift schema_mode: Option, /// how to handle cast failures, either return NULL (safe=true) or return ERR (safe=false) @@ -203,7 +193,6 @@ impl WriteBuilder { predicate: None, target_file_size: None, write_batch_size: None, - batches: None, safe_cast: false, schema_mode: None, writer_properties: None, @@ -255,12 +244,6 @@ impl WriteBuilder { self } - /// Execution plan that produces the data to be written to the delta table - pub fn with_input_batches(mut self, batches: impl IntoIterator) -> Self { - self.batches = Some(batches.into_iter().collect()); - self - } - /// Specify the target file size for data files written to the delta table. pub fn with_target_file_size(mut self, target_file_size: usize) -> Self { self.target_file_size = Some(target_file_size); @@ -323,6 +306,44 @@ impl WriteBuilder { self } + /// Execution plan that produces the data to be written to the delta table + pub fn with_input_batches(mut self, batches: impl IntoIterator) -> Self { + let ctx = SessionContext::new(); + let batches: Vec = batches.into_iter().collect(); + if !batches.is_empty() { + let table_provider: Arc = + Arc::new(MemTable::try_new(batches[0].schema(), vec![batches]).unwrap()); + let df = ctx.read_table(table_provider).unwrap(); + self.input = Some(Arc::new(df.logical_plan().clone())); + } + self + } + + fn get_partition_columns(&self) -> Result, WriteError> { + // validate partition columns + let active_partitions = self + .snapshot + .as_ref() + .map(|s| s.metadata().partition_columns.clone()); + + if let Some(active_part) = active_partitions { + if let Some(ref partition_columns) = self.partition_columns { + if &active_part != partition_columns { + Err(WriteError::PartitionColumnMismatch { + expected: active_part, + got: partition_columns.to_vec(), + }) + } else { + Ok(partition_columns.clone()) + } + } else { + Ok(active_part) + } + } else { + Ok(self.partition_columns.clone().unwrap_or_default()) + } + } + async fn check_preconditions(&self) -> DeltaResult> { if self.schema_mode == Some(SchemaMode::Overwrite) && self.mode != SaveMode::Overwrite { return Err(DeltaTableError::Generic( @@ -330,28 +351,11 @@ impl WriteBuilder { )); } - let batches: &Vec = match &self.batches { - Some(batches) => { - if batches.is_empty() { - error!("The WriteBuilder was an empty set of batches!"); - return Err(WriteError::MissingData.into()); - } - batches - } - None => { - if self.input.is_none() { - error!("The WriteBuilder must have an input plan _or_ batches!"); - return Err(WriteError::MissingData.into()); - } - // provide an empty array in the case that an input plan exists - &vec![] - } - }; - - let schema: StructType = match &self.input { - Some(plan) => (plan.schema().as_arrow()).try_into()?, - None => (batches[0].schema()).try_into()?, - }; + let input = self + .input + .clone() + .ok_or::(WriteError::MissingData.into())?; + let schema: StructType = input.schema().as_arrow().try_into()?; match &self.snapshot { Some(snapshot) => { @@ -397,384 +401,6 @@ impl WriteBuilder { } } } -/// Configuration for the writer on how to collect stats -#[derive(Clone)] -pub struct WriterStatsConfig { - /// Number of columns to collect stats for, idx based - pub num_indexed_cols: i32, - /// Optional list of columns which to collect stats for, takes precedende over num_index_cols - pub stats_columns: Option>, -} - -impl WriterStatsConfig { - /// Create new writer stats config - pub fn 
new(num_indexed_cols: i32, stats_columns: Option>) -> Self { - Self { - num_indexed_cols, - stats_columns, - } - } -} - -#[allow(clippy::too_many_arguments)] -async fn write_execution_plan_with_predicate( - predicate: Option, - snapshot: Option<&DeltaTableState>, - state: SessionState, - plan: Arc, - partition_columns: Vec, - object_store: ObjectStoreRef, - target_file_size: Option, - write_batch_size: Option, - writer_properties: Option, - writer_stats_config: WriterStatsConfig, - sender: Option>, -) -> DeltaResult> { - // We always take the plan Schema since the data may contain Large/View arrow types, - // the schema and batches were prior constructed with this in mind. - let schema: ArrowSchemaRef = plan.schema(); - let checker = if let Some(snapshot) = snapshot { - DeltaDataChecker::new(snapshot) - } else { - debug!("Using plan schema to derive generated columns, since no snapshot was provided. Implies first write."); - let delta_schema: StructType = schema.as_ref().try_into()?; - DeltaDataChecker::new_with_generated_columns( - delta_schema.get_generated_columns().unwrap_or_default(), - ) - }; - let checker = match predicate { - Some(pred) => { - // TODO: get the name of the outer-most column? `*` will also work but would it be slower? - let chk = DeltaConstraint::new("*", &fmt_expr_to_sql(&pred)?); - checker.with_extra_constraints(vec![chk]) - } - _ => checker, - }; - // Write data to disk - let mut tasks = vec![]; - for i in 0..plan.properties().output_partitioning().partition_count() { - let inner_plan = plan.clone(); - let inner_schema = schema.clone(); - let task_ctx = Arc::new(TaskContext::from(&state)); - let config = WriterConfig::new( - inner_schema.clone(), - partition_columns.clone(), - writer_properties.clone(), - target_file_size, - write_batch_size, - writer_stats_config.num_indexed_cols, - writer_stats_config.stats_columns.clone(), - ); - let mut writer = DeltaWriter::new(object_store.clone(), config); - let checker_stream = checker.clone(); - let sender_stream = sender.clone(); - let mut stream = inner_plan.execute(i, task_ctx)?; - - let handle: tokio::task::JoinHandle>> = tokio::task::spawn( - async move { - let sendable = sender_stream.clone(); - while let Some(maybe_batch) = stream.next().await { - let batch = maybe_batch?; - - checker_stream.check_batch(&batch).await?; - - if let Some(s) = sendable.as_ref() { - if let Err(e) = s.send(batch.clone()).await { - error!("Failed to send data to observer: {e:#?}"); - } - } else { - debug!("write_execution_plan_with_predicate did not send any batches, no sender."); - } - writer.write(&batch).await?; - } - let add_actions = writer.close().await; - match add_actions { - Ok(actions) => Ok(actions.into_iter().map(Action::Add).collect::>()), - Err(err) => Err(err), - } - }, - ); - - tasks.push(handle); - } - let actions = futures::future::join_all(tasks) - .await - .into_iter() - .collect::, _>>() - .map_err(|err| WriteError::WriteTask { source: err })? - .into_iter() - .collect::, _>>()? 
- .concat() - .into_iter() - .collect::>(); - // Collect add actions to add to commit - Ok(actions) -} - -#[allow(clippy::too_many_arguments)] -pub(crate) async fn write_execution_plan_cdc( - snapshot: Option<&DeltaTableState>, - state: SessionState, - plan: Arc, - partition_columns: Vec, - object_store: ObjectStoreRef, - target_file_size: Option, - write_batch_size: Option, - writer_properties: Option, - writer_stats_config: WriterStatsConfig, - sender: Option>, -) -> DeltaResult> { - let cdc_store = Arc::new(PrefixStore::new(object_store, "_change_data")); - - Ok(write_execution_plan( - snapshot, - state, - plan, - partition_columns, - cdc_store, - target_file_size, - write_batch_size, - writer_properties, - writer_stats_config, - sender, - ) - .await? - .into_iter() - .map(|add| { - // Modify add actions into CDC actions - match add { - Action::Add(add) => { - Action::Cdc(AddCDCFile { - // This is a gnarly hack, but the action needs the nested path, not the - // path isnide the prefixed store - path: format!("_change_data/{}", add.path), - size: add.size, - partition_values: add.partition_values, - data_change: false, - tags: add.tags, - }) - } - _ => panic!("Expected Add action"), - } - }) - .collect::>()) -} - -#[allow(clippy::too_many_arguments)] -pub(crate) async fn write_execution_plan( - snapshot: Option<&DeltaTableState>, - state: SessionState, - plan: Arc, - partition_columns: Vec, - object_store: ObjectStoreRef, - target_file_size: Option, - write_batch_size: Option, - writer_properties: Option, - writer_stats_config: WriterStatsConfig, - sender: Option>, -) -> DeltaResult> { - write_execution_plan_with_predicate( - None, - snapshot, - state, - plan, - partition_columns, - object_store, - target_file_size, - write_batch_size, - writer_properties, - writer_stats_config, - sender, - ) - .await -} - -#[allow(clippy::too_many_arguments)] -async fn execute_non_empty_expr( - snapshot: &DeltaTableState, - log_store: LogStoreRef, - state: SessionState, - partition_columns: Vec, - expression: &Expr, - rewrite: &[Add], - writer_properties: Option, - writer_stats_config: WriterStatsConfig, - partition_scan: bool, - insert_df: DataFrame, - operation_id: Uuid, -) -> DeltaResult> { - // For each identified file perform a parquet scan + filter + limit (1) + count. - // If returned count is not zero then append the file to be rewritten and removed from the log. Otherwise do nothing to the file. - let mut actions: Vec = Vec::new(); - - // Take the insert plan schema since it might have been schema evolved, if its not - // it is simply the table schema - let scan_config = DeltaScanConfigBuilder::new() - .with_schema(snapshot.input_schema()?) - .build(snapshot)?; - - let target_provider = Arc::new( - DeltaTableProvider::try_new(snapshot.clone(), log_store.clone(), scan_config.clone())? - .with_files(rewrite.to_vec()), - ); - - let target_provider = provider_as_source(target_provider); - let source = LogicalPlanBuilder::scan("target", target_provider.clone(), None)?.build()?; - // We don't want to verify the predicate against existing data - - let df = DataFrame::new(state.clone(), source); - - if !partition_scan { - // Apply the negation of the filter and rewrite files - let negated_expression = Expr::Not(Box::new(Expr::IsTrue(Box::new(expression.clone())))); - - let filter = df - .clone() - .filter(negated_expression)? 
- .create_physical_plan() - .await?; - - let add_actions: Vec = write_execution_plan( - Some(snapshot), - state.clone(), - filter, - partition_columns.clone(), - log_store.object_store(Some(operation_id)), - Some(snapshot.table_config().target_file_size() as usize), - None, - writer_properties.clone(), - writer_stats_config.clone(), - None, - ) - .await?; - - actions.extend(add_actions); - } - - // CDC logic, simply filters data with predicate and adds the _change_type="delete" as literal column - // Only write when CDC actions when it was not a partition scan, load_cdf can deduce the deletes in that case - // based on the remove actions if a partition got deleted - if !partition_scan { - // We only write deletions when it was not a partition scan - if let Some(cdc_actions) = execute_non_empty_expr_cdc( - snapshot, - log_store, - state.clone(), - df, - expression, - partition_columns, - writer_properties, - writer_stats_config, - insert_df, - operation_id, - ) - .await? - { - actions.extend(cdc_actions) - } - } - Ok(actions) -} - -/// If CDC is enabled it writes all the deletions based on predicate into _change_data directory -#[allow(clippy::too_many_arguments)] -pub(crate) async fn execute_non_empty_expr_cdc( - snapshot: &DeltaTableState, - log_store: LogStoreRef, - state: SessionState, - scan: DataFrame, - expression: &Expr, - table_partition_cols: Vec, - writer_properties: Option, - writer_stats_config: WriterStatsConfig, - insert_df: DataFrame, - operation_id: Uuid, -) -> DeltaResult>> { - match should_write_cdc(snapshot) { - // Create CDC scan - Ok(true) => { - let filter = scan.clone().filter(expression.clone())?; - - // Add literal column "_change_type" - let delete_change_type_expr = lit("delete").alias("_change_type"); - - let insert_change_type_expr = lit("insert").alias("_change_type"); - - let delete_df = filter.with_column("_change_type", delete_change_type_expr)?; - - let insert_df = insert_df.with_column("_change_type", insert_change_type_expr)?; - - let cdc_df = delete_df.union(insert_df)?; - - let cdc_actions = write_execution_plan_cdc( - Some(snapshot), - state.clone(), - cdc_df.create_physical_plan().await?, - table_partition_cols.clone(), - log_store.object_store(Some(operation_id)), - Some(snapshot.table_config().target_file_size() as usize), - None, - writer_properties, - writer_stats_config, - None, - ) - .await?; - Ok(Some(cdc_actions)) - } - _ => Ok(None), - } -} - -// This should only be called with a valid predicate -#[allow(clippy::too_many_arguments)] -async fn prepare_predicate_actions( - predicate: Expr, - log_store: LogStoreRef, - snapshot: &DeltaTableState, - state: SessionState, - partition_columns: Vec, - writer_properties: Option, - deletion_timestamp: i64, - writer_stats_config: WriterStatsConfig, - insert_df: DataFrame, - operation_id: Uuid, -) -> DeltaResult> { - let candidates = - find_files(snapshot, log_store.clone(), &state, Some(predicate.clone())).await?; - - let mut actions = execute_non_empty_expr( - snapshot, - log_store, - state, - partition_columns, - &predicate, - &candidates.candidates, - writer_properties, - writer_stats_config, - candidates.partition_scan, - insert_df, - operation_id, - ) - .await?; - - let remove = candidates.candidates; - - for action in remove { - actions.push(Action::Remove(Remove { - path: action.path, - deletion_timestamp: Some(deletion_timestamp), - data_change: true, - extended_file_metadata: Some(true), - partition_values: Some(action.partition_values), - size: Some(action.size), - deletion_vector: 
action.deletion_vector, - tags: None, - base_row_id: action.base_row_id, - default_row_commit_version: action.default_row_commit_version, - })) - } - Ok(actions) -} impl std::future::IntoFuture for WriteBuilder { type Output = DeltaResult; @@ -791,190 +417,111 @@ impl std::future::IntoFuture for WriteBuilder { let mut metrics = WriteMetrics::default(); let exec_start = Instant::now(); - // Create table actions to initialize table in case it does not yet exist and should be - // created + // Create table actions to initialize table in case it does not yet exist + // and should be created let mut actions = this.check_preconditions().await?; - let active_partitions = this - .snapshot - .as_ref() - .map(|s| s.metadata().partition_columns.clone()); - - // validate partition columns - let partition_columns = if let Some(active_part) = active_partitions { - if let Some(ref partition_columns) = this.partition_columns { - if &active_part != partition_columns { - Err(WriteError::PartitionColumnMismatch { - expected: active_part, - got: partition_columns.to_vec(), - }) - } else { - Ok(partition_columns.clone()) - } - } else { - Ok(active_part) - } - } else { - Ok(this.partition_columns.unwrap_or_default()) - }?; + let partition_columns = this.get_partition_columns()?; let state = match this.state { Some(state) => state, None => { - let ctx = SessionContext::new(); - register_store(this.log_store.clone(), ctx.runtime_env()); - ctx.state() + let state = SessionStateBuilder::new().with_default_features().build(); + register_store(this.log_store.clone(), state.runtime_env().clone()); + state } }; - let generated_col_expressions = this .snapshot .as_ref() .map(|v| v.schema().get_generated_columns().unwrap_or_default()) .unwrap_or_default(); + let mut schema_drift = false; - let mut df = if let Some(plan) = this.input { - if this.schema_mode == Some(SchemaMode::Merge) { - return Err(DeltaTableError::Generic( - "Schema merge not supported yet for Datafusion".to_string(), - )); - } - Ok(DataFrame::new(state.clone(), plan.as_ref().clone())) - } else if let Some(batches) = this.batches { - if batches.is_empty() { - Err(WriteError::MissingData) - } else { - let mut schema = batches[0].schema(); - - // Schema merging code should be aware of columns that can be generated during write - // so they might be empty in the batch, but the will exist in the input_schema() - // in this case we have to insert the generated column and it's type in the schema of the batch - let mut new_schema = None; - if let Some(snapshot) = &this.snapshot { - let table_schema = snapshot.input_schema()?; - - // Merge schema's initial round when there are generated columns expressions - // This is to have the batch schema be the same as the input schema without adding new fields - // from the incoming batch - if !generated_col_expressions.is_empty() { - schema = merge_arrow_schema(table_schema.clone(), schema, true)?; - } + let mut source = DataFrame::new(state.clone(), this.input.unwrap().as_ref().clone()); - if let Err(schema_err) = - try_cast_batch(schema.fields(), table_schema.fields()) - { - schema_drift = true; - if this.mode == SaveMode::Overwrite - && this.schema_mode == Some(SchemaMode::Overwrite) - { - if generated_col_expressions.is_empty() { - new_schema = None // we overwrite anyway, so no need to cast - } else { - new_schema = Some(schema.clone()) // we need to cast the batch to include the generated col as empty null - } - } else if this.schema_mode == Some(SchemaMode::Merge) { - new_schema = Some(merge_arrow_schema( - 
table_schema.clone(), - schema.clone(), - schema_drift, - )?); - } else { - return Err(schema_err.into()); - } - } else if this.mode == SaveMode::Overwrite - && this.schema_mode == Some(SchemaMode::Overwrite) - { - if generated_col_expressions.is_empty() { - new_schema = None // we overwrite anyway, so no need to cast - } else { - new_schema = Some(schema.clone()) // we need to cast the batch to include the generated col as empty null - } - } else { - // Schema needs to be merged so that utf8/binary/list types are preserved from the batch side if both table - // and batch contains such type. Other types are preserved from the table side. - // At this stage it will never introduce more fields since try_cast_batch passed correctly. - new_schema = Some(merge_arrow_schema( - table_schema.clone(), - schema.clone(), - schema_drift, - )?); - } + // Add missing generated columns to source_df + let (mut source, missing_generated_columns) = + add_missing_generated_columns(source, &generated_col_expressions)?; + + let source_schema: Arc = Arc::new(source.schema().as_arrow().clone()); + + // Schema merging code should be aware of columns that can be generated during write + // so they might be empty in the batch, but the will exist in the input_schema() + // in this case we have to insert the generated column and it's type in the schema of the batch + let mut new_schema = None; + if let Some(snapshot) = &this.snapshot { + let table_schema = snapshot.input_schema()?; + + if let Err(schema_err) = + try_cast_schema(source_schema.fields(), table_schema.fields()) + { + schema_drift = true; + if this.mode == SaveMode::Overwrite + && this.schema_mode == Some(SchemaMode::Overwrite) + { + new_schema = None // we overwrite anyway, so no need to cast + } else if this.schema_mode == Some(SchemaMode::Merge) { + new_schema = Some(merge_arrow_schema( + table_schema.clone(), + source_schema.clone(), + schema_drift, + )?); + } else { + return Err(schema_err.into()); } - let data = if !partition_columns.is_empty() { - // TODO partitioning should probably happen in its own plan ... - let mut partitions: HashMap> = HashMap::new(); - let mut num_partitions = 0; - let mut num_added_rows = 0; - for batch in batches { - let real_batch = match new_schema.clone() { - Some(new_schema) => cast_record_batch( - &batch, - new_schema, - this.safe_cast, - schema_drift || !generated_col_expressions.is_empty(), // Schema drifted so we have to add the missing columns/structfields or missing generated cols.. - )?, - None => batch, - }; - - let divided = divide_by_partition_values( - new_schema.clone().unwrap_or(schema.clone()), - partition_columns.clone(), - &real_batch, - )?; - num_partitions += divided.len(); - for part in divided { - num_added_rows += part.record_batch.num_rows(); - let key = part.partition_values.hive_partition_path(); - match partitions.get_mut(&key) { - Some(part_batches) => { - part_batches.push(part.record_batch); - } - None => { - partitions.insert(key, vec![part.record_batch]); - } - } - } - } - metrics.num_partitions = num_partitions; - metrics.num_added_rows = num_added_rows; - partitions.into_values().collect::>() + } else if this.mode == SaveMode::Overwrite + && this.schema_mode == Some(SchemaMode::Overwrite) + { + new_schema = None // we overwrite anyway, so no need to cast + } else { + // Schema needs to be merged so that utf8/binary/list types are preserved from the batch side if both table + // and batch contains such type. Other types are preserved from the table side. 
+ // At this stage it will never introduce more fields since try_cast_batch passed correctly. + new_schema = Some(merge_arrow_schema( + table_schema.clone(), + source_schema.clone(), + schema_drift, + )?); + } + } + if let Some(new_schema) = new_schema { + let mut schema_evolution_projection = Vec::new(); + for field in new_schema.fields() { + // If field exist in source data, we cast to new datatype + if source_schema.index_of(field.name()).is_ok() { + let cast_expr = cast( + Expr::Column(Column::from_name(field.name())), + // col(field.name()), + field.data_type().clone(), + ) + .alias(field.name()); + schema_evolution_projection.push(cast_expr) + // If field doesn't exist in source data, we insert the column + // with null values } else { - match new_schema { - Some(ref new_schema) => { - let mut new_batches = vec![]; - let mut num_added_rows = 0; - for batch in batches { - new_batches.push(cast_record_batch( - &batch, - new_schema.clone(), - this.safe_cast, - schema_drift || !generated_col_expressions.is_empty(), // Schema drifted so we have to add the missing columns/structfields or missing generated cols. - )?); - num_added_rows += batch.num_rows(); - } - metrics.num_added_rows = num_added_rows; - vec![new_batches] - } - None => { - metrics.num_added_rows = batches.iter().map(|b| b.num_rows()).sum(); - vec![batches] - } - } - }; + schema_evolution_projection.push( + cast( + lit(ScalarValue::Null).alias(field.name()), + field.data_type().clone(), + ) + .alias(field.name()), + ); + } + } + source = source.select(schema_evolution_projection)?; + } - let ctx = SessionContext::new(); - let table_provider: Arc = Arc::new( - MemTable::try_new(new_schema.unwrap_or(schema).clone(), data).unwrap(), - ); - let df = ctx.read_table(table_provider).unwrap(); + source = add_generated_columns( + source, + &generated_col_expressions, + &missing_generated_columns, + &state, + )?; - Ok(df) - } - } else { - Err(WriteError::MissingData) - }?; + let schema = Arc::new(source.schema().as_arrow().clone()); - let schema = Arc::new(df.schema().as_arrow().clone()); + // Maybe create schema action if this.schema_mode == Some(SchemaMode::Merge) && schema_drift { if let Some(snapshot) = &this.snapshot { let schema_struct: StructType = schema.clone().try_into()?; @@ -997,50 +544,6 @@ impl std::future::IntoFuture for WriteBuilder { } } - // Add when.then expr for generated columns - if !generated_col_expressions.is_empty() { - fn create_field( - field: &arrow_schema::Field, - generated_cols_map: &HashMap, - state: &datafusion::execution::session_state::SessionState, - dfschema: &DFSchema, - ) -> DeltaResult { - match generated_cols_map.get(field.name()) { - Some(generated_col) => { - let generation_expr = when( - col(generated_col.get_name()).is_null(), - state.create_logical_expr( - generated_col.get_generation_expression(), - dfschema, - )?, - ) - .otherwise(col(generated_col.get_name()))? - .cast_to( - &arrow_schema::DataType::try_from(&generated_col.data_type)?, - dfschema, - )? 
- .alias(field.name().to_owned()); - Ok(generation_expr) - } - None => Ok(col(field.name().to_owned())), - } - } - - let dfschema: DFSchema = schema.as_ref().clone().try_into()?; - let generated_cols_map = generated_col_expressions - .into_iter() - .map(|v| (v.name.clone(), v)) - .collect::>(); - let current_fields: DeltaResult> = df - .schema() - .fields() - .into_iter() - .map(|field| create_field(field, &generated_cols_map, &state, &dfschema)) - .collect(); - - df = df.select(current_fields?)?; - }; - let (predicate_str, predicate) = match this.predicate { Some(predicate) => { let pred = match predicate { @@ -1077,7 +580,7 @@ impl std::future::IntoFuture for WriteBuilder { predicate.clone(), this.snapshot.as_ref(), state.clone(), - df.clone().create_physical_plan().await?, + source.clone().create_physical_plan().await?, partition_columns.clone(), this.log_store.object_store(Some(operation_id)).clone(), target_file_size, @@ -1131,7 +634,7 @@ impl std::future::IntoFuture for WriteBuilder { this.writer_properties, deletion_timestamp, writer_stats_config, - df, + source, operation_id, ) .await?; @@ -1193,71 +696,6 @@ impl std::future::IntoFuture for WriteBuilder { } } -fn try_cast_batch(from_fields: &Fields, to_fields: &Fields) -> Result<(), ArrowError> { - if from_fields.len() != to_fields.len() { - return Err(ArrowError::SchemaError(format!( - "Cannot cast schema, number of fields does not match: {} vs {}", - from_fields.len(), - to_fields.len() - ))); - } - - from_fields - .iter() - .map(|f| { - if let Some((_, target_field)) = to_fields.find(f.name()) { - if let (DataType::Struct(fields0), DataType::Struct(fields1)) = - (f.data_type(), target_field.data_type()) - { - try_cast_batch(fields0, fields1) - } else { - match (f.data_type(), target_field.data_type()) { - ( - DataType::Decimal128(left_precision, left_scale) | DataType::Decimal256(left_precision, left_scale), - DataType::Decimal128(right_precision, right_scale) - ) => { - if left_precision <= right_precision && left_scale <= right_scale { - Ok(()) - } else { - Err(ArrowError::SchemaError(format!( - "Cannot cast field {} from {} to {}", - f.name(), - f.data_type(), - target_field.data_type() - ))) - } - }, - ( - _, - DataType::Decimal256(_, _), - ) => { - unreachable!("Target field can never be Decimal 256. 
According to the protocol: 'The precision and scale can be up to 38.'") - }, - (left, right) => { - if !can_cast_types(left, right) { - Err(ArrowError::SchemaError(format!( - "Cannot cast field {} from {} to {}", - f.name(), - f.data_type(), - target_field.data_type() - ))) - } else { - Ok(()) - } - } - } - } - } else { - Err(ArrowError::SchemaError(format!( - "Field {} not found in schema", - f.name() - ))) - } - }) - .collect::, _>>()?; - Ok(()) -} - #[cfg(test)] mod tests { use super::*; @@ -1273,8 +711,10 @@ mod tests { }; use crate::TableProperty; use arrow_array::{Int32Array, StringArray, TimestampMicrosecondArray}; - use arrow_schema::{DataType, Field, Schema as ArrowSchema, TimeUnit}; + use arrow_schema::{DataType, Field, Fields, Schema as ArrowSchema, TimeUnit}; + use datafusion::prelude::*; use datafusion::{assert_batches_eq, assert_batches_sorted_eq}; + use datafusion_physical_plan::ExecutionPlan; use itertools::Itertools; use serde_json::{json, Value}; @@ -2362,8 +1802,6 @@ mod tests { /// SMall module to collect test cases which validate the [WriteBuilder]'s /// check_preconditions() function mod check_preconditions_test { - use crate::operations::transaction::TransactionError; - use super::*; #[tokio::test] @@ -2502,46 +1940,5 @@ mod tests { Ok(()) } - - #[tokio::test] - async fn test_max_retries_zero_disables_conflict_checker() { - let table_schema = get_delta_schema(); - let batch = get_record_batch(None, false); - - let table = DeltaOps::new_in_memory() - .create() - .with_columns(table_schema.fields().cloned()) - .await - .unwrap(); - assert_eq!(table.version(), 0); - assert_eq!(table.history(None).await.unwrap().len(), 1); - - let dt_for_conflicting_write = table.clone(); - - // write some data - let table = DeltaOps(table) - .write(vec![batch.clone()]) - .with_save_mode(SaveMode::Append) - .with_commit_properties(CommitProperties::default().with_max_retries(0)) - .await - .unwrap(); - assert_eq!(table.version(), 1); - assert_eq!(table.get_files_count(), 1); - - let dt_for_conflicting_write = DeltaOps(dt_for_conflicting_write) - .write(vec![batch.clone()]) - .with_save_mode(SaveMode::Append) - .with_commit_properties(CommitProperties::default().with_max_retries(0)) - .await; - - assert!(dt_for_conflicting_write.is_err()); - let err = dt_for_conflicting_write.err().unwrap(); - assert!(matches!( - err, - DeltaTableError::Transaction { - source: TransactionError::MaxCommitAttempts(0) - } - )); - } } } diff --git a/crates/core/src/operations/write/schema_evolution/mod.rs b/crates/core/src/operations/write/schema_evolution/mod.rs new file mode 100644 index 0000000000..2de8c1b6cf --- /dev/null +++ b/crates/core/src/operations/write/schema_evolution/mod.rs @@ -0,0 +1,67 @@ +use arrow_cast::can_cast_types; +use arrow_schema::{ArrowError, DataType, Fields}; + +pub(crate) fn try_cast_schema(from_fields: &Fields, to_fields: &Fields) -> Result<(), ArrowError> { + if from_fields.len() != to_fields.len() { + return Err(ArrowError::SchemaError(format!( + "Cannot cast schema, number of fields does not match: {} vs {}", + from_fields.len(), + to_fields.len() + ))); + } + + from_fields + .iter() + .map(|f| { + if let Some((_, target_field)) = to_fields.find(f.name()) { + if let (DataType::Struct(fields0), DataType::Struct(fields1)) = + (f.data_type(), target_field.data_type()) + { + try_cast_schema(fields0, fields1) + } else { + match (f.data_type(), target_field.data_type()) { + ( + DataType::Decimal128(left_precision, left_scale) | DataType::Decimal256(left_precision, left_scale), + 
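+                            // Decimals may only widen: both precision and scale of the
+                            // target must be >= those of the source (checked below).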
DataType::Decimal128(right_precision, right_scale) + ) => { + if left_precision <= right_precision && left_scale <= right_scale { + Ok(()) + } else { + Err(ArrowError::SchemaError(format!( + "Cannot cast field {} from {} to {}", + f.name(), + f.data_type(), + target_field.data_type() + ))) + } + }, + ( + _, + DataType::Decimal256(_, _), + ) => { + unreachable!("Target field can never be Decimal 256. According to the protocol: 'The precision and scale can be up to 38.'") + }, + (left, right) => { + if !can_cast_types(left, right) { + Err(ArrowError::SchemaError(format!( + "Cannot cast field {} from {} to {}", + f.name(), + f.data_type(), + target_field.data_type() + ))) + } else { + Ok(()) + } + } + } + } + } else { + Err(ArrowError::SchemaError(format!( + "Field {} not found in schema", + f.name() + ))) + } + }) + .collect::, _>>()?; + Ok(()) +} diff --git a/python/src/lib.rs b/python/src/lib.rs index 0652f5581b..4a32a3835d 100644 --- a/python/src/lib.rs +++ b/python/src/lib.rs @@ -6,7 +6,6 @@ mod query; mod schema; mod utils; -use core::num; use std::cmp::min; use std::collections::{HashMap, HashSet}; use std::ffi::CString; @@ -74,7 +73,6 @@ use crate::query::PyQueryBuilder; use crate::schema::{schema_to_pyobject, Field}; use crate::utils::rt; use deltalake::operations::update_field_metadata::UpdateFieldMetadataBuilder; -use deltalake::protocol::DeltaOperation::UpdateFieldMetadata; use pyo3::exceptions::{PyRuntimeError, PyValueError}; use pyo3::prelude::*; use pyo3::pybacked::PyBackedStr; @@ -85,7 +83,7 @@ use uuid::Uuid; #[cfg(all(target_family = "unix", not(target_os = "emscripten")))] use jemallocator::Jemalloc; -#[cfg(all(any(not(target_family = "unix"), target_os = "emscripten")))] +#[cfg(any(not(target_family = "unix"), target_os = "emscripten"))] use mimalloc::MiMalloc; #[global_allocator] @@ -93,7 +91,7 @@ use mimalloc::MiMalloc; static ALLOC: Jemalloc = Jemalloc; #[global_allocator] -#[cfg(all(any(not(target_family = "unix"), target_os = "emscripten")))] +#[cfg(any(not(target_family = "unix"), target_os = "emscripten"))] static ALLOC: MiMalloc = MiMalloc; #[derive(FromPyObject)] diff --git a/python/tests/test_writer.py b/python/tests/test_writer.py index cbf40dbfd1..16479b5e20 100644 --- a/python/tests/test_writer.py +++ b/python/tests/test_writer.py @@ -272,7 +272,7 @@ def test_write_type_castable_types(existing_table: DeltaTable): ) with pytest.raises( Exception, - match="Cast error: Failed to cast int8 from Int8 to Utf8: Cannot cast string 'hello' to value of Int8 type", + match="Cast error: Cannot cast string 'hello' to value of Int8 type", ): write_deltalake( existing_table, @@ -284,7 +284,7 @@ def test_write_type_castable_types(existing_table: DeltaTable): with pytest.raises( Exception, - match="Cast error: Failed to cast int8 from Int8 to Int64: Can't cast value 1000 to type Int8", + match="Cast error: Can't cast value 1000 to type Int8", ): write_deltalake( existing_table, @@ -1035,7 +1035,6 @@ def test_partition_overwrite( tmp_path, sample_data, mode="overwrite", predicate=f"p2 < {filter_string}" ) - @pytest.fixture() def sample_data_for_partitioning() -> pa.Table: return pa.table(