diff --git a/src/daft-logical-plan/src/builder.rs b/src/daft-logical-plan/src/builder.rs index 937fb45f44..cd24cb9e5b 100644 --- a/src/daft-logical-plan/src/builder.rs +++ b/src/daft-logical-plan/src/builder.rs @@ -10,7 +10,7 @@ use common_file_formats::FileFormat; use common_io_config::IOConfig; use common_scan_info::{PhysicalScanInfo, Pushdowns, ScanOperatorRef}; use daft_core::join::{JoinStrategy, JoinType}; -use daft_dsl::{col, ExprRef}; +use daft_dsl::{col, ExprRef, ExprResolver}; use daft_schema::schema::{Schema, SchemaRef}; #[cfg(feature = "python")] use { @@ -188,11 +188,19 @@ impl LogicalPlanBuilder { } pub fn select(&self, to_select: Vec) -> DaftResult { + let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build(); + + let (to_select, _) = expr_resolver.resolve(to_select, &self.schema())?; + let logical_plan: LogicalPlan = ops::Project::try_new(self.plan.clone(), to_select)?.into(); Ok(self.with_new_plan(logical_plan)) } pub fn with_columns(&self, columns: Vec) -> DaftResult { + let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build(); + + let (columns, _) = expr_resolver.resolve(columns, &self.schema())?; + let fields = &self.schema().fields; let current_col_names = fields .iter() @@ -245,6 +253,10 @@ impl LogicalPlanBuilder { } pub fn filter(&self, predicate: ExprRef) -> DaftResult { + let expr_resolver = ExprResolver::default(); + + let (predicate, _) = expr_resolver.resolve_single(predicate, &self.schema())?; + let logical_plan: LogicalPlan = ops::Filter::try_new(self.plan.clone(), predicate)?.into(); Ok(self.with_new_plan(logical_plan)) } @@ -438,17 +450,35 @@ impl LogicalPlanBuilder { join_prefix: Option<&str>, keep_join_keys: bool, ) -> DaftResult { + let left_plan = self.plan.clone(); + let right_plan = right.into(); + + let expr_resolver = ExprResolver::default(); + + let (left_on, _) = expr_resolver.resolve(left_on, &left_plan.schema())?; + let (right_on, _) = expr_resolver.resolve(right_on, &right_plan.schema())?; + + let (left_on, right_on) = ops::Join::rename_join_keys(left_on, right_on); + + let (right_plan, right_on) = ops::Join::rename_right_columns( + left_plan.clone(), + right_plan, + left_on.clone(), + right_on, + join_type, + join_suffix, + join_prefix, + keep_join_keys, + )?; + let logical_plan: LogicalPlan = ops::Join::try_new( - self.plan.clone(), - right.into(), + left_plan, + right_plan, left_on, right_on, null_equals_nulls, join_type, join_strategy, - join_suffix, - join_prefix, - keep_join_keys, )? .into(); Ok(self.with_new_plan(logical_plan)) diff --git a/src/daft-logical-plan/src/logical_plan.rs b/src/daft-logical-plan/src/logical_plan.rs index 4abd15fcaa..8b502b6ad1 100644 --- a/src/daft-logical-plan/src/logical_plan.rs +++ b/src/daft-logical-plan/src/logical_plan.rs @@ -361,9 +361,6 @@ impl LogicalPlan { null_equals_nulls.clone(), *join_type, *join_strategy, - None, // The suffix is already eagerly computed in the constructor - None, // the prefix is already eagerly computed in the constructor - false // this is already eagerly computed in the constructor ).unwrap()), _ => panic!("Logical op {} has one input, but got two", self), }, diff --git a/src/daft-logical-plan/src/ops/filter.rs b/src/daft-logical-plan/src/ops/filter.rs index 2a046b66e7..47fcb6b7ad 100644 --- a/src/daft-logical-plan/src/ops/filter.rs +++ b/src/daft-logical-plan/src/ops/filter.rs @@ -2,7 +2,7 @@ use std::sync::Arc; use common_error::DaftError; use daft_core::prelude::*; -use daft_dsl::{estimated_selectivity, ExprRef, ExprResolver}; +use daft_dsl::{estimated_selectivity, ExprRef}; use snafu::ResultExt; use crate::{ @@ -22,16 +22,12 @@ pub struct Filter { impl Filter { pub(crate) fn try_new(input: Arc, predicate: ExprRef) -> Result { - let expr_resolver = ExprResolver::default(); + let dtype = predicate.to_field(&input.schema())?.dtype; - let (predicate, field) = expr_resolver - .resolve_single(predicate, &input.schema()) - .context(CreationSnafu)?; - - if !matches!(field.dtype, DataType::Boolean) { + if !matches!(dtype, DataType::Boolean) { return Err(DaftError::ValueError(format!( "Expected expression {predicate} to resolve to type Boolean, but received: {}", - field.dtype + dtype ))) .context(CreationSnafu); } diff --git a/src/daft-logical-plan/src/ops/join.rs b/src/daft-logical-plan/src/ops/join.rs index 18dede0720..c51e9553f1 100644 --- a/src/daft-logical-plan/src/ops/join.rs +++ b/src/daft-logical-plan/src/ops/join.rs @@ -9,7 +9,7 @@ use daft_dsl::{ col, join::{get_common_join_keys, infer_join_schema}, optimization::replace_columns_with_expressions, - Expr, ExprRef, ExprResolver, + Expr, ExprRef, }; use itertools::Itertools; use snafu::ResultExt; @@ -19,7 +19,7 @@ use crate::{ logical_plan::{self, CreationSnafu}, ops::Project, stats::{ApproxStats, PlanStats, StatsState}, - LogicalPlan, + LogicalPlan, LogicalPlanRef, }; #[derive(Clone, Debug, PartialEq, Eq)] @@ -51,30 +51,6 @@ impl std::hash::Hash for Join { } impl Join { - #[allow(clippy::too_many_arguments)] - pub(crate) fn new( - left: Arc, - right: Arc, - left_on: Vec, - right_on: Vec, - null_equals_nulls: Option>, - join_type: JoinType, - join_strategy: Option, - output_schema: SchemaRef, - ) -> Self { - Self { - left, - right, - left_on, - right_on, - null_equals_nulls, - join_type, - join_strategy, - output_schema, - stats_state: StatsState::NotMaterialized, - } - } - #[allow(clippy::too_many_arguments)] pub(crate) fn try_new( left: Arc, @@ -84,45 +60,11 @@ impl Join { null_equals_nulls: Option>, join_type: JoinType, join_strategy: Option, - join_suffix: Option<&str>, - join_prefix: Option<&str>, - // if true, then duplicate column names will be kept - // ex: select * from a left join b on a.id = b.id - // if true, then the resulting schema will have two columns named id (id, and b.id) - // In SQL the join column is always kept, while in dataframes it is not - keep_join_keys: bool, ) -> logical_plan::Result { - let expr_resolver = ExprResolver::default(); - - let (left_on, _) = expr_resolver - .resolve(left_on, &left.schema()) - .context(CreationSnafu)?; - let (right_on, _) = expr_resolver - .resolve(right_on, &right.schema()) - .context(CreationSnafu)?; - - let (unique_left_on, unique_right_on) = - Self::rename_join_keys(left_on.clone(), right_on.clone()); - - let left_fields: Vec = unique_left_on - .iter() - .map(|e| e.to_field(&left.schema())) - .collect::>>() - .context(CreationSnafu)?; - - let right_fields: Vec = unique_right_on - .iter() - .map(|e| e.to_field(&right.schema())) - .collect::>>() - .context(CreationSnafu)?; - - for (on_exprs, on_fields) in [ - (&unique_left_on, &left_fields), - (&unique_right_on, &right_fields), - ] { - for (field, expr) in on_fields.iter().zip(on_exprs.iter()) { + for (on_exprs, side) in [(&left_on, &left), (&right_on, &right)] { + for expr in on_exprs { // Null type check for both fields and expressions - if matches!(field.dtype, DataType::Null) { + if matches!(expr.to_field(&side.schema())?.dtype, DataType::Null) { return Err(DaftError::ValueError(format!( "Can't join on null type expressions: {expr}" ))) @@ -141,22 +83,41 @@ impl Join { } } - if matches!(join_type, JoinType::Anti | JoinType::Semi) { - // The output schema is the same as the left input schema for anti and semi joins. + let output_schema = infer_join_schema( + &left.schema(), + &right.schema(), + &left_on, + &right_on, + join_type, + ) + .context(CreationSnafu)?; - let output_schema = left.schema(); + Ok(Self { + left, + right, + left_on, + right_on, + null_equals_nulls, + join_type, + join_strategy, + output_schema, + stats_state: StatsState::NotMaterialized, + }) + } - Ok(Self { - left, - right, - left_on, - right_on, - null_equals_nulls, - join_type, - join_strategy, - output_schema, - stats_state: StatsState::NotMaterialized, - }) + #[allow(clippy::too_many_arguments)] + pub(crate) fn rename_right_columns( + left: LogicalPlanRef, + right: LogicalPlanRef, + left_on: Vec, + right_on: Vec, + join_type: JoinType, + join_suffix: Option<&str>, + join_prefix: Option<&str>, + keep_join_keys: bool, + ) -> DaftResult<(LogicalPlanRef, Vec)> { + if matches!(join_type, JoinType::Anti | JoinType::Semi) { + Ok((right, right_on)) } else { let common_join_keys: HashSet<_> = get_common_join_keys(left_on.as_slice(), right_on.as_slice()) @@ -202,8 +163,8 @@ impl Join { }) .collect(); - let (right, right_on) = if right_rename_mapping.is_empty() { - (right, right_on) + if right_rename_mapping.is_empty() { + Ok((right, right_on)) } else { // projection to update the right side with the new column names let new_right_projection: Vec<_> = right_names @@ -230,29 +191,8 @@ impl Join { .map(|expr| replace_columns_with_expressions(expr, &right_on_replace_map)) .collect::>(); - (new_right.into(), new_right_on) - }; - - let output_schema = infer_join_schema( - &left.schema(), - &right.schema(), - &left_on, - &right_on, - join_type, - ) - .context(CreationSnafu)?; - - Ok(Self { - left, - right, - left_on, - right_on, - null_equals_nulls, - join_type, - join_strategy, - output_schema, - stats_state: StatsState::NotMaterialized, - }) + Ok((new_right.into(), new_right_on)) + } } } @@ -283,7 +223,7 @@ impl Join { /// /// For more details, see [issue #2649](https://github.com/Eventual-Inc/Daft/issues/2649). - fn rename_join_keys( + pub(crate) fn rename_join_keys( left_exprs: Vec>, right_exprs: Vec>, ) -> (Vec>, Vec>) { diff --git a/src/daft-logical-plan/src/ops/project.rs b/src/daft-logical-plan/src/ops/project.rs index 165d989a09..4f05d677c3 100644 --- a/src/daft-logical-plan/src/ops/project.rs +++ b/src/daft-logical-plan/src/ops/project.rs @@ -1,8 +1,9 @@ use std::sync::Arc; +use common_error::DaftResult; use common_treenode::Transformed; use daft_core::prelude::*; -use daft_dsl::{optimization, AggExpr, ApproxPercentileParams, Expr, ExprRef, ExprResolver}; +use daft_dsl::{optimization, AggExpr, ApproxPercentileParams, Expr, ExprRef}; use indexmap::{IndexMap, IndexSet}; use itertools::Itertools; use snafu::ResultExt; @@ -24,16 +25,15 @@ pub struct Project { impl Project { pub(crate) fn try_new(input: Arc, projection: Vec) -> Result { - let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build(); - - let (projection, fields) = expr_resolver - .resolve(projection, &input.schema()) - .context(CreationSnafu)?; - // Factor the projection and see if there are any substitutions to factor out. let (factored_input, factored_projection) = Self::try_factor_subexpressions(input, projection)?; + let fields = factored_projection + .iter() + .map(|expr| expr.to_field(&factored_input.schema())) + .collect::>()?; + let projected_schema = Schema::new(fields).context(CreationSnafu)?.into(); Ok(Self { diff --git a/src/daft-logical-plan/src/ops/set_operations.rs b/src/daft-logical-plan/src/ops/set_operations.rs index 42009182b6..43c02c4625 100644 --- a/src/daft-logical-plan/src/ops/set_operations.rs +++ b/src/daft-logical-plan/src/ops/set_operations.rs @@ -47,9 +47,6 @@ fn intersect_or_except_plan( Some(vec![true; left_on_size]), join_type, None, - None, - None, - false, ); join.map(|j| Distinct::new(j.into()).into()) } diff --git a/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs b/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs index e9e3a2e524..d478757a35 100644 --- a/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs +++ b/src/daft-logical-plan/src/optimization/rules/eliminate_cross_join.rs @@ -303,12 +303,12 @@ fn find_inner_join( if !join_keys.is_empty() { all_join_keys.insert_all(join_keys.iter()); let right_input = rights.remove(i); - let join_schema = left_input - .schema() - .non_distinct_union(right_input.schema().as_ref()); let (left_keys, right_keys) = join_keys.iter().cloned().unzip(); - return Ok(LogicalPlan::Join(Join::new( + + let (left_keys, right_keys) = Join::rename_join_keys(left_keys, right_keys); + + return Ok(LogicalPlan::Join(Join::try_new( left_input, right_input, left_keys, @@ -316,8 +316,7 @@ fn find_inner_join( None, JoinType::Inner, None, - Arc::new(join_schema), - )) + )?) .arced()); } } @@ -325,11 +324,8 @@ fn find_inner_join( // no matching right plan had any join keys, cross join with the first right // plan let right = rights.remove(0); - let join_schema = left_input - .schema() - .non_distinct_union(right.schema().as_ref()); - Ok(LogicalPlan::Join(Join::new( + Ok(LogicalPlan::Join(Join::try_new( left_input, right, vec![], @@ -337,8 +333,7 @@ fn find_inner_join( None, JoinType::Inner, None, - Arc::new(join_schema), - )) + )?) .arced()) } diff --git a/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs b/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs index 5039cc9767..cb2285a545 100644 --- a/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs +++ b/src/daft-logical-plan/src/optimization/rules/unnest_subquery.rs @@ -119,6 +119,8 @@ impl UnnestScalarSubquery { let (decorrelated_subquery, subquery_on, input_on) = pull_up_correlated_cols(subquery_plan)?; + let (input_on, subquery_on) = Join::rename_join_keys(input_on, subquery_on); + if subquery_on.is_empty() { // uncorrelated scalar subquery Ok(Arc::new(LogicalPlan::Join(Join::try_new( @@ -129,9 +131,6 @@ impl UnnestScalarSubquery { None, JoinType::Inner, None, - None, - None, - false, )?))) } else { // correlated scalar subquery @@ -143,9 +142,6 @@ impl UnnestScalarSubquery { None, JoinType::Left, None, - None, - None, - false, )?))) } })?; @@ -327,6 +323,8 @@ impl OptimizerRule for UnnestPredicateSubquery { return Err(DaftError::ValueError("Expected IN/EXISTS subquery to be correlated, found uncorrelated subquery.".to_string())); } + let (input_on, subquery_on) = Join::rename_join_keys(input_on, subquery_on); + Ok(Arc::new(LogicalPlan::Join(Join::try_new( curr_input, decorrelated_subquery, @@ -335,9 +333,6 @@ impl OptimizerRule for UnnestPredicateSubquery { None, join_type, None, - None, - None, - false )?))) })?;