Skip to content

Commit

Permalink
Merge branch 'main' into datafusion-aggregate-test-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
tokoko authored Sep 16, 2024
2 parents ac5c557 + bf4de3c commit 5c99b6b
Show file tree
Hide file tree
Showing 24 changed files with 191 additions and 4,276 deletions.
63 changes: 53 additions & 10 deletions ibis_substrait/compiler/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,12 @@ def function_id(
op_name = IBIS_SUBSTRAIT_OP_MAPPING[type(op).__name__]
sig_key = self.get_signature(op)

extension_signature = f"{op_name}:{'_'.join(sig_key)}"
# the keys for looking up scalar functions consist of
# tuple(tuple(input dtypes), output dtype)
# but the signature we generate in the substrait plan only needs the input types
input_key = sig_key[0]

extension_signature = f"{op_name}:{'_'.join(input_key)}"

try:
function_extension = self.function_extensions[extension_signature]
Expand All @@ -91,7 +96,7 @@ def function_id(
)
return function_extension.function_anchor

def get_signature(self, op: ops.Node) -> tuple[str, ...]:
def get_signature(self, op: ops.Node) -> tuple[tuple[str, ...], str]:
"""Validate and upcast (if necessary) scalar function extension signature."""

op_name = IBIS_SUBSTRAIT_OP_MAPPING[type(op).__name__]
Expand All @@ -102,25 +107,31 @@ def get_signature(self, op: ops.Node) -> tuple[str, ...]:
)

anykey = ("any",) * len([arg for arg in op.args if arg is not None])
sigkey = anykey
input_type_key = anykey
output_type_key = IBIS_SUBSTRAIT_TYPE_MAPPING[op.dtype.name]
sigkey = (input_type_key, output_type_key)

any_sigkey = (anykey, output_type_key)

# First check if `any` is an option
# This function will take arguments of any type
# although we still want to check if the number of args is correct
function_extension = _extension_mapping[op_name].get(anykey)
function_extension = _extension_mapping[op_name].get(any_sigkey)

