From 5c0f3382d25696bb8a5dca1b6e8e332982bedf41 Mon Sep 17 00:00:00 2001 From: Kenneth Moon Date: Wed, 15 May 2024 10:39:14 -0400 Subject: [PATCH] Add better syntax for metaprogramming --- src/exo/pyparser.py | 362 +++++++++++++++++++++++++--------- tests/test_metaprogramming.py | 44 ++--- 2 files changed, 287 insertions(+), 119 deletions(-) diff --git a/src/exo/pyparser.py b/src/exo/pyparser.py index 2efce4069..32d8cfca9 100644 --- a/src/exo/pyparser.py +++ b/src/exo/pyparser.py @@ -173,6 +173,24 @@ def pattern(s, filename=None, lineno=None): return parser.result() +QUOTE_CALLBACK_PREFIX = "__quote_callback" +QUOTE_BLOCK_PLACEHOLDER_PREFIX = "__quote_block" +OUTER_SCOPE_HELPER = "__outer_scope" +NESTED_SCOPE_HELPER = "__nested_scope" +UNQUOTE_RETURN_HELPER = "__unquote_val" +UNQUOTE_BLOCK_KEYWORD = "meta" + + +@dataclass +class ExoExpression: + _inner: Any # note: strict typing is not possible as long as PAST/UAST grammar definition is not static + + +@dataclass +class ExoStatementList: + _inner: tuple[Any, ...] + + class QuoteReplacer(pyast.NodeTransformer): def __init__( self, @@ -187,52 +205,72 @@ def __init__( def visit_With(self, node: pyast.With) -> pyast.Any: if ( len(node.items) == 1 - and isinstance(node.items[0].context_expr, pyast.Name) - and node.items[0].context_expr.id == "quote" - and isinstance(node.items[0].context_expr.ctx, pyast.Load) + and isinstance(node.items[0].context_expr, pyast.UnaryOp) + and isinstance(node.items[0].context_expr.op, pyast.Invert) + and isinstance(node.items[0].context_expr.operand, pyast.Name) + and node.items[0].context_expr.operand.id == UNQUOTE_BLOCK_KEYWORD + and isinstance(node.items[0].context_expr.operand.ctx, pyast.Load) + and ( + isinstance(node.items[0].optional_vars, pyast.Name) + or node.items[0].optional_vars is None + ) ): assert ( self.stmt_collector != None ), "Reached quote block with no buffer to place quoted statements" + should_append = node.items[0].optional_vars is None def quote_callback(): - self.stmt_collector.extend( - Parser( - node.body, - self.parser_parent.src_info, - parent_scope=get_parent_scope(depth=2), - is_quote_stmt=True, - parent_exo_locals=self.parser_parent.exo_locals, - ).result() - ) + stmts = Parser( + node.body, + self.parser_parent.src_info, + parent_scope=get_parent_scope(depth=2), + is_quote_stmt=True, + parent_exo_locals=self.parser_parent.exo_locals, + ).result() + if should_append: + self.stmt_collector.extend(stmts) + else: + return ExoStatementList(tuple(stmts)) callback_name = self.unquote_env.register_quote_callback(quote_callback) - return pyast.Expr( - value=pyast.Call( - func=pyast.Name(id=callback_name, ctx=pyast.Load()), - args=[], - keywords=[], + if should_append: + return pyast.Expr( + value=pyast.Call( + func=pyast.Name(id=callback_name, ctx=pyast.Load()), + args=[], + keywords=[], + ) + ) + else: + return pyast.Assign( + targets=[node.items[0].optional_vars], + value=pyast.Call( + func=pyast.Name(id=callback_name, ctx=pyast.Load()), + args=[], + keywords=[], + ), ) - ) else: return super().generic_visit(node) - def visit_Call(self, node: pyast.Call) -> Any: + def visit_UnaryOp(self, node: pyast.UnaryOp) -> Any: if ( - isinstance(node.func, pyast.Name) - and node.func.id == "quote" - and len(node.keywords) == 0 - and len(node.args) == 1 + isinstance(node.op, pyast.Invert) + and isinstance(node.operand, pyast.Set) + and len(node.operand.elts) == 1 ): def quote_callback(): - return Parser( - node.args[0], - self.parser_parent.src_info, - parent_scope=get_parent_scope(depth=2), - is_quote_expr=True, - parent_exo_locals=self.parser_parent.exo_locals, - ).result() + return ExoExpression( + Parser( + node.operand.elts[0], + self.parser_parent.src_info, + parent_scope=get_parent_scope(depth=2), + is_quote_expr=True, + parent_exo_locals=self.parser_parent.exo_locals, + ).result() + ) callback_name = self.unquote_env.register_quote_callback(quote_callback) return pyast.Call( @@ -244,17 +282,11 @@ def quote_callback(): return super().generic_visit(node) -QUOTE_CALLBACK_PREFIX = "__quote_callback" -QUOTE_BLOCK_PLACEHOLDER_PREFIX = "__quote_block" -OUTER_SCOPE_HELPER = "__outer_scope" -NESTED_SCOPE_HELPER = "__nested_scope" -UNQUOTE_RETURN_HELPER = "__unquote_val" - - @dataclass class UnquoteEnv: parent_globals: dict[str, Any] parent_locals: dict[str, Local] + exo_local_vars: dict[str, Any] def mangle_name(self, prefix: str) -> str: index = 0 @@ -279,6 +311,12 @@ def interpret_quote_block(self, stmts: list[pyast.stmt]) -> Any: unbound_names = { name for name, val in self.parent_locals.items() if val is None } + quote_locals = { + name: ExoExpression(val) + for name, val in self.exo_local_vars.items() + if name not in self.parent_locals + } + env_locals = {**quote_locals, **bound_locals} exec( compile( pyast.fix_missing_locations( @@ -289,7 +327,9 @@ def interpret_quote_block(self, stmts: list[pyast.stmt]) -> Any: args=pyast.arguments( posonlyargs=[], args=[ - pyast.arg(arg=arg) for arg in self.parent_locals + *[pyast.arg(arg=arg) for arg in bound_locals], + *[pyast.arg(arg=arg) for arg in unbound_names], + *[pyast.arg(arg=arg) for arg in quote_locals], ], kwonlyargs=[], kw_defaults=[], @@ -332,10 +372,20 @@ def interpret_quote_block(self, stmts: list[pyast.stmt]) -> Any: ), body=pyast.Tuple( elts=[ - pyast.Name( - id=arg, ctx=pyast.Load() - ) - for arg in self.parent_locals + *[ + pyast.Name( + id=arg, + ctx=pyast.Load(), + ) + for arg in bound_locals + ], + *[ + pyast.Name( + id=arg, + ctx=pyast.Load(), + ) + for arg in unbound_names + ], ], ctx=pyast.Load(), ), @@ -370,12 +420,18 @@ def interpret_quote_block(self, stmts: list[pyast.stmt]) -> Any: ctx=pyast.Load(), ), args=[ - ( + *[ + pyast.Name(id=name, ctx=pyast.Load()) + for name in bound_locals + ], + *[ pyast.Constant(value=None) - if val is None - else pyast.Name(id=name, ctx=pyast.Load()) - ) - for name, val in self.parent_locals.items() + for _ in unbound_names + ], + *[ + pyast.Name(id=name, ctx=pyast.Load()) + for name in quote_locals + ], ], keywords=[], ), @@ -388,9 +444,9 @@ def interpret_quote_block(self, stmts: list[pyast.stmt]) -> Any: "exec", ), self.parent_globals, - bound_locals, + env_locals, ) - return bound_locals[UNQUOTE_RETURN_HELPER] + return env_locals[UNQUOTE_RETURN_HELPER] def interpret_quote_expr(self, expr: pyast.expr): return self.interpret_quote_block([pyast.Return(value=expr)]) @@ -477,6 +533,51 @@ def pop(self): def err(self, node, errstr, origin=None): raise ParseError(f"{self.getsrcinfo(node)}: {errstr}") from origin + def make_exo_var_asts(self, srcinfo): + return { + name: self.AST.Read(val, [], srcinfo) + for name, val in self.exo_locals.items() + if isinstance(val, Sym) + } + + def try_eval_unquote( + self, unquote_node: pyast.expr + ) -> Union[tuple[()], tuple[Any]]: + if isinstance(unquote_node, pyast.Set): + if len(unquote_node.elts) != 1: + self.err(unquote_node, "Unquote must take 1 argument") + else: + unquote_env = UnquoteEnv( + self.parent_scope.get_globals(), + self.parent_scope.read_locals(), + self.make_exo_var_asts(self.getsrcinfo(unquote_node)), + ) + quote_replacer = QuoteReplacer(self, unquote_env) + unquoted = unquote_env.interpret_quote_expr( + quote_replacer.visit(copy.deepcopy(unquote_node.elts[0])) + ) + return (unquoted,) + elif ( + isinstance(unquote_node, pyast.Name) + and isinstance(unquote_node.ctx, pyast.Load) + and unquote_node.id not in self.exo_locals + ): + cur_globals = self.parent_scope.get_globals() + cur_locals = self.parent_scope.read_locals() + return ( + ( + UnquoteEnv( + cur_globals, + cur_locals, + self.make_exo_var_asts(self.getsrcinfo(unquote_node)), + ).interpret_quote_expr(unquote_node), + ) + if unquote_node.id in cur_locals or unquote_node.id in cur_globals + else tuple() + ) + else: + return tuple() + def eval_expr(self, expr): assert isinstance(expr, pyast.expr) return UnquoteEnv( @@ -485,6 +586,7 @@ def eval_expr(self, expr): **self.parent_scope.read_locals(), **{k: BoundLocal(v) for k, v in self.exo_locals.items()}, }, + self.make_exo_var_asts(self.getsrcinfo(expr)), ).interpret_quote_expr(expr) # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - # @@ -747,29 +849,16 @@ def parse_num_type(self, node, is_arg=False): elif isinstance(node, pyast.Name) and node.id in Parser._prim_types: return Parser._prim_types[node.id] - elif ( - isinstance(node, pyast.Call) - and isinstance(node.func, pyast.Name) - and node.func.id == "unquote" - ): - if len(node.keywords) != 0: - self.err(node, "Unquote must take non-keyword argument") - elif len(node.args) != 1: - self.err(node, "Unquote must take 1 argument") - else: - unquote_env = UnquoteEnv( - self.parent_scope.get_globals(), self.parent_scope.read_locals() - ) - quote_replacer = QuoteReplacer(self, unquote_env) - unquoted = unquote_env.interpret_quote_expr( - quote_replacer.visit(copy.deepcopy(node.args[0])) - ) + else: + unquote_eval_result = self.try_eval_unquote(node) + if len(unquote_eval_result) == 1: + unquoted = unquote_eval_result[0] if isinstance(unquoted, str) and unquoted in Parser._prim_types: return Parser._prim_types[unquoted] else: self.err(node, "Unquote computation did not yield valid type") - else: - self.err(node, "unrecognized type: " + astor.dump_tree(node)) + else: + self.err(node, "unrecognized type: " + astor.dump_tree(node)) def parse_stmt_block(self, stmts): assert isinstance(stmts, list) @@ -781,11 +870,14 @@ def parse_stmt_block(self, stmts): if ( len(s.items) == 1 and isinstance(s.items[0].context_expr, pyast.Name) - and s.items[0].context_expr.id == "unquote" + and s.items[0].context_expr.id == UNQUOTE_BLOCK_KEYWORD and isinstance(s.items[0].context_expr.ctx, pyast.Load) + and s.items[0].optional_vars is None ): unquote_env = UnquoteEnv( - self.parent_scope.get_globals(), self.parent_scope.read_locals() + self.parent_scope.get_globals(), + self.parent_scope.read_locals(), + self.make_exo_var_asts(self.getsrcinfo(s)), ) quoted_stmts = [] quote_stmt_replacer = QuoteReplacer(self, unquote_env, quoted_stmts) @@ -797,7 +889,28 @@ def parse_stmt_block(self, stmts): ) rstmts.extend(quoted_stmts) else: - self.err(s.id, "Expected unquote") + self.err(s, "Expected unquote") + elif isinstance(s, pyast.Expr) and isinstance(s.value, pyast.Set): + if len(s.value.elts) != 1: + self.err(s, "Unquote must take 1 argument") + else: + unquoted = self.try_eval_unquote(s.value)[0] + if ( + isinstance(unquoted, ExoStatementList) + and isinstance(unquoted._inner, tuple) + and all( + map( + lambda inner_s: isinstance(inner_s, self.AST.stmt), + unquoted._inner, + ) + ) + ): + rstmts.extend(unquoted._inner) + else: + self.err( + s, + "Statement-level unquote expression must return Exo statements", + ) # ----- Assginment, Reduction, Var Declaration/Allocation parsing elif isinstance(s, (pyast.Assign, pyast.AnnAssign, pyast.AugAssign)): # parse the rhs first, if it's present @@ -1128,12 +1241,75 @@ def parse_array_indexing(self, node): if not isinstance(node.value, pyast.Name): self.err(node, "expected access to have form 'x' or 'x[...]'") - is_window = any(isinstance(e, pyast.Slice) for e in dims) - idxs = [ - (self.parse_slice(e, node) if is_window else self.parse_expr(e)) - for e in dims - ] + def unquote_to_index(unquoted, ref_node, srcinfo, top_level): + if isinstance(unquoted, (int, float)): + return self.AST.Const(unquoted, self.getsrcinfo(e)) + elif isinstance(unquoted, ExoExpression) and isinstance( + unquoted._inner, self.AST.expr + ): + return unquoted._inner + elif isinstance(unquoted, slice) and top_level: + if unquoted.step is None: + return UAST.Interval( + ( + None + if unquoted.start is None + else unquote_to_index(unquoted.start, False) + ), + ( + None + if unquoted.stop is None + else unquote_to_index(unquoted.stop, False) + ), + srcinfo, + ) + else: + self.err(ref_node, "Unquote returned slice index with step") + else: + self.err( + ref_node, "Unquote received input that couldn't be unquoted" + ) + idxs = [] + srcinfo_for_idxs = [] + for e in dims: + if sys.version_info[:3] >= (3, 9): + srcinfo = self.getsrcinfo(e) + else: + if isinstance(e, pyast.Index): + e = e.value + srcinfo = self.getsrcinfo(e) + else: + srcinfo = self.getsrcinfo(node) + if isinstance(e, pyast.Slice): + idxs.append(self.parse_slice(e, node)) + srcinfo_for_idxs.append(srcinfo) + unquote_eval_result = self.try_eval_unquote(e) + if len(unquote_eval_result) == 1: + unquoted = unquote_eval_result[0] + + else: + unquote_eval_result = self.try_eval_unquote(e) + if len(unquote_eval_result) == 1: + unquoted = unquote_eval_result[0] + if isinstance(unquoted, tuple): + for unquoted_val in unquoted: + idxs.append( + unquote_to_index(unquoted_val, e, srcinfo, True) + ) + srcinfo_for_idxs.append(srcinfo) + else: + idxs.append(unquote_to_index(unquoted, e, srcinfo, True)) + srcinfo_for_idxs.append(srcinfo) + else: + idxs.append(self.parse_expr(e)) + srcinfo_for_idxs.append(srcinfo) + + is_window = any(map(lambda idx: isinstance(idx, UAST.Interval), idxs)) + if is_window: + for i in range(len(idxs)): + if not isinstance(idxs[i], UAST.Interval): + idxs[i] = UAST.Point(idxs[i], srcinfo_for_idxs[i]) return node.value, idxs, is_window else: assert False, "bad case" @@ -1166,7 +1342,18 @@ def parse_slice(self, e, node): # parse expressions, including values, indices, and booleans def parse_expr(self, e): - if isinstance(e, (pyast.Name, pyast.Subscript)): + unquote_eval_result = self.try_eval_unquote(e) + if len(unquote_eval_result) == 1: + unquoted = unquote_eval_result[0] + if isinstance(unquoted, (int, float)): + return self.AST.Const(unquoted, self.getsrcinfo(e)) + elif isinstance(unquoted, ExoExpression) and isinstance( + unquoted._inner, self.AST.expr + ): + return unquoted._inner + else: + self.err(e, "Unquote received input that couldn't be unquoted") + elif isinstance(e, (pyast.Name, pyast.Subscript)): nm_node, idxs, is_window = self.parse_array_indexing(e) if self.is_fragment: @@ -1344,27 +1531,8 @@ def parse_expr(self, e): return res elif isinstance(e, pyast.Call): - if isinstance(e.func, pyast.Name) and e.func.id == "unquote": - if len(e.keywords) != 0: - self.err(e, "Unquote must take non-keyword argument") - elif len(e.args) != 1: - self.err(e, "Unquote must take 1 argument") - else: - unquote_env = UnquoteEnv( - self.parent_scope.get_globals(), self.parent_scope.read_locals() - ) - quote_replacer = QuoteReplacer(self, unquote_env) - unquoted = unquote_env.interpret_quote_expr( - quote_replacer.visit(copy.deepcopy(e.args[0])) - ) - if isinstance(unquoted, (int, float)): - return self.AST.Const(unquoted, self.getsrcinfo(e)) - elif isinstance(unquoted, self.AST.expr): - return unquoted - else: - self.err(e, "Unquote received input that couldn't be unquoted") # handle stride expression - elif isinstance(e.func, pyast.Name) and e.func.id == "stride": + if isinstance(e.func, pyast.Name) and e.func.id == "stride": if ( len(e.keywords) > 0 or len(e.args) != 2 diff --git a/tests/test_metaprogramming.py b/tests/test_metaprogramming.py index ec73e1980..93bfe7459 100644 --- a/tests/test_metaprogramming.py +++ b/tests/test_metaprogramming.py @@ -8,9 +8,9 @@ def test_unrolling(golden): def foo(a: i8): b: i8 b = 0 - with unquote: + with meta: for _ in range(10): - with quote: + with ~meta: b += a c_file, _ = compile_procs_to_strings([foo], "test.h") @@ -23,12 +23,12 @@ def foo(cond: bool): @proc def bar(a: i8): b: i8 - with unquote: + with meta: if cond: - with quote: + with ~meta: b = 0 else: - with quote: + with ~meta: b += 1 return bar @@ -44,7 +44,7 @@ def test_scoping(golden): @proc def foo(a: i8): - a = unquote(a) + a = {a} c_file, _ = compile_procs_to_strings([foo], "test.h") assert c_file == golden @@ -55,10 +55,10 @@ def test_scope_nesting(golden): @proc def foo(a: i8, b: i8): - with unquote: + with meta: y = 2 - with quote: - a = unquote(quote(b) if x == 3 and y == 2 else quote(a)) + with ~meta: + a = {~{b} if x == 3 and y == 2 else ~{a}} c_file, _ = compile_procs_to_strings([foo], "test.h") assert c_file == golden @@ -70,9 +70,9 @@ def test_global_scope(): @proc def foo(a: i8): a = 0 - with unquote: - with quote: - with unquote: + with meta: + with ~meta: + with meta: global dict cell[0] = dict dict = None @@ -85,7 +85,7 @@ def test_constant_lifting(golden): @proc def foo(a: f64): - a = unquote((x**x + x) / x) + a = {(x**x + x) / x} c_file, _ = compile_procs_to_strings([foo], "test.h") assert c_file == golden @@ -94,10 +94,10 @@ def foo(a: f64): def test_type_params(golden): def foo(T: str, U: str): @proc - def bar(a: unquote(T), b: unquote(U)): - c: unquote(T)[4] + def bar(a: {T}, b: {U}): + c: {T}[4] for i in seq(0, 3): - d: unquote(T) + d: {T} d = b c[i + 1] = a + c[i] * d a = c[3] @@ -118,11 +118,11 @@ def foo(): @proc def bar(a: i32): - with unquote: + with meta: for _ in range(10): foo() - with quote: - a += unquote(cell[0]) + with ~meta: + a += {cell[0]} c_file, _ = compile_procs_to_strings([bar], "test.h") assert c_file == golden @@ -133,10 +133,10 @@ def test_capture_nested_quote(golden): @proc def foo(a: i32): - with unquote: + with meta: for _ in range(3): - with quote: - a = unquote(a) + with ~meta: + a = {a} c_file, _ = compile_procs_to_strings([foo], "test.h") assert c_file == golden