Skip to content

Commit

Permalink
Rewrites FROM with string path as a table-value function
Browse files Browse the repository at this point in the history
  • Loading branch information
rchowell committed Jan 14, 2025
1 parent 5c3fa87 commit 87aceda
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 11 deletions.
3 changes: 2 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ repos:
(?x)^(
tutorials/.*\.ipynb|
docs/.*\.ipynb|
docs/source/user_guide/fotw/data/
docs/source/user_guide/fotw/data/|
.*\.jsonl
)$
args:
- --autofix
Expand Down
57 changes: 49 additions & 8 deletions src/daft-sql/src/planner.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::{
cell::{Ref, RefCell, RefMut},
collections::{HashMap, HashSet},
path::Path,
rc::Rc,
sync::Arc,
};
Expand All @@ -21,10 +22,11 @@ use daft_functions::{
use daft_logical_plan::{LogicalPlanBuilder, LogicalPlanRef};
use sqlparser::{
ast::{
ArrayElemTypeDef, BinaryOperator, CastKind, ColumnDef, DateTimeField, Distinct,
ExactNumberInfo, ExcludeSelectItem, GroupByExpr, Ident, ObjectName, Query, SelectItem,
SetExpr, Statement, StructField, Subscript, TableAlias, TableWithJoins, TimezoneInfo,
UnaryOperator, Value, WildcardAdditionalOptions, With,
self, ArrayElemTypeDef, BinaryOperator, CastKind, ColumnDef, DateTimeField, Distinct,
ExactNumberInfo, ExcludeSelectItem, FunctionArg, FunctionArgExpr, GroupByExpr, Ident,
ObjectName, Query, SelectItem, SetExpr, Statement, StructField, Subscript, TableAlias,
TableFunctionArgs, TableWithJoins, TimezoneInfo, UnaryOperator, Value,
WildcardAdditionalOptions, With,
},
dialect::GenericDialect,
parser::{Parser, ParserOptions},
Expand Down Expand Up @@ -1056,9 +1058,24 @@ impl<'a> SQLPlanner<'a> {
}
}

/// Plan a `FROM 'path/to/file.extension'` table factor.
fn plan_relation_path(&self, _name: &ObjectName) -> SQLPlannerResult<Relation> {
unsupported_sql_err!("Unsupported table factor: Path")
/// Plan a `FROM '<path>'` table factor by rewriting it to the relevant
/// table-valued function (`read_csv`, `read_json`, or `read_parquet`).
///
/// The rewrite dispatches on the file extension (case-insensitive) and
/// delegates to [`Self::plan_table_function`] with the path as the single
/// unnamed string argument, i.e. `FROM 'x.csv'` becomes `FROM read_csv('x.csv')`.
///
/// # Errors
/// Returns an "invalid operation" error when the path has no extension or an
/// unrecognized one.
fn plan_relation_path(&self, name: &ObjectName) -> SQLPlannerResult<Relation> {
    let path = name.to_string();
    // NOTE(review): assumes the caller (via is_table_path) only passes a
    // single-quoted identifier, so the first and last characters are the
    // quotes being stripped here — confirm; a <2-char string would panic.
    let path = &path[1..path.len() - 1]; // strip single-quotes ' '
    // Map the extension to the table-valued function that reads that format.
    let func = match Path::new(path).extension() {
        Some(ext) if ext.eq_ignore_ascii_case("csv") => "read_csv",
        Some(ext) if ext.eq_ignore_ascii_case("json") => "read_json",
        Some(ext) if ext.eq_ignore_ascii_case("parquet") => "read_parquet",
        Some(_) => invalid_operation_err!("unsupported file path extension: {}", name),
        None => invalid_operation_err!("unsupported file path, no extension: {}", name),
    };
    // Build `func('<path>')` and plan it like any other table function.
    let args = TableFunctionArgs {
        args: vec![FunctionArg::Unnamed(FunctionArgExpr::Expr(
            ast::Expr::Value(Value::SingleQuotedString(path.to_string())),
        ))],
        settings: None,
    };
    self.plan_table_function(func, &args)
}

/// Plan a `FROM <table>` table factor.
Expand Down Expand Up @@ -2280,8 +2297,9 @@ fn unresolve_alias(expr: ExprRef, projection: &[ExprRef]) -> SQLPlannerResult<Ex
#[cfg(test)]
mod tests {
use daft_core::prelude::*;
use sqlparser::ast::{Ident, ObjectName};

use crate::sql_schema;
use crate::{planner::is_table_path, sql_schema};

#[test]
fn test_sql_schema_creates_expected_schema() {
Expand Down Expand Up @@ -2324,4 +2342,27 @@ mod tests {
let expected = Schema::new(vec![Field::new("col1", DataType::Int32)]).unwrap();
assert_eq!(&*result, &expected);
}

#[test]
fn test_is_table_path() {
    // Builds a single-segment ObjectName with the given quote style.
    let single = |value: &str, quote_style: Option<char>| {
        ObjectName(vec![Ident {
            value: value.to_string(),
            quote_style,
        }])
    };

    // A lone single-quoted identifier is recognized as a file path.
    assert!(is_table_path(&single("path/to/file.ext", Some('\''))));
    // A dotted (multi-part) name is a table reference, never a path.
    assert!(!is_table_path(&ObjectName(vec![
        Ident::new("a"),
        Ident::new("b")
    ])));
    // Double-quoted identifiers are ordinary quoted table names.
    assert!(!is_table_path(&single("path/to/file.ext", Some('"'))));
    // Unquoted identifiers are ordinary table names.
    assert!(!is_table_path(&single("path/to/file.ext", None)));
}
}
25 changes: 25 additions & 0 deletions tests/assets/json-data/small.jsonl
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
{ "x": 42, "y": "apple", "z": true }
{ "x": 17, "y": "banana", "z": false }
{ "x": 89, "y": "cherry", "z": true }
{ "x": 3, "y": "date", "z": false }
{ "x": 156, "y": "elderberry", "z": true }
{ "x": 23, "y": "fig", "z": true }
{ "x": 777, "y": "grape", "z": false }
{ "x": 444, "y": "honeydew", "z": true }
{ "x": 91, "y": "kiwi", "z": false }
{ "x": 12, "y": "lemon", "z": true }
{ "x": 365, "y": "mango", "z": false }
{ "x": 55, "y": "nectarine", "z": true }
{ "x": 888, "y": "orange", "z": false }
{ "x": 247, "y": "papaya", "z": true }
{ "x": 33, "y": "quince", "z": false }
{ "x": 159, "y": "raspberry", "z": true }
{ "x": 753, "y": "strawberry", "z": false }
{ "x": 951, "y": "tangerine", "z": true }
{ "x": 426, "y": "ugli fruit", "z": false }
{ "x": 87, "y": "vanilla", "z": true }
{ "x": 234, "y": "watermelon", "z": false }
{ "x": 567, "y": "xigua", "z": true }
{ "x": 111, "y": "yuzu", "z": false }
{ "x": 999, "y": "zucchini", "z": true }
{ "x": 123, "y": "apricot", "z": false }
16 changes: 14 additions & 2 deletions tests/sql/test_table_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,26 @@ def sample_schema():
return {"a": daft.DataType.float32(), "b": daft.DataType.string()}


@pytest.mark.skip("read_json table function not supported (yet) see github #3196")
def test_sql_read_json():
    # The SQL read_json table function must match the DataFrame API result.
    via_sql = daft.sql("SELECT * FROM read_json('tests/assets/json-data/small.jsonl')").collect()
    via_api = daft.read_json("tests/assets/json-data/small.jsonl").collect()
    assert via_sql.to_pydict() == via_api.to_pydict()


@pytest.mark.skip("read_json table function not supported (yet) see github #3196")
def test_sql_read_json_path():
    # A bare quoted path in FROM must behave like an explicit read_json call.
    via_sql = daft.sql("SELECT * FROM 'tests/assets/json-data/small.jsonl'").collect()
    via_api = daft.read_json("tests/assets/json-data/small.jsonl").collect()
    assert via_sql.to_pydict() == via_api.to_pydict()


def test_sql_read_parquet():
    # The SQL read_parquet table function must match the DataFrame API result.
    via_sql = daft.sql("SELECT * FROM read_parquet('tests/assets/parquet-data/mvp.parquet')").collect()
    via_api = daft.read_parquet("tests/assets/parquet-data/mvp.parquet").collect()
    assert via_sql.to_pydict() == via_api.to_pydict()


@pytest.mark.skip(reason="Daft SQL does not support table paths (yet)")
def test_sql_read_parquet_path():
df = daft.sql("SELECT * FROM 'tests/assets/parquet-data/mvp.parquet'").collect()
expected = daft.read_parquet("tests/assets/parquet-data/mvp.parquet").collect()
Expand All @@ -32,7 +45,6 @@ def test_sql_read_csv(sample_csv_path):
assert df.to_pydict() == expected.to_pydict()


@pytest.mark.skip(reason="Daft SQL does not support table paths (yet)")
def test_sql_read_csv_path(sample_csv_path):
df = daft.sql(f"SELECT * FROM '{sample_csv_path}'").collect()
expected = daft.read_csv(sample_csv_path).collect()
Expand Down

0 comments on commit 87aceda

Please sign in to comment.