refactor(connect): internal refactoring to make connect code more organized & extensible (#3680)

There are two refactors in this PR.

### coalesce `SparkAnalyzer` impls

Many of the method impls for `SparkAnalyzer` were spread across multiple files, making
the code difficult to navigate. Most IDEs/editors have much better
support for "goto symbol" or "symbol search" within a single buffer
than project-wide. So coalescing all the `impl SparkAnalyzer` blocks
into a single file makes things much more navigable, without needing
project-wide symbol search or jumping between many files.

### functions refactor for extensibility

Previously, all of the supported Spark functions were hardcoded and
inlined inside a single function. This was unintuitive, and it made adding
certain functionality _(UDFs)_ difficult, if not impossible, without a
registry. So I refactored it to mirror our daft-sql function
implementation: there is now a function registry, and you just need to
impl the trait and register it. Implementing a connect function
should now feel very similar to implementing a SQL function.

For example:


#### you can register a single function

```rs
pub struct CountFunction;

impl SparkFunction for CountFunction {
    fn to_expr(
        &self,
        args: &[Expression],
        analyzer: &SparkAnalyzer,
    ) -> eyre::Result<daft_dsl::ExprRef> {
        todo!()
    }
}

// functions.rs
let mut functions = SparkFunctions::new();
functions.add_fn("count", CountFunction);
```


#### or you can register an entire function module
```rs
// functions/core.rs
pub struct CoreFunctions;

impl FunctionModule for CoreFunctions {
    fn register(parent: &mut super::SparkFunctions) {
        parent.add_fn("count", CountFunction);
    }
}

// functions.rs
let mut functions = SparkFunctions::new();
functions.register::<core::CoreFunctions>();
```
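
As context for the examples above, here is a minimal sketch of how the analyzer side can resolve a registered function by name. The `CONNECT_FUNCTIONS` static and its `get` method are added in this PR; the wrapper function and its error message here are illustrative, not actual PR code:

```rs
use spark_connect::Expression;

use crate::{functions::CONNECT_FUNCTIONS, spark_analyzer::SparkAnalyzer};

// Hypothetical helper: look up a function by name in the registry and
// lower its arguments to a daft expression via the trait's `to_expr`.
fn apply_fn(
    name: &str,
    args: &[Expression],
    analyzer: &SparkAnalyzer,
) -> eyre::Result<daft_dsl::ExprRef> {
    let func = CONNECT_FUNCTIONS
        .get(name)
        .ok_or_else(|| eyre::eyre!("unsupported function: {name}"))?;
    func.to_expr(args, analyzer)
}
```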
universalmind303 authored Jan 16, 2025
1 parent 4b67e5a commit beae462
Showing 26 changed files with 1,102 additions and 1,293 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions src/daft-connect/Cargo.toml
@@ -18,6 +18,7 @@ dashmap = "6.1.0"
 eyre = "0.6.12"
 futures = "0.3.31"
 itertools = {workspace = true}
+once_cell = {workspace = true}
 pyo3 = {workspace = true, optional = true}
 spark-connect = {workspace = true}
 textwrap = "0.16.1"
4 changes: 2 additions & 2 deletions src/daft-connect/src/connect_service.rs
@@ -16,7 +16,7 @@ use crate::{
     invalid_argument_err, not_yet_implemented,
     response_builder::ResponseBuilder,
     session::Session,
-    translation::{self, SparkAnalyzer},
+    spark_analyzer::{to_spark_datatype, SparkAnalyzer},
     util::FromOptionalField,
 };

@@ -180,7 +180,7 @@ impl SparkConnectService for DaftSparkConnectService {

         let daft_schema = daft_schema.to_struct();

-        let schema = translation::to_spark_datatype(&daft_schema);
+        let schema = to_spark_datatype(&daft_schema);

         Ok(Response::new(rb.schema_response(schema)))
     }
10 changes: 5 additions & 5 deletions src/daft-connect/src/execute.rs
@@ -25,8 +25,8 @@ use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status};
 use tracing::debug;

 use crate::{
-    not_yet_implemented, response_builder::ResponseBuilder, session::Session, translation,
-    util::FromOptionalField, ExecuteStream, Runner,
+    not_yet_implemented, response_builder::ResponseBuilder, session::Session,
+    spark_analyzer::SparkAnalyzer, util::FromOptionalField, ExecuteStream, Runner,
 };

 impl Session {
@@ -93,7 +93,7 @@ impl Session {

         tokio::spawn(async move {
             let execution_fut = async {
-                let translator = translation::SparkAnalyzer::new(&this);
+                let translator = SparkAnalyzer::new(&this);
                 match command.rel_type {
                     Some(RelType::ShowString(ss)) => {
                         let response = this.show_string(*ss, res.clone()).await?;
@@ -208,7 +208,7 @@ impl Session {
             }
         };

-        let translator = translation::SparkAnalyzer::new(&this);
+        let translator = SparkAnalyzer::new(&this);

         let plan = translator.to_logical_plan(input).await?;

@@ -241,7 +241,7 @@ impl Session {
         show_string: ShowString,
         response_builder: ResponseBuilder<ExecutePlanResponse>,
     ) -> eyre::Result<ExecutePlanResponse> {
-        let translator = translation::SparkAnalyzer::new(self);
+        let translator = SparkAnalyzer::new(self);

         let ShowString {
             input,
55 changes: 55 additions & 0 deletions src/daft-connect/src/functions.rs
@@ -0,0 +1,55 @@
use std::{collections::HashMap, sync::Arc};

use once_cell::sync::Lazy;
use spark_connect::Expression;

use crate::spark_analyzer::SparkAnalyzer;
mod core;

pub(crate) static CONNECT_FUNCTIONS: Lazy<SparkFunctions> = Lazy::new(|| {
    let mut functions = SparkFunctions::new();
    functions.register::<core::CoreFunctions>();
    functions
});

pub trait SparkFunction: Send + Sync {
    fn to_expr(
        &self,
        args: &[Expression],
        analyzer: &SparkAnalyzer,
    ) -> eyre::Result<daft_dsl::ExprRef>;
}

pub struct SparkFunctions {
    pub(crate) map: HashMap<String, Arc<dyn SparkFunction>>,
}

impl SparkFunctions {
    /// Create a new [SparkFunctions] instance.
    #[must_use]
    pub fn new() -> Self {
        Self {
            map: HashMap::new(),
        }
    }

    /// Register the module to the [SparkFunctions] instance.
    pub fn register<M: FunctionModule>(&mut self) {
        M::register(self);
    }

    /// Add a [SparkFunction] to the [SparkFunctions] instance.
    pub fn add_fn<F: SparkFunction + 'static>(&mut self, name: &str, func: F) {
        self.map.insert(name.to_string(), Arc::new(func));
    }

    /// Get a function by name from the [SparkFunctions] instance.
    #[must_use]
    pub fn get(&self, name: &str) -> Option<&Arc<dyn SparkFunction>> {
        self.map.get(name)
    }
}

pub trait FunctionModule {
    /// Register this module to the given [SparkFunctions] table.
    fn register(_parent: &mut SparkFunctions);
}
105 changes: 105 additions & 0 deletions src/daft-connect/src/functions/core.rs
@@ -0,0 +1,105 @@
use daft_core::count_mode::CountMode;
use daft_dsl::{binary_op, col, ExprRef, Operator};
use daft_schema::dtype::DataType;
use spark_connect::Expression;

use super::{FunctionModule, SparkFunction};
use crate::{invalid_argument_err, spark_analyzer::SparkAnalyzer};

// Core functions are the most basic functions such as `+`, `-`, `*`, `/`, not, notnull, etc.
pub struct CoreFunctions;

impl FunctionModule for CoreFunctions {
    fn register(parent: &mut super::SparkFunctions) {
        parent.add_fn("==", BinaryOpFunction(Operator::Eq));
        parent.add_fn("!=", BinaryOpFunction(Operator::NotEq));
        parent.add_fn("<", BinaryOpFunction(Operator::Lt));
        parent.add_fn("<=", BinaryOpFunction(Operator::LtEq));
        parent.add_fn(">", BinaryOpFunction(Operator::Gt));
        parent.add_fn(">=", BinaryOpFunction(Operator::GtEq));
        parent.add_fn("+", BinaryOpFunction(Operator::Plus));
        parent.add_fn("-", BinaryOpFunction(Operator::Minus));
        parent.add_fn("*", BinaryOpFunction(Operator::Multiply));
        parent.add_fn("/", BinaryOpFunction(Operator::TrueDivide));
        parent.add_fn("//", BinaryOpFunction(Operator::FloorDivide));
        parent.add_fn("%", BinaryOpFunction(Operator::Modulus));
        parent.add_fn("&", BinaryOpFunction(Operator::And));
        parent.add_fn("|", BinaryOpFunction(Operator::Or));
        parent.add_fn("^", BinaryOpFunction(Operator::Xor));
        parent.add_fn("<<", BinaryOpFunction(Operator::ShiftLeft));
        parent.add_fn(">>", BinaryOpFunction(Operator::ShiftRight));
        parent.add_fn("isnotnull", UnaryFunction(|arg| arg.not_null()));
        parent.add_fn("isnull", UnaryFunction(|arg| arg.is_null()));
        parent.add_fn("not", UnaryFunction(|arg| arg.not()));
        parent.add_fn("sum", UnaryFunction(|arg| arg.sum()));
        parent.add_fn("mean", UnaryFunction(|arg| arg.mean()));
        parent.add_fn("stddev", UnaryFunction(|arg| arg.stddev()));
        parent.add_fn("min", UnaryFunction(|arg| arg.min()));
        parent.add_fn("max", UnaryFunction(|arg| arg.max()));
        parent.add_fn("count", CountFunction);
    }
}

pub struct BinaryOpFunction(Operator);
pub struct UnaryFunction(fn(ExprRef) -> ExprRef);
pub struct CountFunction;

impl SparkFunction for BinaryOpFunction {
    fn to_expr(
        &self,
        args: &[Expression],
        analyzer: &SparkAnalyzer,
    ) -> eyre::Result<daft_dsl::ExprRef> {
        let args = args
            .iter()
            .map(|arg| analyzer.to_daft_expr(arg))
            .collect::<eyre::Result<Vec<_>>>()?;

        let [lhs, rhs] = args
            .try_into()
            .map_err(|args| eyre::eyre!("requires exactly two arguments; got {:?}", args))?;

        Ok(binary_op(self.0, lhs, rhs))
    }
}

impl SparkFunction for UnaryFunction {
    fn to_expr(
        &self,
        args: &[Expression],
        analyzer: &SparkAnalyzer,
    ) -> eyre::Result<daft_dsl::ExprRef> {
        match args {
            [arg] => {
                let arg = analyzer.to_daft_expr(arg)?;
                Ok(self.0(arg))
            }
            _ => invalid_argument_err!("requires exactly one argument")?,
        }
    }
}

impl SparkFunction for CountFunction {
    fn to_expr(
        &self,
        args: &[Expression],
        analyzer: &SparkAnalyzer,
    ) -> eyre::Result<daft_dsl::ExprRef> {
        match args {
            [arg] => {
                let arg = analyzer.to_daft_expr(arg)?;

                let arg = if arg.as_literal().and_then(|lit| lit.as_i32()) == Some(1i32) {
                    col("*")
                } else {
                    arg
                };

                let count = arg.count(CountMode::All).cast(&DataType::Int64);

                Ok(count)
            }
            _ => invalid_argument_err!("requires exactly one argument")?,
        }
    }
}
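
A note on the `count` special case above: a literal `1` argument is rewritten to `col("*")`, so `count(1)` counts whole rows, matching Spark's semantics. A minimal sketch of the expression both spellings lower to, reusing the DSL calls from the diff (the wrapper function name is illustrative):

```rs
use daft_core::count_mode::CountMode;
use daft_dsl::{col, ExprRef};
use daft_schema::dtype::DataType;

// `count(1)` and `count("*")` both become a count of every row
// (CountMode::All), cast to Int64 to match Spark's return type.
fn count_star_expr() -> ExprRef {
    col("*").count(CountMode::All).cast(&DataType::Int64)
}
```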
5 changes: 4 additions & 1 deletion src/daft-connect/src/lib.rs
@@ -12,6 +12,9 @@ mod config;
 #[cfg(feature = "python")]
 mod connect_service;

+#[cfg(feature = "python")]
+mod functions;
+
 #[cfg(feature = "python")]
 mod display;
 #[cfg(feature = "python")]
@@ -23,7 +26,7 @@ mod response_builder;
 #[cfg(feature = "python")]
 mod session;
 #[cfg(feature = "python")]
-mod translation;
+mod spark_analyzer;
 #[cfg(feature = "python")]
 pub mod util;