refactor(connect): internal refactoring to make connect code more organized & extensible #3680

Merged

Changes from all commits (32 commits)
6fc29c9  wip (universalmind303, Jan 8, 2025)
6ef13bc  wip (universalmind303, Jan 8, 2025)
445b0d6  wip (universalmind303, Jan 8, 2025)
5114ef6  Merge branch 'main' of https://github.com/Eventual-Inc/Daft into rust… (universalmind303, Jan 8, 2025)
f2c4074  wip (universalmind303, Jan 8, 2025)
8535db9  wip (universalmind303, Jan 9, 2025)
1ade876  wip (universalmind303, Jan 9, 2025)
3ac05d9  Merge branch 'main' of https://github.com/Eventual-Inc/Daft into rust… (universalmind303, Jan 9, 2025)
0a1c028  ray runner for connect (universalmind303, Jan 10, 2025)
1784fec  fix compile feature checks (universalmind303, Jan 10, 2025)
67762b0  machete (universalmind303, Jan 10, 2025)
8959c2f  fix compile feature checks (universalmind303, Jan 10, 2025)
4b83883  fix compile feature checks (universalmind303, Jan 10, 2025)
eb477e8  fix compile feature checks (universalmind303, Jan 10, 2025)
42ebb47  Merge branch 'main' of https://github.com/Eventual-Inc/Daft into rust… (universalmind303, Jan 10, 2025)
f3849b5  add config var for "daft.runner.ray.address" (universalmind303, Jan 10, 2025)
be7247f  chore: clean up daft spark with better error handling (universalmind303, Jan 13, 2025)
3fbec97  revert handle_count changes (universalmind303, Jan 13, 2025)
38118f1  fix feature flagging error (universalmind303, Jan 13, 2025)
c24ceb4  add distinct and sort (universalmind303, Jan 13, 2025)
bba17d6  Merge branch 'main' of https://github.com/Eventual-Inc/Daft into erro… (universalmind303, Jan 14, 2025)
05b73eb  code cleanup (universalmind303, Jan 14, 2025)
7c172b3  code cleanup (universalmind303, Jan 14, 2025)
ed0d04b  Merge branch 'error-messages' of https://github.com/universalmind303/… (universalmind303, Jan 14, 2025)
016ac3e  refactor(connect): make functions more extensible and consolidate lp … (universalmind303, Jan 14, 2025)
bf0d1b8  Merge branch 'main' of https://github.com/Eventual-Inc/Daft into erro… (universalmind303, Jan 14, 2025)
e667b96  fix logic bug (universalmind303, Jan 14, 2025)
88add4e  minor cleanup (universalmind303, Jan 14, 2025)
bfebd60  Merge branch 'main' of https://github.com/Eventual-Inc/Daft into erro… (universalmind303, Jan 15, 2025)
fed17d3  properly remove file (universalmind303, Jan 15, 2025)
4a1c5f7  Merge branch 'error-messages' of https://github.com/universalmind303/… (universalmind303, Jan 15, 2025)
3b5901a  Merge branch 'main' of https://github.com/Eventual-Inc/Daft into conn… (universalmind303, Jan 15, 2025)
1 change: 1 addition & 0 deletions Cargo.lock


1 change: 1 addition & 0 deletions src/daft-connect/Cargo.toml
@@ -18,6 +18,7 @@ dashmap = "6.1.0"
eyre = "0.6.12"
futures = "0.3.31"
itertools = {workspace = true}
once_cell = {workspace = true}
pyo3 = {workspace = true, optional = true}
spark-connect = {workspace = true}
textwrap = "0.16.1"
4 changes: 2 additions & 2 deletions src/daft-connect/src/connect_service.rs
@@ -16,7 +16,7 @@
invalid_argument_err, not_yet_implemented,
response_builder::ResponseBuilder,
session::Session,
translation::{self, SparkAnalyzer},
spark_analyzer::{to_spark_datatype, SparkAnalyzer},
util::FromOptionalField,
};

@@ -180,7 +180,7 @@

let daft_schema = daft_schema.to_struct();

let schema = translation::to_spark_datatype(&daft_schema);
let schema = to_spark_datatype(&daft_schema);

(Codecov warning: added line 183 in src/daft-connect/src/connect_service.rs was not covered by tests)

Ok(Response::new(rb.schema_response(schema)))
}
10 changes: 5 additions & 5 deletions src/daft-connect/src/execute.rs
@@ -25,8 +25,8 @@ use tonic::{codegen::tokio_stream::wrappers::ReceiverStream, Status};
use tracing::debug;

use crate::{
not_yet_implemented, response_builder::ResponseBuilder, session::Session, translation,
util::FromOptionalField, ExecuteStream, Runner,
not_yet_implemented, response_builder::ResponseBuilder, session::Session,
spark_analyzer::SparkAnalyzer, util::FromOptionalField, ExecuteStream, Runner,
};

impl Session {
@@ -93,7 +93,7 @@ impl Session {

tokio::spawn(async move {
let execution_fut = async {
let translator = translation::SparkAnalyzer::new(&this);
let translator = SparkAnalyzer::new(&this);
match command.rel_type {
Some(RelType::ShowString(ss)) => {
let response = this.show_string(*ss, res.clone()).await?;
@@ -208,7 +208,7 @@ impl Session {
}
};

let translator = translation::SparkAnalyzer::new(&this);
let translator = SparkAnalyzer::new(&this);

let plan = translator.to_logical_plan(input).await?;

@@ -241,7 +241,7 @@ impl Session {
show_string: ShowString,
response_builder: ResponseBuilder<ExecutePlanResponse>,
) -> eyre::Result<ExecutePlanResponse> {
let translator = translation::SparkAnalyzer::new(self);
let translator = SparkAnalyzer::new(self);

let ShowString {
input,
55 changes: 55 additions & 0 deletions src/daft-connect/src/functions.rs
@@ -0,0 +1,55 @@
use std::{collections::HashMap, sync::Arc};

use once_cell::sync::Lazy;
use spark_connect::Expression;

use crate::spark_analyzer::SparkAnalyzer;
mod core;

pub(crate) static CONNECT_FUNCTIONS: Lazy<SparkFunctions> = Lazy::new(|| {
let mut functions = SparkFunctions::new();
functions.register::<core::CoreFunctions>();
functions
});

pub trait SparkFunction: Send + Sync {
fn to_expr(
&self,
args: &[Expression],
analyzer: &SparkAnalyzer,
) -> eyre::Result<daft_dsl::ExprRef>;
}
Comment on lines +15 to +21
Contributor:
Something like:

enum SparkFunction {
  BinaryOpFunction(BinaryOpFunction),
  UnaryFunction(UnaryFunction),
  CountFunction(CountFunction),
}

impl SparkFunction {
  fn to_expr(&self, args: &[Expression], analyzer: &SparkAnalyzer) -> eyre::Result<daft_dsl::ExprRef> {
    match self {
      Self::BinaryOpFunction(..) => ..,
      Self::UnaryFunction(..) => ..,
      Self::CountFunction(..) => ..,
    }
  }
}

This would allow us to avoid all the dyn SparkFunction stuff elsewhere.

Contributor Author (@universalmind303, Jan 15, 2025):
So the main reasons behind using the trait impl instead of enums are:

  • it's (subjectively) a little easier to work with
  • the performance overhead here is relatively small compared to the actual execution
  • most importantly, it makes it possible to support UDFs later down the road (see the sketch below)
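
As a rough sketch of that last point (illustrative only, not part of this PR; ArgCountUdf and its body are hypothetical), a runtime-registered UDF only needs to implement the same trait:

use daft_dsl::{lit, ExprRef};
use spark_connect::Expression;

use crate::{functions::SparkFunction, spark_analyzer::SparkAnalyzer};

// Hypothetical user-defined function: evaluates to the number of
// arguments it was called with, as an i64 literal.
struct ArgCountUdf;

impl SparkFunction for ArgCountUdf {
    fn to_expr(
        &self,
        args: &[Expression],
        _analyzer: &SparkAnalyzer,
    ) -> eyre::Result<ExprRef> {
        Ok(lit(args.len() as i64))
    }
}

// Because the registry stores Arc<dyn SparkFunction>, this can be
// registered at runtime, which a closed enum would not allow:
// functions.add_fn("arg_count", ArgCountUdf);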


pub struct SparkFunctions {
pub(crate) map: HashMap<String, Arc<dyn SparkFunction>>,
}

impl SparkFunctions {
/// Create a new [SparkFunctions] instance.
#[must_use]
pub fn new() -> Self {
Self {
map: HashMap::new(),
}
}

/// Register the module to the [SparkFunctions] instance.
pub fn register<M: FunctionModule>(&mut self) {
M::register(self);
}
/// Add a [SparkFunction] to the [SparkFunctions] instance.
pub fn add_fn<F: SparkFunction + 'static>(&mut self, name: &str, func: F) {
self.map.insert(name.to_string(), Arc::new(func));
}

/// Get a function by name from the [SparkFunctions] instance.
#[must_use]
pub fn get(&self, name: &str) -> Option<&Arc<dyn SparkFunction>> {
self.map.get(name)
}
}

pub trait FunctionModule {
/// Register this module to the given [SparkFunctions] table.
fn register(_parent: &mut SparkFunctions);
}
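
For orientation, a minimal sketch of how a call site might resolve a function name through this registry (the actual lookup lives in the analyzer and is not part of this diff; resolve_function is a hypothetical name):

use spark_connect::Expression;

use crate::{functions::CONNECT_FUNCTIONS, spark_analyzer::SparkAnalyzer};

// Hypothetical helper: look a Spark function name up in the global
// registry and delegate argument translation to the matched entry.
fn resolve_function(
    name: &str,
    args: &[Expression],
    analyzer: &SparkAnalyzer,
) -> eyre::Result<daft_dsl::ExprRef> {
    match CONNECT_FUNCTIONS.get(name) {
        Some(func) => func.to_expr(args, analyzer),
        None => Err(eyre::eyre!("unknown function: {name}")),
    }
}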
105 changes: 105 additions & 0 deletions src/daft-connect/src/functions/core.rs
@@ -0,0 +1,105 @@
use daft_core::count_mode::CountMode;
use daft_dsl::{binary_op, col, ExprRef, Operator};
use daft_schema::dtype::DataType;
use spark_connect::Expression;

use super::{FunctionModule, SparkFunction};
use crate::{invalid_argument_err, spark_analyzer::SparkAnalyzer};

// Core functions are the most basic functions such as `+`, `-`, `*`, `/`, not, notnull, etc.
pub struct CoreFunctions;

impl FunctionModule for CoreFunctions {
fn register(parent: &mut super::SparkFunctions) {
parent.add_fn("==", BinaryOpFunction(Operator::Eq));
parent.add_fn("!=", BinaryOpFunction(Operator::NotEq));
parent.add_fn("<", BinaryOpFunction(Operator::Lt));
parent.add_fn("<=", BinaryOpFunction(Operator::LtEq));
parent.add_fn(">", BinaryOpFunction(Operator::Gt));
parent.add_fn(">=", BinaryOpFunction(Operator::GtEq));
parent.add_fn("+", BinaryOpFunction(Operator::Plus));
parent.add_fn("-", BinaryOpFunction(Operator::Minus));
parent.add_fn("*", BinaryOpFunction(Operator::Multiply));
parent.add_fn("/", BinaryOpFunction(Operator::TrueDivide));
parent.add_fn("//", BinaryOpFunction(Operator::FloorDivide));
parent.add_fn("%", BinaryOpFunction(Operator::Modulus));
parent.add_fn("&", BinaryOpFunction(Operator::And));
parent.add_fn("|", BinaryOpFunction(Operator::Or));
parent.add_fn("^", BinaryOpFunction(Operator::Xor));
parent.add_fn("<<", BinaryOpFunction(Operator::ShiftLeft));
parent.add_fn(">>", BinaryOpFunction(Operator::ShiftRight));
parent.add_fn("isnotnull", UnaryFunction(|arg| arg.not_null()));
parent.add_fn("isnull", UnaryFunction(|arg| arg.is_null()));
parent.add_fn("not", UnaryFunction(|arg| arg.not()));
parent.add_fn("sum", UnaryFunction(|arg| arg.sum()));
parent.add_fn("mean", UnaryFunction(|arg| arg.mean()));
parent.add_fn("stddev", UnaryFunction(|arg| arg.stddev()));
parent.add_fn("min", UnaryFunction(|arg| arg.min()));
parent.add_fn("max", UnaryFunction(|arg| arg.max()));
parent.add_fn("count", CountFunction);
}
}

pub struct BinaryOpFunction(Operator);
pub struct UnaryFunction(fn(ExprRef) -> ExprRef);
pub struct CountFunction;

impl SparkFunction for BinaryOpFunction {
Contributor:
Doesn't need to be implemented in this PR, but usually, if you have a fixed number of types implementing a trait, then opting for an enum seems to be better form.

Contributor Author:
So it's currently a fixed number, but this will expand quite a bit, and it will become impractical to contain all of those in a single enum, as Spark has a LOT of functions.

Additionally, we'll need to support dynamically registered UDFs later down the road, so it makes sense to lay the groundwork now instead of needing another refactor later.

fn to_expr(
&self,
args: &[Expression],
analyzer: &SparkAnalyzer,
) -> eyre::Result<daft_dsl::ExprRef> {
Contributor:
Not the right PR for this, but I'm just curious: Why are we using eyre instead of thiserror or anyhow?

Contributor Author:
Totally agree here. It's been on my todo list to change the error handling in daft-connect.


let args = args
.iter()
.map(|arg| analyzer.to_daft_expr(arg))
.collect::<eyre::Result<Vec<_>>>()?;

let [lhs, rhs] = args
.try_into()
.map_err(|args| eyre::eyre!("requires exactly two arguments; got {:?}", args))?;

Ok(binary_op(self.0, lhs, rhs))
Comment on lines +53 to +62
Contributor:
Maybe we could try to keep the same pattern for to_expr as the other implementations? I.e., something like:

match args {
  [lhs, rhs] => ..,
  _ => return invalid_argument_err!("requires exactly two arguments; got {args:?}"),
}

}
}

impl SparkFunction for UnaryFunction {
fn to_expr(
&self,
args: &[Expression],
analyzer: &SparkAnalyzer,
) -> eyre::Result<daft_dsl::ExprRef> {
match args {
[arg] => {
let arg = analyzer.to_daft_expr(arg)?;
Ok(self.0(arg))
}
_ => invalid_argument_err!("requires exactly one argument")?,

(Codecov warning: added line 77 in src/daft-connect/src/functions/core.rs was not covered by tests)
}
}
}

impl SparkFunction for CountFunction {
fn to_expr(
&self,
args: &[Expression],
analyzer: &SparkAnalyzer,
) -> eyre::Result<daft_dsl::ExprRef> {
match args {
[arg] => {
let arg = analyzer.to_daft_expr(arg)?;

let arg = if arg.as_literal().and_then(|lit| lit.as_i32()) == Some(1i32) {
Contributor:
Isn't count(2) also equivalent to count(1), which is equivalent to count(*)? In fact, isn't count(n), for any integer n, equivalent to count(*)?

Contributor Author (@universalmind303, Jan 15, 2025):
In Spark, it is only ever sent over as count(1), and it never allows any args:

https://spark.apache.org/docs/latest/api/python/reference/pyspark.sql/api/pyspark.sql.DataFrame.count.html

So it can only ever be called via

df.count()

which gets serialized via protobuf as count(1). So this is really about special-casing what Spark Connect sends us vs. how we internally represent a count(*) / count(n).

col("*")
} else {
arg

(Codecov warning: added line 95 in src/daft-connect/src/functions/core.rs was not covered by tests)
};

let count = arg.count(CountMode::All).cast(&DataType::Int64);

Ok(count)
}
_ => invalid_argument_err!("requires exactly one argument")?,

(Codecov warning: added line 102 in src/daft-connect/src/functions/core.rs was not covered by tests)
}
}
}
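
To make the count(1) discussion above concrete, here is the normalization step in isolation (a hypothetical free function extracted for illustration; the PR inlines this logic in CountFunction::to_expr):

use daft_dsl::{col, ExprRef};

// Spark Connect serializes df.count() as the aggregate count(1); Daft
// represents a whole-frame count as a count over col("*"), so a literal
// 1 argument is swapped for col("*") before building the expression.
fn normalize_count_arg(arg: ExprRef) -> ExprRef {
    if arg.as_literal().and_then(|lit| lit.as_i32()) == Some(1) {
        col("*")
    } else {
        arg
    }
}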
5 changes: 4 additions & 1 deletion src/daft-connect/src/lib.rs
@@ -12,6 +12,9 @@ mod config;
#[cfg(feature = "python")]
mod connect_service;

#[cfg(feature = "python")]
mod functions;

#[cfg(feature = "python")]
mod display;
#[cfg(feature = "python")]
@@ -23,7 +26,7 @@ mod response_builder;
#[cfg(feature = "python")]
mod session;
#[cfg(feature = "python")]
mod translation;
mod spark_analyzer;
#[cfg(feature = "python")]
pub mod util;
