Skip to content

Commit

Permalink
feat: add extension language interpreter
Browse files Browse the repository at this point in the history
  • Loading branch information
tokoko committed Sep 27, 2024
1 parent 399dac5 commit 41eddad
Show file tree
Hide file tree
Showing 4 changed files with 279 additions and 73 deletions.
114 changes: 114 additions & 0 deletions ibis_substrait/compiler/extension_language.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
from typing import Optional, Union

from pyparsing import (
Forward,
Group,
Literal,
ParseResults,
Word,
ZeroOrMore,
alphas,
identchars,
infix_notation,
nums,
oneOf,
opAssoc,
pyparsing_common,
)

# Grammar for a small expression language: integer arithmetic, comparisons,
# a ternary operator, two-argument min/max calls, two-parameter type
# expressions such as ``name<a, b>``, and leading ``name = expr`` assignments.
# The results names attached below ("assign", "variable", "fn", "dtype",
# "result") are the keys evaluate_pr looks for when deciding how to evaluate
# a parsed group.

# Function names accepted in call position by the ``fn`` rule.
supported_functions = ["max", "min"]

# Binary operator tokens, one rule per precedence level used in infix_notation.
multop = oneOf("* /")
plusop = oneOf("+ -")
compop = oneOf("> <")
# Assignment targets and variable references match the same characters; they
# differ only in the results name attached to the match.
assign = Word(identchars + nums)("assign")
variable = Word(identchars + nums)("variable")
# NOTE(review): not referenced by the rules below, which use
# pyparsing_common.integer for literals -- confirm nothing else imports it.
integer = Word(nums)
# Forward declaration: fn and dtype contain expr, and expr is built from an
# operand that contains fn and dtype.
expr = Forward()
# Two-argument call ``name(expr, expr)``; parentheses and the comma are
# suppressed, so the group holds the function name and the two arguments.
fn = Group(
oneOf(supported_functions)("fn")
+ Literal("(").suppress()
+ expr
+ Literal(",").suppress()
+ expr
+ Literal(")").suppress()
)
# Parameterized type ``name<expr, expr>`` with exactly two parameters; the
# angle brackets and the comma are suppressed.
dtype = Group(
Word(alphas)("dtype")
+ Literal("<").suppress()
+ expr
+ Literal(",").suppress()
+ expr
+ Literal(">").suppress()
)

# Alternatives are tried left to right, so an integer literal, a call, or a
# type expression is preferred over reading the text as a plain variable.
operand = pyparsing_common.integer | fn | dtype | Group(variable)

# Operator levels are listed from tightest-binding to loosest-binding.
expr << infix_notation(
operand,
[
(multop, 2, opAssoc.LEFT),
(plusop, 2, opAssoc.LEFT),
(compop, 2, opAssoc.LEFT),
(("?", ":"), 3, opAssoc.RIGHT),
],
)

# A program: zero or more ``name = expr`` assignments (each in its own group)
# followed by one final expression stored under the "result" name.
multiline_expr = ZeroOrMore(
Group(assign + Literal("=").suppress() + Group(expr))
) + expr("result")


def evaluate_pr(pr: ParseResults, values: dict) -> Union[str, int, dict]:
    """Recursively evaluate one parsed group of the expression grammar.

    Every nested ``ParseResults`` child is evaluated first. A child that
    evaluates to a dict (an assignment) is merged into ``values`` so that the
    children after it can refer to the new name. The group is then interpreted
    according to the results name it carries:

    * ``assign``   -> a one-entry dict ``{name: value}``
    * ``fn``       -> ``min`` / ``max`` of the two arguments
    * ``variable`` -> the value bound to that name in ``values``
    * ``dtype``    -> the string ``name<p1,p2>``
    * ``result``   -> the value of the last child

    A group without any of those names is treated as an operator chain
    ``operand (operator operand)*`` and folded from left to right.

    Args:
        pr: The parsed group to evaluate.
        values: Variable bindings visible to this group. The caller's mapping
            is never mutated; a merged copy is created when needed.

    Returns:
        The evaluated value, or a one-entry dict for an assignment group.

    Raises:
        KeyError: If a variable is not present in ``values``.
        Exception: If an operator position holds an unrecognised token.
    """
    named = pr.as_dict()

    results = []
    for child in pr:
        if isinstance(child, ParseResults):
            child = evaluate_pr(child, values)
        if isinstance(child, dict):
            # An assignment was evaluated: expose it to the following siblings.
            values = {**values, **child}
        results.append(child)

    if "assign" in named:
        return {named["assign"]: results[1]}
    if "fn" in named:
        if named["fn"] == "min":
            return min(results[1], results[2])
        if named["fn"] == "max":
            return max(results[1], results[2])
        # Any other name deliberately falls through to the operator chain.
    elif "variable" in named:
        return values[named["variable"]]
    elif "dtype" in named:
        params = ",".join(map(str, results[1:]))
        return f"{results[0]}<{params}>"
    elif "result" in named:
        return results[-1]

    # Operator chain: operands sit at even positions, operators at odd ones.
    acc = results[0]
    for pos in range(1, len(results), 2):
        op = results[pos]
        if op == "*":
            acc = acc * results[pos + 1]
        elif op == "/":
            # True division: the result is a float even for integer operands.
            acc = acc / results[pos + 1]
        elif op == "+":
            acc = acc + results[pos + 1]
        elif op == "-":
            acc = acc - results[pos + 1]
        elif op == ">":
            return acc > results[pos + 1]
        elif op == "<":
            return acc < results[pos + 1]
        elif op == "?":
            # Layout is [cond, "?", if_true, ":", if_false].
            return results[pos + 1] if acc else results[pos + 3]
        else:
            raise Exception(f"Unknown {op}")

    return acc


