diff --git a/docs/Metaprogramming.md b/docs/Metaprogramming.md new file mode 100644 index 00000000..6be41d43 --- /dev/null +++ b/docs/Metaprogramming.md @@ -0,0 +1,128 @@ +# Metaprogramming + +In the context of Exo, metaprogramming refers to the composition of [object code](object_code.md) fragments, similar to macros in languages like C. Unlike scheduling operations, metaprogramming does not seek to preserve equivalence as it transforms the object code - instead, it stitches together Exo code fragments, allowing the user to make code more concise or parametrizable. + +The user can get a reference to one of these Exo code fragments through *quoting*, which produces a Python reference to the code fragment. After manipulating this code fragment as a Python object, the user can then paste in a code fragment from Python through *unquoting*. + +## Quoting and Unquoting Statements + +An unquote statement composes any quoted fragments that are executed within it. Syntactically, it is a block of *Python* code which is wrapped in a `with python:` block. Within this block, there may be multiple quoted *Exo* fragments which get executed, which are represented as `with exo:` blocks. + +Note that we are carefully distinguishing *Python* code from *Exo* code here. The Python code inside the `with python:` block does not describe any operations in Exo. Instead, it describes how the Exo fragments within it are composed. Thus, this code can use familiar Python constructs, such as `range(...)` loops (as opposed to Exo's `seq(...)` loops). + +An unquote statement will only read a quoted fragment when its corresponding `with exo:` block gets executed in the Python code. So, the following example results in an empty Exo procedure: +```python +@proc +def foo(a: i32): + with python: + if False: + with exo: + a += 1 +``` + +A `with exo:` may also be executed multiple times. The following example compiles to 10 `a += 1` statements in a row: +```python +@proc +def foo(a: i32): + with python: + for i in range(10): + with exo: + a += 1 +``` + +## Quoting and Unquoting Expressions + +An unquote expression reads the Exo expression that is referred to by a Python object. This is syntactically represented as `{...}`, where the insides of the braces are interpreted as a Python object. To obtain a Python object that refers to an Exo expression, one can use an unquote expression, represented as `~{...}`. + +As a simple example, we can try iterating through a list of Exo expressions. The following example should be equivalent to `a += a; a += b * 2`: +```python +@proc +def foo(a: i32, b: i32): + with python: + exprs = [~{a}, ~{b * 2}] + for expr in exprs: + with exo: + a += {expr} +``` + +### Implicit Quotes and Unquotes + +As we can see from the example, it is often the case that quote and unquote expressions will consist of a single variable. For convenience, if a variable name would otherwise be an invalid reference, the parser will try unquoting or quoting it before throwing an error. So, the following code is equivalent to the previous example: +```python +@proc +def foo(a: i32, b: i32): + with python: + exprs = [a, ~{b * 2}] + for expr in exprs: + with exo: + a += expr +``` + +### Unquoting Numbers + +Besides quoted expressions, a Python number can also be unquoted and converted into the corresponding numeric literal in Exo. The following example will alternate between `a += 1` and `a += 2` 10 times: +```python +@proc +def foo(a: i32): + with python: + for i in range(10): + with exo: + a += {i % 2} +``` + +### Unquoting Types + +When an unquote expression occurs in the place that a type would normally be used in Exo, for instance in the declaration of function arguments, the unquote expression will read the Python object as a string and parse it as the corresponding type. The following example will take an argument whose type depends on the first statement: +```python +T = "i32" + +@proc +def foo(a: {T}, b: {T}): + a += b +``` + +### Unquoting Indices + +Unquote expressions can also be used to index into a buffer. The Python object that gets unquoted may be a single Exo expression, a number, or a slice object. + +### Unquoting Memories + +Memory objects can also be unquoted. Note that memories in Exo correspond to Python objects in the base language anyway, so the process of unquoting an object representing a type of memory in Exo is relatively straightforward. For instance, the memory used to pass in the arguments to this function are determined by the first line: +```python +mem = DRAM + +@proc +def foo(a: i32 @ {mem}, b: i32 @ {mem}): + a += b +``` + +## Binding Quoted Statements to Variables + +A quoted Exo statement does not have to be executed immediately in the place that it is declared. Instead, the quote may be stored in a Python variable using the syntax `with exo as ...:`. It can then be unquoted with the `{...}` operator if it appears as a statement. + +The following example is equivalent to `a += b; a += b`: +```python +@proc +def foo(a: i32, b: i32): + with python: + with exo as stmt: + a += b + {stmt} + {stmt} +``` + +## Limitations + +- There is currently no support for defining quotes outside of an Exo procedure. Thus, it is difficult to share metaprogramming logic between two different Exo procedures. +- Attempting to execute a quoted statement while unquoting an expression will result in an error being thrown. Since Exo expressions do not have side effects, the semantics of such a program would be unclear if allowed. For instance: +```python +@proc +def foo(a: i32): + with python: + def bar(): + with exo: + a += 1 + return 2 + a *= {bar()} # illegal! +``` +- Identifiers that appear on the left hand side of assignment and reductions in Exo cannot be unquoted. This is partly due to limitations in the Python grammar, which Exo must conform to. \ No newline at end of file diff --git a/docs/README.md b/docs/README.md index 7fb03de9..9e26170f 100644 --- a/docs/README.md +++ b/docs/README.md @@ -8,6 +8,7 @@ This directory provides detailed documentation about Exo's interface and interna - To learn how to define **hardware targets externally to the compiler**, refer to [externs.md](externs.md), [instructions.md](instructions.md), and [memories.md](memories.md). - To learn how to define **new scheduling operations externally to the compiler**, refer to [Cursors.md](./Cursors.md) and [inspection.md](./inspection.md). - To understand the available scheduling primitives and how to use them, look into the [primitives/](./primitives) directory. +- To learn about metaprogramming as a method for writing cleaner code, see [Metaprogramming.md](Metaprogramming.md). The scheduling primitives are classified into six categories: diff --git a/src/exo/API.py b/src/exo/API.py index 3a690ca3..d00ff5de 100644 --- a/src/exo/API.py +++ b/src/exo/API.py @@ -21,7 +21,7 @@ # Moved to new file from .core.proc_eqv import decl_new_proc, derive_proc, assert_eqv_proc, check_eqv_proc -from .frontend.pyparser import get_ast_from_python, Parser, get_src_locals +from .frontend.pyparser import get_ast_from_python, Parser, get_parent_scope from .frontend.typecheck import TypeChecker from . import API_cursors as C @@ -36,14 +36,13 @@ def proc(f, _instr=None) -> "Procedure": if not isinstance(f, types.FunctionType): raise TypeError("@proc decorator must be applied to a function") - body, getsrcinfo = get_ast_from_python(f) + body, src_info = get_ast_from_python(f) assert isinstance(body, pyast.FunctionDef) parser = Parser( body, - getsrcinfo, - func_globals=f.__globals__, - srclocals=get_src_locals(depth=3 if _instr else 2), + src_info, + parent_scope=get_parent_scope(depth=3 if _instr else 2), instr=_instr, as_func=True, ) @@ -68,14 +67,13 @@ def parse_config(cls): if not inspect.isclass(cls): raise TypeError("@config decorator must be applied to a class") - body, getsrcinfo = get_ast_from_python(cls) + body, src_info = get_ast_from_python(cls) assert isinstance(body, pyast.ClassDef) parser = Parser( body, - getsrcinfo, - func_globals={}, - srclocals=get_src_locals(depth=2), + src_info, + parent_scope=get_parent_scope(depth=2), as_config=True, ) return Config(*parser.result(), not readwrite) diff --git a/src/exo/frontend/pattern_match.py b/src/exo/frontend/pattern_match.py index 55eca676..32b71704 100644 --- a/src/exo/frontend/pattern_match.py +++ b/src/exo/frontend/pattern_match.py @@ -83,7 +83,7 @@ def match_pattern( # get source location where this is getting called from caller = inspect.getframeinfo(stack_frames[call_depth][0]) func_locals = ChainMap(stack_frames[call_depth].frame.f_locals) - func_globals = ChainMap(stack_frames[call_depth].frame.f_globals) + func_globals = stack_frames[call_depth].frame.f_globals # parse the pattern we're going to use to match p_ast = pyparser.pattern( diff --git a/src/exo/frontend/pyparser.py b/src/exo/frontend/pyparser.py index b341b42e..c2023737 100644 --- a/src/exo/frontend/pyparser.py +++ b/src/exo/frontend/pyparser.py @@ -15,6 +15,9 @@ from ..core.prelude import * from ..core.extern import Extern +from typing import Any, Callable, Union, NoReturn, Optional +import copy +from dataclasses import dataclass # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # @@ -31,8 +34,35 @@ def __init__(self, nm): self.nm = nm -def str_to_mem(name): - return getattr(sys.modules[__name__], name) +@dataclass +class SourceInfo: + """ + Source code locations that are needed to compute the location of AST nodes. + """ + + src_file: str + src_line_offset: int + src_col_offset: int + + def get_src_info(self, node: pyast.AST): + """ + Computes the location of the given AST node based on line and column offsets. + """ + return SrcInfo( + filename=self.src_file, + lineno=node.lineno + self.src_line_offset, + col_offset=node.col_offset + self.src_col_offset, + end_lineno=( + None + if node.end_lineno is None + else node.end_lineno + self.src_line_offset + ), + end_col_offset=( + None + if node.end_col_offset is None + else node.end_col_offset + self.src_col_offset + ), + ) # --------------------------------------------------------------------------- # @@ -40,49 +70,101 @@ def str_to_mem(name): # Top-level decorator -def get_ast_from_python(f): +def get_ast_from_python(f: Callable[..., Any]) -> tuple[pyast.stmt, SourceInfo]: # note that we must dedent in case the function is defined # inside of a local scope - rawsrc = inspect.getsource(f) src = textwrap.dedent(rawsrc) n_dedent = len(re.match("^(.*)", rawsrc).group()) - len( re.match("^(.*)", src).group() ) - srcfilename = inspect.getsourcefile(f) - _, srclineno = inspect.getsourcelines(f) - srclineno -= 1 # adjust for decorator line - - # create way to query for src-code information - def getsrcinfo(node): - return SrcInfo( - filename=srcfilename, - lineno=node.lineno + srclineno, - col_offset=node.col_offset + n_dedent, - end_lineno=( - None if node.end_lineno is None else node.end_lineno + srclineno - ), - end_col_offset=( - None if node.end_col_offset is None else node.end_col_offset + n_dedent - ), - ) # convert into AST nodes; which should be a module with a single node module = pyast.parse(src) assert len(module.body) == 1 - return module.body[0], getsrcinfo + return module.body[0], SourceInfo( + src_file=inspect.getsourcefile(f), + src_line_offset=inspect.getsourcelines(f)[1] - 1, + src_col_offset=n_dedent, + ) + + +@dataclass +class BoundLocal: + """ + Wrapper class that represents locals that have been assigned a value. + """ + + val: Any + + +Local = Optional[BoundLocal] # Locals that are unassigned will be represesnted as None + + +@dataclass +class FrameScope: + """ + Wrapper around frame object to read local and global variables. + """ + + frame: inspect.frame + + def get_globals(self) -> dict[str, Any]: + """ + Get globals dictionary for the frame. The globals dictionary is not a copy. If the + returned dictionary is modified, the globals of the scope will be changed. + """ + return self.frame.f_globals + + def read_locals(self) -> dict[str, Local]: + """ + Return a copy of the local variables held by the scope. In contrast to globals, it is + not possible to add new local variables or modify the local variables by modifying + the returned dictionary. + """ + return { + var: ( + BoundLocal(self.frame.f_locals[var]) + if var in self.frame.f_locals + else None + ) + for var in self.frame.f_code.co_varnames + + self.frame.f_code.co_cellvars + + self.frame.f_code.co_freevars + } -def get_src_locals(*, depth): +@dataclass +class DummyScope: + """ + Wrapper for emulating a scope with a set of global and local variables. + Used for parsing patterns, which should not be able to capture local variables from the enclosing scope. + """ + + global_dict: dict[str, Any] + local_dict: dict[str, Any] + + def get_globals(self) -> dict[str, Any]: + return self.global_dict + + def read_locals(self) -> dict[str, Any]: + return self.local_dict.copy() + + +Scope = Union[ + DummyScope, FrameScope +] # Type to represent scopes, which have an API for getting global and local variables. + + +def get_parent_scope(*, depth) -> Scope: """ Get global and local environments for context capture purposes """ - stack_frames: [inspect.FrameInfo] = inspect.stack() + stack_frames = inspect.stack() assert len(stack_frames) >= depth - func_locals = stack_frames[depth].frame.f_locals - assert isinstance(func_locals, dict) - return ChainMap(func_locals) + frame = stack_frames[depth].frame + return FrameScope(frame) # --------------------------------------------------------------------------- # @@ -105,30 +187,391 @@ def pattern(s, filename=None, lineno=None, srclocals=None, srcglobals=None): module = pyast.parse(src) assert isinstance(module, pyast.Module) - # create way to query for src-code information - def getsrcinfo(node): - return SrcInfo( - filename=srcfilename, - lineno=node.lineno + srclineno, - col_offset=node.col_offset + n_dedent, - end_lineno=( - None if node.end_lineno is None else node.end_lineno + srclineno - ), - end_col_offset=( - None if node.end_col_offset is None else node.end_col_offset + n_dedent - ), - ) - parser = Parser( module.body, - getsrcinfo, + SourceInfo( + src_file=srcfilename, src_line_offset=srclineno, src_col_offset=n_dedent + ), + parent_scope=DummyScope( + srcglobals if srcglobals is not None else {}, + ( + {k: BoundLocal(v) for k, v in srclocals.items()} + if srclocals is not None + else {} + ), + ), # add globals from enclosing scope is_fragment=True, - func_globals=srcglobals, - srclocals=srclocals, ) return parser.result() +# These constants are used to name helper variables that allow the metalanguage to be parsed and evaluated. +# All of them start with two underscores, so there is not collision in names if the user avoids using names +# with two underscores. +QUOTE_CALLBACK_PREFIX = "__quote_callback" +OUTER_SCOPE_HELPER = "__outer_scope" +NESTED_SCOPE_HELPER = "__nested_scope" +UNQUOTE_RETURN_HELPER = "__unquote_val" +QUOTE_STMT_PROCESSOR = "__process_quote_stmt" + +QUOTE_BLOCK_KEYWORD = "exo" +UNQUOTE_BLOCK_KEYWORD = "python" + + +@dataclass +class ExoExpression: + """ + Opaque wrapper class for representing expressions in object code. Can be unquoted. + """ + + _inner: Any # note: strict typing is not possible as long as PAST/UAST grammar definition is not static + + +@dataclass +class ExoStatementList: + """ + Opaque wrapper class for representing a list of statements in object code. Can be unquoted. + """ + + _inner: tuple[Any, ...] + + +@dataclass +class QuoteReplacer(pyast.NodeTransformer): + """ + Replace quotes (Exo object code statements/expressions) in the metalanguage with calls to + helper functions that will parse and return the quoted code. + """ + + src_info: SourceInfo + exo_locals: dict[str, Any] + unquote_env: "UnquoteEnv" + inside_function: bool = False + + def visit_With(self, node: pyast.With) -> pyast.Any: + """ + Replace quoted statements. These will begin with "with exo:". + """ + if ( + len(node.items) == 1 + and isinstance(node.items[0].context_expr, pyast.Name) + and node.items[0].context_expr.id == QUOTE_BLOCK_KEYWORD + and isinstance(node.items[0].context_expr.ctx, pyast.Load) + ): + stmt_destination = node.items[0].optional_vars + + def parse_quote_block(): + return Parser( + node.body, + self.src_info, + parent_scope=get_parent_scope(depth=3), + is_quote_stmt=True, + parent_exo_locals=self.exo_locals, + ).result() + + if stmt_destination is None: + + def quote_callback( + quote_stmt_processor: Optional[Callable[[Any], None]] + ): + if quote_stmt_processor is None: + raise TypeError( + "Cannot unquote Exo statements in this context. You are likely trying to unquote Exo statements while inside an Exo expression." + ) + quote_stmt_processor(parse_quote_block()) + + callback_name = self.unquote_env.register_quote_callback(quote_callback) + + return pyast.Expr( + value=pyast.Call( + func=pyast.Name(id=callback_name, ctx=pyast.Load()), + args=[pyast.Name(id=QUOTE_STMT_PROCESSOR, ctx=pyast.Load())], + keywords=[], + ) + ) + else: + callback_name = self.unquote_env.register_quote_callback( + lambda: ExoStatementList(tuple(parse_quote_block())) + ) + return pyast.Assign( + targets=[stmt_destination], + value=pyast.Call( + func=pyast.Name(id=callback_name, ctx=pyast.Load()), + args=[], + keywords=[], + ), + ) + else: + return super().generic_visit(node) + + def visit_UnaryOp(self, node: pyast.UnaryOp) -> Any: + """ + Replace quoted expressions. These will look like "~{...}". + """ + if ( + isinstance(node.op, pyast.Invert) + and isinstance(node.operand, pyast.Set) + and len(node.operand.elts) == 1 + ): + + def quote_callback(): + return ExoExpression( + Parser( + node.operand.elts[0], + self.src_info, + parent_scope=get_parent_scope(depth=2), + is_quote_expr=True, + parent_exo_locals=self.exo_locals, + ).result() + ) + + callback_name = self.unquote_env.register_quote_callback(quote_callback) + return pyast.Call( + func=pyast.Name(id=callback_name, ctx=pyast.Load()), + args=[], + keywords=[], + ) + else: + return super().generic_visit(node) + + def visit_Nonlocal(self, node: pyast.Nonlocal) -> Any: + raise ParseError( + f"{self.src_info.get_src_info(node)}: nonlocal is not supported in metalanguage" + ) + + def visit_FunctionDef(self, node: pyast.FunctionDef): + """ + Record whether we are inside a function definition in the metalanguage, so that we can + prevent return statements that occur outside a function. + """ + was_inside_function = self.inside_function + self.inside_function = True + result = super().generic_visit(node) + self.inside_function = was_inside_function + return result + + def visit_AsyncFunctionDef(self, node): + was_inside_function = self.inside_function + self.inside_function = True + result = super().generic_visit(node) + self.inside_function = was_inside_function + return result + + def visit_Return(self, node): + if not self.inside_function: + raise ParseError( + f"{self.src_info.get_src_info(node)}: cannot return from metalanguage fragment" + ) + + return super().generic_visit(node) + + +@dataclass +class UnquoteEnv: + """ + Record of all the context needed to interpret a block of metalanguage code. + This includes the local and global variables of the scope that the metalanguage code will be evaluated in + and the Exo variables of the surrounding object code. + """ + + parent_globals: dict[str, Any] + parent_locals: dict[str, Local] + exo_local_vars: dict[str, Any] + + def mangle_name(self, prefix: str) -> str: + """ + Create unique names for helper functions that are used to parse object code + (see QuoteReplacer). + """ + index = 0 + while True: + mangled_name = f"{prefix}{index}" + if ( + mangled_name not in self.parent_locals + and mangled_name not in self.parent_globals + ): + return mangled_name + index += 1 + + def register_quote_callback(self, quote_callback: Callable[..., Any]) -> str: + """ + Store helper functions that are used to parse object code so that they may be referenced + when we interpret the metalanguage code. + """ + mangled_name = self.mangle_name(QUOTE_CALLBACK_PREFIX) + self.parent_locals[mangled_name] = BoundLocal(quote_callback) + return mangled_name + + def interpret_unquote_block( + self, + stmts: list[pyast.stmt], + quote_stmt_processor: Optional[Callable[[Any], None]], + ) -> Any: + """ + Interpret a metalanguage block of code. This is done by pasting the AST of the metalanguage code + into a helper function that sets up the local variables that need to be referenced in the metalanguage code, + and then calling that helper function. + + This function is also used to parse metalanguage expressions by representing the expressions as return statements + and saving the output returned by the helper function. + """ + bound_locals = { + name: val.val for name, val in self.parent_locals.items() if val is not None + } + unbound_names = { + name for name, val in self.parent_locals.items() if val is None + } + quote_locals = { + name: ExoExpression(val) + for name, val in self.exo_local_vars.items() + if name not in self.parent_locals + } + env_locals = {**quote_locals, **bound_locals} + old_stmt_processor = ( + self.parent_globals[QUOTE_STMT_PROCESSOR] + if QUOTE_STMT_PROCESSOR in self.parent_globals + else None + ) + self.parent_globals[QUOTE_STMT_PROCESSOR] = quote_stmt_processor + exec( + compile( + pyast.fix_missing_locations( + pyast.Module( + body=[ + pyast.FunctionDef( + name=OUTER_SCOPE_HELPER, + args=pyast.arguments( + posonlyargs=[], + args=[ + *[pyast.arg(arg=arg) for arg in bound_locals], + *[pyast.arg(arg=arg) for arg in unbound_names], + *[pyast.arg(arg=arg) for arg in quote_locals], + ], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=[ + *( + [ + pyast.Delete( + targets=[ + pyast.Name( + id=name, + ctx=pyast.Del(), + ) + for name in unbound_names + ] + ) + ] + if len(unbound_names) != 0 + else [] + ), + pyast.FunctionDef( + name=NESTED_SCOPE_HELPER, + args=pyast.arguments( + posonlyargs=[], + args=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=[ + pyast.Expr( + value=pyast.Lambda( + args=pyast.arguments( + posonlyargs=[], + args=[], + kwonlyargs=[], + kw_defaults=[], + defaults=[], + ), + body=pyast.Tuple( + elts=[ + *[ + pyast.Name( + id=arg, + ctx=pyast.Load(), + ) + for arg in bound_locals + ], + *[ + pyast.Name( + id=arg, + ctx=pyast.Load(), + ) + for arg in unbound_names + ], + ], + ctx=pyast.Load(), + ), + ) + ), + *stmts, + ], + decorator_list=[], + ), + pyast.Return( + value=pyast.Call( + func=pyast.Name( + id=NESTED_SCOPE_HELPER, + ctx=pyast.Load(), + ), + args=[], + keywords=[], + ) + ), + ], + decorator_list=[], + ), + pyast.Assign( + targets=[ + pyast.Name( + id=UNQUOTE_RETURN_HELPER, ctx=pyast.Store() + ) + ], + value=pyast.Call( + func=pyast.Name( + id=OUTER_SCOPE_HELPER, + ctx=pyast.Load(), + ), + args=[ + *[ + pyast.Name(id=name, ctx=pyast.Load()) + for name in bound_locals + ], + *[ + pyast.Constant(value=None) + for _ in unbound_names + ], + *[ + pyast.Name(id=name, ctx=pyast.Load()) + for name in quote_locals + ], + ], + keywords=[], + ), + ), + ], + type_ignores=[], + ) + ), + "", + "exec", + ), + self.parent_globals, + env_locals, + ) + self.parent_globals[QUOTE_STMT_PROCESSOR] = old_stmt_processor + return env_locals[UNQUOTE_RETURN_HELPER] + + def interpret_unquote_expr(self, expr: pyast.expr): + """ + Parse a metalanguage expression using the machinery provided by interpret_unquote_block. + """ + return self.interpret_unquote_block([pyast.Return(value=expr)], None) + + # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # # Parser Pass object @@ -156,27 +599,28 @@ class Parser: def __init__( self, module_ast, - getsrcinfo, + src_info, + parent_scope=None, is_fragment=False, - func_globals=None, - srclocals=None, as_func=False, as_config=False, instr=None, + is_quote_stmt=False, + is_quote_expr=False, + parent_exo_locals=None, ): - self.module_ast = module_ast - self.globals = func_globals - self.locals = srclocals or ChainMap() - self.getsrcinfo = getsrcinfo + self.parent_scope = parent_scope + self.exo_locals = ChainMap() if parent_exo_locals is None else parent_exo_locals + self.src_info = src_info self.is_fragment = is_fragment self.push() special_cases = ["stride"] - for key, val in self.globals.items(): + for key, val in parent_scope.get_globals().items(): if isinstance(val, Extern): special_cases.append(key) - for key, val in self.locals.items(): + for key, val in parent_scope.read_locals().items(): if isinstance(val, Extern): special_cases.append(key) @@ -203,18 +647,25 @@ def __init__( self._cached_result = self.parse_expr(s.value) else: self._cached_result = self.parse_stmt_block(module_ast) + elif is_quote_expr: + self._cached_result = self.parse_expr(module_ast) + elif is_quote_stmt: + self._cached_result = self.parse_stmt_block(module_ast) else: assert False, "parser mode configuration unsupported" self.pop() + def getsrcinfo(self, ast): + return self.src_info.get_src_info(ast) + def result(self): return self._cached_result def push(self): - self.locals = self.locals.new_child() + self.exo_locals = self.exo_locals.new_child() def pop(self): - self.locals = self.locals.parents + self.exo_locals = self.exo_locals.parents # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - # # parser helper routines @@ -222,11 +673,64 @@ def pop(self): def err(self, node, errstr, origin=None): raise ParseError(f"{self.getsrcinfo(node)}: {errstr}") from origin + def make_exo_var_asts(self, srcinfo): + return { + name: self.AST.Read(val, [], srcinfo) + for name, val in self.exo_locals.items() + if isinstance(val, Sym) + } + + def try_eval_unquote( + self, unquote_node: pyast.expr + ) -> Union[tuple[()], tuple[Any]]: + if isinstance(unquote_node, pyast.Set): + if len(unquote_node.elts) != 1: + self.err(unquote_node, "Unquote must take 1 argument") + else: + unquote_env = UnquoteEnv( + self.parent_scope.get_globals(), + self.parent_scope.read_locals(), + self.make_exo_var_asts(self.getsrcinfo(unquote_node)), + ) + quote_replacer = QuoteReplacer( + self.src_info, self.exo_locals, unquote_env + ) + unquoted = unquote_env.interpret_unquote_expr( + quote_replacer.visit(copy.deepcopy(unquote_node.elts[0])) + ) + return (unquoted,) + elif ( + isinstance(unquote_node, pyast.Name) + and isinstance(unquote_node.ctx, pyast.Load) + and unquote_node.id not in self.exo_locals + and not self.is_fragment + ): + cur_globals = self.parent_scope.get_globals() + cur_locals = self.parent_scope.read_locals() + return ( + ( + UnquoteEnv( + cur_globals, + cur_locals, + self.make_exo_var_asts(self.getsrcinfo(unquote_node)), + ).interpret_unquote_expr(unquote_node), + ) + if unquote_node.id in cur_locals or unquote_node.id in cur_globals + else tuple() + ) + else: + return tuple() + def eval_expr(self, expr): assert isinstance(expr, pyast.expr) - code = compile(pyast.Expression(expr), "", "eval") - e_obj = eval(code, self.globals, self.locals) - return e_obj + return UnquoteEnv( + self.parent_scope.get_globals(), + { + **self.parent_scope.read_locals(), + **{k: BoundLocal(v) for k, v in self.exo_locals.items()}, + }, + self.make_exo_var_asts(self.getsrcinfo(expr)), + ).interpret_unquote_expr(expr) # - # - # - # - # - # - # - # - # - # - # - # - # - # - # - # # structural parsing rules... @@ -268,10 +772,10 @@ def parse_fdef(self, fdef, instr=None): names.add(a.arg) nm = Sym(a.arg) if isinstance(typ, UAST.Size): - self.locals[a.arg] = SizeStub(nm) + self.exo_locals[a.arg] = SizeStub(nm) else: # note we don't need to stub the index variables - self.locals[a.arg] = nm + self.exo_locals[a.arg] = nm args.append(UAST.fnarg(nm, typ, mem, self.getsrcinfo(a))) # return types are non-sensical for Exo, b/c it models procedures @@ -453,11 +957,8 @@ def parse_num_type(self, node, is_arg=False): typ = _prim_types[node.value.id] is_window = False else: - self.err( - node, - "expected tensor type to be " - "of the form 'R[...]', 'f32[...]', etc.", - ) + typ = self.parse_num_type(node.value) + is_window = False if sys.version_info[:3] >= (3, 9): # unpack single or multi-arg indexing to list of slices/indices @@ -493,7 +994,13 @@ def parse_num_type(self, node, is_arg=False): node, f"Cannot allocate an intermediate value of type {node.id}" ) else: - self.err(node, "unrecognized type: " + pyast.dump(node)) + unquote_eval_result = self.try_eval_unquote(node) + if len(unquote_eval_result) == 1: + unquoted = unquote_eval_result[0] + if isinstance(unquoted, str) and unquoted in _prim_types: + return _prim_types[unquoted] + else: + self.err(node, "Unquote computation did not yield valid type") def parse_stmt_block(self, stmts): assert isinstance(stmts, list) @@ -501,8 +1008,56 @@ def parse_stmt_block(self, stmts): rstmts = [] for s in stmts: + if isinstance(s, pyast.With): + if ( + len(s.items) == 1 + and isinstance(s.items[0].context_expr, pyast.Name) + and s.items[0].context_expr.id == UNQUOTE_BLOCK_KEYWORD + and isinstance(s.items[0].context_expr.ctx, pyast.Load) + and s.items[0].optional_vars is None + ): + unquote_env = UnquoteEnv( + self.parent_scope.get_globals(), + self.parent_scope.read_locals(), + self.make_exo_var_asts(self.getsrcinfo(s)), + ) + quote_stmt_replacer = QuoteReplacer( + self.src_info, + self.exo_locals, + unquote_env, + ) + unquote_env.interpret_unquote_block( + [ + quote_stmt_replacer.visit(copy.deepcopy(python_s)) + for python_s in s.body + ], + lambda stmts: rstmts.extend(stmts), + ) + else: + self.err(s, "Expected unquote") + elif isinstance(s, pyast.Expr) and isinstance(s.value, pyast.Set): + if len(s.value.elts) != 1: + self.err(s, "Unquote must take 1 argument") + else: + unquoted = self.try_eval_unquote(s.value)[0] + if ( + isinstance(unquoted, ExoStatementList) + and isinstance(unquoted._inner, tuple) + and all( + map( + lambda inner_s: isinstance(inner_s, self.AST.stmt), + unquoted._inner, + ) + ) + ): + rstmts.extend(unquoted._inner) + else: + self.err( + s, + "Statement-level unquote expression must return Exo statements", + ) # ----- Assginment, Reduction, Var Declaration/Allocation parsing - if isinstance(s, (pyast.Assign, pyast.AnnAssign, pyast.AugAssign)): + elif isinstance(s, (pyast.Assign, pyast.AnnAssign, pyast.AugAssign)): # parse the rhs first, if it's present rhs = None if isinstance(s, pyast.AnnAssign): @@ -601,7 +1156,7 @@ def parse_stmt_block(self, stmts): # insert any needed Allocs if isinstance(s, pyast.AnnAssign): nm = Sym(name_node.id) - self.locals[name_node.id] = nm + self.exo_locals[name_node.id] = nm typ, mem = self.parse_alloc_typmem(s.annotation) rstmts.append(UAST.Alloc(nm, typ, mem, self.getsrcinfo(s))) @@ -610,10 +1165,10 @@ def parse_stmt_block(self, stmts): if ( isinstance(s, pyast.Assign) and len(idxs) == 0 - and name_node.id not in self.locals + and name_node.id not in self.exo_locals ): nm = Sym(name_node.id) - self.locals[name_node.id] = nm + self.exo_locals[name_node.id] = nm do_fresh_assignment = True else: do_fresh_assignment = False @@ -621,9 +1176,9 @@ def parse_stmt_block(self, stmts): # get the symbol corresponding to the name on the # left-hand-side if isinstance(s, (pyast.Assign, pyast.AugAssign)): - if name_node.id not in self.locals: + if name_node.id not in self.exo_locals: self.err(name_node, f"variable '{name_node.id}' undefined") - nm = self.locals[name_node.id] + nm = self.exo_locals[name_node.id] if isinstance(nm, SizeStub): self.err( name_node, @@ -660,7 +1215,7 @@ def parse_stmt_block(self, stmts): itr = s.target.id else: itr = Sym(s.target.id) - self.locals[s.target.id] = itr + self.exo_locals[s.target.id] = itr cond = self.parse_loop_cond(s.iter) body = self.parse_stmt_block(s.body) @@ -831,12 +1386,75 @@ def parse_array_indexing(self, node): if not isinstance(node.value, pyast.Name): self.err(node, "expected access to have form 'x' or 'x[...]'") - is_window = any(isinstance(e, pyast.Slice) for e in dims) - idxs = [ - (self.parse_slice(e, node) if is_window else self.parse_expr(e)) - for e in dims - ] + def unquote_to_index(unquoted, ref_node, srcinfo, top_level): + if isinstance(unquoted, (int, float)): + return self.AST.Const(unquoted, srcinfo) + elif isinstance(unquoted, ExoExpression) and isinstance( + unquoted._inner, self.AST.expr + ): + return unquoted._inner + elif isinstance(unquoted, slice) and top_level: + if unquoted.step is None: + return UAST.Interval( + ( + None + if unquoted.start is None + else unquote_to_index( + unquoted.start, ref_node, srcinfo, False + ) + ), + ( + None + if unquoted.stop is None + else unquote_to_index( + unquoted.stop, ref_node, srcinfo, False + ) + ), + srcinfo, + ) + else: + self.err(ref_node, "Unquote returned slice index with step") + else: + self.err( + ref_node, "Unquote received input that couldn't be unquoted" + ) + idxs = [] + srcinfo_for_idxs = [] + for e in dims: + if sys.version_info[:3] >= (3, 9): + srcinfo = self.getsrcinfo(e) + else: + if isinstance(e, pyast.Index): + e = e.value + srcinfo = self.getsrcinfo(e) + else: + srcinfo = self.getsrcinfo(node) + if isinstance(e, pyast.Slice): + idxs.append(self.parse_slice(e, node)) + srcinfo_for_idxs.append(srcinfo) + else: + unquote_eval_result = self.try_eval_unquote(e) + if len(unquote_eval_result) == 1: + unquoted = unquote_eval_result[0] + if isinstance(unquoted, tuple): + for unquoted_val in unquoted: + idxs.append( + unquote_to_index(unquoted_val, e, srcinfo, True) + ) + srcinfo_for_idxs.append(srcinfo) + else: + idxs.append(unquote_to_index(unquoted, e, srcinfo, True)) + srcinfo_for_idxs.append(srcinfo) + else: + idxs.append(self.parse_expr(e)) + srcinfo_for_idxs.append(srcinfo) + + is_window = any(map(lambda idx: isinstance(idx, UAST.Interval), idxs)) + if is_window: + for i in range(len(idxs)): + if not isinstance(idxs[i], UAST.Interval): + idxs[i] = UAST.Point(idxs[i], srcinfo_for_idxs[i]) return node.value, idxs, is_window else: assert False, "bad case" @@ -853,23 +1471,31 @@ def parse_slice(self, e, node): else: srcinfo = self.getsrcinfo(node) - if isinstance(e, pyast.Slice): - lo = None if e.lower is None else self.parse_expr(e.lower) - hi = None if e.upper is None else self.parse_expr(e.upper) - if e.step is not None: - self.err( - e, - "expected windowing to have the form x[:], " - "x[i:], x[:j], or x[i:j], but not x[i:j:k]", - ) + lo = None if e.lower is None else self.parse_expr(e.lower) + hi = None if e.upper is None else self.parse_expr(e.upper) + if e.step is not None: + self.err( + e, + "expected windowing to have the form x[:], " + "x[i:], x[:j], or x[i:j], but not x[i:j:k]", + ) - return UAST.Interval(lo, hi, srcinfo) - else: - return UAST.Point(self.parse_expr(e), srcinfo) + return UAST.Interval(lo, hi, srcinfo) # parse expressions, including values, indices, and booleans def parse_expr(self, e): - if isinstance(e, (pyast.Name, pyast.Subscript)): + unquote_eval_result = self.try_eval_unquote(e) + if len(unquote_eval_result) == 1: + unquoted = unquote_eval_result[0] + if isinstance(unquoted, (int, float)): + return self.AST.Const(unquoted, self.getsrcinfo(e)) + elif isinstance(unquoted, ExoExpression) and isinstance( + unquoted._inner, self.AST.expr + ): + return unquoted._inner + else: + self.err(e, "Unquote received input that couldn't be unquoted") + elif isinstance(e, (pyast.Name, pyast.Subscript)): nm_node, idxs, is_window = self.parse_array_indexing(e) if self.is_fragment: @@ -879,11 +1505,9 @@ def parse_expr(self, e): else: return PAST.Read(nm, idxs, self.getsrcinfo(e)) else: - if nm_node.id in self.locals: - nm = self.locals[nm_node.id] - elif nm_node.id in self.globals: - nm = self.globals[nm_node.id] - else: # could not resolve name to anything + if nm_node.id in self.exo_locals: + nm = self.exo_locals[nm_node.id] + else: self.err(nm_node, f"variable '{nm_node.id}' undefined") if isinstance(nm, SizeStub): @@ -937,11 +1561,15 @@ def parse_expr(self, e): opnm = ( "+" if isinstance(e.op, pyast.UAdd) - else "not" - if isinstance(e.op, pyast.Not) - else "~" - if isinstance(e.op, pyast.Invert) - else "ERROR-BAD-OP-CASE" + else ( + "not" + if isinstance(e.op, pyast.Not) + else ( + "~" + if isinstance(e.op, pyast.Invert) + else "ERROR-BAD-OP-CASE" + ) + ) ) self.err(e, f"unsupported unary operator: {opnm}") @@ -1067,9 +1695,9 @@ def parse_expr(self, e): dim = int(e.args[1].value) if not self.is_fragment: - if name not in self.locals: + if name not in self.exo_locals: self.err(e.args[0], f"variable '{name}' undefined") - name = self.locals[name] + name = self.exo_locals[name] return self.AST.StrideExpr(name, dim, self.getsrcinfo(e)) diff --git a/tests/golden/test_metaprogramming/test_capture_nested_quote.txt b/tests/golden/test_metaprogramming/test_capture_nested_quote.txt new file mode 100644 index 00000000..56a54d49 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_capture_nested_quote.txt @@ -0,0 +1,20 @@ +EXO IR: +def foo(a: i32 @ DRAM): + a = 2 + a = 2 + a = 2 +C: +#include "test.h" + +#include +#include + +// foo( +// a : i32 @DRAM +// ) +void foo( void *ctxt, int32_t* a ) { +*a = ((int32_t) 2); +*a = ((int32_t) 2); +*a = ((int32_t) 2); +} + diff --git a/tests/golden/test_metaprogramming/test_captured_closure.txt b/tests/golden/test_metaprogramming/test_captured_closure.txt new file mode 100644 index 00000000..569653d3 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_captured_closure.txt @@ -0,0 +1,34 @@ +EXO IR: +def bar(a: i32 @ DRAM): + a += 1 + a += 2 + a += 3 + a += 4 + a += 5 + a += 6 + a += 7 + a += 8 + a += 9 + a += 10 +C: +#include "test.h" + +#include +#include + +// bar( +// a : i32 @DRAM +// ) +void bar( void *ctxt, int32_t* a ) { +*a += ((int32_t) 1); +*a += ((int32_t) 2); +*a += ((int32_t) 3); +*a += ((int32_t) 4); +*a += ((int32_t) 5); +*a += ((int32_t) 6); +*a += ((int32_t) 7); +*a += ((int32_t) 8); +*a += ((int32_t) 9); +*a += ((int32_t) 10); +} + diff --git a/tests/golden/test_metaprogramming/test_conditional.txt b/tests/golden/test_metaprogramming/test_conditional.txt new file mode 100644 index 00000000..7e3473e5 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_conditional.txt @@ -0,0 +1,29 @@ +EXO IR: +def bar1(a: i8 @ DRAM): + b: i8 @ DRAM + b += 1 +def bar2(a: i8 @ DRAM): + b: i8 @ DRAM + b = 0 +C: +#include "test.h" + +#include +#include + +// bar1( +// a : i8 @DRAM +// ) +void bar1( void *ctxt, const int8_t* a ) { +int8_t b; +b += ((int8_t) 1); +} + +// bar2( +// a : i8 @DRAM +// ) +void bar2( void *ctxt, const int8_t* a ) { +int8_t b; +b = ((int8_t) 0); +} + diff --git a/tests/golden/test_metaprogramming/test_constant_lifting.txt b/tests/golden/test_metaprogramming/test_constant_lifting.txt new file mode 100644 index 00000000..5ac001ad --- /dev/null +++ b/tests/golden/test_metaprogramming/test_constant_lifting.txt @@ -0,0 +1,16 @@ +EXO IR: +def foo(a: f64 @ DRAM): + a = 2.0818897486445276 +C: +#include "test.h" + +#include +#include + +// foo( +// a : f64 @DRAM +// ) +void foo( void *ctxt, double* a ) { +*a = 2.0818897486445276; +} + diff --git a/tests/golden/test_metaprogramming/test_eval_expr_in_mem.txt b/tests/golden/test_metaprogramming/test_eval_expr_in_mem.txt new file mode 100644 index 00000000..29bc1782 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_eval_expr_in_mem.txt @@ -0,0 +1,16 @@ +EXO IR: +def foo(a: f32 @ DRAM): + pass +C: +#include "test.h" + +#include +#include + +// foo( +// a : f32 @DRAM +// ) +void foo( void *ctxt, const float* a ) { +; // NO-OP +} + diff --git a/tests/golden/test_metaprogramming/test_local_externs.txt b/tests/golden/test_metaprogramming/test_local_externs.txt new file mode 100644 index 00000000..504175e7 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_local_externs.txt @@ -0,0 +1,17 @@ +EXO IR: +def foo(a: f64 @ DRAM): + a = sin(a) +C: +#include "test.h" + +#include +#include + +#include +// foo( +// a : f64 @DRAM +// ) +void foo( void *ctxt, double* a ) { +*a = sin((double)*a); +} + diff --git a/tests/golden/test_metaprogramming/test_proc_shadowing.txt b/tests/golden/test_metaprogramming/test_proc_shadowing.txt new file mode 100644 index 00000000..5a3d3670 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_proc_shadowing.txt @@ -0,0 +1,28 @@ +EXO IR: +def foo(a: f32 @ DRAM): + sin(a) +C: +#include "test.h" + +#include +#include + +// sin( +// a : f32 @DRAM +// ) +static void sin( void *ctxt, float* a ); + +// foo( +// a : f32 @DRAM +// ) +void foo( void *ctxt, float* a ) { +sin(ctxt,a); +} + +// sin( +// a : f32 @DRAM +// ) +static void sin( void *ctxt, float* a ) { +*a = 0.0f; +} + diff --git a/tests/golden/test_metaprogramming/test_quote_complex_expr.txt b/tests/golden/test_metaprogramming/test_quote_complex_expr.txt new file mode 100644 index 00000000..b111df4f --- /dev/null +++ b/tests/golden/test_metaprogramming/test_quote_complex_expr.txt @@ -0,0 +1,16 @@ +EXO IR: +def foo(a: i32 @ DRAM): + a = a + 1 + 1 +C: +#include "test.h" + +#include +#include + +// foo( +// a : i32 @DRAM +// ) +void foo( void *ctxt, int32_t* a ) { +*a = *a + ((int32_t) 1) + ((int32_t) 1); +} + diff --git a/tests/golden/test_metaprogramming/test_quote_elision.txt b/tests/golden/test_metaprogramming/test_quote_elision.txt new file mode 100644 index 00000000..a22821c7 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_quote_elision.txt @@ -0,0 +1,17 @@ +EXO IR: +def foo(a: i32 @ DRAM, b: i32 @ DRAM): + b = a +C: +#include "test.h" + +#include +#include + +// foo( +// a : i32 @DRAM, +// b : i32 @DRAM +// ) +void foo( void *ctxt, const int32_t* a, int32_t* b ) { +*b = *a; +} + diff --git a/tests/golden/test_metaprogramming/test_scope_collision1.txt b/tests/golden/test_metaprogramming/test_scope_collision1.txt new file mode 100644 index 00000000..c2d6b20c --- /dev/null +++ b/tests/golden/test_metaprogramming/test_scope_collision1.txt @@ -0,0 +1,20 @@ +EXO IR: +def foo(a: i32 @ DRAM): + b: i32 @ DRAM + b = 2 + a = 1 +C: +#include "test.h" + +#include +#include + +// foo( +// a : i32 @DRAM +// ) +void foo( void *ctxt, int32_t* a ) { +int32_t b; +b = ((int32_t) 2); +*a = ((int32_t) 1); +} + diff --git a/tests/golden/test_metaprogramming/test_scope_collision2.txt b/tests/golden/test_metaprogramming/test_scope_collision2.txt new file mode 100644 index 00000000..a22821c7 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_scope_collision2.txt @@ -0,0 +1,17 @@ +EXO IR: +def foo(a: i32 @ DRAM, b: i32 @ DRAM): + b = a +C: +#include "test.h" + +#include +#include + +// foo( +// a : i32 @DRAM, +// b : i32 @DRAM +// ) +void foo( void *ctxt, const int32_t* a, int32_t* b ) { +*b = *a; +} + diff --git a/tests/golden/test_metaprogramming/test_scope_nesting.txt b/tests/golden/test_metaprogramming/test_scope_nesting.txt new file mode 100644 index 00000000..0ae39ca1 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_scope_nesting.txt @@ -0,0 +1,17 @@ +EXO IR: +def foo(a: i8 @ DRAM, b: i8 @ DRAM): + a = b +C: +#include "test.h" + +#include +#include + +// foo( +// a : i8 @DRAM, +// b : i8 @DRAM +// ) +void foo( void *ctxt, int8_t* a, const int8_t* b ) { +*a = *b; +} + diff --git a/tests/golden/test_metaprogramming/test_scoping.txt b/tests/golden/test_metaprogramming/test_scoping.txt new file mode 100644 index 00000000..331db00a --- /dev/null +++ b/tests/golden/test_metaprogramming/test_scoping.txt @@ -0,0 +1,16 @@ +EXO IR: +def foo(a: i8 @ DRAM): + a = 3 +C: +#include "test.h" + +#include +#include + +// foo( +// a : i8 @DRAM +// ) +void foo( void *ctxt, int8_t* a ) { +*a = ((int8_t) 3); +} + diff --git a/tests/golden/test_metaprogramming/test_statement_assignment.txt b/tests/golden/test_metaprogramming/test_statement_assignment.txt new file mode 100644 index 00000000..a8ea5b1a --- /dev/null +++ b/tests/golden/test_metaprogramming/test_statement_assignment.txt @@ -0,0 +1,22 @@ +EXO IR: +def foo(a: i32 @ DRAM): + a += 1 + a += 2 + a += 1 + a += 2 +C: +#include "test.h" + +#include +#include + +// foo( +// a : i32 @DRAM +// ) +void foo( void *ctxt, int32_t* a ) { +*a += ((int32_t) 1); +*a += ((int32_t) 2); +*a += ((int32_t) 1); +*a += ((int32_t) 2); +} + diff --git a/tests/golden/test_metaprogramming/test_statements.txt b/tests/golden/test_metaprogramming/test_statements.txt new file mode 100644 index 00000000..cf5820b4 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_statements.txt @@ -0,0 +1,21 @@ +#include "test.h" + + + +#include +#include + + + +// foo( +// a : i32 @DRAM +// ) +void foo( void *ctxt, int32_t* a ) { +*a += ((int32_t) 1); +*a += ((int32_t) 1); +for (int_fast32_t i = 0; i < 2; i++) { + *a += ((int32_t) 1); + *a += ((int32_t) 1); +} +} + diff --git a/tests/golden/test_metaprogramming/test_type_params.txt b/tests/golden/test_metaprogramming/test_type_params.txt new file mode 100644 index 00000000..98c6282a --- /dev/null +++ b/tests/golden/test_metaprogramming/test_type_params.txt @@ -0,0 +1,51 @@ +EXO IR: +def bar1(a: i32 @ DRAM, b: i8 @ DRAM): + c: i32[4] @ DRAM + for i in seq(0, 3): + d: i32 @ DRAM + d = b + c[i + 1] = a + c[i] * d + a = c[3] +def bar2(a: f64 @ DRAM, b: f64 @ DRAM): + c: f64[4] @ DRAM + for i in seq(0, 3): + d: f64 @ DRAM + d = b + c[i + 1] = a + c[i] * d + a = c[3] +C: +#include "test.h" + +#include +#include + +// bar1( +// a : i32 @DRAM, +// b : i8 @DRAM +// ) +void bar1( void *ctxt, int32_t* a, const int8_t* b ) { +int32_t *c = (int32_t*) malloc(4 * sizeof(*c)); +for (int_fast32_t i = 0; i < 3; i++) { + int32_t d; + d = (int32_t)(*b); + c[i + 1] = *a + c[i] * d; +} +*a = c[3]; +free(c); +} + +// bar2( +// a : f64 @DRAM, +// b : f64 @DRAM +// ) +void bar2( void *ctxt, double* a, const double* b ) { +double *c = (double*) malloc(4 * sizeof(*c)); +for (int_fast32_t i = 0; i < 3; i++) { + double d; + d = *b; + c[i + 1] = *a + c[i] * d; +} +*a = c[3]; +free(c); +} + diff --git a/tests/golden/test_metaprogramming/test_type_quote_elision.txt b/tests/golden/test_metaprogramming/test_type_quote_elision.txt new file mode 100644 index 00000000..d9173f3d --- /dev/null +++ b/tests/golden/test_metaprogramming/test_type_quote_elision.txt @@ -0,0 +1,19 @@ +EXO IR: +def foo(a: i8 @ DRAM, x: i8[2] @ DRAM): + a += x[0] + a += x[1] +C: +#include "test.h" + +#include +#include + +// foo( +// a : i8 @DRAM, +// x : i8[2] @DRAM +// ) +void foo( void *ctxt, int8_t* a, const int8_t* x ) { +*a += x[0]; +*a += x[1]; +} + diff --git a/tests/golden/test_metaprogramming/test_unary_ops.txt b/tests/golden/test_metaprogramming/test_unary_ops.txt new file mode 100644 index 00000000..028ac6f3 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_unary_ops.txt @@ -0,0 +1,16 @@ +EXO IR: +def foo(a: i32 @ DRAM): + a = -2 +C: +#include "test.h" + +#include +#include + +// foo( +// a : i32 @DRAM +// ) +void foo( void *ctxt, int32_t* a ) { +*a = ((int32_t) -2); +} + diff --git a/tests/golden/test_metaprogramming/test_unquote_elision.txt b/tests/golden/test_metaprogramming/test_unquote_elision.txt new file mode 100644 index 00000000..71079913 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_unquote_elision.txt @@ -0,0 +1,16 @@ +EXO IR: +def foo(a: i32 @ DRAM): + a = a * 2 +C: +#include "test.h" + +#include +#include + +// foo( +// a : i32 @DRAM +// ) +void foo( void *ctxt, int32_t* a ) { +*a = *a * ((int32_t) 2); +} + diff --git a/tests/golden/test_metaprogramming/test_unquote_in_slice.txt b/tests/golden/test_metaprogramming/test_unquote_in_slice.txt new file mode 100644 index 00000000..de0fc0e9 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_unquote_in_slice.txt @@ -0,0 +1,28 @@ +EXO IR: +def foo(a: [i8][2] @ DRAM): + a[0] += a[1] +def bar(a: i8[10, 10] @ DRAM): + for i in seq(0, 5): + foo(a[i, 2:4]) +C: +#include "test.h" + +#include +#include + +// bar( +// a : i8[10, 10] @DRAM +// ) +void bar( void *ctxt, int8_t* a ) { +for (int_fast32_t i = 0; i < 5; i++) { + foo(ctxt,(struct exo_win_1i8){ &a[(i) * (10) + 2], { 1 } }); +} +} + +// foo( +// a : [i8][2] @DRAM +// ) +void foo( void *ctxt, struct exo_win_1i8 a ) { +a.data[0] += a.data[a.strides[0]]; +} + diff --git a/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt b/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt new file mode 100644 index 00000000..49abf306 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_unquote_index_tuple.txt @@ -0,0 +1,30 @@ +EXO IR: +def foo(a: [i8][2, 2] @ DRAM): + a[0, 0] += a[0, 1] + a[1, 0] += a[1, 1] +def bar(a: i8[10, 10, 10] @ DRAM): + for i in seq(0, 7): + foo(a[i, i:i + 2, i + 1:i + 3]) +C: +#include "test.h" + +#include +#include + +// bar( +// a : i8[10, 10, 10] @DRAM +// ) +void bar( void *ctxt, int8_t* a ) { +for (int_fast32_t i = 0; i < 7; i++) { + foo(ctxt,(struct exo_win_2i8){ &a[(i) * (100) + (i) * (10) + i + 1], { 10, 1 } }); +} +} + +// foo( +// a : [i8][2, 2] @DRAM +// ) +void foo( void *ctxt, struct exo_win_2i8 a ) { +a.data[0] += a.data[a.strides[1]]; +a.data[a.strides[0]] += a.data[a.strides[0] + a.strides[1]]; +} + diff --git a/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt b/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt new file mode 100644 index 00000000..ea4f9798 --- /dev/null +++ b/tests/golden/test_metaprogramming/test_unquote_slice_object1.txt @@ -0,0 +1,38 @@ +EXO IR: +def foo(a: [i8][2] @ DRAM): + a[0] += a[1] +def bar(a: i8[10, 10] @ DRAM): + for i in seq(0, 10): + foo(a[i, 1:3]) + for i in seq(0, 10): + foo(a[i, 5:7]) + for i in seq(0, 10): + foo(a[i, 2:4]) +C: +#include "test.h" + +#include +#include + +// bar( +// a : i8[10, 10] @DRAM +// ) +void bar( void *ctxt, int8_t* a ) { +for (int_fast32_t i = 0; i < 10; i++) { + foo(ctxt,(struct exo_win_1i8){ &a[(i) * (10) + 1], { 1 } }); +} +for (int_fast32_t i = 0; i < 10; i++) { + foo(ctxt,(struct exo_win_1i8){ &a[(i) * (10) + 5], { 1 } }); +} +for (int_fast32_t i = 0; i < 10; i++) { + foo(ctxt,(struct exo_win_1i8){ &a[(i) * (10) + 2], { 1 } }); +} +} + +// foo( +// a : [i8][2] @DRAM +// ) +void foo( void *ctxt, struct exo_win_1i8 a ) { +a.data[0] += a.data[a.strides[0]]; +} + diff --git a/tests/golden/test_metaprogramming/test_unrolling.txt b/tests/golden/test_metaprogramming/test_unrolling.txt new file mode 100644 index 00000000..136c770c --- /dev/null +++ b/tests/golden/test_metaprogramming/test_unrolling.txt @@ -0,0 +1,38 @@ +EXO IR: +def foo(a: i8 @ DRAM): + b: i8 @ DRAM + b = 0 + b += a + b += a + b += a + b += a + b += a + b += a + b += a + b += a + b += a + b += a +C: +#include "test.h" + +#include +#include + +// foo( +// a : i8 @DRAM +// ) +void foo( void *ctxt, const int8_t* a ) { +int8_t b; +b = ((int8_t) 0); +b += *a; +b += *a; +b += *a; +b += *a; +b += *a; +b += *a; +b += *a; +b += *a; +b += *a; +b += *a; +} + diff --git a/tests/test_metaprogramming.py b/tests/test_metaprogramming.py new file mode 100644 index 00000000..7067db37 --- /dev/null +++ b/tests/test_metaprogramming.py @@ -0,0 +1,527 @@ +from __future__ import annotations +from exo import proc, compile_procs_to_strings +from exo.API_scheduling import rename +from exo.frontend.pyparser import ParseError +import pytest +import warnings +from exo.libs.externs import * +from exo.platforms.x86 import DRAM + + +def test_unrolling(golden): + @proc + def foo(a: i8): + b: i8 + b = 0 + with python: + for _ in range(10): + with exo: + b += a + + c_file, _ = compile_procs_to_strings([foo], "test.h") + + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_conditional(golden): + def foo(cond: bool): + @proc + def bar(a: i8): + b: i8 + with python: + if cond: + with exo: + b = 0 + else: + with exo: + b += 1 + + return bar + + bar1 = rename(foo(False), "bar1") + bar2 = rename(foo(True), "bar2") + + c_file, _ = compile_procs_to_strings([bar1, bar2], "test.h") + assert f"EXO IR:\n{str(bar1)}\n{str(bar2)}\nC:\n{c_file}" == golden + + +def test_scoping(golden): + a = 3 + + @proc + def foo(a: i8): + a = {a} + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_scope_nesting(golden): + x = 3 + + @proc + def foo(a: i8, b: i8): + with python: + y = 2 + with exo: + a = {~{b} if x == 3 and y == 2 else ~{a}} + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_global_scope(): + cell = [0] + + @proc + def foo(a: i8): + a = 0 + with python: + with exo: + with python: + global dict + cell[0] = dict + dict = None + + assert cell[0] == dict + + +def test_constant_lifting(golden): + x = 1.3 + + @proc + def foo(a: f64): + a = {(x**x + x) / x} + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_type_params(golden): + def foo(T: str, U: str): + @proc + def bar(a: {T}, b: {U}): + c: {T}[4] + for i in seq(0, 3): + d: {T} + d = b + c[i + 1] = a + c[i] * d + a = c[3] + + return bar + + bar1 = rename(foo("i32", "i8"), "bar1") + bar2 = rename(foo("f64", "f64"), "bar2") + + c_file, _ = compile_procs_to_strings([bar1, bar2], "test.h") + assert f"EXO IR:\n{str(bar1)}\n{str(bar2)}\nC:\n{c_file}" == golden + + +def test_captured_closure(golden): + cell = [0] + + def foo(): + cell[0] += 1 + + @proc + def bar(a: i32): + with python: + for _ in range(10): + foo() + with exo: + a += {cell[0]} + + c_file, _ = compile_procs_to_strings([bar], "test.h") + assert f"EXO IR:\n{str(bar)}\nC:\n{c_file}" == golden + + +def test_capture_nested_quote(golden): + a = 2 + + @proc + def foo(a: i32): + with python: + for _ in range(3): + with exo: + a = {a} + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_quote_elision(golden): + @proc + def foo(a: i32, b: i32): + with python: + + def bar(): + return a + + with exo: + b = {bar()} + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_unquote_elision(golden): + @proc + def foo(a: i32): + with python: + x = 2 + with exo: + a = a * x + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_scope_collision1(golden): + @proc + def foo(a: i32): + with python: + b = 1 + with exo: + b: i32 + b = 2 + with python: + c = b + with exo: + a = c + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_scope_collision2(golden): + @proc + def foo(a: i32, b: i32): + with python: + a = 1 + with exo: + b = a + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_scope_collision3(): + with pytest.raises( + NameError, + match="free variable 'x' referenced before assignment in enclosing scope", + ): + + @proc + def foo(a: i32, b: i32): + with python: + with exo: + a = b * x + x = 1 + + +def test_type_quote_elision(golden): + T = "i8" + + @proc + def foo(a: T, x: T[2]): + a += x[0] + a += x[1] + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_unquote_in_slice(golden): + @proc + def foo(a: [i8][2]): + a[0] += a[1] + + @proc + def bar(a: i8[10, 10]): + with python: + x = 2 + with exo: + for i in seq(0, 5): + foo(a[i, {x} : {2 * x}]) + + c_file, _ = compile_procs_to_strings([foo, bar], "test.h") + assert f"EXO IR:\n{str(foo)}\n{str(bar)}\nC:\n{c_file}" == golden + + +def test_unquote_slice_object1(golden): + @proc + def foo(a: [i8][2]): + a[0] += a[1] + + @proc + def bar(a: i8[10, 10]): + with python: + for s in [slice(1, 3), slice(5, 7), slice(2, 4)]: + with exo: + for i in seq(0, 10): + foo(a[i, s]) + + c_file, _ = compile_procs_to_strings([foo, bar], "test.h") + assert f"EXO IR:\n{str(foo)}\n{str(bar)}\nC:\n{c_file}" == golden + + +def test_unquote_slice_object2(): + with pytest.raises( + ParseError, match="cannot perform windowing on left-hand-side of an assignment" + ): + + @proc + def foo(a: i8[10, 10]): + with python: + for s in [slice(1, 3), slice(5, 7), slice(2, 4)]: + with exo: + for i in seq(0, 10): + a[i, s] = 2 + + +def test_unquote_index_tuple(golden): + @proc + def foo(a: [i8][2, 2]): + a[0, 0] += a[0, 1] + a[1, 0] += a[1, 1] + + @proc + def bar(a: i8[10, 10, 10]): + with python: + + def get_index(i): + return slice(i, ~{i + 2}), slice(~{i + 1}, ~{i + 3}) + + with exo: + for i in seq(0, 7): + foo(a[i, {get_index(i)}]) + + c_file, _ = compile_procs_to_strings([foo, bar], "test.h") + assert f"EXO IR:\n{str(foo)}\n{str(bar)}\nC:\n{c_file}" == golden + + +def test_unquote_err(): + with pytest.raises( + ParseError, match="Unquote computation did not yield valid type" + ): + T = 1 + + @proc + def foo(a: T): + a += 1 + + +def test_quote_complex_expr(golden): + @proc + def foo(a: i32): + with python: + + def bar(x): + return ~{x + 1} + + with exo: + a = {bar(~{a + 1})} + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_statement_assignment(golden): + @proc + def foo(a: i32): + with python: + with exo as s1: + a += 1 + a += 2 + with exo as s2: + a += 3 + a += 4 + s = s1 if True else s2 + with exo: + {s} + {s} + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_statement_in_expr(): + with pytest.raises( + TypeError, match="Cannot unquote Exo statements in this context." + ): + + @proc + def foo(a: i32): + with python: + + def bar(): + with exo: + a += 1 + return 2 + + with exo: + a += {bar()} + a += {bar()} + + +def test_nonlocal_disallowed(): + with pytest.raises(ParseError, match="nonlocal is not supported"): + x = 0 + + @proc + def foo(a: i32): + with python: + nonlocal x + + +def test_outer_return_disallowed(): + with pytest.raises(ParseError, match="cannot return from metalanguage fragment"): + + @proc + def foo(a: i32): + with python: + return + + +def test_with_block(): + @proc + def foo(a: i32): + with python: + + def issue_warning(): + warnings.warn("deprecated", DeprecationWarning) + + with warnings.catch_warnings(record=True) as recorded_warnings: + issue_warning() + assert len(recorded_warnings) == 1 + pass + + +def test_unary_ops(golden): + @proc + def foo(a: i32): + with python: + x = ~1 + with exo: + a = x + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_return_in_async(): + @proc + def foo(a: i32): + with python: + + async def bar(): + return 1 + + pass + + +def test_local_externs(golden): + my_sin = sin + + @proc + def foo(a: f64): + a = my_sin(a) + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_unquote_multiple_exprs(): + with pytest.raises(ParseError, match="Unquote must take 1 argument"): + x = 0 + + @proc + def foo(a: i32): + a = {x, x} + + +def test_disallow_with_in_exo(): + with pytest.raises(ParseError, match="Expected unquote"): + + @proc + def foo(a: i32): + with a: + pass + + +def test_unquote_multiple_stmts(): + with pytest.raises(ParseError, match="Unquote must take 1 argument"): + + @proc + def foo(a: i32): + with python: + with exo as s: + a += 1 + with exo: + {s, s} + + +def test_unquote_non_statement(): + with pytest.raises( + ParseError, + match="Statement-level unquote expression must return Exo statements", + ): + + @proc + def foo(a: i32): + with python: + x = ~{a} + with exo: + {x} + + +def test_unquote_slice_with_step(): + with pytest.raises(ParseError, match="Unquote returned slice index with step"): + + @proc + def bar(a: [i32][10]): + a[0] = 0 + + @proc + def foo(a: i32[20]): + with python: + x = slice(0, 20, 2) + with exo: + bar(a[x]) + + +def test_typecheck_unquote_index(): + with pytest.raises( + ParseError, match="Unquote received input that couldn't be unquoted" + ): + + @proc + def foo(a: i32[20]): + with python: + x = "0" + with exo: + a[x] = 0 + + +def test_proc_shadowing(golden): + @proc + def sin(a: f32): + a = 0 + + @proc + def foo(a: f32): + sin(a) + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden + + +def test_eval_expr_in_mem(golden): + mems = [DRAM] + + @proc + def foo(a: f32 @ mems[0]): + pass + + c_file, _ = compile_procs_to_strings([foo], "test.h") + assert f"EXO IR:\n{str(foo)}\nC:\n{c_file}" == golden diff --git a/tests/test_typecheck.py b/tests/test_typecheck.py index fe9f86d0..84a8b21c 100644 --- a/tests/test_typecheck.py +++ b/tests/test_typecheck.py @@ -80,14 +80,14 @@ def foo(n: size, A: R[n] @ GEMM_SCRATCH): def test_sin1(): @proc - def sin(x: f32): + def sin_proc(x: f32): y: f32 y = sin(x) def test_sin2(): @proc - def sin(x: f32): + def sin_proc(x: f32): y: f32 if False: y = sin(x) diff --git a/tests/test_uast.py b/tests/test_uast.py index 08f771e8..4bf8b5ab 100644 --- a/tests/test_uast.py +++ b/tests/test_uast.py @@ -5,7 +5,7 @@ from exo import DRAM from exo.frontend.pyparser import ( Parser, - get_src_locals, + get_parent_scope, get_ast_from_python, ParseError, ) @@ -16,8 +16,7 @@ def to_uast(f): parser = Parser( body, getsrcinfo, - func_globals=f.__globals__, - srclocals=get_src_locals(depth=2), + parent_scope=get_parent_scope(depth=2), instr=("TEST", ""), as_func=True, ) @@ -99,7 +98,9 @@ def func(f: f32): for i in seq(0, global_str): f += 1 - with pytest.raises(ParseError, match="type "): + with pytest.raises( + ParseError, match="Unquote received input that couldn't be unquoted" + ): to_uast(func) local_str = "xyzzy" @@ -108,7 +109,9 @@ def func(f: f32): for i in seq(0, local_str): f += 1 - with pytest.raises(ParseError, match="type "): + with pytest.raises( + ParseError, match="Unquote received input that couldn't be unquoted" + ): to_uast(func)