From 382d9e5ec06ac4b176fb4b68dcca04e0ca2e6ec4 Mon Sep 17 00:00:00 2001
From: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com>
Date: Sun, 16 Feb 2025 22:17:49 +0100
Subject: [PATCH] feat: write metrics extension planner

Signed-off-by: Ion Koutsouris <15728914+ion-elgreco@users.noreply.github.com>
Signed-off-by: Liam Brannigan
---
 crates/core/src/operations/delete.rs          |  2 -
 crates/core/src/operations/update.rs          |  2 -
 crates/core/src/operations/write/execution.rs | 27 ++--------
 crates/core/src/operations/write/metrics.rs   | 45 ++++++++++++++++
 crates/core/src/operations/write/mod.rs       | 54 ++++++++++++++-----
 crates/core/tests/integration_datafusion.rs   |  1 +
 python/src/lib.rs                             |  3 +-
 .../write/lazy.rs => python/src/writer.rs     | 20 +++----
 8 files changed, 102 insertions(+), 52 deletions(-)
 create mode 100644 crates/core/src/operations/write/metrics.rs
 rename crates/core/src/operations/write/lazy.rs => python/src/writer.rs (74%)

diff --git a/crates/core/src/operations/delete.rs b/crates/core/src/operations/delete.rs
index ecf27819a9..c6d943e1c0 100644
--- a/crates/core/src/operations/delete.rs
+++ b/crates/core/src/operations/delete.rs
@@ -259,7 +259,6 @@ async fn execute_non_empty_expr(
         None,
         writer_properties.clone(),
         writer_stats_config.clone(),
-        None,
     )
     .await?;
 
@@ -296,7 +295,6 @@ async fn execute_non_empty_expr(
         None,
         writer_properties,
         writer_stats_config,
-        None,
     )
     .await?;
     actions.extend(cdc_actions)
diff --git a/crates/core/src/operations/update.rs b/crates/core/src/operations/update.rs
index eba990014e..8f7a15f6ce 100644
--- a/crates/core/src/operations/update.rs
+++ b/crates/core/src/operations/update.rs
@@ -399,7 +399,6 @@ async fn execute(
         None,
         writer_properties.clone(),
         writer_stats_config.clone(),
-        None,
     )
     .await?;
 
@@ -462,7 +461,6 @@ async fn execute(
         None,
         writer_properties,
         writer_stats_config,
-        None,
     )
     .await?;
     actions.extend(cdc_actions);
diff --git a/crates/core/src/operations/write/execution.rs b/crates/core/src/operations/write/execution.rs
index cdf57a56d0..3a5796d964 100644
--- a/crates/core/src/operations/write/execution.rs
+++ b/crates/core/src/operations/write/execution.rs
@@ -1,7 +1,6 @@
 use std::sync::Arc;
 use std::vec;
 
-use arrow_array::RecordBatch;
 use arrow_schema::SchemaRef as ArrowSchemaRef;
 use datafusion::datasource::provider_as_source;
 use datafusion::execution::context::{SessionState, TaskContext};
@@ -25,7 +24,6 @@ use crate::operations::writer::{DeltaWriter, WriterConfig};
 use crate::storage::ObjectStoreRef;
 use crate::table::state::DeltaTableState;
 use crate::table::Constraint as DeltaConstraint;
-use tokio::sync::mpsc::Sender;
 
 use super::configs::WriterStatsConfig;
 use super::WriteError;
@@ -42,7 +40,6 @@ pub(crate) async fn write_execution_plan_with_predicate(
     write_batch_size: Option<usize>,
     writer_properties: Option<WriterProperties>,
     writer_stats_config: WriterStatsConfig,
-    sender: Option<Sender<RecordBatch>>,
 ) -> DeltaResult<Vec<Action>> {
     // We always take the plan Schema since the data may contain Large/View arrow types,
     // the schema and batches were prior constructed with this in mind.
@@ -81,24 +78,13 @@ pub(crate) async fn write_execution_plan_with_predicate(
         );
         let mut writer = DeltaWriter::new(object_store.clone(), config);
         let checker_stream = checker.clone();
-        let sender_stream = sender.clone();
         let mut stream = inner_plan.execute(i, task_ctx)?;
 
-        let handle: tokio::task::JoinHandle<DeltaResult<Vec<Action>>> = tokio::task::spawn(
-            async move {
-                let sendable = sender_stream.clone();
+        let handle: tokio::task::JoinHandle<DeltaResult<Vec<Action>>> =
+            tokio::task::spawn(async move {
                 while let Some(maybe_batch) = stream.next().await {
                     let batch = maybe_batch?;
-
                     checker_stream.check_batch(&batch).await?;
-
-                    if let Some(s) = sendable.as_ref() {
-                        if let Err(e) = s.send(batch.clone()).await {
-                            error!("Failed to send data to observer: {e:#?}");
-                        }
-                    } else {
-                        debug!("write_execution_plan_with_predicate did not send any batches, no sender.");
-                    }
                     writer.write(&batch).await?;
                 }
                 let add_actions = writer.close().await;
@@ -106,8 +92,7 @@
                 Ok(actions) => Ok(actions.into_iter().map(Action::Add).collect::<Vec<_>>()),
                 Err(err) => Err(err),
             }
-            },
-        );
+            });
         tasks.push(handle);
     }
 
@@ -136,7 +121,6 @@ pub(crate) async fn write_execution_plan_cdc(
     write_batch_size: Option<usize>,
     writer_properties: Option<WriterProperties>,
     writer_stats_config: WriterStatsConfig,
-    sender: Option<Sender<RecordBatch>>,
 ) -> DeltaResult<Vec<Action>> {
     let cdc_store = Arc::new(PrefixStore::new(object_store, "_change_data"));
 
@@ -150,7 +134,6 @@ pub(crate) async fn write_execution_plan_cdc(
         write_batch_size,
         writer_properties,
         writer_stats_config,
-        sender,
     )
     .await?
     .into_iter()
@@ -185,7 +168,6 @@ pub(crate) async fn write_execution_plan(
     write_batch_size: Option<usize>,
     writer_properties: Option<WriterProperties>,
     writer_stats_config: WriterStatsConfig,
-    sender: Option<Sender<RecordBatch>>,
 ) -> DeltaResult<Vec<Action>> {
     write_execution_plan_with_predicate(
         None,
@@ -198,7 +180,6 @@ pub(crate) async fn write_execution_plan(
         write_batch_size,
         writer_properties,
         writer_stats_config,
-        sender,
     )
     .await
 }
@@ -258,7 +239,6 @@ pub(crate) async fn execute_non_empty_expr(
         None,
         writer_properties.clone(),
         writer_stats_config.clone(),
-        None,
     )
     .await?;
 
@@ -330,7 +310,6 @@ pub(crate) async fn execute_non_empty_expr_cdc(
         None,
         writer_properties,
         writer_stats_config,
-        None,
     )
     .await?;
     Ok(Some(cdc_actions))
diff --git a/crates/core/src/operations/write/metrics.rs b/crates/core/src/operations/write/metrics.rs
new file mode 100644
index 0000000000..bfb5072cb8
--- /dev/null
+++ b/crates/core/src/operations/write/metrics.rs
@@ -0,0 +1,45 @@
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use datafusion::{
+    execution::SessionState,
+    physical_planner::{ExtensionPlanner, PhysicalPlanner},
+};
+use datafusion_common::Result as DataFusionResult;
+use datafusion_expr::{LogicalPlan, UserDefinedLogicalNode};
+use datafusion_physical_plan::{metrics::MetricBuilder, ExecutionPlan};
+
+use crate::delta_datafusion::{logical::MetricObserver, physical::MetricObserverExec};
+
+pub(crate) const SOURCE_COUNT_ID: &str = "write_source_count";
+pub(crate) const SOURCE_COUNT_METRIC: &str = "num_source_rows";
+
+#[derive(Clone, Debug)]
+pub(crate) struct WriteMetricExtensionPlanner {}
+
+#[async_trait]
+impl ExtensionPlanner for WriteMetricExtensionPlanner {
+    async fn plan_extension(
+        &self,
+        _planner: &dyn PhysicalPlanner,
+        node: &dyn UserDefinedLogicalNode,
+        _logical_inputs: &[&LogicalPlan],
+        physical_inputs: &[Arc<dyn ExecutionPlan>],
+        _session_state: &SessionState,
+    ) -> DataFusionResult<Option<Arc<dyn ExecutionPlan>>> {
+        if let Some(metric_observer) = node.as_any().downcast_ref::<MetricObserver>() {
+            if metric_observer.id.eq(SOURCE_COUNT_ID) {
+                return Ok(Some(MetricObserverExec::try_new(
+                    SOURCE_COUNT_ID.into(),
+                    physical_inputs,
+                    |batch, metrics| {
+                        MetricBuilder::new(metrics)
+                            .global_counter(SOURCE_COUNT_METRIC)
+                            .add(batch.num_rows());
+                    },
+                )?));
+            }
+        }
+        Ok(None)
+    }
+}
diff --git a/crates/core/src/operations/write/mod.rs b/crates/core/src/operations/write/mod.rs
index 74061b2c57..02ef8b94ef 100644
--- a/crates/core/src/operations/write/mod.rs
+++ b/crates/core/src/operations/write/mod.rs
@@ -26,13 +26,14 @@
 pub mod configs;
 pub(crate) mod execution;
 pub(crate) mod generated_columns;
-pub mod lazy;
+pub(crate) mod metrics;
 pub(crate) mod schema_evolution;
 
 use arrow_schema::Schema;
 pub use configs::WriterStatsConfig;
 use datafusion::execution::SessionStateBuilder;
 use generated_columns::{add_generated_columns, add_missing_generated_columns};
+use metrics::{WriteMetricExtensionPlanner, SOURCE_COUNT_ID, SOURCE_COUNT_METRIC};
 use std::collections::HashMap;
 use std::str::FromStr;
 use std::sync::Arc;
@@ -45,7 +46,7 @@ use datafusion::datasource::MemTable;
 use datafusion::execution::context::{SessionContext, SessionState};
 use datafusion::prelude::DataFrame;
 use datafusion_common::{Column, DFSchema, Result, ScalarValue};
-use datafusion_expr::{cast, lit, Expr, LogicalPlan};
+use datafusion_expr::{cast, lit, try_cast, Expr, Extension, LogicalPlan};
 use execution::{prepare_predicate_actions, write_execution_plan_with_predicate};
 use futures::future::BoxFuture;
 use parquet::file::properties::WriterProperties;
@@ -58,6 +59,9 @@ use super::transaction::{CommitBuilder, CommitProperties, TableReference, PROTOC
 use super::{CreateBuilder, CustomExecuteHandler, Operation};
 use crate::delta_datafusion::expr::fmt_expr_to_sql;
 use crate::delta_datafusion::expr::parse_predicate_expression;
+use crate::delta_datafusion::logical::MetricObserver;
+use crate::delta_datafusion::physical::{find_metric_node, get_metric};
+use crate::delta_datafusion::planner::DeltaPlanner;
 use crate::delta_datafusion::register_store;
 use crate::delta_datafusion::DataFusionMixins;
 use crate::errors::{DeltaResult, DeltaTableError};
@@ -418,6 +422,10 @@ impl std::future::IntoFuture for WriteBuilder {
             let mut metrics = WriteMetrics::default();
             let exec_start = Instant::now();
 
+            let write_planner = DeltaPlanner::<WriteMetricExtensionPlanner> {
+                extension_planner: WriteMetricExtensionPlanner {},
+            };
+
             // Create table actions to initialize table in case it does not yet exist
             // and should be created
             let mut actions = this.check_preconditions().await?;
@@ -425,9 +433,14 @@ impl std::future::IntoFuture for WriteBuilder {
             let partition_columns = this.get_partition_columns()?;
 
             let state = match this.state {
-                Some(state) => state,
+                Some(state) => SessionStateBuilder::new_from_existing(state.clone())
+                    .with_query_planner(Arc::new(write_planner))
+                    .build(),
                 None => {
-                    let state = SessionStateBuilder::new().with_default_features().build();
+                    let state = SessionStateBuilder::new()
+                        .with_default_features()
+                        .with_query_planner(Arc::new(write_planner))
+                        .build();
                     register_store(this.log_store.clone(), state.runtime_env().clone());
                     state
                 }
@@ -491,7 +504,8 @@ impl std::future::IntoFuture for WriteBuilder {
                     for field in new_schema.fields() {
                         // If field exist in source data, we cast to new datatype
                         if source_schema.index_of(field.name()).is_ok() {
-                            let cast_expr = cast(
+                            let cast_fn = if this.safe_cast { try_cast } else { cast };
+                            let cast_expr = cast_fn(
                                 Expr::Column(Column::from_name(field.name())),
                                 // col(field.name()),
                                 field.data_type().clone(),
@@ -520,6 +534,16 @@ impl std::future::IntoFuture for WriteBuilder {
                 &state,
             )?;
 
+            let source = LogicalPlan::Extension(Extension {
+                node: Arc::new(MetricObserver {
+                    id: "write_source_count".into(),
+                    input: source.logical_plan().clone(),
+                    enable_pushdown: false,
+                }),
+            });
+
+            let source = DataFrame::new(state.clone(), source);
+
             let schema = Arc::new(source.schema().as_arrow().clone());
 
             // Maybe create schema action
@@ -576,21 +600,31 @@ impl std::future::IntoFuture for WriteBuilder {
                 stats_columns,
             };
 
+            let source_plan = source.clone().create_physical_plan().await?;
+
             // Here we need to validate if the new data conforms to a predicate if one is provided
             let add_actions = write_execution_plan_with_predicate(
                 predicate.clone(),
                 this.snapshot.as_ref(),
                 state.clone(),
-                source.clone().create_physical_plan().await?,
+                source_plan.clone(),
                 partition_columns.clone(),
                 this.log_store.object_store(Some(operation_id)).clone(),
                 target_file_size,
                 this.write_batch_size,
                 this.writer_properties.clone(),
                 writer_stats_config.clone(),
-                None,
             )
             .await?;
+
+            let source_count =
+                find_metric_node(SOURCE_COUNT_ID, &source_plan).ok_or_else(|| {
+                    DeltaTableError::Generic("Unable to locate expected metric node".into())
+                })?;
+            let source_count_metrics = source_count.metrics().unwrap();
+            let num_added_rows = get_metric(&source_count_metrics, SOURCE_COUNT_METRIC);
+            metrics.num_added_rows = num_added_rows;
+
             metrics.num_added_files = add_actions.len();
             actions.extend(add_actions);
 
@@ -989,7 +1023,6 @@ mod tests {
         assert_eq!(table.version(), 0);
         assert_eq!(table.get_files_count(), 2);
         let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await;
-        assert!(write_metrics.num_partitions > 0);
         assert_eq!(write_metrics.num_added_files, 2);
         assert_common_write_metrics(write_metrics);
 
@@ -1003,7 +1036,6 @@ mod tests {
         assert_eq!(table.get_files_count(), 4);
 
         let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await;
-        assert!(write_metrics.num_partitions > 0);
         assert_eq!(write_metrics.num_added_files, 4);
         assert_common_write_metrics(write_metrics);
     }
@@ -1093,7 +1125,6 @@ mod tests {
         assert_eq!(table.version(), 0);
 
         let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await;
-        assert!(write_metrics.num_partitions > 0);
         assert_common_write_metrics(write_metrics);
 
         let mut new_schema_builder = arrow_schema::SchemaBuilder::new();
@@ -1146,7 +1177,6 @@ mod tests {
        assert_eq!(part_cols, vec!["id", "value"]); // we want to preserve partitions
 
        let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await;
-       assert!(write_metrics.num_partitions > 0);
        assert_common_write_metrics(write_metrics);
    }
 
@@ -1668,7 +1698,6 @@ mod tests {
         assert_eq!(table.version(), 1);
         let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await;
         assert_eq!(write_metrics.num_added_rows, 3);
-        assert!(write_metrics.num_partitions > 0);
         assert_common_write_metrics(write_metrics);
 
         let table = DeltaOps(table)
@@ -1680,7 +1709,6 @@ mod tests {
         assert_eq!(table.version(), 2);
         let write_metrics: WriteMetrics = get_write_metrics(table.clone()).await;
         assert_eq!(write_metrics.num_added_rows, 1);
-        assert!(write_metrics.num_partitions > 0);
         assert!(write_metrics.num_removed_files > 0);
         assert_common_write_metrics(write_metrics);
 
diff --git a/crates/core/tests/integration_datafusion.rs b/crates/core/tests/integration_datafusion.rs
index e55145accb..bf58d8cf29 100644
--- a/crates/core/tests/integration_datafusion.rs
+++ b/crates/core/tests/integration_datafusion.rs
@@ -1119,6 +1119,7 @@ mod local {
         let _ = write_builder
             .with_input_execution_plan(plan)
             .with_save_mode(SaveMode::Overwrite)
+            .with_schema_mode(deltalake_core::operations::write::SchemaMode::Overwrite)
             .await
             .unwrap();
 
diff --git a/python/src/lib.rs b/python/src/lib.rs
index 1792785a6f..d1ff324a23 100644
--- a/python/src/lib.rs
+++ b/python/src/lib.rs
@@ -5,6 +5,7 @@ mod merge;
 mod query;
 mod schema;
 mod utils;
+mod writer;
 
 use std::cmp::min;
 use std::collections::{HashMap, HashSet};
@@ -2179,9 +2180,9 @@ fn write_to_deltalake(
         );
         builder = builder.with_input_batches(data.0.map(|batch| batch.unwrap()));
     } else {
+        use crate::writer::to_lazy_table;
         use deltalake::datafusion::datasource::provider_as_source;
         use deltalake::datafusion::logical_expr::LogicalPlanBuilder;
-        use deltalake::operations::write::lazy::to_lazy_table;
 
         let table_provider = to_lazy_table(data.0).map_err(PythonError::from)?;
         let plan = LogicalPlanBuilder::scan("source", provider_as_source(table_provider), None)
diff --git a/crates/core/src/operations/write/lazy.rs b/python/src/writer.rs
similarity index 74%
rename from crates/core/src/operations/write/lazy.rs
rename to python/src/writer.rs
index a16f579d65..0e8c20cf8f 100644
--- a/crates/core/src/operations/write/lazy.rs
+++ b/python/src/writer.rs
@@ -1,19 +1,17 @@
 //! This module contains helper functions to create a LazyTableProvider from an ArrowArrayStreamReader
 
 use crate::DeltaResult;
-use arrow::ffi_stream::ArrowArrayStreamReader;
-use datafusion::catalog::TableProvider;
-use datafusion::physical_plan::memory::LazyBatchGenerator;
-use delta_datafusion::LazyTableProvider;
+use deltalake::arrow::ffi_stream::ArrowArrayStreamReader;
+use deltalake::datafusion::catalog::TableProvider;
+use deltalake::datafusion::physical_plan::memory::LazyBatchGenerator;
+use deltalake::delta_datafusion::LazyTableProvider;
 use parking_lot::RwLock;
 use std::fmt::{self};
 use std::sync::{Arc, Mutex};
 
-use crate::delta_datafusion;
-
 /// Convert an [ArrowArrayStreamReader] into a [LazyTableProvider]
 pub fn to_lazy_table(source: ArrowArrayStreamReader) -> DeltaResult<Arc<dyn TableProvider>> {
-    use arrow::array::RecordBatchReader;
+    use deltalake::arrow::array::RecordBatchReader;
     let schema = source.schema();
     let arrow_stream: Arc<Mutex<ArrowArrayStreamReader>> = Arc::new(Mutex::new(source));
     let arrow_stream_batch_generator: Arc<RwLock<ArrowStreamBatchGenerator>> =
@@ -49,16 +47,18 @@ impl ArrowStreamBatchGenerator {
 impl LazyBatchGenerator for ArrowStreamBatchGenerator {
     fn generate_next_batch(
         &mut self,
-    ) -> datafusion::error::Result<Option<RecordBatch>> {
+    ) -> deltalake::datafusion::error::Result<Option<RecordBatch>> {
         let mut stream_reader = self.array_stream.lock().map_err(|_| {
-            datafusion::error::DataFusionError::Execution(
+            deltalake::datafusion::error::DataFusionError::Execution(
                 "Failed to lock the ArrowArrayStreamReader".to_string(),
            )
         })?;
 
         match stream_reader.next() {
             Some(Ok(record_batch)) => Ok(Some(record_batch)),
-            Some(Err(err)) => Err(datafusion::error::DataFusionError::ArrowError(err, None)),
+            Some(Err(err)) => Err(deltalake::datafusion::error::DataFusionError::ArrowError(
+                err, None,
+            )),
             None => Ok(None), // End of stream
         }
     }
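
For reviewers: a condensed, crate-internal sketch of the metric flow this patch introduces, pulling the execution.rs and mod.rs pieces together in one place. It only reuses names defined in the diff above (DeltaPlanner, WriteMetricExtensionPlanner, MetricObserver, find_metric_node, get_metric); the surrounding builder code and error handling are elided, so treat it as an illustration rather than a drop-in snippet.

    // Build a SessionState whose query planner knows how to lower the
    // MetricObserver logical node into a counting MetricObserverExec.
    let write_planner = DeltaPlanner::<WriteMetricExtensionPlanner> {
        extension_planner: WriteMetricExtensionPlanner {},
    };
    let state = SessionStateBuilder::new()
        .with_default_features()
        .with_query_planner(Arc::new(write_planner))
        .build();

    // Wrap the source plan in the observer node before planning, so the
    // extension planner can attach the row counter.
    let source = LogicalPlan::Extension(Extension {
        node: Arc::new(MetricObserver {
            id: SOURCE_COUNT_ID.into(),
            input: source.logical_plan().clone(),
            enable_pushdown: false,
        }),
    });
    let source_plan = DataFrame::new(state.clone(), source)
        .create_physical_plan()
        .await?;

    // ... run write_execution_plan_with_predicate(.., source_plan.clone(), ..) ...

    // Afterwards, read the per-batch row counter back out of the physical plan.
    let source_count = find_metric_node(SOURCE_COUNT_ID, &source_plan)
        .ok_or_else(|| DeltaTableError::Generic("Unable to locate expected metric node".into()))?;
    let source_count_metrics = source_count.metrics().unwrap();
    metrics.num_added_rows = get_metric(&source_count_metrics, SOURCE_COUNT_METRIC);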