Skip to content

Commit

Permalink
feat(connect): printSchema (#3617)
Browse files Browse the repository at this point in the history
### TODO

- should we reuse `TreeDisplay`?
- remove `unwrap`s


https://github.com/Eventual-Inc/Daft/blob/56e872c6297a37fc3b9406fad28af3e38926717d/src/common/display/src/tree.rs#L3

- should we make our own?


- Example of our own implementation that would need to be tested (don't take it
seriously!)

```rust
/// Renders `schema` as a Spark-`printSchema`-style tree, starting with a
/// bare `root` line followed by one line per (possibly nested) field.
pub fn to_tree_string(schema: &Schema) -> eyre::Result<String> {
    let mut rendered = String::new();
    // The tree always begins with a bare "root" line, like Spark's printSchema.
    writeln!(&mut rendered, "root")?;
    // Each top-level field starts at indentation level 1 and is rendered as nullable.
    for (name, field) in &schema.fields {
        print_field(&mut rendered, name, &field.dtype, /*nullable*/ true, 1)?;
    }
    Ok(rendered)
}

// A helper function to print a field at a given level of indentation.
//
// Mimics Spark's `printSchema` indentation:
//   level 1: " |-- "
//   level 2: " |    |-- "
//   level n: " " + "|    " repeated (n-1) times + "|-- "
//
// `level` must be >= 1 (top-level fields start at 1); struct children
// recurse with `level + 1`.
fn print_field(
    w: &mut String,
    field_name: &str,
    dtype: &DataType,
    nullable: bool,
    level: usize,
) -> eyre::Result<()> {
    // Build the indentation prefix. "|    ".repeat(0) is "", so level 1
    // yields exactly " |-- ". The previous formulation dropped the inner
    // "|" separators for level >= 2, printing " |    -- " instead of the
    // documented " |    |-- ".
    let indent = format!(" {}|-- ", "|    ".repeat(level - 1));

    // Get a user-friendly string for dtype.
    let dtype_str = type_to_string(dtype);

    writeln!(
        w,
        "{}{}: {} (nullable = {})",
        indent, field_name, dtype_str, nullable
    )?;

    // If the dtype is a struct, print its child fields with increased indentation.
    if let DataType::Struct(fields) = dtype {
        for field in fields {
            print_field(w, &field.name, &field.dtype, true, level + 1)?;
        }
    }

    Ok(())
}

/// Maps a Daft [`DataType`] onto a Spark-style, human-readable type name
/// (e.g. "integer", "string"), allocating the owned `String` once at the end.
fn type_to_string(dtype: &DataType) -> String {
    let name = match dtype {
        DataType::Null => "null",
        DataType::Boolean => "boolean",
        DataType::Int8
        | DataType::Int16
        | DataType::Int32
        | DataType::Int64
        | DataType::UInt8
        | DataType::UInt16
        | DataType::UInt32
        | DataType::UInt64 => "integer", // Spark doesn't differentiate sizes
        DataType::Float32 | DataType::Float64 => "double", // Spark calls all floats double for printing
        DataType::Decimal128(_, _) => "decimal",
        DataType::Timestamp(_, _) => "timestamp",
        DataType::Date => "date",
        DataType::Time(_) => "time",
        DataType::Duration(_) => "duration",
        DataType::Interval => "interval",
        DataType::Binary => "binary",
        DataType::FixedSizeBinary(_) => "fixed_size_binary",
        DataType::Utf8 => "string",
        DataType::FixedSizeList(_, _) | DataType::List(_) => "array", // Spark calls them arrays
        DataType::Struct(_) => "struct",
        DataType::Map { .. } => "map",
        DataType::Extension(_, _, _) => "extension",
        DataType::Embedding(_, _) => "embedding",
        DataType::Image(_) => "image",
        DataType::FixedShapeImage(_, _, _) => "fixed_shape_image",
        DataType::Tensor(_) => "tensor",
        DataType::FixedShapeTensor(_, _) => "fixed_shape_tensor",
        DataType::SparseTensor(_) => "sparse_tensor",
        DataType::FixedShapeSparseTensor(_, _) => "fixed_shape_sparse_tensor",
        #[cfg(feature = "python")]
        DataType::Python => "python_object",
        DataType::Unknown => "unknown",
    };
    name.to_string()
}
```

---------

Co-authored-by: universalmind303 <cory.grinstead@gmail.com>
  • Loading branch information
andrewgazelka and universalmind303 authored Jan 8, 2025
1 parent 426ddd0 commit 519afce
Show file tree
Hide file tree
Showing 7 changed files with 509 additions and 20 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ chrono = "0.4.38"
chrono-tz = "0.10.0"
comfy-table = "7.1.1"
common-daft-config = {path = "src/common/daft-config"}
common-display = {path = "src/common/display", default-features = false}
common-error = {path = "src/common/error", default-features = false}
common-file-formats = {path = "src/common/file-formats"}
common-runtime = {path = "src/common/runtime", default-features = false}
Expand Down
362 changes: 362 additions & 0 deletions src/daft-connect/src/display.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,362 @@
use std::fmt::Write;

use daft_core::prelude::*;

// note: right now this is only implemented for Schema, but we'll want to extend this for our dataframe output, and the plan repr.
pub trait SparkDisplay {
/// Returns a Spark-style textual representation of `self`.
/// For `Schema` this matches the layout of Spark's `printSchema` output.
fn repr_spark_string(&self) -> String;
}

impl SparkDisplay for Schema {
    /// Renders this schema like Spark's `StructType.printSchema`: a `root`
    /// line followed by one (possibly nested) line per field.
    fn repr_spark_string(&self) -> String {
        let mut out = String::new();
        // Writing into a `String` is infallible, so these unwraps cannot panic.
        writeln!(&mut out, "root").unwrap();
        // Every top-level field is rendered starting at indentation level 1.
        for (name, field) in &self.fields {
            write_field(&mut out, name, &field.dtype, 1).unwrap();
        }
        out
    }
}

// Private helpers to mimic the original indentation style and recursive printing:
// Recursively renders one field (and any nested fields) in Spark's
// `printSchema` line format: `<indent><name>: <type> (<label> = true)`.
fn write_field(
    w: &mut String,
    field_name: &str,
    dtype: &DataType,
    level: usize,
) -> eyre::Result<()> {
    // Inner worker carries the extra `is_list` flag so the public signature
    // stays minimal; list elements use a different nullability label.
    fn recurse(
        w: &mut String,
        field_name: &str,
        dtype: &DataType,
        level: usize,
        is_list: bool,
    ) -> eyre::Result<()> {
        let indent = make_indent(level);
        let dtype_str = type_to_string(dtype);

        writeln!(
            w,
            "{indent}{field_name}: {dtype_str} ({nullable} = true)",
            // for some reason, spark prints "containsNulls" instead of "nullable" for lists
            nullable = if is_list { "containsNulls" } else { "nullable" }
        )?;

        // Descend into nested dtypes.
        match dtype {
            // Both list flavors print a synthetic "element" child one level deeper.
            DataType::List(inner) | DataType::FixedSizeList(inner, _) => {
                recurse(w, "element", inner, level + 1, true)
            }
            DataType::Struct(fields) => fields
                .iter()
                .try_for_each(|f| recurse(w, &f.name, &f.dtype, level + 1, false)),
            _ => Ok(()),
        }
    }

    recurse(w, field_name, dtype, level, false)
}

/// Builds the indentation prefix for a given nesting depth:
/// * `level == 0` — empty string (a bare field outside any schema),
/// * `level == 1` — `" |-- "`,
/// * `level >= 2` — one `" |"` per ancestor, then `"-- "`,
///   e.g. level 2 yields `" | |-- "`.
fn make_indent(level: usize) -> String {
    match level {
        0 => String::new(),
        1 => " |-- ".to_string(),
        n => format!(" |{}-- ", " |".repeat(n - 1)),
    }
}

/// Maps a Daft [`DataType`] onto the type name shown in the Spark-style
/// schema tree. Types Spark knows get Spark's names; the rest fall back to
/// `arrow.`- or `daft.`-prefixed names.
fn type_to_string(dtype: &DataType) -> String {
    match dtype {
        // Types with a direct Spark equivalent.
        DataType::Null => String::from("null"),
        DataType::Boolean => String::from("boolean"),
        DataType::Int8 => String::from("byte"),
        DataType::Int16 => String::from("short"),
        DataType::Int32 => String::from("integer"),
        DataType::Int64 => String::from("long"),
        DataType::Float32 => String::from("float"),
        DataType::Float64 => String::from("double"),
        DataType::Decimal128(precision, scale) => format!("decimal({precision},{scale})"),
        DataType::Timestamp(_, _) => String::from("timestamp"),
        DataType::Date => String::from("date"),
        DataType::Time(_) => String::from("time"),
        DataType::Duration(_) => String::from("duration"),
        DataType::Interval => String::from("interval"),
        DataType::Binary => String::from("binary"),
        DataType::Utf8 => String::from("string"),
        DataType::List(_) => String::from("array"),
        DataType::Struct(_) => String::from("struct"),
        DataType::Map { .. } => String::from("map"),
        DataType::Unknown => String::from("unknown"),
        // Spark has no unsigned integers or fixed-size containers; use Arrow names.
        DataType::UInt8 => String::from("arrow.uint8"),
        DataType::UInt16 => String::from("arrow.uint16"),
        DataType::UInt32 => String::from("arrow.uint32"),
        DataType::UInt64 => String::from("arrow.uint64"),
        DataType::FixedSizeBinary(_) => String::from("arrow.fixed_size_binary"),
        DataType::FixedSizeList(_, _) => String::from("arrow.fixed_size_list"),
        // Daft-specific logical types get a `daft.` prefix.
        DataType::Extension(_, _, _) => String::from("daft.extension"),
        DataType::Embedding(_, _) => String::from("daft.embedding"),
        DataType::Image(_) => String::from("daft.image"),
        DataType::FixedShapeImage(_, _, _) => String::from("daft.fixed_shape_image"),
        DataType::Tensor(_) => String::from("daft.tensor"),
        DataType::FixedShapeTensor(_, _) => String::from("daft.fixed_shape_tensor"),
        DataType::SparseTensor(_) => String::from("daft.sparse_tensor"),
        DataType::FixedShapeSparseTensor(_, _) => String::from("daft.fixed_shape_sparse_tensor"),
        #[cfg(feature = "python")]
        DataType::Python => String::from("daft.python"),
    }
}

// Unit tests for the Spark-style schema rendering. The expected strings pin
// the exact output format (indent per level, "nullable" vs "containsNulls"
// labels, and the type-name mapping), so treat any change here as a
// user-visible format change.
#[cfg(test)]
mod tests {
use super::*;

#[test]
fn test_empty_schema() -> eyre::Result<()> {
let schema = Schema::empty();
let output = schema.repr_spark_string();
let expected = "root\n";
assert_eq!(output, expected);
Ok(())
}

#[test]
fn test_single_field_schema() -> eyre::Result<()> {
let mut fields = Vec::new();
fields.push(Field::new("step", DataType::Int32));
let schema = Schema::new(fields)?;
let output = schema.repr_spark_string();
let expected = "root\n |-- step: integer (nullable = true)\n";
assert_eq!(output, expected);
Ok(())
}

#[test]
fn test_multiple_simple_fields() -> eyre::Result<()> {
let mut fields = Vec::new();
fields.push(Field::new("step", DataType::Int32));
fields.push(Field::new("type", DataType::Utf8));
fields.push(Field::new("amount", DataType::Float64));
let schema = Schema::new(fields)?;
let output = schema.repr_spark_string();
let expected = "\
root
 |-- step: integer (nullable = true)
 |-- type: string (nullable = true)
 |-- amount: double (nullable = true)
";
assert_eq!(output, expected);
Ok(())
}

#[test]
fn test_struct_field() -> eyre::Result<()> {
// Create a schema with a struct field
let inner_fields = vec![
Field::new("inner1", DataType::Utf8),
Field::new("inner2", DataType::Float32),
];
let struct_dtype = DataType::Struct(inner_fields);

let mut fields = Vec::new();
fields.push(Field::new("parent", struct_dtype));
fields.push(Field::new("count", DataType::Int64));
let schema = Schema::new(fields)?;

let output = schema.repr_spark_string();
let expected = "\
root
 |-- parent: struct (nullable = true)
 | |-- inner1: string (nullable = true)
 | |-- inner2: float (nullable = true)
 |-- count: long (nullable = true)
";
assert_eq!(output, expected);
Ok(())
}

#[test]
fn test_nested_struct_in_struct() -> eyre::Result<()> {
let inner_struct = DataType::Struct(vec![
Field::new("deep", DataType::Boolean),
Field::new("deeper", DataType::Utf8),
]);
let mid_struct = DataType::Struct(vec![
Field::new("mid1", DataType::Int8),
Field::new("nested", inner_struct),
]);

let mut fields = Vec::new();
fields.push(Field::new("top", mid_struct));
let schema = Schema::new(fields)?;

let output = schema.repr_spark_string();
let expected = "\
root
 |-- top: struct (nullable = true)
 | |-- mid1: byte (nullable = true)
 | |-- nested: struct (nullable = true)
 | | |-- deep: boolean (nullable = true)
 | | |-- deeper: string (nullable = true)
";
assert_eq!(output, expected);
Ok(())
}

#[test]
fn test_list_fields() -> eyre::Result<()> {
let list_of_int = DataType::List(Box::new(DataType::Int16));
let fixed_list_of_floats = DataType::FixedSizeList(Box::new(DataType::Float32), 3);

let mut fields = Vec::new();
fields.push(Field::new("ints", list_of_int));
fields.push(Field::new("floats", fixed_list_of_floats));
let schema = Schema::new(fields)?;

let output = schema.repr_spark_string();
let expected = "\
root
 |-- ints: array (nullable = true)
 | |-- element: short (containsNulls = true)
 |-- floats: arrow.fixed_size_list (nullable = true)
 | |-- element: float (containsNulls = true)
";
assert_eq!(output, expected);
Ok(())
}

#[test]
fn test_map_field() -> eyre::Result<()> {
let map_type = DataType::Map {
key: Box::new(DataType::Utf8),
value: Box::new(DataType::Int32),
};

let mut fields = Vec::new();
fields.push(Field::new("m", map_type));
let schema = Schema::new(fields)?;

let output = schema.repr_spark_string();
// Spark-like print doesn't show the internal "entries" struct by name, but we do show it as "struct":
let expected = "\
root
 |-- m: map (nullable = true)
";
// Note: If you decide to recurse into Map children (currently we do not), you'd see something like:
// | |-- key: string (nullable = true)
// | |-- value: integer (nullable = true)
// If you update the code to print the internals of a map, update the test accordingly.
assert_eq!(output, expected);
Ok(())
}

#[test]
fn test_extension_type() -> eyre::Result<()> {
let extension_type =
DataType::Extension("some_ext_type".to_string(), Box::new(DataType::Int32), None);

let mut fields = Vec::new();
fields.push(Field::new("ext_field", extension_type));
let schema = Schema::new(fields)?;

let output = schema.repr_spark_string();
let expected = "\
root
 |-- ext_field: daft.extension (nullable = true)
";
assert_eq!(output, expected);
Ok(())
}

#[test]
fn test_complex_nested_schema() -> eyre::Result<()> {
// A very nested schema to test indentation and various types together
let struct_inner = DataType::Struct(vec![
Field::new("sub_list", DataType::List(Box::new(DataType::Utf8))),
Field::new(
"sub_struct",
DataType::Struct(vec![
Field::new("a", DataType::Int32),
Field::new("b", DataType::Float64),
]),
),
]);

let main_fields = vec![
Field::new("name", DataType::Utf8),
Field::new("values", DataType::List(Box::new(DataType::Int64))),
Field::new("nested", struct_inner),
];

let mut fields = Vec::new();
fields.push(Field::new("record", DataType::Struct(main_fields)));
let schema = Schema::new(fields)?;

let output = schema.repr_spark_string();
let expected = "\
root
 |-- record: struct (nullable = true)
 | |-- name: string (nullable = true)
 | |-- values: array (nullable = true)
 | | |-- element: long (containsNulls = true)
 | |-- nested: struct (nullable = true)
 | | |-- sub_list: array (nullable = true)
 | | | |-- element: string (containsNulls = true)
 | | |-- sub_struct: struct (nullable = true)
 | | | |-- a: integer (nullable = true)
 | | | |-- b: double (nullable = true)
";
assert_eq!(output, expected);
Ok(())
}

#[test]
fn test_field_name_special_chars() -> eyre::Result<()> {
// Field with spaces and special characters
let mut fields = Vec::new();
fields.push(Field::new("weird field@!#", DataType::Utf8));
let schema = Schema::new(fields)?;
let output = schema.repr_spark_string();
let expected = "\
root
 |-- weird field@!#: string (nullable = true)
";
assert_eq!(output, expected);
Ok(())
}

#[test]
fn test_zero_sized_fixed_list() -> eyre::Result<()> {
// Although unusual, test a fixed size list with size=0
let zero_sized_list = DataType::FixedSizeList(Box::new(DataType::Int8), 0);
let mut fields = Vec::new();
fields.push(Field::new("empty_list", zero_sized_list));
let schema = Schema::new(fields)?;

let output = schema.repr_spark_string();
let expected = "\
root
 |-- empty_list: arrow.fixed_size_list (nullable = true)
 | |-- element: byte (containsNulls = true)
";
assert_eq!(output, expected);
Ok(())
}
}
Loading

0 comments on commit 519afce

Please sign in to comment.