Skip to content

Commit

Permalink
perf: Improve stats for join side determination (#3655)
Browse files Browse the repository at this point in the history
This PR updates swordfish join side determination logic to compare num
rows instead of upper bound size bytes.

### Details:
- Use a fixed `num_rows` and `size_bytes` in the `ApproxStats` instead
of lower / upper bounds.
- Instead of having a fixed 20% selectivity for filters, factor the
complexity of the filter expressions into the selectivity estimate.
E.g. `AND`s will be more selective than `OR`s, and `IS_NULL` will
generally be less selective than comparisons or equalities. (This is
useful because all of our joins have null filter pushdowns, but it tends
to be the case that the side with the more complex filters will be the
better side for the hash table, and a fixed 20% selectivity would
miss out on this.)

### Results on TPCH SF10:

| Query | Original (ms) | Latest (ms) | Change (%) |
|-------|--------------|-------------|------------|
| Q1    | 493.39 | 500.53 | +1.45% |
| Q2    | 158.53 | 149.09 | -5.95% |
| Q3    | 499.76 | 490.63 | -1.83% |
| Q4    | 2527.00 | 269.22 | -89.34% |
| Q5    | 757.68 | 721.30 | -4.80% |
| Q6    | 151.95 | 156.64 | +3.09% |
| Q7    | 471.33 | 458.81 | -2.66% |
| Q8    | 568.80 | 1519.60 | +167.16% |
| Q9    | 3644.70 | 3572.20 | -1.99% |
| Q10   | 752.60 | 722.10 | -4.05% |
| Q11   | 238.79 | 223.87 | -6.25% |
| Q12   | 2676.30 | 320.82 | -88.01% |
| Q13   | 979.72 | 962.65 | -1.74% |
| Q14   | 510.17 | 504.23 | -1.16% |
| Q15   | 480.39 | 468.77 | -2.42% |
| Q16   | 183.70 | 188.36 | +2.54% |
| Q17   | 392.50 | 375.32 | -4.38% |
| Q18   | 7706.50 | 855.50 | -88.90% |
| Q19   | 955.58 | 977.44 | +2.29% |
| Q20   | 458.89 | 1191.30 | +159.61% |
| Q21   | 10236.90 | 9616.30 | -6.06% |
| Q22   | 2188.00 | 186.51 | -91.47% |

Total time:
- Before: 36.03s
- After: 24.36s

### Notes:
- Q8 and Q20 now have regressions. This is because cardinality
estimation for joins is not accurate (it assumes a primary key / foreign
key join), leading to wrong probe-side decisions for subsequent joins.

---------

Co-authored-by: Colin Ho <colinho@Colins-MBP.localdomain>
  • Loading branch information
colin-ho and Colin Ho authored Jan 10, 2025
1 parent bb85070 commit c932ec9
Show file tree
Hide file tree
Showing 19 changed files with 243 additions and 272 deletions.
11 changes: 10 additions & 1 deletion src/common/scan-info/src/pushdowns.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use std::sync::Arc;

use common_display::DisplayAs;
use daft_dsl::ExprRef;
use daft_dsl::{estimated_selectivity, ExprRef};
use daft_schema::schema::Schema;
use serde::{Deserialize, Serialize};

#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
Expand Down Expand Up @@ -103,6 +104,14 @@ impl Pushdowns {
}
res
}

/// Estimates the fraction of rows that survive the pushed-down filters.
///
/// With no filters pushed down nothing is pruned, so every row is kept
/// and the selectivity is 1.0.
pub fn estimated_selectivity(&self, schema: &Schema) -> f64 {
    self.filters
        .as_ref()
        .map_or(1.0, |preds| estimated_selectivity(preds, schema))
}
}

impl DisplayAs for Pushdowns {
Expand Down
73 changes: 73 additions & 0 deletions src/daft-dsl/src/expr/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1325,3 +1325,76 @@ pub fn count_actor_pool_udfs(exprs: &[ExprRef]) -> usize {
})
.sum()
}

