Skip to content

Commit

Permalink
chore: refactor logical op constructor+builder boundary
Browse files Browse the repository at this point in the history
  • Loading branch information
kevinzwang committed Jan 14, 2025
1 parent feab49a commit 562bc4e
Show file tree
Hide file tree
Showing 8 changed files with 101 additions and 151 deletions.
42 changes: 36 additions & 6 deletions src/daft-logical-plan/src/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use common_file_formats::FileFormat;
use common_io_config::IOConfig;
use common_scan_info::{PhysicalScanInfo, Pushdowns, ScanOperatorRef};
use daft_core::join::{JoinStrategy, JoinType};
use daft_dsl::{col, ExprRef};
use daft_dsl::{col, ExprRef, ExprResolver};
use daft_schema::schema::{Schema, SchemaRef};
#[cfg(feature = "python")]
use {
Expand Down Expand Up @@ -188,11 +188,19 @@ impl LogicalPlanBuilder {
}

/// Projects the current plan onto the given expressions.
///
/// The expressions are first resolved against this builder's schema using an
/// `ExprResolver` configured to allow actor-pool UDFs. The resolved
/// expressions are then passed to `ops::Project::try_new`, and the resulting
/// projection becomes the plan of the returned builder.
///
/// # Errors
///
/// Returns any error raised while resolving the expressions or while
/// constructing the projection.
pub fn select(&self, to_select: Vec<ExprRef>) -> DaftResult<Self> {
    let input_schema = self.schema();
    let resolver = ExprResolver::builder().allow_actor_pool_udf(true).build();
    let (resolved_exprs, _) = resolver.resolve(to_select, &input_schema)?;

    let project = ops::Project::try_new(self.plan.clone(), resolved_exprs)?;
    let projected_plan: LogicalPlan = project.into();
    Ok(self.with_new_plan(projected_plan))
}

pub fn with_columns(&self, columns: Vec<ExprRef>) -> DaftResult<Self> {
let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build();

let (columns, _) = expr_resolver.resolve(columns, &self.schema())?;

let fields = &self.schema().fields;
let current_col_names = fields
.iter()
Expand Down Expand Up @@ -245,6 +253,10 @@ impl LogicalPlanBuilder {
}

/// Adds a filter over the current plan using `predicate`.
///
/// The predicate is resolved against this builder's schema with a default
/// `ExprResolver` before being passed to `ops::Filter::try_new`; the
/// resulting filter becomes the plan of the returned builder.
///
/// # Errors
///
/// Returns any error raised while resolving the predicate or while
/// constructing the filter.
pub fn filter(&self, predicate: ExprRef) -> DaftResult<Self> {
    let input_schema = self.schema();
    let resolver = ExprResolver::default();
    let (resolved_predicate, _) = resolver.resolve_single(predicate, &input_schema)?;

    let filter_op = ops::Filter::try_new(self.plan.clone(), resolved_predicate)?;
    let filtered_plan: LogicalPlan = filter_op.into();
    Ok(self.with_new_plan(filtered_plan))
}
Expand Down Expand Up @@ -438,17 +450,35 @@ impl LogicalPlanBuilder {
join_prefix: Option<&str>,
keep_join_keys: bool,
) -> DaftResult<Self> {
let left_plan = self.plan.clone();
let right_plan = right.into();

let expr_resolver = ExprResolver::default();

let (left_on, _) = expr_resolver.resolve(left_on, &left_plan.schema())?;
let (right_on, _) = expr_resolver.resolve(right_on, &right_plan.schema())?;

let (left_on, right_on) = ops::Join::rename_join_keys(left_on, right_on);

let (right_plan, right_on) = ops::Join::rename_right_columns(
left_plan.clone(),
right_plan,
left_on.clone(),
right_on,
join_type,
join_suffix,
join_prefix,
keep_join_keys,
)?;

let logical_plan: LogicalPlan = ops::Join::try_new(
self.plan.clone(),
right.into(),
left_plan,
right_plan,
left_on,
right_on,
null_equals_nulls,
join_type,
join_strategy,
join_suffix,
join_prefix,
keep_join_keys,
)?
.into();
Ok(self.with_new_plan(logical_plan))
Expand Down
3 changes: 0 additions & 3 deletions src/daft-logical-plan/src/logical_plan.rs
Original file line number Diff line number Diff line change
Expand Up @@ -361,9 +361,6 @@ impl LogicalPlan {
null_equals_nulls.clone(),
*join_type,
*join_strategy,
None, // The suffix is already eagerly computed in the constructor
None, // the prefix is already eagerly computed in the constructor
false // this is already eagerly computed in the constructor
).unwrap()),
_ => panic!("Logical op {} has one input, but got two", self),
},
Expand Down
12 changes: 4 additions & 8 deletions src/daft-logical-plan/src/ops/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Arc;

use common_error::DaftError;
use daft_core::prelude::*;
use daft_dsl::{estimated_selectivity, ExprRef, ExprResolver};
use daft_dsl::{estimated_selectivity, ExprRef};
use snafu::ResultExt;

use crate::{
Expand All @@ -22,16 +22,12 @@ pub struct Filter {

impl Filter {
pub(crate) fn try_new(input: Arc<LogicalPlan>, predicate: ExprRef) -> Result<Self> {
let expr_resolver = ExprResolver::default();
let dtype = predicate.to_field(&input.schema())?.dtype;

let (predicate, field) = expr_resolver
.resolve_single(predicate, &input.schema())
.context(CreationSnafu)?;

if !matches!(field.dtype, DataType::Boolean) {
if !matches!(dtype, DataType::Boolean) {
return Err(DaftError::ValueError(format!(
"Expected expression {predicate} to resolve to type Boolean, but received: {}",
field.dtype
dtype
)))
.context(CreationSnafu);
}
Expand Down
146 changes: 43 additions & 103 deletions src/daft-logical-plan/src/ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use daft_dsl::{
col,
join::{get_common_join_keys, infer_join_schema},
optimization::replace_columns_with_expressions,
Expr, ExprRef, ExprResolver,
Expr, ExprRef,
};
use itertools::Itertools;
use snafu::ResultExt;
Expand All @@ -19,7 +19,7 @@ use crate::{
logical_plan::{self, CreationSnafu},
ops::Project,
stats::{ApproxStats, PlanStats, StatsState},
LogicalPlan,
LogicalPlan, LogicalPlanRef,
};

#[derive(Clone, Debug, PartialEq, Eq)]
Expand Down Expand Up @@ -51,30 +51,6 @@ impl std::hash::Hash for Join {
}

impl Join {
#[allow(clippy::too_many_arguments)]
pub(crate) fn new(
left: Arc<LogicalPlan>,
right: Arc<LogicalPlan>,
left_on: Vec<ExprRef>,
right_on: Vec<ExprRef>,
null_equals_nulls: Option<Vec<bool>>,
join_type: JoinType,
join_strategy: Option<JoinStrategy>,
output_schema: SchemaRef,
) -> Self {
Self {
left,
right,
left_on,
right_on,
null_equals_nulls,
join_type,
join_strategy,
output_schema,
stats_state: StatsState::NotMaterialized,
}
}

#[allow(clippy::too_many_arguments)]
pub(crate) fn try_new(
left: Arc<LogicalPlan>,
Expand All @@ -84,45 +60,11 @@ impl Join {
null_equals_nulls: Option<Vec<bool>>,
join_type: JoinType,
join_strategy: Option<JoinStrategy>,
join_suffix: Option<&str>,
join_prefix: Option<&str>,
// if true, then duplicate column names will be kept
// ex: select * from a left join b on a.id = b.id
// if true, then the resulting schema will have two columns named id (id, and b.id)
// In SQL the join column is always kept, while in dataframes it is not
keep_join_keys: bool,
) -> logical_plan::Result<Self> {
let expr_resolver = ExprResolver::default();

let (left_on, _) = expr_resolver
.resolve(left_on, &left.schema())
.context(CreationSnafu)?;
let (right_on, _) = expr_resolver
.resolve(right_on, &right.schema())
.context(CreationSnafu)?;

let (unique_left_on, unique_right_on) =
Self::rename_join_keys(left_on.clone(), right_on.clone());

let left_fields: Vec<Field> = unique_left_on
.iter()
.map(|e| e.to_field(&left.schema()))
.collect::<DaftResult<Vec<Field>>>()
.context(CreationSnafu)?;

let right_fields: Vec<Field> = unique_right_on
.iter()
.map(|e| e.to_field(&right.schema()))
.collect::<DaftResult<Vec<Field>>>()
.context(CreationSnafu)?;

for (on_exprs, on_fields) in [
(&unique_left_on, &left_fields),
(&unique_right_on, &right_fields),
] {
for (field, expr) in on_fields.iter().zip(on_exprs.iter()) {
for (on_exprs, side) in [(&left_on, &left), (&right_on, &right)] {
for expr in on_exprs {
// Null type check for both fields and expressions
if matches!(field.dtype, DataType::Null) {
if matches!(expr.to_field(&side.schema())?.dtype, DataType::Null) {
return Err(DaftError::ValueError(format!(
"Can't join on null type expressions: {expr}"
)))
Expand All @@ -141,22 +83,41 @@ impl Join {
}
}

if matches!(join_type, JoinType::Anti | JoinType::Semi) {
// The output schema is the same as the left input schema for anti and semi joins.
let output_schema = infer_join_schema(
&left.schema(),
&right.schema(),
&left_on,
&right_on,
join_type,
)
.context(CreationSnafu)?;

let output_schema = left.schema();
Ok(Self {
left,
right,
left_on,
right_on,
null_equals_nulls,
join_type,
join_strategy,
output_schema,
stats_state: StatsState::NotMaterialized,
})
}

Ok(Self {
left,
right,
left_on,
right_on,
null_equals_nulls,
join_type,
join_strategy,
output_schema,
stats_state: StatsState::NotMaterialized,
})
#[allow(clippy::too_many_arguments)]
pub(crate) fn rename_right_columns(
left: LogicalPlanRef,
right: LogicalPlanRef,
left_on: Vec<ExprRef>,
right_on: Vec<ExprRef>,
join_type: JoinType,
join_suffix: Option<&str>,
join_prefix: Option<&str>,
keep_join_keys: bool,
) -> DaftResult<(LogicalPlanRef, Vec<ExprRef>)> {
if matches!(join_type, JoinType::Anti | JoinType::Semi) {
Ok((right, right_on))
} else {
let common_join_keys: HashSet<_> =
get_common_join_keys(left_on.as_slice(), right_on.as_slice())
Expand Down Expand Up @@ -202,8 +163,8 @@ impl Join {
})
.collect();

let (right, right_on) = if right_rename_mapping.is_empty() {
(right, right_on)
if right_rename_mapping.is_empty() {
Ok((right, right_on))
} else {
// projection to update the right side with the new column names
let new_right_projection: Vec<_> = right_names
Expand All @@ -230,29 +191,8 @@ impl Join {
.map(|expr| replace_columns_with_expressions(expr, &right_on_replace_map))
.collect::<Vec<_>>();

(new_right.into(), new_right_on)
};

let output_schema = infer_join_schema(
&left.schema(),
&right.schema(),
&left_on,
&right_on,
join_type,
)
.context(CreationSnafu)?;

Ok(Self {
left,
right,
left_on,
right_on,
null_equals_nulls,
join_type,
join_strategy,
output_schema,
stats_state: StatsState::NotMaterialized,
})
Ok((new_right.into(), new_right_on))
}
}
}

Expand Down Expand Up @@ -283,7 +223,7 @@ impl Join {
///
/// For more details, see [issue #2649](https://github.com/Eventual-Inc/Daft/issues/2649).
fn rename_join_keys(
pub(crate) fn rename_join_keys(
left_exprs: Vec<Arc<Expr>>,
right_exprs: Vec<Arc<Expr>>,
) -> (Vec<Arc<Expr>>, Vec<Arc<Expr>>) {
Expand Down
14 changes: 7 additions & 7 deletions src/daft-logical-plan/src/ops/project.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
use std::sync::Arc;

use common_error::DaftResult;
use common_treenode::Transformed;
use daft_core::prelude::*;
use daft_dsl::{optimization, AggExpr, ApproxPercentileParams, Expr, ExprRef, ExprResolver};
use daft_dsl::{optimization, AggExpr, ApproxPercentileParams, Expr, ExprRef};
use indexmap::{IndexMap, IndexSet};
use itertools::Itertools;
use snafu::ResultExt;
Expand All @@ -24,16 +25,15 @@ pub struct Project {

impl Project {
pub(crate) fn try_new(input: Arc<LogicalPlan>, projection: Vec<ExprRef>) -> Result<Self> {
let expr_resolver = ExprResolver::builder().allow_actor_pool_udf(true).build();

let (projection, fields) = expr_resolver
.resolve(projection, &input.schema())
.context(CreationSnafu)?;

// Factor the projection and see if there are any substitutions to factor out.
let (factored_input, factored_projection) =
Self::try_factor_subexpressions(input, projection)?;

let fields = factored_projection
.iter()
.map(|expr| expr.to_field(&factored_input.schema()))
.collect::<DaftResult<_>>()?;

let projected_schema = Schema::new(fields).context(CreationSnafu)?.into();

Ok(Self {
Expand Down
3 changes: 0 additions & 3 deletions src/daft-logical-plan/src/ops/set_operations.rs
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,6 @@ fn intersect_or_except_plan(
Some(vec![true; left_on_size]),
join_type,
None,
None,
None,
false,
);
join.map(|j| Distinct::new(j.into()).into())
}
Expand Down
Loading

0 comments on commit 562bc4e

Please sign in to comment.