# Then try to look up extension based on input datatypes
# Each substrait function defines the types of the inputs and at this
# stage we should have performed the appropriate casts to ensure that
# argument types match.
if function_extension is None:
sigkey = tuple(
input_type_key = tuple(
[
IBIS_SUBSTRAIT_TYPE_MAPPING[arg.dtype.name] # type: ignore
for arg in op.args
if arg is not None and isinstance(arg, ops.Node)
if arg is not None and isinstance(arg, ops.Value)
]
)
output_type_key = IBIS_SUBSTRAIT_TYPE_MAPPING[op.dtype.name]
sigkey = (input_type_key, output_type_key)
function_extension = _extension_mapping[op_name].get(sigkey)

# Then check if extension is variadic
Expand All @@ -130,15 +141,45 @@ def get_signature(self, op: ops.Node) -> tuple[str, ...]:
# type is only repeated once, so we try to perform a lookup that way, then
# assert, if we find anything, that the function is, indeed, variadic.
if function_extension is None:
function_extension = _extension_mapping[op_name].get((sigkey[0],))
# variadic signature would be in the form of
# ((oneof_input_arg_dtype,), output_dtype)
variadic_sig = ((sigkey[0][0],), sigkey[1])
function_extension = _extension_mapping[op_name].get(variadic_sig)
if function_extension is not None:
assert function_extension.variadic
# Function signature for a variadic should contain the type of
# the argument(s) at _least_ once but ideally should contain
# types == the minimum number of variadic args allowed (but keep
# it nonzero)
arg_count_min = max(function_extension.variadic.get("min", 0), 1)
sigkey = (sigkey[0],) * arg_count_min
input_type_key = (sigkey[0][0],) * arg_count_min
output_type_key = IBIS_SUBSTRAIT_TYPE_MAPPING[op.dtype.name]
sigkey = (input_type_key, output_type_key)

# Then check if we have an op that has a `date` somewhere in the input
# args and the output listed as `i32`.
# Ibis assumes i32 for the output of all time extraction functions
# because no one is going to be around in i64 years, but Substrait
# expects i64 as the output
if function_extension is None:
if "date" in sigkey[0] and sigkey[1] == "i32":
sigkey = (sigkey[0], "i64")
function_extension = _extension_mapping[op_name].get(sigkey)

# Ibis doesn't always handle decimal promotion correctly (I think?)
# And all decimal inputs are expected to be decimal outputs, so we have
# to massage the signature key
if function_extension is None:
if set(sigkey[0]) == {"dec"} and sigkey[1] != "dec":
sigkey = (sigkey[0], "dec")
function_extension = _extension_mapping[op_name].get(sigkey)

# How many special cases do you want? We've got lots.
# Some string functions can only have i64 outputs
if function_extension is None:
if isinstance(op, ops.StringLength):
sigkey = (sigkey[0], "i64")
function_extension = _extension_mapping[op_name].get(sigkey)

# If it's still None then we're borked.
if function_extension is None:
Expand All @@ -151,15 +192,17 @@ def get_signature(self, op: ops.Node) -> tuple[str, ...]:
def create_extension(
self,
op_name: str,
sigkey: tuple[str, ...],
sigkey: tuple[tuple[str, ...], str],
) -> ste.SimpleExtensionDeclaration.ExtensionFunction:
"""Register extension uri and create extension function."""

function_extension = _extension_mapping[op_name][sigkey]
extension_uri = self.register_extension_uri(function_extension.uri)

input_key = sigkey[0]

extension_function = self.create_extension_function(
extension_uri, f"{op_name}:{'_'.join(sigkey)}"
extension_uri, f"{op_name}:{'_'.join(input_key)}"
)

return extension_function
Expand Down
72 changes: 51 additions & 21 deletions ibis_substrait/compiler/mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"CountStar": "count",
"CountDistinct": "count",
"Divide": "divide",
"SubstraitDivide": "divide",
"EndsWith": "ends_with",
"Equals": "equal",
"Exp": "exp",
Expand Down Expand Up @@ -69,6 +70,7 @@
"RegexReplace": "regexp_replace",
"Repeat": "repeat",
"Reverse": "reverse",
"SubstraitRound": "round",
"Round": "round",
"RPad": "rpad",
"RStrip": "rtrim",
Expand Down Expand Up @@ -119,26 +121,52 @@
}

_normalized_key_names = {
# decimal precision and scale aren't part of the
# extension signature they're passed in separately
"decimal<p, s>": "dec",
"decimal<p,s>": "dec",
"decimal<p1,s1>": "dec",
"decimal<p2,s2>": "dec",
# we don't care about string length
"fixedchar<l1>": "str",
"fixedchar<l2>": "str",
"varchar<l1>": "str",
"varchar<l2>": "str",
"varchar<l3>": "str",
# for now ignore nullability marker
"boolean?": "bool",
# why is there a 1?
"any1": "any",
"Date": "date",
"binary": "vbin",
"interval_compound": "icompound",
"interval_day": "iday",
"interval_year": "iyear",
"string": "str",
"timestamp": "ts",
"timestamp_tz": "tstz",
}