def evaluate(txt: str, values: Optional[dict] = None) -> Union[str, int]:
    """Parse and evaluate a program written in the extension expression language.

    The program consists of zero or more ``name = expr`` assignments followed
    by one final expression; the value of that final expression is returned.

    Args:
        txt: Source text of the program.
        values: Initial variable bindings. ``None`` (or an empty mapping)
            means no variables are predefined.

    Returns:
        The value of the final expression: an int (bool for a comparison) or,
        for a type expression, a string such as ``"decimal<21,11>"``.

    Raises:
        pyparsing.ParseException: If ``txt`` does not match the grammar in its
            entirety.
        KeyError: If the program references an undefined variable.
        AssertionError: If the result is neither an int nor a str, e.g. the
            float produced by the ``/`` operator.
    """
    if not values:
        values = {}
    # Require the whole input to match. Without parse_all, text the grammar
    # cannot consume (for example an unsupported operator) is silently
    # dropped and only the leading portion is evaluated.
    parsed = multiline_expr.parse_string(txt, parse_all=True)
    result = evaluate_pr(parsed, values)
    # Narrows the Union returned by evaluate_pr to the declared return type.
    assert isinstance(result, (int, str))
    return result
71 changes: 71 additions & 0 deletions ibis_substrait/tests/compiler/test_extension_language.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
from ibis_substrait.compiler.extension_language import evaluate


def test_simple_arithmetic():
    """Two integer literals joined by ``+`` evaluate to their sum."""
    outcome = evaluate("1 + 1")
    assert outcome == 2


def test_simple_arithmetic_with_variables():
    """A name in the expression is resolved from the supplied bindings."""
    bindings = {"var": 2}
    outcome = evaluate("1 + var", bindings)
    assert outcome == 3


def test_simple_arithmetic_precedence():
    """Multiplication binds tighter than addition: 1 + (2 * 3)."""
    bindings = {"var": 2}
    outcome = evaluate("1 + var * 3", bindings)
    assert outcome == 7


def test_simple_arithmetic_parenthesis():
    """Parentheses override the default precedence: (1 + 2) * 3."""
    bindings = {"var": 2}
    outcome = evaluate("(1 + var) * 3", bindings)
    assert outcome == 9


def test_min_max():
    """min/max calls can be used as operands: min(5, 7) + max(5, 7) * 2."""
    bindings = {"var": 5}
    outcome = evaluate("min(var, 7) + max(var, 7) * 2", bindings)
    assert outcome == 19


def test_ternary():
    """The ternary picks the first branch when the condition holds, else the second."""
    cases = ((5, 1), (2, 0))
    for var, expected in cases:
        assert evaluate("var > 3 ? 1 : 0", {"var": var}) == expected


def test_multiline():
    """An assignment on an earlier line is visible to the final expression."""
    program = (
        "\n"
        "temp = min(var, 7) + max(var, 7) * 2\n"
        "temp + 1\n"
    )
    outcome = evaluate(program, {"var": 5})
    assert outcome == 20


def test_data_type():
    """A type expression renders as ``name<p1,p2>`` with its parameters evaluated."""
    bindings = {"S": 20, "P": 10}
    outcome = evaluate("decimal<S + 1, P + 1>", bindings)
    assert outcome == "decimal<21,11>"


def test_decimal_example():
    """A multi-statement decimal type derivation matches a plain-Python reference."""

    def reference(P1, S1, P2, S2):
        # Same derivation as the program below, written directly in Python.
        init_scale = max(S1, S2)  # 8 for the sample arguments
        init_prec = init_scale + max(P1 - S1, P2 - S2) + 1
        prec = min(init_prec, 38)
        if init_prec > 38:
            # Precision overflowed: give up scale digits, but keep a minimum.
            delta = init_prec - 38
            min_scale = min(init_scale, 6)
            scale = max(init_scale - delta, min_scale)
        else:
            scale = init_scale
        return f"DECIMAL<{prec},{scale}>"

    args = {"P1": 10, "S1": 8, "P2": 14, "S2": 2}

    program = (
        "\n"
        "init_scale = max(S1,S2)\n"
        "init_prec = init_scale + max(P1 - S1, P2 - S2) + 1\n"
        "min_scale = min(init_scale, 6)\n"
        "delta = init_prec - 38\n"
        "prec = min(init_prec, 38)\n"
        "scale_after_borrow = max(init_scale - delta, min_scale)\n"
        "scale = init_prec > 38 ? scale_after_borrow : init_scale\n"
        "DECIMAL<prec, scale>\n"
    )

    actual = evaluate(program, args)
    assert actual == reference(**args)
Loading

0 comments on commit 41eddad

Please sign in to comment.