-
Notifications
You must be signed in to change notification settings - Fork 182
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor(connect): internal refactoring to make connect code more org…
…anized & extensible (#3680) There are two refactors in this PR. ### Coalesce `SparkAnalyzer` impls Previously, many method impls for `SparkAnalyzer` were spread across multiple files, making the code difficult to navigate. Most IDEs/editors have much better support for "goto symbol" or "symbol search" within a single buffer than for project-wide symbol search. Coalescing all `impl SparkAnalyzer` blocks into a single file makes things much more navigable without needing to use project-wide symbol search, and without needing to jump between many files. ### Functions refactor for extensibility Previously, all of the supported Spark functions were hardcoded and inlined inside a single function. This felt unintuitive, and adding certain functionality _(udf)_ becomes difficult, if not impossible, without a registry. So I refactored it to mirror our daft-sql function implementation. Now there is a function registry, and you just need to impl the trait and register it. Implementing a connect function should now feel very similar to implementing a sql function. ex: #### you can register a single function ```rs pub struct CountFunction; impl SparkFunction for CountFunction { fn to_expr( &self, args: &[Expression], analyzer: &SparkAnalyzer, ) -> eyre::Result<daft_dsl::ExprRef> { todo!() } } // functions.rs let mut functions = SparkFunctions::new(); functions.add_fn("count", CountFunction); ``` #### or you can register an entire function module ```rs // functions/core.rs pub struct CoreFunctions; impl FunctionModule for CoreFunctions { fn register(parent: &mut super::SparkFunctions) { parent.add_fn("count", CountFunction); } } // functions.rs let mut functions = SparkFunctions::new(); functions.register::<core::CoreFunctions>(); ```
- Loading branch information
1 parent
4b67e5a
commit beae462
Showing
26 changed files
with
1,102 additions
and
1,293 deletions.
There are no files selected for viewing
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
use std::{collections::HashMap, sync::Arc}; | ||
|
||
use once_cell::sync::Lazy; | ||
use spark_connect::Expression; | ||
|
||
use crate::spark_analyzer::SparkAnalyzer; | ||
mod core; | ||
|
||
/// Global registry of every Spark Connect function this crate supports.
///
/// Built lazily on first access; each function module registers its
/// entries into the table below.
pub(crate) static CONNECT_FUNCTIONS: Lazy<SparkFunctions> = Lazy::new(|| {
    let mut functions = SparkFunctions::new();
    functions.register::<core::CoreFunctions>();
    functions
});
|
||
/// A Spark Connect function that can be translated into a Daft expression.
///
/// Implementors turn the raw protobuf [`Expression`] arguments into a
/// `daft_dsl::ExprRef`, using the analyzer to convert each argument.
pub trait SparkFunction: Send + Sync {
    /// Convert `args` into a single Daft expression.
    ///
    /// # Errors
    /// Returns an error when the arguments are invalid for this function
    /// (e.g. wrong arity) or when converting an argument fails.
    fn to_expr(
        &self,
        args: &[Expression],
        analyzer: &SparkAnalyzer,
    ) -> eyre::Result<daft_dsl::ExprRef>;
}
|
||
/// Name -> implementation registry of Spark Connect functions.
pub struct SparkFunctions {
    // Lookup table keyed by the Spark function name (e.g. "count", "+").
    pub(crate) map: HashMap<String, Arc<dyn SparkFunction>>,
}
|
||
impl SparkFunctions { | ||
/// Create a new [SparkFunction] instance. | ||
#[must_use] | ||
pub fn new() -> Self { | ||
Self { | ||
map: HashMap::new(), | ||
} | ||
} | ||
|
||
/// Register the module to the [SparkFunctions] instance. | ||
pub fn register<M: FunctionModule>(&mut self) { | ||
M::register(self); | ||
} | ||
/// Add a [FunctionExpr] to the [SparkFunction] instance. | ||
pub fn add_fn<F: SparkFunction + 'static>(&mut self, name: &str, func: F) { | ||
self.map.insert(name.to_string(), Arc::new(func)); | ||
} | ||
|
||
/// Get a function by name from the [SparkFunctions] instance. | ||
#[must_use] | ||
pub fn get(&self, name: &str) -> Option<&Arc<dyn SparkFunction>> { | ||
self.map.get(name) | ||
} | ||
} | ||
|
||
pub trait FunctionModule { | ||
/// Register this module to the given [SparkFunctions] table. | ||
fn register(_parent: &mut SparkFunctions); | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,105 @@ | ||
use daft_core::count_mode::CountMode; | ||
use daft_dsl::{binary_op, col, ExprRef, Operator}; | ||
use daft_schema::dtype::DataType; | ||
use spark_connect::Expression; | ||
|
||
use super::{FunctionModule, SparkFunction}; | ||
use crate::{invalid_argument_err, spark_analyzer::SparkAnalyzer}; | ||
|
||
/// Core functions: the most basic operators and helpers, such as
/// `+`, `-`, `*`, `/`, `not`, `isnull`/`isnotnull`, and simple aggregations.
pub struct CoreFunctions;
|
||
impl FunctionModule for CoreFunctions { | ||
fn register(parent: &mut super::SparkFunctions) { | ||
parent.add_fn("==", BinaryOpFunction(Operator::Eq)); | ||
parent.add_fn("!=", BinaryOpFunction(Operator::NotEq)); | ||
parent.add_fn("<", BinaryOpFunction(Operator::Lt)); | ||
parent.add_fn("<=", BinaryOpFunction(Operator::LtEq)); | ||
parent.add_fn(">", BinaryOpFunction(Operator::Gt)); | ||
parent.add_fn(">=", BinaryOpFunction(Operator::GtEq)); | ||
parent.add_fn("+", BinaryOpFunction(Operator::Plus)); | ||
parent.add_fn("-", BinaryOpFunction(Operator::Minus)); | ||
parent.add_fn("*", BinaryOpFunction(Operator::Multiply)); | ||
parent.add_fn("/", BinaryOpFunction(Operator::TrueDivide)); | ||
parent.add_fn("//", BinaryOpFunction(Operator::FloorDivide)); | ||
parent.add_fn("%", BinaryOpFunction(Operator::Modulus)); | ||
parent.add_fn("&", BinaryOpFunction(Operator::And)); | ||
parent.add_fn("|", BinaryOpFunction(Operator::Or)); | ||
parent.add_fn("^", BinaryOpFunction(Operator::Xor)); | ||
parent.add_fn("<<", BinaryOpFunction(Operator::ShiftLeft)); | ||
parent.add_fn(">>", BinaryOpFunction(Operator::ShiftRight)); | ||
parent.add_fn("isnotnull", UnaryFunction(|arg| arg.not_null())); | ||
parent.add_fn("isnull", UnaryFunction(|arg| arg.is_null())); | ||
parent.add_fn("not", UnaryFunction(|arg| arg.not())); | ||
parent.add_fn("sum", UnaryFunction(|arg| arg.sum())); | ||
parent.add_fn("mean", UnaryFunction(|arg| arg.mean())); | ||
parent.add_fn("stddev", UnaryFunction(|arg| arg.stddev())); | ||
parent.add_fn("min", UnaryFunction(|arg| arg.min())); | ||
parent.add_fn("max", UnaryFunction(|arg| arg.max())); | ||
parent.add_fn("count", CountFunction); | ||
} | ||
} | ||
|
||
/// Wraps a Daft binary [`Operator`] (e.g. `+`, `==`) as a Spark function.
pub struct BinaryOpFunction(Operator);
/// Wraps a single-argument expression transform as a Spark function.
pub struct UnaryFunction(fn(ExprRef) -> ExprRef);
/// Spark's `count`, including the `count(1)` -> `count(*)` special case.
pub struct CountFunction;
|
||
impl SparkFunction for BinaryOpFunction { | ||
fn to_expr( | ||
&self, | ||
args: &[Expression], | ||
analyzer: &SparkAnalyzer, | ||
) -> eyre::Result<daft_dsl::ExprRef> { | ||
let args = args | ||
.iter() | ||
.map(|arg| analyzer.to_daft_expr(arg)) | ||
.collect::<eyre::Result<Vec<_>>>()?; | ||
|
||
let [lhs, rhs] = args | ||
.try_into() | ||
.map_err(|args| eyre::eyre!("requires exactly two arguments; got {:?}", args))?; | ||
|
||
Ok(binary_op(self.0, lhs, rhs)) | ||
} | ||
} | ||
|
||
impl SparkFunction for UnaryFunction { | ||
fn to_expr( | ||
&self, | ||
args: &[Expression], | ||
analyzer: &SparkAnalyzer, | ||
) -> eyre::Result<daft_dsl::ExprRef> { | ||
match args { | ||
[arg] => { | ||
let arg = analyzer.to_daft_expr(arg)?; | ||
Ok(self.0(arg)) | ||
} | ||
_ => invalid_argument_err!("requires exactly one argument")?, | ||
} | ||
} | ||
} | ||
|
||
impl SparkFunction for CountFunction { | ||
fn to_expr( | ||
&self, | ||
args: &[Expression], | ||
analyzer: &SparkAnalyzer, | ||
) -> eyre::Result<daft_dsl::ExprRef> { | ||
match args { | ||
[arg] => { | ||
let arg = analyzer.to_daft_expr(arg)?; | ||
|
||
let arg = if arg.as_literal().and_then(|lit| lit.as_i32()) == Some(1i32) { | ||
col("*") | ||
} else { | ||
arg | ||
}; | ||
|
||
let count = arg.count(CountMode::All).cast(&DataType::Int64); | ||
|
||
Ok(count) | ||
} | ||
_ => invalid_argument_err!("requires exactly one argument")?, | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.