def normalize_substrait_type_names(typ: str) -> str:
    """Reduce a Substrait type string to the short name used in signature keys.

    The string is stripped of leading/trailing ``?`` characters and
    lower-cased.  If it then starts with one of the known prefixes below it is
    replaced wholesale by that prefix's abbreviation, which discards any
    parameters that follow the prefix (e.g. everything after ``decimal``).
    Finally the result is passed through ``_normalized_key_names``, falling
    back to the string itself when there is no entry.

    Parameters
    ----------
    typ
        Type string as it appears in an extension definition.

    Returns
    -------
    str
        The normalized, lower-case short type name.
    """
    # First strip off any punctuation
    typ = typ.strip("?").lower()

    # Common prefixes whose trailing information does not matter to an
    # extension function signature.
    #
    # Only the first matching prefix is applied, so a prefix that extends
    # another one MUST be listed before it: ``precision_timestamp_tz`` has to
    # be tested before ``precision_timestamp``, otherwise the tz variant is
    # collapsed to ``pts`` and can never become ``ptstz``.
    prefix_abbreviations = (
        ("fixedchar", "fchar"),
        ("varchar", "vchar"),
        ("fixedbinary", "fbin"),
        ("decimal", "dec"),
        ("precision_timestamp_tz", "ptstz"),
        ("precision_timestamp", "pts"),
        ("struct", "struct"),
        ("list", "list"),
        ("map", "map"),
        ("any", "any"),
        ("boolean", "bool"),
        # Absolute garbage type info
        # NOTE(review): these look like fragments of parameterised decimal
        # return expressions leaking into the type field -- confirm against
        # the extension YAML before changing.
        ("delta", "dec"),
        ("prec", "dec"),
        ("scale", "dec"),
        ("init_", "dec"),
        ("min_", "dec"),
        ("max_", "dec"),
    )
    for prefix, abbr in prefix_abbreviations:
        if typ.startswith(prefix):
            typ = abbr
            break

    # Then pass through the dictionary of mappings, defaulting to just the
    # existing string (already lower-case at this point).
    return _normalized_key_names.get(typ, typ)


_extension_mapping: Mapping[str, Any] = defaultdict(dict)


Expand All @@ -151,13 +179,13 @@ def __init__(self, name: str) -> None:
self.uri: str = ""