/// Heuristically estimates the selectivity (fraction of rows kept, in
/// `[0.0, 1.0]`) of `expr` when used as a filter predicate.
///
/// Comparisons and null checks use fixed constants; logical operators
/// combine the selectivities of their children assuming independence;
/// expressions that are not boolean-valued do not filter and return 1.0.
pub fn estimated_selectivity(expr: &Expr, schema: &Schema) -> f64 {
    match expr {
        // Binary operators: comparisons get fixed estimates, logical
        // operators combine child selectivities, and arithmetic/bitwise
        // operators pass every row through.
        Expr::BinaryOp { op, left, right } => {
            let lhs = estimated_selectivity(left, schema);
            let rhs = estimated_selectivity(right, schema);
            match op {
                // Equality is assumed highly selective; inequality the opposite.
                Operator::Eq => 0.1,
                Operator::NotEq => 0.9,
                // Range comparisons share one fixed estimate.
                Operator::Lt | Operator::LtEq | Operator::Gt | Operator::GtEq => 0.2,

                // Assuming independence: P(A and B) = P(A) * P(B).
                Operator::And => lhs * rhs,
                // P(A or B) = P(A) + P(B) - P(A) * P(B).
                Operator::Or => lhs.mul_add(-rhs, lhs + rhs),
                // P(A xor B) = P(A) + P(B) - 2 * P(A) * P(B).
                Operator::Xor => 2.0f64.mul_add(-(lhs * rhs), lhs + rhs),

                // Non-boolean results cannot prune rows.
                Operator::Plus
                | Operator::Minus
                | Operator::Multiply
                | Operator::TrueDivide
                | Operator::FloorDivide
                | Operator::Modulus
                | Operator::ShiftLeft
                | Operator::ShiftRight => 1.0,
            }
        }

        // NOT inverts the child's selectivity.
        Expr::Not(inner) => 1.0 - estimated_selectivity(inner, schema),

        // Null checks: assume columns contain few nulls.
        Expr::IsNull(_) => 0.1,
        Expr::NotNull(_) => 0.9,

        // Membership-style predicates share one fixed estimate.
        Expr::IsIn(_, _) | Expr::Between(_, _, _) | Expr::InSubquery(_, _) | Expr::Exists(_) => 0.2,

        // Transparent wrappers: defer to the wrapped expression.
        Expr::Cast(inner, _) | Expr::Alias(inner, _) => estimated_selectivity(inner, schema),

        // A boolean literal keeps all rows (true) or none (false); any
        // other literal is not a predicate and does not filter.
        Expr::Literal(value) => match value {
            lit::LiteralValue::Boolean(true) => 1.0,
            lit::LiteralValue::Boolean(false) => 0.0,
            _ => 1.0,
        },

        // Expressions whose result type must be inspected: boolean-valued
        // ones are treated as generic predicates (0.2); anything else
        // passes rows through unchanged.
        Expr::ScalarFunction(_)
        | Expr::Function { .. }
        | Expr::Column(_)
        | Expr::OuterReferenceColumn(_)
        | Expr::IfElse { .. }
        | Expr::FillNull(_, _) => match expr.to_field(schema) {
            Ok(field) if field.dtype == DataType::Boolean => 0.2,
            _ => 1.0,
        },

        // Subqueries are assumed not to filter.
        Expr::Subquery(_) => 1.0,
        // Aggregates are invalid in a filter predicate; treat as a bug.
        Expr::Agg(_) => panic!("Aggregates are not allowed in WHERE clauses"),
    }
}
6 changes: 3 additions & 3 deletions src/daft-dsl/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,9 @@ mod resolve_expr;
mod treenode;
pub use common_treenode;
pub use expr::{
binary_op, col, count_actor_pool_udfs, has_agg, is_actor_pool_udf, is_partition_compatible,
AggExpr, ApproxPercentileParams, Expr, ExprRef, Operator, OuterReferenceColumn, SketchType,
Subquery, SubqueryPlan,
binary_op, col, count_actor_pool_udfs, estimated_selectivity, has_agg, is_actor_pool_udf,
is_partition_compatible, AggExpr, ApproxPercentileParams, Expr, ExprRef, Operator,
OuterReferenceColumn, SketchType, Subquery, SubqueryPlan,
};
pub use lit::{lit, literal_value, literals_to_series, null_lit, Literal, LiteralValue};
#[cfg(feature = "python")]
Expand Down
68 changes: 24 additions & 44 deletions src/daft-local-execution/src/pipeline.rs
Original file line number Diff line number Diff line change
Expand Up @@ -283,17 +283,15 @@ pub fn physical_plan_to_pipeline(
StatsState::Materialized(left_stats),
StatsState::Materialized(right_stats),
) => {
let left_size = left_stats.approx_stats.upper_bound_bytes;
let right_size = right_stats.approx_stats.upper_bound_bytes;
left_size.zip(right_size).map_or(true, |(l, r)| l <= r)
let left_size = left_stats.approx_stats.size_bytes;
let right_size = right_stats.approx_stats.size_bytes;
left_size <= right_size
}
// If stats are only available on the right side of the join, and the upper bound bytes on the
// right are under the broadcast join size threshold, we build on the right instead of the left.
(StatsState::NotMaterialized, StatsState::Materialized(right_stats)) => {
right_stats
.approx_stats
.upper_bound_bytes
.map_or(true, |size| size > cfg.broadcast_join_size_bytes_threshold)
right_stats.approx_stats.size_bytes
> cfg.broadcast_join_size_bytes_threshold
}
_ => true,
},
Expand All @@ -304,21 +302,15 @@ pub fn physical_plan_to_pipeline(
StatsState::Materialized(left_stats),
StatsState::Materialized(right_stats),
) => {
let left_size = left_stats.approx_stats.upper_bound_bytes;
let right_size = right_stats.approx_stats.upper_bound_bytes;
left_size
.zip(right_size)
.map_or(false, |(l, r)| (r as f64) >= ((l as f64) * 1.5))
let left_size = left_stats.approx_stats.size_bytes;
let right_size = right_stats.approx_stats.size_bytes;
right_size as f64 >= left_size as f64 * 1.5
}
// If stats are only available on the left side of the join, and the upper bound bytes on the left
// are under the broadcast join size threshold, we build on the left instead of the right.
(StatsState::Materialized(left_stats), StatsState::NotMaterialized) => {
left_stats
.approx_stats
.upper_bound_bytes
.map_or(false, |size| {
size <= cfg.broadcast_join_size_bytes_threshold
})
left_stats.approx_stats.size_bytes
<= cfg.broadcast_join_size_bytes_threshold
}
_ => false,
},
Expand All @@ -329,19 +321,15 @@ pub fn physical_plan_to_pipeline(
StatsState::Materialized(left_stats),
StatsState::Materialized(right_stats),
) => {
let left_size = left_stats.approx_stats.upper_bound_bytes;
let right_size = right_stats.approx_stats.upper_bound_bytes;
left_size
.zip(right_size)
.map_or(true, |(l, r)| ((r as f64) * 1.5) >= (l as f64))
let left_size = left_stats.approx_stats.size_bytes;
let right_size = right_stats.approx_stats.size_bytes;
(right_size as f64 * 1.5) >= left_size as f64
}
// If stats are only available on the right side of the join, and the upper bound bytes on the
// right are under the broadcast join size threshold, we build on the right instead of the left.
(StatsState::NotMaterialized, StatsState::Materialized(right_stats)) => {
right_stats
.approx_stats
.upper_bound_bytes
.map_or(true, |size| size > cfg.broadcast_join_size_bytes_threshold)
right_stats.approx_stats.size_bytes
> cfg.broadcast_join_size_bytes_threshold
}
_ => true,
},
Expand All @@ -352,21 +340,15 @@ pub fn physical_plan_to_pipeline(
StatsState::Materialized(left_stats),
StatsState::Materialized(right_stats),
) => {
let left_size = left_stats.approx_stats.upper_bound_bytes;
let right_size = right_stats.approx_stats.upper_bound_bytes;
left_size
.zip(right_size)
.map_or(false, |(l, r)| (r as f64) > ((l as f64) * 1.5))
let left_size = left_stats.approx_stats.size_bytes;
let right_size = right_stats.approx_stats.size_bytes;
right_size as f64 > left_size as f64 * 1.5
}
// If stats are only available on the left side of the join, and the upper bound bytes on the left
// are under the broadcast join size threshold, we build on the left instead of the right.
(StatsState::Materialized(left_stats), StatsState::NotMaterialized) => {
left_stats
.approx_stats
.upper_bound_bytes
.map_or(false, |size| {
size <= cfg.broadcast_join_size_bytes_threshold
})
left_stats.approx_stats.size_bytes
<= cfg.broadcast_join_size_bytes_threshold
}
// Else, default to building on the right
_ => false,
Expand Down Expand Up @@ -494,15 +476,13 @@ pub fn physical_plan_to_pipeline(
// the larger side to stream so that it can be parallelized via an intermediate op. Default to left side.
let stream_on_left = match (left_stats_state, right_stats_state) {
(StatsState::Materialized(left_stats), StatsState::Materialized(right_stats)) => {
left_stats.approx_stats.upper_bound_bytes
> right_stats.approx_stats.upper_bound_bytes
left_stats.approx_stats.num_rows > right_stats.approx_stats.num_rows
}
// If stats are only available on the left side of the join, and the upper bound bytes on the
// left are under the broadcast join size threshold, we stream on the right.
(StatsState::Materialized(left_stats), StatsState::NotMaterialized) => left_stats
.approx_stats
.upper_bound_bytes
.map_or(true, |size| size > cfg.broadcast_join_size_bytes_threshold),
(StatsState::Materialized(left_stats), StatsState::NotMaterialized) => {
left_stats.approx_stats.size_bytes > cfg.broadcast_join_size_bytes_threshold
}
// If stats are not available, we fall back and stream on the left by default.
_ => true,
};
Expand Down
30 changes: 8 additions & 22 deletions src/daft-logical-plan/src/ops/agg.rs
Original file line number Diff line number Diff line change
Expand Up @@ -64,33 +64,19 @@ impl Aggregate {
pub(crate) fn with_materialized_stats(mut self) -> Self {
// TODO(desmond): We can use the schema here for better estimations. For now, use the old logic.
let input_stats = self.input.materialized_stats();
let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes
/ (input_stats.approx_stats.lower_bound_rows.max(1));
let est_bytes_per_row_upper =
input_stats
.approx_stats
.upper_bound_bytes
.and_then(|bytes| {
input_stats
.approx_stats
.upper_bound_rows
.map(|rows| bytes / rows.max(1))
});
let est_bytes_per_row =
input_stats.approx_stats.size_bytes / (input_stats.approx_stats.num_rows.max(1));
let approx_stats = if self.groupby.is_empty() {
ApproxStats {
lower_bound_rows: input_stats.approx_stats.lower_bound_rows.min(1),
upper_bound_rows: Some(1),
lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes.min(1)
* est_bytes_per_row_lower,
upper_bound_bytes: est_bytes_per_row_upper,
num_rows: 1,
size_bytes: est_bytes_per_row,
}
} else {
// Assume high cardinality for group by columns, and 80% of rows are unique.
let est_num_groups = input_stats.approx_stats.num_rows * 4 / 5;
ApproxStats {
lower_bound_rows: input_stats.approx_stats.lower_bound_rows.min(1),
upper_bound_rows: input_stats.approx_stats.upper_bound_rows,
lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes.min(1)
* est_bytes_per_row_lower,
upper_bound_bytes: input_stats.approx_stats.upper_bound_bytes,
num_rows: est_num_groups,
size_bytes: est_bytes_per_row * est_num_groups,
}
};
self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into());
Expand Down
13 changes: 6 additions & 7 deletions src/daft-logical-plan/src/ops/distinct.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@ impl Distinct {
pub(crate) fn with_materialized_stats(mut self) -> Self {
// TODO(desmond): We can simply use NDVs here. For now, do a naive estimation.
let input_stats = self.input.materialized_stats();
let est_bytes_per_row_lower = input_stats.approx_stats.lower_bound_bytes
/ (input_stats.approx_stats.lower_bound_rows.max(1));
let est_bytes_per_row =
input_stats.approx_stats.size_bytes / (input_stats.approx_stats.num_rows.max(1));
// Assume high cardinality, 80% of rows are distinct.
let est_distinct_values = input_stats.approx_stats.num_rows * 4 / 5;
let approx_stats = ApproxStats {
lower_bound_rows: input_stats.approx_stats.lower_bound_rows.min(1),
upper_bound_rows: input_stats.approx_stats.upper_bound_rows,
lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes.min(1)
* est_bytes_per_row_lower,
upper_bound_bytes: input_stats.approx_stats.upper_bound_bytes,
num_rows: est_distinct_values,
size_bytes: est_distinct_values * est_bytes_per_row,
};
self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into());
self
Expand Down
7 changes: 3 additions & 4 deletions src/daft-logical-plan/src/ops/explode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,11 +67,10 @@ impl Explode {

pub(crate) fn with_materialized_stats(mut self) -> Self {
let input_stats = self.input.materialized_stats();
let est_num_exploded_rows = input_stats.approx_stats.num_rows * 4;
let approx_stats = ApproxStats {
lower_bound_rows: input_stats.approx_stats.lower_bound_rows,
upper_bound_rows: None,
lower_bound_bytes: input_stats.approx_stats.lower_bound_bytes,
upper_bound_bytes: None,
num_rows: est_num_exploded_rows,
size_bytes: input_stats.approx_stats.size_bytes,
};
self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into());
self
Expand Down
13 changes: 6 additions & 7 deletions src/daft-logical-plan/src/ops/filter.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::sync::Arc;

use common_error::DaftError;
use daft_core::prelude::*;
use daft_dsl::{ExprRef, ExprResolver};
use daft_dsl::{estimated_selectivity, ExprRef, ExprResolver};
use snafu::ResultExt;

use crate::{
Expand Down Expand Up @@ -46,13 +46,12 @@ impl Filter {
// Assume no row/column pruning in cardinality-affecting operations.
// TODO(desmond): We can do better estimations here. For now, reuse the old logic.
let input_stats = self.input.materialized_stats();
let upper_bound_rows = input_stats.approx_stats.upper_bound_rows;
let upper_bound_bytes = input_stats.approx_stats.upper_bound_bytes;
let estimated_selectivity = estimated_selectivity(&self.predicate, &self.input.schema());
let approx_stats = ApproxStats {
lower_bound_rows: 0,
upper_bound_rows,
lower_bound_bytes: 0,
upper_bound_bytes,
num_rows: (input_stats.approx_stats.num_rows as f64 * estimated_selectivity).ceil()
as usize,
size_bytes: (input_stats.approx_stats.size_bytes as f64 * estimated_selectivity).ceil()
as usize,
};
self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into());
self
Expand Down
14 changes: 6 additions & 8 deletions src/daft-logical-plan/src/ops/join.rs
Original file line number Diff line number Diff line change
Expand Up @@ -317,16 +317,14 @@ impl Join {
let left_stats = self.left.materialized_stats();
let right_stats = self.right.materialized_stats();
let approx_stats = ApproxStats {
lower_bound_rows: 0,
upper_bound_rows: left_stats
num_rows: left_stats
.approx_stats
.upper_bound_rows
.and_then(|l| right_stats.approx_stats.upper_bound_rows.map(|r| l.max(r))),
lower_bound_bytes: 0,
upper_bound_bytes: left_stats
.num_rows
.max(right_stats.approx_stats.num_rows),
size_bytes: left_stats
.approx_stats
.upper_bound_bytes
.and_then(|l| right_stats.approx_stats.upper_bound_bytes.map(|r| l.max(r))),
.size_bytes
.max(right_stats.approx_stats.size_bytes),
};
self.stats_state = StatsState::Materialized(PlanStats::new(approx_stats).into());
self
Expand Down
Loading

0 comments on commit c932ec9

Please sign in to comment.