-
Notifications
You must be signed in to change notification settings - Fork 182
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
refactor(connect): internal refactoring to make connect code more organized & extensible #3680
Changes from all commits
6fc29c9
6ef13bc
445b0d6
5114ef6
f2c4074
8535db9
1ade876
3ac05d9
0a1c028
1784fec
67762b0
8959c2f
4b83883
eb477e8
42ebb47
f3849b5
be7247f
3fbec97
38118f1
c24ceb4
bba17d6
05b73eb
7c172b3
ed0d04b
016ac3e
bf0d1b8
e667b96
88add4e
bfebd60
fed17d3
4a1c5f7
3b5901a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
use std::{collections::HashMap, sync::Arc}; | ||
|
||
use once_cell::sync::Lazy; | ||
use spark_connect::Expression; | ||
|
||
use crate::spark_analyzer::SparkAnalyzer; | ||
mod core; | ||
|
||
pub(crate) static CONNECT_FUNCTIONS: Lazy<SparkFunctions> = Lazy::new(|| { | ||
let mut functions = SparkFunctions::new(); | ||
functions.register::<core::CoreFunctions>(); | ||
functions | ||
}); | ||
|
||
pub trait SparkFunction: Send + Sync { | ||
fn to_expr( | ||
&self, | ||
args: &[Expression], | ||
analyzer: &SparkAnalyzer, | ||
) -> eyre::Result<daft_dsl::ExprRef>; | ||
} | ||
|
||
pub struct SparkFunctions { | ||
pub(crate) map: HashMap<String, Arc<dyn SparkFunction>>, | ||
} | ||
|
||
impl SparkFunctions { | ||
/// Create a new [SparkFunction] instance. | ||
#[must_use] | ||
pub fn new() -> Self { | ||
Self { | ||
map: HashMap::new(), | ||
} | ||
} | ||
|
||
/// Register the module to the [SparkFunctions] instance. | ||
pub fn register<M: FunctionModule>(&mut self) { | ||
M::register(self); | ||
} | ||
/// Add a [FunctionExpr] to the [SparkFunction] instance. | ||
pub fn add_fn<F: SparkFunction + 'static>(&mut self, name: &str, func: F) { | ||
self.map.insert(name.to_string(), Arc::new(func)); | ||
} | ||
|
||
/// Get a function by name from the [SparkFunctions] instance. | ||
#[must_use] | ||
pub fn get(&self, name: &str) -> Option<&Arc<dyn SparkFunction>> { | ||
self.map.get(name) | ||
} | ||
} | ||
|
||
pub trait FunctionModule { | ||
/// Register this module to the given [SparkFunctions] table. | ||
fn register(_parent: &mut SparkFunctions); | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
use daft_core::count_mode::CountMode; | ||
use daft_dsl::{binary_op, col, ExprRef, Operator}; | ||
use daft_schema::dtype::DataType; | ||
use spark_connect::Expression; | ||
|
||
use super::{FunctionModule, SparkFunction}; | ||
use crate::{invalid_argument_err, spark_analyzer::SparkAnalyzer}; | ||
|
||
/// Core functions: the most basic operations such as `+`, `-`, `*`, `/`,
/// `not`, `isnotnull`, and simple aggregations.
pub struct CoreFunctions;
|
||
impl FunctionModule for CoreFunctions { | ||
fn register(parent: &mut super::SparkFunctions) { | ||
parent.add_fn("==", BinaryOpFunction(Operator::Eq)); | ||
parent.add_fn("!=", BinaryOpFunction(Operator::NotEq)); | ||
parent.add_fn("<", BinaryOpFunction(Operator::Lt)); | ||
parent.add_fn("<=", BinaryOpFunction(Operator::LtEq)); | ||
parent.add_fn(">", BinaryOpFunction(Operator::Gt)); | ||
parent.add_fn(">=", BinaryOpFunction(Operator::GtEq)); | ||
parent.add_fn("+", BinaryOpFunction(Operator::Plus)); | ||
parent.add_fn("-", BinaryOpFunction(Operator::Minus)); | ||
parent.add_fn("*", BinaryOpFunction(Operator::Multiply)); | ||
parent.add_fn("/", BinaryOpFunction(Operator::TrueDivide)); | ||
parent.add_fn("//", BinaryOpFunction(Operator::FloorDivide)); | ||
parent.add_fn("%", BinaryOpFunction(Operator::Modulus)); | ||
parent.add_fn("&", BinaryOpFunction(Operator::And)); | ||
parent.add_fn("|", BinaryOpFunction(Operator::Or)); | ||
parent.add_fn("^", BinaryOpFunction(Operator::Xor)); | ||
parent.add_fn("<<", BinaryOpFunction(Operator::ShiftLeft)); | ||
parent.add_fn(">>", BinaryOpFunction(Operator::ShiftRight)); | ||
parent.add_fn("isnotnull", UnaryFunction(|arg| arg.not_null())); | ||
parent.add_fn("isnull", UnaryFunction(|arg| arg.is_null())); | ||
parent.add_fn("not", UnaryFunction(|arg| arg.not())); | ||
parent.add_fn("sum", UnaryFunction(|arg| arg.sum())); | ||
parent.add_fn("mean", UnaryFunction(|arg| arg.mean())); | ||
parent.add_fn("stddev", UnaryFunction(|arg| arg.stddev())); | ||
parent.add_fn("min", UnaryFunction(|arg| arg.min())); | ||
parent.add_fn("max", UnaryFunction(|arg| arg.max())); | ||
parent.add_fn("count", CountFunction); | ||
} | ||
} | ||
|
||
pub struct BinaryOpFunction(Operator); | ||
pub struct UnaryFunction(fn(ExprRef) -> ExprRef); | ||
pub struct CountFunction; | ||
|
||
impl SparkFunction for BinaryOpFunction { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Doesn't need to be implemented in this PR, but usually, if you have a fixed number of types implementing a trait, then opting for an enum seems to be better form. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. so it's currently a fixed number, but this will expand quite a bit, and it will become impractical to contain all of those in a single enum. as spark has A LOT of functions. Additionally, we'll need to support dynamically registered UDF's later down the road, so it makes sense to lay the groundwork now instead of needing to do another refactor later down the road. |
||
fn to_expr( | ||
&self, | ||
args: &[Expression], | ||
analyzer: &SparkAnalyzer, | ||
) -> eyre::Result<daft_dsl::ExprRef> { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not the right PR for this, but I'm just curious: Why are we using There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Totally agree here. It's been on my todo list to change the error handling in daft-connect. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
let args = args | ||
.iter() | ||
.map(|arg| analyzer.to_daft_expr(arg)) | ||
.collect::<eyre::Result<Vec<_>>>()?; | ||
|
||
let [lhs, rhs] = args | ||
.try_into() | ||
.map_err(|args| eyre::eyre!("requires exactly two arguments; got {:?}", args))?; | ||
|
||
Ok(binary_op(self.0, lhs, rhs)) | ||
Comment on lines
+53
to
+62
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Maybe we could try to keep the same pattern for match args {
[lhs, rhs] => ..,
_ => return invalid_argument_err!("requires exactly two arguments; got {args:?}"),
} |
||
} | ||
} | ||
|
||
impl SparkFunction for UnaryFunction { | ||
fn to_expr( | ||
&self, | ||
args: &[Expression], | ||
analyzer: &SparkAnalyzer, | ||
) -> eyre::Result<daft_dsl::ExprRef> { | ||
match args { | ||
[arg] => { | ||
let arg = analyzer.to_daft_expr(arg)?; | ||
Ok(self.0(arg)) | ||
} | ||
_ => invalid_argument_err!("requires exactly one argument")?, | ||
} | ||
} | ||
} | ||
|
||
impl SparkFunction for CountFunction { | ||
fn to_expr( | ||
&self, | ||
args: &[Expression], | ||
analyzer: &SparkAnalyzer, | ||
) -> eyre::Result<daft_dsl::ExprRef> { | ||
match args { | ||
[arg] => { | ||
let arg = analyzer.to_daft_expr(arg)?; | ||
|
||
let arg = if arg.as_literal().and_then(|lit| lit.as_i32()) == Some(1i32) { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Isn't There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. in spark, it is only ever sent over as so it can only ever be called via df.count() which gets serialized via protobuf as |
||
col("*") | ||
} else { | ||
arg | ||
}; | ||
|
||
let count = arg.count(CountMode::All).cast(&DataType::Int64); | ||
|
||
Ok(count) | ||
} | ||
_ => invalid_argument_err!("requires exactly one argument")?, | ||
} | ||
} | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Something like:
This would allow us to avoid all the `dyn SparkFunction` trait-object machinery elsewhere.There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
so the main reasons behind using the trait impl instead of enums are: Spark exposes a very large number of functions, so a single enum would become impractical as coverage grows, and dynamically registered UDFs will need to be supported later, which requires an open (trait-based) registry rather than a closed enum.