def parse(self, impl: Mapping[str, Any]) -> None:
self.rtn = impl["return"]
self.rtn = normalize_substrait_type_names(impl["return"])
self.nullability = impl.get("nullability", False)
self.variadic = impl.get("variadic", False)
if input_args := impl.get("args", []):
for val in input_args:
if typ := val.get("value", None):
typ = _normalized_key_names.get(typ.lower(), typ.lower())
if typ := val.get("value"):
typ = normalize_substrait_type_names(typ)
self.inputs.append(typ)
elif arg_name := val.get("name", None):
self.arg_names.append(arg_name)
Expand Down Expand Up @@ -212,7 +240,9 @@ def register_extension_yaml(
for function in named_functions:
for func in _parse_func(function):
func.uri = uri or f"{prefix}/{fname.name}"
_extension_mapping[function["name"]][tuple(func.inputs)] = func
_extension_mapping[function["name"]][(tuple(func.inputs), func.rtn)] = (
func
)


def _populate_default_extensions() -> None:
Expand Down
64 changes: 56 additions & 8 deletions ibis_substrait/compiler/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
import operator
import uuid
from collections.abc import Iterable, Mapping, MutableMapping, Sequence
from typing import Any, TypeVar, Union
from typing import Any, Optional, TypeVar, Union

import ibis
import ibis.expr.datatypes as dt
import ibis.expr.operations as ops
import ibis.expr.rules as rlz
import ibis.expr.schema as sch
import ibis.expr.types as ir
from ibis import util
Expand All @@ -29,6 +30,7 @@
from ibis_substrait.compiler.core import SubstraitCompiler, _get_fields
from ibis_substrait.compiler.mapping import (
IBIS_SUBSTRAIT_OP_MAPPING,
IBIS_SUBSTRAIT_TYPE_MAPPING,
_extension_mapping,
)

Expand Down Expand Up @@ -505,17 +507,17 @@ def value_op(
) -> stalg.Expression:
# Check if scalar function is valid for input dtype(s) and insert casts as needed to
# make sure inputs are correct.
op = _check_and_upcast(op)
newop = _check_and_upcast(op)
# given the details of `op` -> function id
return stalg.Expression(
scalar_function=stalg.Expression.ScalarFunction(
function_reference=compiler.function_id(op),
output_type=translate(op.dtype),
function_reference=compiler.function_id(newop),
output_type=translate(newop.dtype),
arguments=[
stalg.FunctionArgument(
value=translate(arg, compiler=compiler, **kwargs)
)
for arg in op.args
for arg in newop.args
if isinstance(arg, ops.Value)
],
)
Expand All @@ -538,6 +540,8 @@ def window_op(

lower_bound, upper_bound = _translate_window_bounds(start, end)

func = _check_and_upcast(func)

return stalg.Expression(
window_function=stalg.Expression.WindowFunction(
function_reference=compiler.function_id(func),
Expand Down Expand Up @@ -565,6 +569,7 @@ def _reduction(
compiler: SubstraitCompiler,
**kwargs: Any,
) -> stalg.AggregateFunction:
op = _check_and_upcast(op)
return stalg.AggregateFunction(
function_reference=compiler.function_id(op),
arguments=[
Expand Down Expand Up @@ -1408,8 +1413,11 @@ def _check_and_upcast(op: ops.Node) -> ops.Node:
op_name = IBIS_SUBSTRAIT_OP_MAPPING[type(op).__name__]
anykey = ("any",) * len([arg for arg in op.args if arg is not None])

output_type_key = IBIS_SUBSTRAIT_TYPE_MAPPING[op.dtype.name]
any_sigkey = (anykey, output_type_key)

# First check if `any` is an option
function_extension = _extension_mapping[op_name].get(anykey)
function_extension = _extension_mapping[op_name].get(any_sigkey)

# Otherwise, if the types don't match, cast up
if function_extension is None:
Expand Down Expand Up @@ -1463,15 +1471,55 @@ def _upcast_string_op(op: string_op) -> string_op:
return type(op)(*casted_args)


# Ibis has (usually good) opinions about what the dtypes of certain ops should be
# Substrait disagrees sometimes
class SubstraitRound(ops.Value):
    """Rounding operation whose output dtype is the dtype of its input.

    The result takes its shape from ``arg`` and reports ``arg``'s dtype
    unchanged as its own.
    """

    # Numeric value to be rounded.
    arg: ops.Value[dt.Numeric]
    # Integer number of digits to round to; optional, defaults to None.
    digits: Optional[ops.Value[dt.Integer]] = None

    shape = rlz.shape_like("arg")

    @property
    def dtype(self) -> dt.DataType:
        """Return the dtype of ``arg``."""
        rounded_value = self.arg
        return rounded_value.dtype


class SubstraitDivide(ops.NumericBinary):
    """Division whose result dtype is always the dtype of the left operand."""

    @property
    def dtype(self) -> dt.DataType:
        """Return the dtype of ``left``."""
        left_operand = self.left
        return left_operand.dtype


@_upcast.register(ops.Round)
def _upcast_round_digits(op: ops.Round) -> ops.Round:
def _upcast_round_digits(op: ops.Round) -> SubstraitRound:
# Substrait wants Int32 for decimal place argument to round
if op.digits is None:
raise ValueError(
"Substrait requires that a rounding operation specify the number of digits to round to"
)
elif not isinstance(op.digits.dtype, dt.Int32):
return ops.Round(
return SubstraitRound(
op.arg, op.digits.copy(dtype=dt.Int32(nullable=op.digits.dtype.nullable))
)
return SubstraitRound(op.arg, op.digits)


@_upcast.register(ops.Mean)
def _upcast_mean(op: ops.Mean) -> ops.Mean:
    """Make the input dtype of a mean agree with its output dtype.

    Substrait wants the input and output types of reductions to match, so
    when they differ the *input* is cast to the output dtype:
    ``mean(some_int) -> float`` becomes ``mean(cast(some_int as float)) -> float``.
    An op whose argument already has the output dtype is returned as-is.
    """
    needs_cast = op.arg.dtype != op.dtype
    if not needs_cast:
        return op

    casted_arg = ops.Cast(op.arg, to=op.dtype)
    return ops.Mean(arg=casted_arg, where=op.where)


@_upcast.register(ops.Divide)
def _matchy_matchy_divide(op: ops.Divide) -> SubstraitDivide:
    """Rebuild a ``Divide`` as a ``SubstraitDivide`` and upcast its operands.

    The replacement op is built from the same ``left`` and ``right`` operands
    and then handed to ``_upcast_bin_op``, whose result is returned.
    """
    return _upcast_bin_op(SubstraitDivide(op.left, op.right))
Empty file.
10 changes: 0 additions & 10 deletions ibis_substrait/extensions/extension_types.yaml

This file was deleted.

Loading

0 comments on commit 5c99b6b

Please sign in to comment.