diff --git a/.github/dependabot.yml b/.github/dependabot.yml index b38df29f4..55dbe8564 100644 --- a/.github/dependabot.yml +++ b/.github/dependabot.yml @@ -3,4 +3,4 @@ updates: - package-ecosystem: "pip" directory: "/" schedule: - interval: "daily" + interval: "monthly" diff --git a/CITATION.cff b/CITATION.cff index cebdd7c66..62b2cf70a 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -28,8 +28,8 @@ license: MIT # These fields should point to the latest release and be updated as soon as a # new release is tagged. commit: 09e666ee0690797a6c56103b65f5d83abd79c60e -version: 0.2.1 -date-released: 2024-08-28 +version: 1.0.0 +date-released: 2024-11-05 # This is the citation for the PLDI 2022 paper. preferred-citation: diff --git a/README.md b/README.md index bbbe3cf1f..46eeb2322 100644 --- a/README.md +++ b/README.md @@ -11,6 +11,7 @@ If you're just using Exo, install it using `pip`: ```sh $ pip install exo-lang ``` +In case of `ModuleNotFoundError: No module named 'attrs'` please upgrade your attrs module by `pip install --upgrade attrs`. ## Compile Exo @@ -29,11 +30,6 @@ You can use optional arguments to customize the output: - The `--stem` argument allows you to specify custom names for the C file and header file. -# Examples - -Take a look at [examples](examples/README.md) for scheduling examples, and [API documentation](docs/API.md) for complete scheduling interface documentation. - - # Build Exo from source We make active use of newer Python 3.x features. Please use Python 3.9 or 3.10 if you're getting errors about unsupported features. @@ -62,7 +58,6 @@ Finally, you can build and install Exo. (exo) $ pip install dist/*.whl ``` - ## PySMT Depending on your setup, getting PySMT to work correctly may be difficult. You @@ -117,31 +112,16 @@ pytest --cov=./ --cov-report=html Then, if you want to see annotated source files, open `./htmlcov/index.html`. +--- -# Repository structure - -In this repository, folders are structured as follows: +# Learn about Exo -1. `src/exo` is where the core Exo implementation resides. - - **APIs.** Documentation for the APIs can be found in the [API documentation](docs/API.md). - - `API.py` defines a stable API for top-level decorators (`proc`, `instr`, and `config`). - - `API_scheduling.py` defines a API for scheduling primitives. - - `API_cursors.py` defines a API for scheduling primitives. - - **Standard libraries.** These could be user-defined, but we provide them for convenience. - - `libs/` contains some common memory definitions (`memories.py`) and custom malloc implementations. - - `platforms/` contains instruction definitions that are part of the release. - - `stdlib/` contains user-level scheduling functions such as `vectorize`. - - Other files are implementation details of Exo (e.g., `typecheck.py` implements typecheck), are not exposed to users. -2. `apps/` contains some sample applications written in Exo. -3. `dependencies/` contains submodules that Exo's apps and testing depends on. -4. `examples/` contains a step-by-step example of scheduling basic matrix multiplication on AVX2. -5. `tests/` contains the Exo test suite. -6. `docs/` contains additional Exo documentation. +Take a look at the [examples](examples/README.md) directory for scheduling examples and the [documentation](docs/README.md) directory for various documentation about Exo. # Contact -Please contact [exo@mit.edu](mailto:exo@mit.edu) if you have any questions. +Please contact [exo@mit.edu](mailto:exo@mit.edu) or [yuka@csail.mit.edu](mailto:yuka@csail.mit.edu) if you have any questions. 
# Publication diff --git a/apps/x86/conv/conv.py b/apps/x86/conv/conv.py index c8282c610..975d22a33 100644 --- a/apps/x86/conv/conv.py +++ b/apps/x86/conv/conv.py @@ -1,9 +1,9 @@ from __future__ import annotations from exo import * -from exo.builtins import * +from exo.libs.externs import * from exo.platforms.x86 import * -from exo.syntax import * +from exo.frontend.syntax import * from exo.stdlib.scheduling import * diff --git a/apps/x86/sgemm/sgemm.py b/apps/x86/sgemm/sgemm.py index ea1f663ef..1d2e005b6 100644 --- a/apps/x86/sgemm/sgemm.py +++ b/apps/x86/sgemm/sgemm.py @@ -3,7 +3,7 @@ from exo import * from exo.libs.memories import DRAM_STATIC from exo.platforms.x86 import * -from exo.syntax import * +from exo.frontend.syntax import * from exo.stdlib.scheduling import * from exo.stdlib.stdlib import * diff --git a/dev-requirements.txt b/dev-requirements.txt index f9c2b4a4e..ea5b6ab5c 100644 --- a/dev-requirements.txt +++ b/dev-requirements.txt @@ -1,9 +1,9 @@ -black==24.8.0 -coverage==7.6.1 -pre-commit==3.8.0 -pytest-cov==5.0.0 +black==24.10.0 +coverage==7.6.4 +pre-commit==4.0.1 +pytest-cov==6.0.0 pytest-xdist==3.6.1 pytest==8.3.3 -tox==4.21.2 -numpy==2.1.1 -Pillow==10.4.0 +tox==4.23.2 +numpy==2.1.2 +Pillow==11.0.0 diff --git a/docs/Cursors.md b/docs/Cursors.md new file mode 100644 index 000000000..dfbaaf533 --- /dev/null +++ b/docs/Cursors.md @@ -0,0 +1,376 @@ +# Cursors + +This documentation covers how to use cursors to navigate, point-to, and apply forwarding on procedures. +Throughout this document: +- `p` refers to an Exo `Procedure` object +- `c` refers to an Exo `Cursor` object + +## Obtaining Cursors + +### From Procedures +An Exo `Procedure` provides methods to obtain `Cursor`s: + +- `p.args()`: Returns cursors to the procedure's arguments. +- `p.body()`: Returns a `BlockCursor` selecting the entire body of the procedure. +- `p.find(pattern, many=False)`: Finds cursor(s) matching the given `pattern` string: + - If `many=False` (default), returns the first matching cursor. + - If `many=True`, returns a list of all matching cursors. +- `p.find_loop(loop_pattern, many=False)`: Finds cursor(s) to a loop, expanding shorthand patterns: + - `"name"` or `"name #n"` are expanded to `for name in _:_` + - Works like `p.find()`, returning the first match by default unless `many=True` +- `p.find_alloc_or_arg(buf_name)`: Finds an allocation or argument cursor, expanding the name to `buf_name: _`. +- `p.find_all(pattern)`: Shorthand for `p.find(pattern, many=True)`, returning all matching cursors. + +### From Cursors +A `Cursor` provides a similar method to find sub-cursors within its sub-AST: + +- `c.find(pattern, many=False)`: Finds cursor(s) matching the `pattern` within the cursor's sub-AST. + - Like `p.find()`, returns the first match by default unless `many=True`. + +### Pattern Language +The `pattern` argument is a string using the following special syntax: + +- `_` is a wildcard matching any statement or expression +- `#n` at the end selects the `n+1`th match instead of the first + - Ex. 
`"for i in _:_ #2"` matches the 3rd `i` loop +- `;` is a sequence of statements + +Example patterns: +- `"for i in _:_"` matches a `for i in seq(0, n):...` loop +- `"if i == 0:_"` or `"if _:_"` match `if` statements +- `"a : i8"` or `"a : _"` match an allocation of a buffer `a` +- `"a = 3.0"` or `"a = _"` match an assignment to `a` +- `"a += 3.0"` or `"a += _"` match a reduction on `a` +- `"a = 3.0 ; b = 2.0"` matches a block with those two statements + +## Cursor Types + +Exo defines the following `Cursor` types: + +- `StmtCursor`: Cursor to a specific Exo IR statement +- `GapCursor`: Cursor to the space between statements, anchored to (before or after) a statement +- `BlockCursor`: Cursor to a block (sequence) of statements +- `ArgCursor`: Cursor to a procedure argument (no navigation) +- `InvalidCursor`: Special cursor type for invalid cursors + +## Common Cursor Methods + +All `Cursor` types provide these common methods: + +- `c.parent()`: Returns `StmtCursor` to the parent node in Exo IR + - Raises `InvalidCursorError` if at the root with no parent +- `c.proc()`: Returns the `Procedure` this cursor is pointing to +- `c.find(pattern, many=False)`: Finds cursors by pattern-match within `c`s sub-AST + +## Statement Cursor Navigation + +A `StmtCursor` (pointing to one IR statement) provides these navigation methods. + +- `c.next()`: Returns `StmtCursor` to next statement +- `c.prev()`: Returns `StmtCursor` to previous statement +- `c.before()`: Returns `GapCursor` to space immediately before this statement +- `c.after()`: Returns `GapCursor` to space immediately after this statement +- `c.as_block()`: Returns a `BlockCursor` containing only this one statement +- `c.expand()`: Shorthand for `stmt_cursor.as_block().expand(...)` +- `c.body()`: Returns a `BlockCursor` to the body. Only works on `ForCursor` and `IfCursor`. +- `c.orelse()`: Returns a `BlockCursor` to the orelse branch. Works only on `IfCursor`. + +`c.next()` / `c.prev()` return an `InvalidCursor` when there is no next/previous statement. +`c.before()` / `c.after()` return anchored `GapCursor`s that move with their anchor statements. + +Examples: +``` +s1 <- c +s2 <- c.next() + +s1 <- c.prev() +s2 <- c + +s1 + <- c.before() +s2 <- c + +s1 +s2 <- c + <- c.after() +``` + +## Other Cursor Navigation + +- `GapCursor.anchor()`: Returns `StmtCursor` to the statement this gap is anchored to + +- `BlockCursor.expand(delta_lo=None, delta_hi=None)`: Returns an expanded block cursor + - `delta_lo`/`delta_hi` specify statements to add at start/end; `None` means expand fully + - Ex. in `s1; s2; s3`, if `c` is a `BlockCursor` pointing `s1; s2`, then `c.expand(0, 1)` returns a new `BlockCursor` pointing `s1; s2; s3` +- `BlockCursor.before()`: Returns `GapCursor` before block's first statement +- `BlockCursor.after()`: Returns `GapCursor` after block's last statement +- `BlockCursor[pt]`: Returns a `pt+1`th `StmtCursor` within the BlockCursor (e.g. `c[0]` returns `s1` when `c` is pointing to `s1;s2;...`) +- `BlockCursor[lo:hi]`: Returns a slice of `BlockCursor` from `lo` to `hi-1`. (e.g. `c[0:2]` returns `s1;s2` when `c` is pointing to `s2;s2;...`) + +## Cursor inspection + +`StmtCursor`s wrap the underlying Exo IR object and can be inspected. + - Ex. check cursor type with `isinstance(c, PC.AllocCursor)` + +`StmtCursor`s are one of the following types. + +#### `ArgCursor` + +Represents a cursor pointing to a procedure argument of the form: +``` +name : type @ mem +``` + +Methods: +- `name() -> str`: Returns the name of the argument. 
+- `mem() -> Memory`: Returns the memory location of the argument. +- `is_tensor() -> bool`: Checks if the argument is a tensor. +- `shape() -> ExprListCursor`: Returns a cursor to the shape expression list. +- `type() -> API.ExoType`: Returns the type of the argument. + +#### `AssignCursor` + +Represents a cursor pointing to an assignment statement of the form: +``` +name[idx] = rhs +``` + +Methods: +- `name() -> str`: Returns the name of the variable being assigned to. +- `idx() -> ExprListCursor`: Returns a cursor to the index expression list. +- `rhs() -> ExprCursor`: Returns a cursor to the right-hand side expression. +- `type() -> API.ExoType`: Returns the type of the assignment. + +#### `ReduceCursor` + +Represents a cursor pointing to a reduction statement of the form: +``` +name[idx] += rhs +``` + +Methods: +- `name() -> str`: Returns the name of the variable being reduced. +- `idx() -> ExprListCursor`: Returns a cursor to the index expression list. +- `rhs() -> ExprCursor`: Returns a cursor to the right-hand side expression. + + +#### `AssignConfigCursor` + +Represents a cursor pointing to a configuration assignment statement of the form: +``` +config.field = rhs +``` + +Methods: +- `config() -> Config`: Returns the configuration object. +- `field() -> str`: Returns the name of the configuration field being assigned to. +- `rhs() -> ExprCursor`: Returns a cursor to the right-hand side expression. + +#### `PassCursor` + +Represents a cursor pointing to a no-op statement: +``` +pass +``` + +#### `IfCursor` + +Represents a cursor pointing to an if statement of the form: +``` +if condition: + body +``` +or +``` +if condition: + body +else: + orelse +``` + +Methods: +- `cond() -> ExprCursor`: Returns a cursor to the if condition expression. +- `body() -> BlockCursor`: Returns a cursor to the if body block. +- `orelse() -> BlockCursor | InvalidCursor`: Returns a cursor to the else block if present, otherwise returns an invalid cursor. + +#### `ForCursor` + +Represents a cursor pointing to a loop statement of the form: +``` +for name in seq(0, hi): + body +``` + +Methods: +- `name() -> str`: Returns the loop variable name. +- `lo() -> ExprCursor`: Returns a cursor to the lower bound expression (defaults to 0). +- `hi() -> ExprCursor`: Returns a cursor to the upper bound expression. +- `body() -> BlockCursor`: Returns a cursor to the loop body block. + + +#### `AllocCursor` + +Represents a cursor pointing to a buffer allocation statement of the form: +``` +name : type @ mem +``` + +Methods: +- `name() -> str`: Returns the name of the allocated buffer. +- `mem() -> Memory`: Returns the memory location of the buffer. +- `is_tensor() -> bool`: Checks if the allocated buffer is a tensor. +- `shape() -> ExprListCursor`: Returns a cursor to the shape expression list. +- `type() -> API.ExoType`: Returns the type of the allocated buffer. + + +#### `CallCursor` + +Represents a cursor pointing to a sub-procedure call statement of the form: +``` +subproc(args) +``` + +Methods: +- `subproc()`: Returns the called sub-procedure. +- `args() -> ExprListCursor`: Returns a cursor to the argument expression list. + + +#### `WindowStmtCursor` + +Represents a cursor pointing to a window declaration statement of the form: +``` +name = winexpr +``` + +Methods: +- `name() -> str`: Returns the name of the window. +- `winexpr() -> ExprCursor`: Returns a cursor to the window expression. + + +## ExoType + +The `ExoType` enumeration represents user-facing various data and control types. 
It is a wrapper around Exo IR types. + +- `F16`: Represents a 16-bit floating-point type. +- `F32`: Represents a 32-bit floating-point type. +- `F64`: Represents a 64-bit floating-point type. +- `UI8`: Represents an 8-bit unsigned integer type. +- `I8`: Represents an 8-bit signed integer type. +- `UI16`: Represents a 16-bit unsigned integer type. +- `I32`: Represents a 32-bit signed integer type. +- `R`: Represents a generic numeric type. +- `Index`: Represents an index type. +- `Bool`: Represents a boolean type. +- `Size`: Represents a size type. +- `Int`: Represents a generic integer type. +- `Stride`: Represents a stride type. + +The `ExoType` provides the following utility methods: + +#### `is_indexable()` + +Returns `True` if the `ExoType` is one of the indexable types, which include: +- `ExoType.Index` +- `ExoType.Size` +- `ExoType.Int` +- `ExoType.Stride` + +#### `is_numeric()` + +Returns `True` if the `ExoType` is one of the numeric types, which include: +- `ExoType.F16` +- `ExoType.F32` +- `ExoType.F64` +- `ExoType.I8` +- `ExoType.UI8` +- `ExoType.UI16` +- `ExoType.I32` +- `ExoType.R` + +#### `is_bool()` + +Returns `True` if the `ExoType` is the boolean type (`ExoType.Bool`). + + +## Cursor Forwarding + +When a procedure `p1` is transformed into a new procedure `p2` by applying scheduling primitives, any cursors pointing into `p1` need to be updated to point to the corresponding locations in `p2`. This process is called *cursor forwarding*. + +To forward a cursor `c1` from `p1` to `p2`, you can use the `forward` method on the new procedure: +```python +c2 = p2.forward(c1) +``` + +### How Forwarding Works + +Internally, each scheduling primitive returns a *forwarding function* that maps AST locations in the input procedure to locations in the output procedure. + +When you call `p2.forward(c1)`, Exo composes the forwarding functions for all the scheduling steps between `c1.proc()` (the procedure `c1` points into, in this case `p1`) and `p2` (the final procedure). This composition produces a single function that can map `c1` from its original procedure to the corresponding location in `p2`. + +Here's the actual implementation of the forwarding in `src/exo/API.py`: + +```python +def forward(self, cur: C.Cursor): + p = self + fwds = [] + while p is not None and p is not cur.proc(): + fwds.append(p._forward) + p = p._provenance_eq_Procedure + + ir = cur._impl + for fn in reversed(fwds): + ir = fn(ir) + + return C.lift_cursor(ir, self) +``` + +The key steps are: + +1. Collect the forwarding functions (`p._forward`) for all procedures between `cur.proc()` and `self` (the final procedure). +2. Get the underlying Exo IR for the input cursor (`cur._impl`). +3. Apply the forwarding functions in reverse order to map the IR node to its final location. +4. Lift the mapped IR node back into a cursor in the final procedure. + +So in summary, `p.forward(c)` computes and applies the composite forwarding function to map cursor `c` from its original procedure to the corresponding location in procedure `p`. + +Note that a forwarding function can return an invalid cursor, and that is expected. For example, when a statement cease to exist by a rewrite, cursors pointing to the statement will be forwarded to an invalid cursor. + +### Implicit and Explicit Cursor Forwarding in Scheduling Primitives + +Scheduling primitives, such as `lift_alloc` and `expand_dim`, operate on a target procedure, which is passed as the first argument. 
When passing cursors to these primitives, the cursors should be forwarded to point to the target procedure. + +Consider the following example: +```python +c = p0.find("x : _") +p1 = lift_alloc(p0, c) +p2 = expand_dim(p1, p1.forward(c), ...) +``` + +In the call to `expand_dim`, the cursor `c` is explicitly forwarded to `p1` using `p1.forward(c)`. This is necessary because `c` was originally obtained from `p0`, and it needs to be adjusted to point to the correct location in `p1`. + +However, the scheduling primitives support *implicit forwarding* of cursors. This means that all the cursors passed to these primitives will be automatically forwarded to point to the first argument procedure. The above code can be simplified as follows: + +```python +c = p0.find("x : _") +p1 = lift_alloc(p0, c) +p2 = expand_dim(p1, c, ...) # implicit forwarding! +``` + +In this case, `c` is implicitly forwarded to `p1` within the `expand_dim` primitive, eliminating the need for explicit forwarding. + +#### Limitations of Implicit Forwarding + +It is important to note that implicit forwarding does not work when navigation is applied to a forwarded cursor. Consider the following example: + +```python +c = p0.find("x : _") +p1 = lift_alloc(p0, c) +p2 = reorder_scope(p1, p1.forward(c).next(), ...) +``` + +In this code, the navigation `.next()` is applied to the forwarded cursor `p1.forward(c)`. Attempting to change `p1.forward(c).next()` to `p1.forward(c.next())` will result in incorrect behavior. This is because navigation and forwarding are *not commutative*. + +## Further Reading +More details of the design principles of Cursors can be found in our [ASPLOS '25 paper](.) or in [Kevin Qian's MEng thesis](https://dspace.mit.edu/handle/1721.1/157187). + + diff --git a/docs/Design.md b/docs/Design.md new file mode 100644 index 000000000..153aa7d29 --- /dev/null +++ b/docs/Design.md @@ -0,0 +1,70 @@ +# Design Document for Exo + +Exo is a domain-specific language designed to enable productive development of high-performance kernel libraries that target specialized hardware accelerators. + +The key design principles of Exo are: +- **Performance Transparency**: We do not do "magic optimizations" that are surprising and opaque to users. +- **WYSIWYG**: Exo IR closely models C-style code and will be trivially lowered to C code. +- **User Control**: Give the performance control back to users. + +--- + +# Exocompilation: Externalizing Hardware Targets + +One of the main ideas behind Exo is **exocompilation**, which allows users to define hardware targets externally to the compiler in user-level libraries. This has several advantages: + +- Hardware vendors can support new accelerators without maintaining compiler forks. +- The cost of adding support for new hardware is significantly reduced. +- Proprietary details of hardware can be protected. + +Users can model custom [memories](./memories.md), [instructions](./instructions.md), and configuration state in libraries to target a specific accelerator. These hardware abstractions can then be used to write hand-optimized code or as building blocks for higher-level scheduling transformations. + +More info can be found in the [PLDI paper](https://people.csail.mit.edu/yuka/pdf/exo_pldi2022_full.pdf), [instructions.md](./instructions.md), and [memories.md](./memories.md). + +## Fine-Grained Primitives for Performance Control + +Exo provides a set of fine-grained scheduling primitives that offer users low-level control over performance-critical aspects. 
These primitives can be combined to create complex transformation schedules. Some examples of these primitives include: + +- `replace`: Maps code fragments to custom instructions +- `delete_config`: Removes redundant configuration statements + +The key research contributions of Exo were supporting `replace` through unification and the ability to reason about configuration states. Explicit control over these low-level details allows Exo to achieve performance comparable to highly-tuned vendor libraries and hand-optimized assembly code. All the primitives can be found in the [primitives/](./primitives/) directory. + +## Rewrite-based Scheduling Language + +Exo employs a *rewrite-based* compilation process, which differs from the *lowering-based* approach used by popular frameworks like Halide and TVM. + +The rewrite-based approach offers several advantages: + +- Reduced complexity and less "magic" involved +- Easier to print and inspect the state of the scheduling process at any point + +--- + +# User-Defined Scheduling Operations + +While the flexibility of fine-grained primitives is necessary for achieving peak performance, directly using them can be verbose and laborious. To address this, Exo allows users to define new higher-level scheduling operations by composing the core primitives. + +These user-defined scheduling operations can encapsulate common optimization patterns and hardware-specific transformations such as auto-vectorize, tiling, and even simulate scheduling operations from other USLs (like Halide's `compute_at`). +They can be put together in reusable libraries, further enabling modularity and portability. + +More infomation can be found in the [ASPLOS paper](.) and [Cursor.md](./Cursor.md). + +## The AIR Framework: Action, Inspection, Reference + +We identified that Action, Inspection, and Reference are the key scheduling language design mechanisms that enable user-defined scheduling operations. + +- **[Actions](./primitives)** are scheduling operations that transform the code. This could be compiler-provided *primitive actions* (e.g., `divide_loop`, `reorder`), or *user-defined* (e.g., tile2D in the ASPLOS paper). +- **[Inspections](./inspection.md)** query properties of the code (e.g., loop bounds, memory access patterns). +- **References** point to specific parts of the code to apply actions to. + +Together, AIR allows scheduling operations to be defined as composable rewrites on the code. The language implementation guarantees the correctness of these primitive rewrites with a set of effect analyses. + +## Cursors: Enabling Relative References + +A novel feature in Exo's design is the concept of cursors, which serve as relative references into the code. Similar to a text editing cursor, an Exo cursor can refer to a specific location in the program AST, such as a statement, loop nest, or even the gap between statements. + +Cursors support navigation operations such as `next`, `prev`, `parent`, enabling powerful code transformations using relative positions. +Furthermore, Cursor _forwarding_ let users reuse the cursor from the previous procedure in the current procedure. +Multiple cursors can coexist, allowing different parts of the code to be referenced and modified simultaneously. + diff --git a/docs/Imports.md b/docs/Imports.md new file mode 100644 index 000000000..61eed7f15 --- /dev/null +++ b/docs/Imports.md @@ -0,0 +1,97 @@ +# Imports in Exo + +This document provides an overview of the imports used when writing Exo. 
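+As a quick reference, the sketch below gathers the imports discussed in the rest of this document into a single typical preamble. The specific memory, instruction, extern, and stdlib names are simply the examples used later in this document, not a required set.
+
+```python
+from __future__ import annotations                  # postponed evaluation of annotations
+
+from exo import *                                   # proc, instr, config, Memory, Extern, DRAM, ...
+from exo.libs.memories import DRAM_STATIC, AVX2, AVX512            # memory definitions
+from exo.platforms.x86 import mm256_loadu_ps, mm256_setzero_ps     # instruction definitions
+from exo.libs.externs import sin, relu              # extern functions
+from exo.frontend.syntax import *                   # special symbols used inside Exo code
+from exo.stdlib.scheduling import repeat, replace_all              # scheduling standard library
+from exo.stdlib.stdlib import vectorize, tile_loops                # user-level scheduling helpers
+from exo.API_cursors import *                       # cursor types for inspection
+```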
+ +Exo's parser only resolves names in the local and global namespaces, and Exo reserves the attribute syntax (foo.bar) for configurations. +Therefore, if users wish to utilize Exo constructs, they must import them into their local namespace. + +## Table of Contents + +1. [Standard Python Future Import](#1-standard-python-future-import) +2. [Core Exo Module](#2-core-exo-module) +3. [Memory Libraries](#3-memory-libraries) +4. [Instruction Libraries](#4-instruction-libraries) +5. [Extern Libraries](#5-extern-libraries) +6. [Frontend Syntax Utilities](#6-frontend-syntax-utilities) +7. [Standard Library Scheduling Functions](#7-standard-library-scheduling-functions) +8. [API Cursors](#8-api-cursors) + + +## 1. Standard Python Future Import + +```python +from __future__ import annotations +``` + +Enables postponed evaluation of type annotations, allowing you to use forward references in type hints without causing issues during runtime. This is necessary to support Exo's `x : f32` syntax. + + +## 2. Core Exo Module + +```python +from exo import * +``` + +Imports basic classes and functions necessary for defining and manipulating high-performance computational kernels, such as `proc`, `instr`, `config`, `Memory`, `Extern`, `DRAM`, and `SchedulingError`. + + +## 3. Memory Libraries + +Even though users can define memory definitions externally to the compiler in the user code (see [memories.md](./memories.md)), we provide memory definitions for some architectures for convinience. +The supported memory definitions can be found by looking into `src/exo/libs/memories.py`. + +```python +from exo.libs.memories import DRAM_STATIC, AVX2, AVX512 +``` + +For example, you can import `DRAM_STATIC`, `AVX2`, or `AVX512` as shown above. + + +## 4. Instruction Libraries + +Similar to memories, we provide some hardware instruction definitions for convinience (see [instructions.md](./instructions.md) to learn how to define your own accelerator instructions). + +```python +from exo.platforms.x86 import mm256_loadu_ps, mm256_setzero_ps, mm256_broadcast_ss +``` + +## 5. Extern Libraries + +Similary, convinience extern libraries can be imported as follows. See [externs.md](./externs.md) to learn how to define your own externs. + +```python +from exo.libs.externs import sin, relu +``` + + +## 6. Frontend Syntax Utilities + +```python +from exo.frontend.syntax import * +``` + +This module defines special symbols that are used inside Exo code. +Importing this can suppress warnings inside an IDE (like PyCharm). + + +## 7. Standard Library Scheduling Functions + +Exo provides users with the ability to define new scheduling operations using Cursors. For convenience, we have implemented scheduling libraries (standard library) that contain common scheduling operations users may want to use, such as vectorization and tiling. Users can import the standard library as follows: + +```python +from exo.stdlib.scheduling import repeat, replace_all +from exo.stdlib.stdlib import vectorize, tile_loops +``` + +Alternatively, users can define their own scheduling operations by composing scheduling primitives directly in their code. + +## 8. API Cursors + +Cursors (see [Cursors.md](./Cursors.md)) are Exo's reference mechanism that allows users to navigate and inspect object code. When users define new scheduling operators using Cursors, they may wish to write their own inspection pass (see [inspection.md](./inspection.md)). API Cursors define types that will be useful for user inspection. 
+ +```python +from exo.API_cursors import ForCursor, AssignCursor, InvalidCursor +``` + +These API Cursors provide specific types, such as `ForCursor` for for-loops, `AssignCursor` for assignments, and `InvalidCursor` for invalid cursors. Users can leverage these types when inspecting and manipulating code using Cursors. + diff --git a/docs/API.md b/docs/Procedures.md similarity index 59% rename from docs/API.md rename to docs/Procedures.md index 9f7280e20..6516a36a1 100644 --- a/docs/API.md +++ b/docs/Procedures.md @@ -8,17 +8,17 @@ ## Procedure Object Methods +The following are methods on Exo Procedures (functions decorated with `@proc` or `@instr`). + ### Inspection Operations - `.name()`: Returns the procedure name. - `.is_instr()`: Returns `True` if the procedure has a hardware instruction string. - `.get_instr()`: Returns the hardware instruction string. -- `.args()`: Returns cursors to procedure arguments. -- `.body()`: Returns a BlockCursor selecting the entire body of the Procedure. -- `.find(pattern, many=False)`: Finds a cursor for the given pattern. If `many=True`, returns a list of all cursors matching the pattern. -- `.find_loop(loop_pattern, many=False)`: Finds a cursor pointing to a loop. Similar to `proc.find(...)`, but if the supplied pattern is of the form 'name' or 'name #n', it will be auto-expanded to `for name in _:_`. -- `.find_alloc_or_arg(pattern)`: Finds an allocation or argument cursor. -- `.find_all(pattern)`: Finds a list of all cursors matching the pattern. + +### Obtaining Cursors + +Cursors can be obtained by querying patterns on a procedure. All the Cursor related documentations are in [Cursors.md](Cursors.md). ### Compilation Operations @@ -32,14 +32,3 @@ - `.transpose(arg_cursor)`: Transposes a 2D buffer argument in the signature and the body. Returns a new procedure and is non-equivalence preserving because the signature has changed. - `.add_assertion(assertion)`: Adds an assertion to the procedure. - `.is_eq(other_proc)`: Checks the equivalence of this procedure with another procedure. - -## Scheduling Primitives - -We have classified scheduling primitives into six categories. Here are the links to each: - -- [Buffer Transformations](buffer_ops.md) -- [Loop and Scope Transformations](loop_ops.md) -- [Configuration States](config_ops.md) -- [Subprocedure Operations](subproc_ops.md) -- [Memory, Precision, and Parallelism Transformations](backend_ops.md) -- [Other Operations](other_ops.md) diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 000000000..322887831 --- /dev/null +++ b/docs/README.md @@ -0,0 +1,30 @@ +# Documentation + +This directory provides detailed documentation about Exo's interface and internal systems. + +- To learn about the design principles of Exo, read [Design.md](Design.md). +- To understand how the Exo system is implemented, read [System.md](System.md). +- For information on writing Exo object code, APIs, and imports, refer to [Procedures.md](Procedures.md), [object_code.md](object_code.md), and [Imports.md](Imports.md). +- To learn how to define **hardware targets externally to the compiler**, refer to [externs.md](externs.md), [instructions.md](instructions.md), and [memories.md](memories.md). +- To learn how to define **new scheduling operations externally to the compiler**, refer to [Cursors.md](./Cursors.md) and [inspection.md](./inspection.md). +- To understand the available scheduling primitives and how to use them, look into the [primitives/](./primitives) directory. 
+ +The scheduling primitives are classified into six categories: + +1. [Buffer Transformations](primitives/buffer_ops.md) +2. [Loop and Scope Transformations](primitives/loop_ops.md) +3. [Configuration States](primitives/config_ops.md) +4. [Subprocedure Operations](primitives/subproc_ops.md) +5. [Memory, Precision, and Parallelism Transformations](primitives/backend_ops.md) +6. [Other Operations](primitives/other_ops.md) + +# Further Reading + +The following papers provide a high-level and holistic view of Exo as a project: + +- [PLDI '22 paper](https://people.csail.mit.edu/yuka/pdf/exo_pldi2022_full.pdf) +- [ASPLOS '25 paper](.) +- [Kevin Qian's MEng thesis](https://dspace.mit.edu/handle/1721.1/157187) +- [Samir Droubi's MEng thesis](https://dspace.mit.edu/handle/1721.1/156752) + +For more documentation with running Exo code, refer to the [Examples](../examples/README.md) directory. diff --git a/docs/System.md b/docs/System.md new file mode 100644 index 000000000..c6869f8b4 --- /dev/null +++ b/docs/System.md @@ -0,0 +1,99 @@ +# System Overview + +This document provides an overview of the Exo compilation process, as illustrated in Figure 1 of the PLDI'22 paper. + +![System overview](images/system-overview.png) + +The Exo compiler's frontend starts by parsing the Python AST and constructing the Untyped Exo AST (UAST). +It then runs various frontend checks before converting the UAST into LoopIR, which serves as Exo's primary IR. +Exo supports rewrite-based user-scheduling, where scheduling primitives take a LoopIR and returns another (transformed) LoopIR. +These primitives take the immutable LoopIR and rewrite it into a new LoopIR. +Finally, in the backend, the optimized LoopIR is code-generated into C code. + +The input to the compiler is a set of Exo source files (`*.py`), and the output is generated C code (`*.c`). + +In this repository, folders are structured as follows: + +1. `src/exo` is where the core Exo implementation resides. + - **APIs.** + - `API.py` defines a stable API for top-level decorators (`proc`, `instr`, and `config`). + - `API_scheduling.py` defines a API for scheduling primitives. + - `API_cursors.py` defines a API for Cursors. + - **Standard libraries.** These could be user-defined, but we provide them for convenience. + - `libs/` contains some common memory definitions (`memories.py`) and custom malloc implementations. + - `platforms/` contains instruction definitions that are part of the release. + - `stdlib/` contains user-level scheduling functions such as `vectorize`. + - Other files are implementation of Exo (e.g., `typecheck.py` implements typecheck), are not exposed to users. +2. `apps/` contains some sample applications written in Exo. +3. `dependencies/` contains submodules that Exo's apps and testing depends on. +4. `examples/` contains examples of scheduling with Exo. +5. `tests/` contains the Exo test suite. +6. `docs/` contains additional Exo documentation. + +--- + +## Core + +`src/exo/core` defines IRs used in Exo and other core implementations. +- `LoopIR.py` is the main file that defines IRs (LoopIR, UAST, PAST), and their visitor functions (LoopIR_Do, LoopIR_Rewrite). +- `LoopIR_pprint.py` implements a printing procedure for the IRs defined in `LoopIR.py`. +- `prelude.py` defines `Sym` and `Srcinfo`. + +User-defined features like config, externs, and Memory's parent class implementations are in `configs.py`, `extern.py`, and `memory.py`, respectively. 
+ +`internal_cursors` defines primitive cursor movements (see Section 5.2 "Cursor implementation" of our ASPLOS paper) that are used internally by `LoopIR_scheduling` implementations of scheduling primitives. +`proc_eqv.py` defines a union-find tree which we use to track the equivalence of procedures. + +--- + +## Frontend + +`API.py` provides various user-facing entry points to Exo. The frontend consists of three types of parsing passes, all of which are located in the `src/exo/frontend` directory. + +### Procedures + +The `@proc` and `@instr` decorators are defined in this section and call into the `Pyparser`. The frontend workflow is as follows: +``` +API -> Parser -> TypeCheck -> BoundsCheck/AssertCheck +``` + +`frontend/pyparser.py` defines a parser that translates the Python AST to UAST/PAST. Instead of implementing a custom lexer, Exo relies on the Python lexer to build the Python AST and hijacks it to translate it into Exo's internal ASTs. UAST (Untyped AST) is an untyped version of LoopIR (LoopIR is the "Exo IR" in the paper terminology). UAST is used when parsing full procedure definitions (`@proc` or `@instr`). PAST (Pattern AST) is an AST with holes, used to parse fragments from the user code outside the procedure (see next two sections). + +`typecheck.py` performs type checking and converts UAST to LoopIR. +`boundscheck.py` checks for any out-of-bounds errors in the frontend code and ensures that all assertions in the code are satisfiable by invoking an SMT solver. + +### New LoopIR Expressions + +Some scheduling primitives (such as `expand_dim` and all primitives that take `NewExprA` as an argument) require the construction of new LoopIR expressions. +`parse_fragment.py` implements this pass by calling into `pyparser.pattern`, which invokes the parser with `is_fragment=True`. +When parsing new expressions, it is not possible to use holes `_`. Holes are used for pattern matching for obtaining a cursor referene. + +### Pattern Match for Reference + +Cursors can be obtained by pattern matching. The pattern gets parsed into PAST and then matched against the LoopIR to obtain a reference. +`frontend/pattern_match.py` implements this functionality. + + +--- + +## Rewrites (User-Scheduling) + +After the frontend pass, we obtain LoopIR. The files in `src/exo/rewrite` implement Exo's rewrite-based user-scheduling process. + +- `LoopIR_scheduling.py` is the main file that implements all the scheduling primitives. Many implementations of primitives call into `Check_...` functions, which are the safety checks implemented in `new_eff.py`. +- The handling of analysis to preserve functional equivalence of rewrites is a separate topic not covered in detail here. `new_eff.py`, `new_analysis_core.py`, and `analysis_simplify.py` are all files related to the analysis. +- `LoopIR_unification.py` implements a unification process to support the `replace(p, ...)` rewrite primitive. + +--- + +## Backend + +The backend is responsible for lowering LoopIR to C code and performing backend checks, including precision analysis, window analysis, and parallelism analysis. + +- `LoopIR_compiler.py` is the main file in the backend, which compiles LoopIR to C code. +- `mem_analysis.py` implements a memory consistency check. For example, if a callee expects an `AVX2` annotation but the caller passes `DRAM` memory, it raises an error. +- `parallel_analysis.py` implements a parallel analysis. +- `prec_analysis.py` implements a precision consistency check and coerces the precision where possible. 
+- `win_analysis.py` implements a window analysis to check if callee and caller window annotations (tensor or window) match with each other. + + diff --git a/docs/externs.md b/docs/externs.md new file mode 100644 index 000000000..69adc8493 --- /dev/null +++ b/docs/externs.md @@ -0,0 +1,172 @@ +# Externs + +Externs in Exo provide a mechanism to interface with external functions and libraries directly from your Exo code. By defining custom extern functions, you can extend the capabilities of Exo and leverage existing code written in other languages like C or C++. Externs can be used as expressions in your code, particularly on the right-hand side (RHS) of assignment and reduction statements. + +## Defining Externs in User Code + +Extern functions are defined by subclassing the `Extern` class provided by Exo. This allows you to specify how the extern function should behave, including type checking, compilation, and any global code it might require. + +### Step-by-Step Guide + +#### 1. Import the Extern Class + +Before you can define an extern function, you need to import the `Extern` class and the `_EErr` exception from `exo.core.extern`. + +```python +from exo.core.extern import Extern, _EErr +``` + +- `Extern`: The base class for creating custom extern functions. +- `_EErr`: An exception class used for error handling during type checking. + +#### 2. Subclass the Extern Class + +Create a new class that inherits from `Extern`. This class represents your custom extern function. + +```python +class _Sin(Extern): + # Implementation details will go here +``` + +#### 3. Implement Required Methods + +Your subclass must implement several methods to define the behavior of the extern function. + +##### `__init__(self)` + +Initialize your extern function with its name. + +```python +def __init__(self): + super().__init__("sin") +``` + +- `"sin"`: The name of the external function as it will appear in the printed Exo object code. + +##### `typecheck(self, args)` + +Define how the function checks the types of its arguments. + +```python +def typecheck(self, args): + if len(args) != 1: + raise _EErr(f"expected 1 argument, got {len(args)}") + + arg_type = args[0].type + if not arg_type.is_real_scalar(): + raise _EErr( + f"expected argument to be a real scalar value, but got type {arg_type}" + ) + return arg_type +``` + +- Checks that there is exactly one argument. +- Ensures the argument is a real scalar type (e.g., `float`, `double`). +- Returns the type of the argument as the return type of the function. + +##### `compile(self, args, prim_type)` + +Define how the function is compiled into target code. +- `args`: list of arguments as C strings +- `prim_type`: A C string representing the primitive data type. It could be one of the following C strings, mapping from LoopIR types to C strings: + - `f16` -> `"_Float16"` + - `f32` -> `"float"` + - `f64` -> `"double"` + - `i8` -> `"int8_t"` + - `ui8` -> `"uint8_t"` + - `ui16`-> `"uint16_t"` + - `i32` -> `"int32_t"` + +```python +def compile(self, args, prim_type): + return f"sin(({prim_type}){args[0]})" +``` + +- Generates the code that calls the external function, ensuring proper casting to the primitive type. + +##### `globl(self, prim_type)` + +Provide any global code or headers needed. + +```python +def globl(self, prim_type): + return "#include " +``` + +- Includes necessary headers required for the external function (e.g., `` for mathematical functions). +- `globl` is called and is instantiated for every `prim_type`s. + +#### 4. 
Instantiate the Extern Function + +Create an instance of your extern class to make it usable in your code. + +```python +sin = _Sin() +``` + +- `sin` now represents the extern function and can be used like any other expression in Exo. + +## Using Externs as Expressions + +Externs can be used as expressions on the RHS of assignment and reduction statements. This allows you to incorporate external functions seamlessly into your Exo computations. + +Unlike Exo procedures that do not allow aliasing in their arguments, you _can_ pass the same buffer to external arguments (e.g., `select(xi, xi, xi, xi)`). +This is because there is no concern about aliasing since all external arguments are read-only, as opposed to Exo procedure arguments which can have write effects on the input arguments. + +### Example: Using `sin` in an Expression + +Here's a complete example demonstrating how to define and use the `sin` extern function within an expression. + +```python +from __future__ import annotations +from exo import * +from exo.core.extern import Extern, _EErr + +class _Sin(Extern): + def __init__(self): + super().__init__("sin") + + def typecheck(self, args): + if len(args) != 1: + raise _EErr(f"expected 1 argument, got {len(args)}") + + arg_type = args[0].type + if not arg_type.is_real_scalar(): + raise _EErr( + f"expected argument to be a real scalar value, but got type {arg_type}" + ) + return arg_type + + def compile(self, args, prim_type): + return f"sin(({prim_type}){args[0]})" + + def globl(self, prim_type): + return "#include " + + def interpret(self, args): + import math + return math.sin(args[0]) + +# Instantiate the extern function +sin = _Sin() + +# Define an Exo procedure using the extern function in an expression +@proc +def foo(x: f32): + x = sin(x) * 3.0 + +print(foo) +``` + +### Output + +When you run the code above with `exocc`, the generated C code will be: +```c +#include +// foo( +// x : f32 @DRAM +// ) +void foo( void *ctxt, float* x ) { + *x = sin((float)*x) * 3.0f; +} +``` diff --git a/docs/images/system-overview.png b/docs/images/system-overview.png new file mode 100644 index 000000000..bb3817143 Binary files /dev/null and b/docs/images/system-overview.png differ diff --git a/docs/inspection.md b/docs/inspection.md new file mode 100644 index 000000000..8760ede49 --- /dev/null +++ b/docs/inspection.md @@ -0,0 +1,50 @@ +# External Inspection Functions + +Inspection is a metaprogramming feature that enables metaprograms (like schedules) to dynamically examine the properties of object code. Exo provides inspection through [Cursors](./Cursors.md), allowing users to examine standard AST properties such as variable names, literal expression values, and annotations (e.g., memory spaces and precisions) at scheduling time. Cursors also support local AST navigation, for example, accessing loop bounds (`loop.hi()`) and bodies (`loop.body()`). Inspection functions can be written externally from the Exo compiler, giving users the ability to customize them according to their needs. +For convinience, standard library inspection functions are provided as `exo.stdlib.inspection` module. 
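+As a small taste of such a user-written inspection (the `exo.API_cursors` import it relies on is explained just below), here is a hedged sketch of a helper that walks a cursor's ancestors to collect every enclosing loop. The name `enclosing_loops` is illustrative, not part of Exo's API; the sketch only uses the `forward`, `parent`, and cursor-type checks documented in [Cursors.md](./Cursors.md).
+
+```python
+from exo.API_cursors import ForCursor, InvalidCursor
+
+def enclosing_loops(proc, c):
+    # Collect cursors to every loop enclosing cursor `c`, innermost first.
+    c = proc.forward(c)            # re-anchor the cursor in the current procedure
+    loops = []
+    parent = c.parent()
+    while not isinstance(parent, InvalidCursor):
+        if isinstance(parent, ForCursor):
+            loops.append(parent)
+        parent = parent.parent()
+    return loops
+```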
+ +Cursor types (such as `ForCursor` and `IfCursor`) are defined in `exo.API_cursors`, so you should import it when writing inspection functions: + +```python +from exo.API_cursors import * +``` + +Here are some simple inspection functions: + +```python +def is_loop(proc, loop): + loop = proc.forward(loop) + return isinstance(loop, ForCursor) + +def get_top_level_stmt(proc, c): + c = proc.forward(c) + + while not isinstance(c.parent(), InvalidCursor): + c = c.parent() + return c +``` + +Explanation: +- The `is_loop` function takes a `proc` object and a `loop` cursor as input. It forwards the `loop` cursor using `proc.forward(loop)` and checks if the resulting cursor is an instance of `ForCursor`. This function determines whether the given cursor points to a loop statement. +- The `get_top_level_stmt` function takes a `proc` object and a cursor `c` as input. It forwards the cursor `c` using `proc.forward(c)` and then iteratively moves the cursor to its parent using `c.parent()` until it reaches an `InvalidCursor`, which means the cursor reached the outer-most level of the procedure. This function finds the top-level statement that wraps the given cursor. + +Exo also exposes `ExoType` for expression types (defined in `src/exo/API_types.py`), which users can access using constructs like `ExoType.F16` and branch on it. + +```python +class ExoType(Enum): + F16 = auto() + F32 = auto() + F64 = auto() + UI8 = auto() + I8 = auto() + UI16 = auto() + I32 = auto() + R = auto() + Index = auto() + Bool = auto() + Size = auto() + Int = auto() + Stride = auto() +``` + +All the Cursor types and the kind of navigation you can perform on them are documented in [Cursors.md](./Cursors.md). diff --git a/docs/instructions.md b/docs/instructions.md new file mode 100644 index 000000000..ca9fc228a --- /dev/null +++ b/docs/instructions.md @@ -0,0 +1,165 @@ +# External Instruction Definitions + +Exo allows users to define custom hardware instructions within their code using the `@instr` annotation. +These user-defined instructions can be leveraged during the scheduling process to replace specific code fragments with calls to hardware-optimized instructions. + +## Overview + +- **Custom Instructions**: Define hardware-specific instructions as procedures using the `@instr` decorator. +- **Replace**: Use the `replace` primitive to substitute code fragments with calls to these instructions. +- **Code Generation**: Custom instructions can emit arbitrary C code, including inline assembly, with placeholders for arguments. + +## Defining Custom Instructions + +Custom instructions are defined as procedures annotated with `@instr`. +The `@instr` decorator allows you to specify the C code to be emitted when the instruction is called. + +### Syntax + +```python +@instr("C code") +def instruction_name(args): + # Specification of the instruction's behavior +``` +- **`@instr`**: Decorator that specifies the C code to emit. In the string provided to `@instr`, you can include placeholders wrapped in `{}`. These placeholders will be replaced with the names of the arguments when the code is compiled. +- **`instruction_name`**: The name of your custom instruction. +- **`args`**: Arguments to the instruction. +- **semantics**: Semantics of the hardware instruction, written as Exo object code. + +### Example: Defining a Neon Load Instruction + +Below is an example of defining a NEON load instruction that loads four `f32` values into Neon memory. 
+ +```python +from exo import * + +@instr("{dst_data} = vld1q_f32(&{src_data});") +def neon_vld_4xf32(dst: [f32][4] @ Neon, src: [f32][4] @ DRAM): + assert stride(src, 0) == 1 + assert stride(dst, 0) == 1 + + for i in seq(0, 4): + dst[i] = src[i] +``` + +- **`@instr(...)`**: Specifies the semantics of the hardware instruction and the C code to emit. + - `{dst_data}` and `{src_data}` are format strings that will be replaced with the actual arguments during codegen. You can put `_data` after the function argument names and surround them with curly braces (`{dst_data}`). + - `"{dst_data} = vld1q_f32(&{src_data});"`: The argument to `@instr` decorators specifies the C code to emit for this instruction. +- **`dst: [f32][4] @ Neon`**: Declares `dst` as a 4-element array of `f32` in `Neon` memory. +- **`src: [f32][4] @ DRAM`**: Declares `src` as a 4-element array of `f32` in `DRAM`. +- **Assertions**: Ensure that the strides of `src` and `dst` are 1 for correct memory access. +- **Body**: The function body specifies the semantics of the instruction (written in Exo object code), copying elements from `src` to `dst`. + +### Defining the Memory Annotation `Neon` + +The `Neon` memory type can be defined similarly to how custom memories are defined, as explained in [memories.md](memories.md). + +```python +class Neon(Memory): + @classmethod + def global_(cls): + return "#include " + + # Implement other required methods +``` + +## Using Custom Instructions + +Once you've defined a custom instruction, you can use it to replace code fragments in your procedures. + +### Define Your Procedure + +Define your Exo procedure as usual. + +```python +@proc +def foo(src: [f32][4] @ DRAM, dst: [f32][4] @ Neon): + ... + for i in seq(0, ...): + ... + for j in seq(0, 4): + dst[j] = src[j] + ... +``` + +### Use `replace` to Substitute the Instruction + +Use the `replace` primitive to substitute the loop with the custom instruction. + +```python +# Replace the loop with the custom instruction +foo = replace(foo, "for j in _:_", neon_vld_4xf32) +``` + +- **`replace(foo, "for i in _:_", neon_vld_4xf32)`**: + - **`foo`**: The procedure in which to perform the replacement. + - **`"for i in _:_"`**: A cursor pointing to the loop to replace. + - **`neon_vld_4xf32`**: The instruction to replace the loop with. + +After `replace`, the procedure `foo` will look like: +```python +@proc +def foo(M: size, src: [f32][4] @ DRAM, dst: [f32][4] @ Neon): + ... + for i in seq(0, M/4): + ... + neon_vld_4xf32(dst, src) + ... +``` + +#### How `replace` Works + +The `replace` primitive is used to substitute a fragment of code within a procedure with a call to another procedure (e.g., a custom instruction). The syntax for `replace` is as follows: + +```python +replace(proc, cursor_path, subproc) +``` + +- **`proc`**: The procedure containing the code to be replaced. +- **`cursor`**: A cursor pointing to the code fragment to be replaced. +- **`subproc`**: The procedure whose body will replace the code fragment. + +The `replace` primitive works by performing an unification modulo linear equalities. The process can be broken down into two main steps: + +1. **Pattern Matching**: The body of the sub-procedure `subproc` is unified (pattern matched) with the designated statement block `s` in the original procedure `proc`. During this process: + - The arguments of `subproc` are treated as unknowns. + - The free variables of `s` are treated as known symbols. + - Any symbols introduced or bound within the body of `subproc` or within `s` are unified. 
+ + The ASTs (Abstract Syntax Trees) of `subproc` and `s` are required to match exactly with respect to statements and all expressions that are not simply integer-typed control. + +2. **Solving Linear Equations**: Any equivalences between integer-typed control expressions are recorded as a system of linear equations. These equations are then solved to determine the values of the unknowns and ensure a consistent substitution. + +By following this process, the `replace` primitive effectively replaces the designated code fragment with a call to the sub-procedure, while ensuring that the substitution is valid and consistent. + + +### Generated C Code + +`exocc` can be used to compile Exo code into C. + +```c +void foo(float src[4], float32x4_t dst) { + ... + for (int_fast32_t i = 0; i < ...; i++) { + ... + dst = vld1q_f32(&src[0]); + } + ... +} +``` + +- **`dst = vld1q_f32(&src[0]);`**: The custom instruction is emitted as specified in the `@instr` decorator, with arguments replaced. + +## Understanding the Magic + +By defining the behavior of hardware instructions in Python using Exo procedures, you can express the semantics of your accelerator or specialized hardware. The `replace` primitive allows Exo to reason about whether it's safe to offload certain computations to hardware instructions based on their specifications. + +- **No Compiler Backend Needed**: The heavy lifting is done within Exo, eliminating the need for a separate compiler backend. +- **Semantics Encoding**: The instruction's body acts as a specification, encoding its semantics for Exo's pattern matching. +- **Flexible and Extensible**: Users can define any instruction and specify how it should be matched and replaced. + + +## Further Reading and Examples + +- **RVM Tutorial**: [https://exo-lang.dev/tutorial.html](https://exo-lang.dev/tutorial.html) +- **Running Code Examples**: [examples/rvm_conv1d/exo/conv1d.py](https://github.com/exo-lang/exo/blob/main/examples/rvm_conv1d/exo/conv1d.py) diff --git a/docs/memories.md b/docs/memories.md new file mode 100644 index 000000000..2cc946dfa --- /dev/null +++ b/docs/memories.md @@ -0,0 +1,200 @@ +# External Memory Definitions + +Exo allows users to define custom memory types external to the compiler. +This feature enables modeling of specialized memory systems, such as vector machines and hardware accelerator memories, directly within your Exo code. +By defining custom memories, you can optimize your programs to target specific hardware architectures. + +## Overview + +- **Custom Memories**: Define your own memory types by subclassing the `Memory` class. +- **Usage**: Use custom memories as annotations in your Exo code or set them during scheduling. + +## Defining Custom Memories + +To define a custom memory, you need to create a class that inherits from `Memory` and implement the required methods. +Below is an example of defining an `AVX512` memory, which models the AVX-512 vector registers. 
+ +### Example: Defining AVX512 Memory + +```python +class AVX512(Memory): + @classmethod + def global_(cls): + return "#include " + + @classmethod + def can_read(cls): + return False + + @classmethod + def alloc(cls, new_name, prim_type, shape, srcinfo): + if not shape: + raise MemGenError(f"{srcinfo}: AVX512 vectors are not scalar values") + if not prim_type == "float": + raise MemGenError(f"{srcinfo}: AVX512 vectors must be f32 (for now)") + if not shape[-1].isdecimal() and int(shape[-1]) == 16: + raise MemGenError(f"{srcinfo}: AVX512 vectors must be 16-wide") + shape = shape[:-1] + if shape: + result = f'__m512 {new_name}[{"][".join(map(str, shape))}];' + else: + result = f"__m512 {new_name};" + return result + + @classmethod + def free(cls, new_name, prim_type, shape, srcinfo): + return "" + + @classmethod + def window(cls, basetyp, baseptr, indices, strides, srcinfo): + assert strides[-1] == "1" + idxs = indices[:-1] or "" + if idxs: + idxs = "[" + "][".join(idxs) + "]" + return f"{baseptr}{idxs}" +``` + +#### Explanation of Methods + +- **`global_(cls)`**: Returns any global code or headers needed. Here, it includes the AVX-512 intrinsic header. + + ```python + @classmethod + def global_(cls): + return "#include " + ``` + +- **`can_read(cls)`**: Controls whether the memory can be read directly. Setting it to `False` means you cannot read/write directly to this memory using standard array access. + + ```python + @classmethod + def can_read(cls): + return False + ``` + +- **`alloc(cls, new_name, prim_type, shape, srcinfo)`**: Defines how to lower `LoopIR.Alloc` into C code. + Allocation in Exo is expressed as `x : f32[N, M]`. + - `new_name`: A C string representing the allocated variable name. In this example, it would be `"x"`. + - `prim_type`: A C string representing the primitive data type. In this example, it would be `"float"`. The mapping from LoopIR types to C types is as follows: + - `f16` -> `"_Float16"` + - `f32` -> `"float"` + - `f64` -> `"double"` + - `i8` -> `"int8_t"` + - `ui8` -> `"uint8_t"` + - `ui16`-> `"uint16_t"` + - `i32` -> `"int32_t"` + - `shape`: A list of C strings representing the shape of each dimension. In the example above, it would be `["N", "M"]`. + + For `AVX512` memory, the `alloc` method ensures that the allocated memory represents 16-wide vectors (`shape[-1].isdecimal() and int(shape[-1]) == 16`) of the `float` type (`prim_type == "float"`). + + +- **`free(cls, new_name, prim_type, shape, srcinfo)`**: Handles memory deallocation. For `AVX512`, no action is needed. + + ```python + @classmethod + def free(cls, new_name, prim_type, shape, srcinfo): + return "" + ``` + +- **`window(cls, basetyp, baseptr, indices, strides, srcinfo)`**: Defines how array accesses are lowered into C code. + + Usually, you cannot access your specialized hardware accelerator memory from C code, and you will need to use your accelerator instructions to access it, like the following: + + ```python + x : f32[16,16] @ your_memory + your_instr(x[0, 0:16]) + ``` + + The `window` member defines how `x[0, 0:16]` should be lowered to C code, as different accelerator instructions and memory have different addressing schemes. + + For example, the Gemmini accelerator's scratchpad memory is 2D and has a fixed column width of 16. The Gemmini instruction expects accessing the scratchpad by *number of rows* only, and accessing columns is not permitted. 
Therefore, the window definition will look like: + + ```python + @classmethod + def window(cls, basetyp, baseptr, indices, strides, srcinfo): + # Assume that strides[-1] == 1 + # and that strides[-2] == 16 (if there is a strides[-2]) + assert len(indices) == len(strides) and len(strides) >= 2 + prim_type = basetyp.basetype().ctype() + offset = generate_offset(indices, strides) + return ( + f"*({prim_type}*)((uint64_t)( " + f"((uint32_t)((uint64_t){baseptr})) + " + f"({offset})/16))" + ) + ``` + + Explanation of arguments: + - `basetyp`: type of the buffer in `LoopIR.type` + - `baseptr`: C pointer string to the buffer (e.g., `x`) + - `indices`: List of C strings for index accesses for each dimension + - `strides`: List of C strings for strides for each dimension + - `srcinfo`: Source location information, Can be used for error messages + + Both tensor and window expressions will be resolved to vanilla indices and strides. + + +## Understanding `can_read` + +The `can_read` method controls whether direct array access is allowed for the memory type. When `can_read` is set to `False`, you cannot read or write to the memory using standard array indexing in Exo or the generated C code. This models hardware that requires special instructions for memory access, such as vector registers. + +### Invalid Usage + +Attempting to read or write directly results in an error. + +```python +x: f32[16] @ AVX512 +x[0] = 3.0 # Invalid when can_read() is False +``` + +### Valid Usage + +To interact with the memory, you must use specific instructions or operations designed for that memory type (e.g., AVX-512 intrinsics). + +```python +# Use AVX-512 instructions to manipulate x +x: f32[16] @ AVX512 +mm512_loadu_ps(x, inp[16*i : 16*i+16]) +``` +To learn more about how to define and use instructions in Exo, see [instructions.md](./instructions.md). + +## Using Custom Memories + +There are two primary ways to use custom memories in Exo: + +1. **Direct Annotation**: Annotate variables with the custom memory type using the `@` symbol. +2. **Scheduling Primitive**: Change the memory annotation during scheduling using `set_memory`. + +### 1. Direct Annotation + +Annotate buffers at the time of declaration. +```python +from exo import * +from exo.libs.memories import AVX512 + +@proc +def foo(x: f32[16] @ AVX512): + y: f32[16] @ AVX512 + # Function body +``` + +- **`x: f32[16] @ AVX512`**: Declares `x` as a 16-element array of `f32` stored in `AVX512` memory. +- **`y: f32[16] @ AVX512`**: Similarly declares `y` in `AVX512` memory. + +### 2. Changing Memory During Scheduling + +Use the `set_memory` primitive to change the memory annotation of a variable during scheduling. +- **`set_memory(p, "C", AVX512)`**: Changes the memory of variable `C` in procedure `p` to `AVX512`. +- This is common when optimizing simple object code (e.g., GEMM) for specific hardware. + +#### Documentation for `set_memory` + +The `set_memory` primitive is documented in [primitives/buffer_ops.md](primitives/buffer_ops.md). + + +## Additional Examples + +- **Memory Definitions**: More examples of custom memory definitions can be found in [src/exo/libs/memories.py](https://github.com/exo-lang/exo/blob/main/src/exo/libs/memories.py). +- **Usage in Applications**: Examples of using custom memories in real applications are available in [examples/rvm_conv1d/exo/conv1d.py](https://github.com/exo-lang/exo/blob/main/examples/rvm_conv1d/exo/conv1d.py). 
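+
+## Putting It Together
+
+As a closing illustration, here is a compressed sketch of the second usage route above (changing a buffer's memory with `set_memory` during scheduling). The proc name `stage_c` and the buffer name `C_reg` are hypothetical, and the sketch stops before the instruction-mapping step (e.g., `replace_all` with the corresponding intrinsic procs) that would still be needed before this code could be compiled to C, since `AVX512` forbids direct reads and writes:
+
+```python
+from exo import *
+from exo.libs.memories import AVX512
+from exo.stdlib.scheduling import *
+
+@proc
+def stage_c(C: f32[16] @ DRAM):
+    C_reg: f32[16]          # no annotation yet, so this defaults to DRAM
+    for i in seq(0, 16):
+        C_reg[i] = C[i]
+
+# Rewrite the declaration of C_reg so the buffer lives in AVX-512 registers.
+stage_c_avx = set_memory(stage_c, "C_reg", AVX512)
+print(stage_c_avx)
+```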
+ + diff --git a/docs/object_code.md b/docs/object_code.md new file mode 100644 index 000000000..863288269 --- /dev/null +++ b/docs/object_code.md @@ -0,0 +1,367 @@ +# Exo Object Code Syntax + +In Exo, object code can be defined using Python-like syntax with specific annotations and constructs to model low-level programming concepts. + +This documentation explains Exo's object code syntax using the following example of a 1D convolution operation: + +```python +@proc +def generic_conv1d( + data: i32[IC, N] @ DRAM, + kernels: i32[OC, IC, W] @ DRAM, + out: i32[OC, N] @ DRAM, +): + # Perform the convolution + for i in seq(0, OC): + for j in seq(0, N): + # Zero out the output memory + out[i, j] = 0.0 + for c in seq(0, IC): + for r in seq(0, W): + y: i32 + if j + r < N: + y = data[c, j + r] + else: + y = 0 + out[i, j] += kernels[i, c, r] * y +``` + +## Table of Contents + +- [Annotations and Decorators](#annotations-and-decorators) + - [`@proc` Decorator](#proc-decorator) + - [Type and Memory Annotations](#type-and-memory-annotations) + - [Procedure Arguments](#procedure-arguments) + - [Allocations](#allocations) + - [Memories](#memories) +- [Loops](#loops) + - [`for` Loop Syntax](#for-loop-syntax) +- [Conditional Statements](#conditional-statements) +- [Assignments](#assignments) +- [Understanding the Example](#understanding-the-example) + +## Annotations and Decorators + +### `@proc` Decorator + +The `@proc` decorator is used to define an Exo procedure (analogous to a function in other programming languages). It indicates that the following function definition should be treated as Exo object code (not Python), which can be further optimized and transformed. + +```python +@proc +def function_name(arguments): + # Function body +``` + +### Type and Memory Annotations + +In Exo, types and memory spaces are explicitly annotated. The syntax is: + +```python +name: type[size] @ memory +``` + +- **`name`**: The variable name. +- **`type`**: The data type. Supported precision types are: `f16`, `f32`, `f64`, `i8`, `i32`, `ui8`, and `ui16`. +- **`[size]`**: The dimensions of the array (optional for scalars). +- **`@ memory`**: The memory space where the variable resides. + + +### Procedure Arguments + +Procedure arguments are declared with their types, sizes, and memory spaces. They can have sizes that depend on other arguments. + +Example from the code: + +```python +data: i32[IC, N] @ DRAM +``` + +- **`data`**: The name of the argument. +- **`i32`**: The data type (32-bit integer). +- **`[IC, N]`**: A 2D array with dimensions `IC` and `N`. +- **`@ DRAM`**: Specifies that `data` resides in DRAM memory. + +The `data` buffer above represents **tensor** types, which means the stride of the innermost dimension is 1, and the strides of other dimensions are simple multiples of the shapes of the inner dimensions. + +Exo allows **window expressions** as well, which are similar to array slicing in Python. Instead of accessing the buffer point-wise (e.g., `x[i]`), users can *window* the array as `x[i:i+2]`. This will create a windowed array of size 2. +Exo procedures take tensor expressions when annotated with `x:f32[3]` syntax and take window expressions when annotated with `x:[f32][3]`, with square brackets around the types. + +```python +@proc +def foo(x: [f32][3]): + for i in seq(0, 3): + x[i] = 0.0 + +@proc +def bar(y: f32[10], z: f32[20, 20]): + foo(y[2:5]) + foo(z[1, 10:13]) +``` + +In this example, `foo` takes a window array of size 3, and `bar` calls `foo` by slicing `y` and `z`, respectively. 
Running `exocc` on this will generate the following C code: + +```c +#include "tmp.h" + +#include +#include + +// bar( +// y : f32[10] @DRAM, +// z : f32[20, 20] @DRAM +// ) +void bar(void *ctxt, float* y, float* z) { + foo(ctxt, (struct exo_win_1f32){ &y[2], { 1 } }); + foo(ctxt, (struct exo_win_1f32){ &z[20 + 10], { 1 } }); +} + +// foo( +// x : [f32][3] @DRAM +// ) +void foo(void *ctxt, struct exo_win_1f32 x) { + for (int_fast32_t i = 0; i < 3; i++) { + x.data[i * x.strides[0]] = 0.0f; + } +} +``` + +Moreover, Exo checks the consistency of tensor and window bounds in the frontend. If you modify `foo(y[2:5])` to `foo(y[2:6])` in the code above, the bounds check will fail and emit the following error: + +``` +TypeError: Errors occurred during effect checking: +/private/tmp/tmp.py:12:8: type-shape of calling argument may not equal the required type-shape: [Effects.BinOp(op='-', lhs=Effects.Const(val=6, type=LoopIR.Int(), srcinfo=), rhs=Effects.Const(val=2, type=LoopIR.Int(), srcinfo=), type=LoopIR.Index(), srcinfo=)] vs. [Effects.Const(val=3, type=LoopIR.Int(), srcinfo=)]. It could be non equal when: + y_stride_0 = 1, z_stride_0 = 20, z_stride_1 = 1 +``` + +#### Aliasing Limitations + +When passing buffers to procedure arguments, aliasing is not allowed. Concretely, you cannot write something like: + +```python +foo(y, y) +foo(y[0:5], y[2:7]) +``` + +This limitation exists because the analysis would be imprecise if we allowed such aliasing. This is similar to how C++ compilers can perform more optimization when you use the `__restrict__` keyword to explicitly indicate that you're not aliasing your buffers. + + +#### Passing Tensor Window Slices to Functions Expecting Non-Window Tensors + +It is not allowed to pass a _window_ to a function that expects a non-window tensor as an argument. Consider the following example: + +```python +@proc +def callee(x: f32[10]): + pass + +@proc +def caller(x: f32[2, 10]): + callee(x[0]) # Error: Passing a window slice to a function expecting a non-window tensor + callee(x[1, :]) # Error: Passing a window slice to a function expecting a non-window tensor +``` + +In this code snippet, the `callee` function expects a non-window tensor `x` of shape `f32[10]`. However, in the `caller` function, we attempt to pass slices of the `x` tensor (`x[0]` and `x[1]`) to the `callee` function. These slices are windows of the original tensor, and passing them to a function expecting a non-window tensor is not allowed. + +To resolve this issue, you can either: +1. Modify the `callee` function to accept a window tensor as an argument, or +2. Create a new non-window tensor from the slice before passing it to the `callee` function. + + +### Allocations + +Variables within the procedure are declared similarly to arguments. + +Example: + +```python +y: i32 +``` + +- **`y`**: The variable name. +- **`i32`**: The data type (32-bit integer). +- **No memory annotation**: Defaults to `DRAM` if memory is unspecified. + +### Memories + +Memory annotations in Exo are used to model different hardware memory regions, such as DRAM, caches, or specialized memories. The `@` symbol is used to specify the memory space, for example: `@DRAM`, `@AVX2`, or `@Neon`. +Memory annotations for your custom hardware accelerators can be defined externally to Exo and can be used as annotations in the same way. 
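+
+For illustration, here is a small sketch of how these annotations appear on declarations (`MY_ACCEL` is a hypothetical user-defined memory, not one provided by Exo):
+
+```python
+x: f32[16]                  # no annotation: defaults to DRAM
+y: f32[16] @ DRAM           # the same placement, spelled out explicitly
+tile: i32[4, 4] @ MY_ACCEL  # placed in a user-defined accelerator memory
+```
+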
+While Exo provides default memory (`DRAM`) and some library memory definitions for convenience (`AVX2`, `AVX512`, `Neon`, `GEMM_SCRATCH`, etc.), it is recommended and encouraged that users define their own memory annotations for their specific hardware. For more information on defining custom memory annotations, refer to [memories.md](./memories.md). + + + +## Loops + +### `for` Loop Syntax + +Exo uses explicit loop constructs to model iteration. The `for` loop syntax is: + +```python +for loop_variable in seq(start, end): + # Loop body +``` + +- **`loop_variable`**: The loop counter variable. +- **`seq(start, end)`**: Iterates from `start` to `end - 1`. + +Example from the code: + +```python +for i in seq(0, OC): + # Iterates i from 0 to OC - 1 +``` + +## Conditional Statements + +Conditional logic is expressed using `if` and `else` statements. + +Syntax: + +```python +if condition: + # True branch +else: + # False branch +``` + +Example: + +```python +if j + r < N: + y = data[c, j + r] +else: + y = 0 +``` + +- Checks if `j + r` is less than `N`. +- Assigns `y` accordingly. + +## Assignments + +- **Assignment (`=`)**: Assigns a value to a variable. + + ```python + y = data[c, j + r] + ``` + +- **Reduction (`+=`)**: Adds a value to a variable and stores the result back. + + ```python + out[i, j] += kernels[i, c, r] * y + ``` + +- **Array Access**: Uses square brackets to access array elements. + + ```python + data[c, j + r] + ``` + +- **Window Statements**: Creates a slice (in other words, _window_) of the buffer and assign a new name. + ```python + y = x[0:3] + ``` + +## Limitations + +Exo has a few limitations that users should be aware of: + +1. **Non-affine indexing**: Exo does not support non-affine indexing. This means that any indexing operation must be a linear combination of loop variables and constants. For example, the following expressions are not allowed: + + ```python + data[i * j + r] = 0.0 # i * j is non-affine + if n * m < 30: # n * m is non-affine + pass + ``` + + Exo allows quasi-affine indexing by division (e.g., `i/3`) and modulo (e.g., `i%3`) by constants. + + To work around this limitation, you may need to restructure your code or use additional variables to represent the non-affine expressions. + +2. **Value-dependent control flow**: Exo separates control values from buffer values, which means that it is not possible to write value-dependent control flow. For instance, the following code is not allowed: + + ```python + if data[i] < 3.0: + pass + ``` + + If you need to express such operations, consider using externs (see [externs documentation](./externs.md)). + + +## Understanding the Example + +Let's break down the example code step by step. + +### Procedure Definition + +```python +@proc +def generic_conv1d( + data: i32[IC, N] @ DRAM, + kernels: i32[OC, IC, W] @ DRAM, + out: i32[OC, N] @ DRAM, +): +``` + +- **`generic_conv1d`**: The procedure name. +- **Arguments**: + - **`data`**: Input data array of shape `[IC, N]` in DRAM. + - **`kernels`**: Kernel weights array of shape `[OC, IC, W]` in DRAM. + - **`out`**: Output data array of shape `[OC, N]` in DRAM. +- **Variables**: + - **`IC`**, **`OC`**, **`N`**, **`W`**: Dimensions, assumed to be defined elsewhere or passed as parameters. 
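+
+For reference, here is a minimal sketch of the two ways such dimensions are typically supplied, both of which appear in this repository's examples (the proc name `conv1d_sized` and the constant values are illustrative only):
+
+```python
+# Option 1: module-level Python constants, fixed when the proc is defined
+# (the RVM conv1d example uses this style):
+#     IC, OC, N, W = 4, 16, 16, 4
+
+# Option 2: declare the dimensions as `size` arguments of the proc
+# (the AVX2 matmul example passes `K: size` in the same way):
+@proc
+def conv1d_sized(
+    IC: size,
+    OC: size,
+    N: size,
+    W: size,
+    data: i32[IC, N] @ DRAM,
+    kernels: i32[OC, IC, W] @ DRAM,
+    out: i32[OC, N] @ DRAM,
+):
+    pass
+```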
+ +### Loop Nest + +```python +for i in seq(0, OC): + for j in seq(0, N): + # Zero out the output memory + out[i, j] = 0.0 + for c in seq(0, IC): + for r in seq(0, W): + y: i32 + if j + r < N: + y = data[c, j + r] + else: + y = 0 + out[i, j] += kernels[i, c, r] * y +``` + +#### Outer Loops + +- **`for i in seq(0, OC):`**: Iterates over the output channels. +- **`for j in seq(0, N):`**: Iterates over the spatial dimension of the output. + +#### Initialization + +- **`out[i, j] = 0.0`**: Initializes the output element at `(i, j)` to zero. + +#### Inner Loops + +- **`for c in seq(0, IC):`**: Iterates over the input channels. +- **`for r in seq(0, W):`**: Iterates over the kernel width. + +#### Conditional Data Access + +```python +y: i32 +if j + r < N: + y = data[c, j + r] +else: + y = 0 +``` + +- **Purpose**: Handles boundary conditions where the kernel extends beyond the input data. +- **`y`**: Temporary variable to hold the input data or zero. +- **Condition**: + - **If `j + r < N`**: Valid index; assign `data[c, j + r]` to `y`. + - **Else**: Out-of-bounds; assign `0` to `y`. + +#### Accumulation + +```python +out[i, j] += kernels[i, c, r] * y +``` + +- **Operation**: Accumulates the product of the kernel weight and the input data into the output. +- **`kernels[i, c, r]`**: Kernel weight for output channel `i`, input channel `c`, at position `r`. +- **`y`**: The input data value or zero. diff --git a/docs/backend_ops.md b/docs/primitives/backend_ops.md similarity index 100% rename from docs/backend_ops.md rename to docs/primitives/backend_ops.md diff --git a/docs/buffer_ops.md b/docs/primitives/buffer_ops.md similarity index 100% rename from docs/buffer_ops.md rename to docs/primitives/buffer_ops.md diff --git a/docs/config_ops.md b/docs/primitives/config_ops.md similarity index 100% rename from docs/config_ops.md rename to docs/primitives/config_ops.md diff --git a/docs/loop_ops.md b/docs/primitives/loop_ops.md similarity index 100% rename from docs/loop_ops.md rename to docs/primitives/loop_ops.md diff --git a/docs/other_ops.md b/docs/primitives/other_ops.md similarity index 100% rename from docs/other_ops.md rename to docs/primitives/other_ops.md diff --git a/docs/subproc_ops.md b/docs/primitives/subproc_ops.md similarity index 100% rename from docs/subproc_ops.md rename to docs/primitives/subproc_ops.md diff --git a/examples/README.md b/examples/README.md index a586f4786..caf1c0656 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,268 +1,10 @@ -# Scheduling Example +# Scheduling Examples -Please [install Exo](https://github.com/exo-lang/exo#install-exo) before proceeding with this example. -This tutorial assumes some familiarity with SIMD instructions. +This directory contains several examples, along with documentation and code. +If you are new to Exo, we recommend going through the examples in the following order: -Exo provides *scheduling operators* to transform program and rewrite them to make use of complex hardware instructions. -We'll show you how to take a simple matrix multiplication kernel and transform it into an implementation that can make use of [AVX2](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) vector instructions. +1. [AVX2 Matmul](./avx2_matmul/README.md): This example demonstrates how to take a simple matrix multiplication kernel and transform it into an implementation that can make use of AVX2 instructions. It provides an overview of Exo and its scheduling system. 
-The complete code with scheduling operations can be found in `exo/examples/x86_matmul.py`, and running `make` will compile the Exo code and generate an executable `avx2_matmul`. - -## Basic Implementation - -To start off, let's implement a basic matrix multiplication kernel in Exo object code: -```py -from __future__ import annotation -from exo import * - -@proc -def rank_k_reduce_6x16( - K: size, A: f32[6, K] @ DRAM, B: f32[K, 16] @ DRAM, C: f32[6, 16] @ DRAM -): - for i in seq(0, 6): - for j in seq(0, 16): - for k in seq(0, K): - C[i, j] += A[i, k] * B[k, j] - -print(rank_k_reduce_6x16) -``` - -These microkernels usually function as the inner loops of highly optimized linear algebra computations. For example, [BLIS][] (an open-source [BLAS][] library) is architected around re-implementing such microkernels for each new target architecture that they support. The goal of Exo is to make this specialization process dramatically easier. - -For our example, we want to specialize the kernel to use the AVX2 instructions; it is likely the case that a vectorizing compiler cannot automatically transform this kernel. - -## Scheduling Walkthrough - -Scheduling plays a central role in Exo's machinery to generate high-performance kernels. -Instead of relying on automated compiler passes, we can specify *program rewrites* that allow Exo to generate high-performance code. - -Looking at our kernel, we can see that the contraction dimension `k` is amenable to streaming. -This means that we want to perform vectorized computation using the `i` and `j` iterators. -At high-level, we're going to perform the following set of rewrites to enable our vectorized computation: -- Reorder to the loops to expose streaming behavior -- Decompose the loading and storing of the output -- Vectorize the inner loop computation - -While doing this, we also have to contend with the restriction that the AVX2 instruction set exposes *16 vector registers* which means if our computation attempts to use any more than that, we'll have register spillage and lose out on the performance. - -### Reordering Loops - -The first step is reordering the loops in our program so that the streaming nature of the computation is better expressed. -We can do this using the `reorder_loops`. -Also, just to keep things easy to follow, we're going to rename our kernel to `rank_k_reduce_6x16_scheduled`: - -First, let's import the scheduling operations at the top of the file: -```py -from exo.stdlib.scheduling import * -``` - -Next, we can add the *scheduling commands* which act upon a give kernel and return a new kernel for us. Kernels in Exo are also called procs so we'll use those names interchangeably: -```py -avx = rename(rank_k_reduce_6x16, "rank_k_reduce_6x16_scheduled") -avx = reorder_loops(avx, 'j k') -avx = reorder_loops(avx, 'i k') - -print(avx) -``` - -The `rename` command is straightforward: it renamed our proc. Most of the time, we want access to both our original kernel and the optimized kernel so we recommend rename them. -The `reorder_loops` command is more interesting, it takes a pattern or a *cursor* to the loops that should be reordered. -For example, the pattern `j k` is the same as: -```py -for j in _: - for k in _: _ -``` -This tells Exo to find a program fragment that matches those two, immediately nested loop nests, and reorder them. -The `j k` is a shorthand syntax for exactly that pattern. - -Finally, the `print(avx)` shows us the resulting program's loop nests. Note that they have been reordered! -```py -... 
- for k in seq(0, K): - for i in seq(0, 6): - for j in seq(0, 16): - C[i, j] += A[i, k] * B[k, j] - -``` - -> When scheduling a new program, we often leave the `print(...)` command at the bottom and keep running the program to the see the output of the last scheduling step. - -### Vectorizing the Output - -The reordered loops let us better see the opportunity to expose vectorizing in our program. -At a high-level, we produce our outputs as a $6\times 16$ matrix which can be stored in 12, 8-wide vectors. -Since the size of the `k` dimension is unknown, we have to keep iterating on it, but we can make use of a register blocking strategy to vectorize our computation. - -To do this, we will use some more complicated scheduling operations in Exo. We encourage you to step through the transformation done by each operation by printing out `avx`: - -```py -avx = divide_loop(avx, 'for j in _: _', 8, ['jo', 'ji'], perfect=True) -avx = stage_mem(avx, 'for k in _:_', 'C[0:6, 0:16]', 'C_reg') -avx = simplify(avx) -``` - -We perform three transformations: -- `divide_loop` splits the innermost `j` loop into two loops so that we have a `for _ in seq(0, 8)` which represents the size of our vectors. -- `stage_mem` replaces the use of the output memory `C` with `C_reg` and generates loops to load and store values from and to the memory. -- `simplify` simplifies simple constant expressions - -Note that in the result, we have a new memory `C_reg: f32[6, 16] @ DRAM`. -This is not quite in the shape we want; a vector register should have a size of 8 so that we can map it to the AVX2 instructions. -The next set of transformations will address this: - -```py -avx = divide_dim(avx, 'C_reg:_', 1, 8) -avx = repeat(divide_loop)(avx, 'for i1 in _: _', 8, ['i2', 'i3'], perfect=True) -avx = simplify(avx) -``` - -The `divide_dim` operation splits the last dimension of `C_reg` into two dimensions the latter of which has 8 elements. -Next, we use the `divide_loop` operator to split apart the loops that operate on the memory `C_reg` and see our first *higher-order scheduling operator* `repeat` which applies a scheduling operator till the scheduling operation fails. -The final `simplify` simplifies the index expressions. - -These changes give us a couple of loop nests amenable for mapping onto vector instructions: -```py - ... - for i3 in seq(0, 8): - C_reg[i0, i2, i3] = C[i0, i3 + 8 * i2] - ... - for jo in seq(0, 2): - for ji in seq(0, 8): - C_reg[i, jo, ji] += A[i, k] * B[k, ji + 8 * jo] - ... - for i3 in seq(0, 8): - C[i0, i3 + 8 * i2] = C_reg[i0, i2, i3] -``` - -In order of appearance, they perform a load from `C` into `C_reg`, performs the computation on `C_reg`, and store the results into `C` from `C_reg`. -The second loop nest cannot be vectorized just yet but the other two are vectorizable. - -### Instruction Mapping - -Exo support *instruction mapping* which takes a particular program fragment and replaces it with an equivalent instruction. -For example, we can take the following loop nest: -```py -for i3 in seq(0, 8): - C_reg[i0, i2, i3] = C[i0, i3 + 8 * i2] -``` -And turn it into the AVX2 `mm256_loadu_ps`. - -To do this, we import the AVX2 instructions and use the `replace_all` operator to replace all matching loop nests: -```py -from exo.platforms.x86 import * -... 
-avx = set_memory(avx, 'C_reg:_', AVX2) -avx = replace_all(avx, mm256_loadu_ps) -avx = simplify(avx) -print(avx) -``` - -This transforms the above loop nest into: -```py -mm256_loadu_ps(C_reg[i0, i2, 0:8], C[i0, 8 * i2:8 + 8 * i2]) -``` - -The `set_memory` operator marks the `C_reg` memory as an AVX2 vector register explicitly and `replace_all` attempts to rewrite all loops in the code that implement a load into the `mm256_loadu_ps` instruction. - -The latter is a bit magical! How does the scheduling operator know what the semantics of the instruction are and when it is safe to rewrite loops to the instructions? -This is the final part of Exo's magic: the definitions of these instructions are *externalized*, i.e., provided by you: -```py -@instr("{dst_data} = _mm256_loadu_ps(&{src_data});") -def mm256_loadu_ps(dst: [f32][8] @ AVX2, src: [f32][8] @ DRAM): - assert stride(src, 0) == 1 - assert stride(dst, 0) == 1 - - for i in seq(0, 8): - dst[i] = src[i] -``` -The definition implements the semantics of the instruction using plain old python code and the `replace_all` command knows how to replace them using this definition. - -### Take a Breather - -Congratulations on getting through a whirlwind tour of Exo's capabilities. To review, we've seen a couple of concepts that work in tandem to make enable productive performance engineering: -- *Scheduling operators* allow you to rewrite programs. -- *Instruction mapping* uses user-level instruction definitions to rewrite program fragments to backend instructions. - -### Vectorizing the Computation - -Next, we're going to vectorize the innermost computation. However, we have to work with our original constraint: AVX2 exposes 16 vector registers, and we've consumed 12 of those for our output memory. The rest of computation needs to be staged carefully so that we don't end up taking more than 4 registers. - -The scheduling will follow a similar pattern to the previous sections: we want to stage memories `A` and `B` using vector registers and replace their uses from the computational kernel. - -Let's start off with `B` which is the larger of the two: -```py -# B is easy, it is just two vector loads -avx = stage_mem(avx, 'for i in _:_', 'B[k, 0:16]', 'B_reg') -avx = simplify(avx) -avx = divide_loop(avx, 'for i0 in _: _ #1', 8, ['io', 'ii'], perfect=True) -avx = divide_dim(avx, 'B_reg:_', 0, 8) -avx = set_memory(avx, 'B_reg:_', AVX2) -avx = simplify(avx) -avx = replace_all(avx, mm256_loadu_ps) -avx = simplify(avx) -print(avx) -``` - -We'll not be going into the details of each scheduling operate since you've already seen all of them before, but we encourage you to step through them and printing out `avx` after each operation. - -The rewritten program exposes the reuse pattern available for the data in `B`: -```py -... - for k in seq(0, K): - B_reg: f32[2, 8] @ AVX2 - for io in seq(0, 2): - mm256_loadu_ps(B_reg[io, 0:8], B[k, 8 * io:8 + 8 * io]) - for i in seq(0, 6): - for jo in seq(0, 2): - for ji in seq(0, 8): - C_reg[i, jo, ji] += A[i, k] * B_reg[jo, ji] -``` -For each `k` value, we get to load 16 values from `B` (two vector register's worth) and perform the computation using those. 
- -Next, we need to stage `A`: -```py -avx = bind_expr(avx, 'A[i, k]', 'A_reg') -avx = expand_dim(avx, 'A_reg', 8, 'ji') -avx = lift_alloc(avx, 'A_reg', n_lifts=2) -avx = fission(avx, avx.find('A_reg[ji] = _').after(), n_lifts=2) -avx = remove_loop(avx, 'for jo in _: _') -avx = set_memory(avx, 'A_reg:_', AVX2) -avx = replace_all(avx, mm256_broadcast_ss) -print(avx) -``` - -Staging `A` is a little more complex because unlike `C` and `B`, its reuse pattern is different: each value of `A` is broadcast into `A_reg` which is then used to perform the innermost computation. There are a couple of new scheduling operators: -- `lift_alloc`: Move an variable definition through the specified number of loops. -- `fission`: Splits apart the loop using the given cursor. -- `remove_loop`: Eliminates an unused loop. - -Finally, we can vectorize the computation: -```py -avx = replace_all(avx, mm256_fmadd_ps) -print(avx) -``` -This is perhaps a bit underwhelming however, under the hood, Exo has been performing analyses, automatic rewriting of loop bounds and indexing expressions to make the process easier. The analysis serve as guard rails for the powerful rewrite rules and are topic of another tutorial. - -## Compiling - -Finally, the code can be compiled and run on your machine if you have AVX2 instructions. -We provided a main function in `main.c` to call these procedures and to time them. -Please run `make` or compile manually: - -```sh -$ exocc -o . --stem avx2_matmul x86_matmul.py -$ gcc -o avx2_matmul -march=native main.c avx2_matmul.c -``` - -This will print out the results of running kernel with and without the AVX instructions. - -[blas]: https://www.netlib.org/blas/ -[blis]: https://github.com/flame/blis - -## Stay tuned for more automation! - -Congratulations on completing this example! -You might have felt that the scheduling operations in this example were very low-level and might be laborious to write. -We felt the same! We have a pre-release version of Exo that provides scheduling automation _external_ to the compiler implementation. -By sharing the repeated pattern of schedules and using our novel reference mechanism called Cursors, we achieve fewer lines of code than what we've shown here in the upcoming release. Please contact Exo developers at exo@mit.edu if you want to learn more or wish to collaborate! +2. [Cursor](./cursors/README.md): This example shows how to use Cursors to efficiently write schedules and define a new scheduling operator. +3. [RVM](./rvm_conv1d/README.md): This example illustrates how to use Exo to define and target a new hardware accelerator entirely in the user code. diff --git a/examples/Makefile b/examples/avx2_matmul/Makefile similarity index 100% rename from examples/Makefile rename to examples/avx2_matmul/Makefile diff --git a/examples/avx2_matmul/README.md b/examples/avx2_matmul/README.md new file mode 100644 index 000000000..4db6742c1 --- /dev/null +++ b/examples/avx2_matmul/README.md @@ -0,0 +1,267 @@ +# Scheduling Example + +Please [install Exo](https://github.com/exo-lang/exo#install-exo) before proceeding with this example. +This tutorial assumes some familiarity with SIMD instructions. + +Exo provides *scheduling operators* to transform program and rewrite them to make use of complex hardware instructions. +We'll show you how to take a simple matrix multiplication kernel and transform it into an implementation that can make use of [AVX2](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) vector instructions. 
+
+The complete code with scheduling operations can be found in `exo/examples/avx2_matmul/x86_matmul.py`, and running `make` will compile the Exo code and generate an executable `avx2_matmul`.
+
+## Basic Implementation
+
+To start off, let's implement a basic matrix multiplication kernel in Exo object code:
+```py
+from __future__ import annotations
+from exo import *
+
+@proc
+def rank_k_reduce_6x16(
+    K: size, A: f32[6, K] @ DRAM, B: f32[K, 16] @ DRAM, C: f32[6, 16] @ DRAM
+):
+    for i in seq(0, 6):
+        for j in seq(0, 16):
+            for k in seq(0, K):
+                C[i, j] += A[i, k] * B[k, j]
+
+print(rank_k_reduce_6x16)
+```
+
+These microkernels usually function as the inner loops of highly optimized linear algebra computations. For example, [BLIS][] (an open-source [BLAS][] library) is architected around re-implementing such microkernels for each new target architecture that it supports. The goal of Exo is to make this specialization process dramatically easier.
+
+For our example, we want to specialize the kernel to use AVX2 instructions; it is likely that a vectorizing compiler cannot transform this kernel automatically.
+
+## Scheduling Walkthrough
+
+Scheduling plays a central role in Exo's machinery for generating high-performance kernels.
+Instead of relying on automated compiler passes, we specify *program rewrites* that allow Exo to generate high-performance code.
+
+Looking at our kernel, we can see that the contraction dimension `k` is amenable to streaming.
+This means that we want to perform vectorized computation using the `i` and `j` iterators.
+At a high level, we're going to perform the following set of rewrites to enable our vectorized computation:
+- Reorder the loops to expose streaming behavior
+- Decompose the loading and storing of the output
+- Vectorize the inner loop computation
+
+While doing this, we also have to contend with the restriction that the AVX2 instruction set exposes *16 vector registers*: if our computation attempts to use any more than that, we'll get register spills and lose performance.
+
+### Reordering Loops
+
+The first step is reordering the loops in our program so that the streaming nature of the computation is better expressed.
+We can do this using the `reorder_loops` operator.
+Also, just to keep things easy to follow, we're going to rename our kernel to `rank_k_reduce_6x16_scheduled`.
+
+First, let's import the scheduling operations at the top of the file:
+```py
+from exo.stdlib.scheduling import *
+```
+
+Next, we can add the *scheduling commands*, which act upon a given kernel and return a new kernel for us. Kernels in Exo are also called procs, so we'll use those names interchangeably:
+```py
+avx = rename(rank_k_reduce_6x16, "rank_k_reduce_6x16_scheduled")
+avx = reorder_loops(avx, 'j k')
+avx = reorder_loops(avx, 'i k')
+
+print(avx)
+```
+
+The `rename` command is straightforward: it renames our proc. Most of the time, we want access to both the original kernel and the optimized kernel, so we recommend renaming the scheduled version.
+The `reorder_loops` command is more interesting: it takes a pattern or a *cursor* to the loops that should be reordered.
+For example, the pattern `j k` is the same as:
+```py
+for j in _:
+    for k in _: _
+```
+This tells Exo to find a program fragment that matches these two immediately nested loops and reorder them.
+The `j k` form is a shorthand syntax for exactly that pattern.
+
+Finally, the `print(avx)` shows us the resulting program's loop nests. Note that they have been reordered!
+```py
+...
+    for k in seq(0, K):
+        for i in seq(0, 6):
+            for j in seq(0, 16):
+                C[i, j] += A[i, k] * B[k, j]
+
+```
+
+> When scheduling a new program, we often leave the `print(...)` command at the bottom and keep running the program to see the output of the last scheduling step.
+
+### Vectorizing the Output
+
+The reordered loops make it easier to see the vectorization opportunity in our program.
+At a high level, we produce our output as a $6\times 16$ matrix, which can be stored in twelve 8-wide vectors.
+Since the size of the `k` dimension is unknown, we have to keep iterating over it, but we can make use of a register blocking strategy to vectorize our computation.
+
+To do this, we will use some more complicated scheduling operations in Exo. We encourage you to step through the transformation done by each operation by printing out `avx`:
+
+```py
+avx = divide_loop(avx, 'for j in _: _', 8, ['jo', 'ji'], perfect=True)
+avx = stage_mem(avx, 'for k in _:_', 'C[0:6, 0:16]', 'C_reg')
+avx = simplify(avx)
+```
+
+We perform three transformations:
+- `divide_loop` splits the innermost `j` loop into two loops so that we have a `for _ in seq(0, 8)` loop, which represents the size of our vectors.
+- `stage_mem` replaces the use of the output memory `C` with `C_reg` and generates loops to load and store values from and to that memory.
+- `simplify` simplifies constant expressions.
+
+Note that in the result, we have a new memory `C_reg: f32[6, 16] @ DRAM`.
+This is not quite the shape we want; a vector register should have a size of 8 so that we can map it to the AVX2 instructions.
+The next set of transformations will address this:
+
+```py
+avx = divide_dim(avx, 'C_reg:_', 1, 8)
+avx = repeat(divide_loop)(avx, 'for i1 in _: _', 8, ['i2', 'i3'], perfect=True)
+avx = simplify(avx)
+```
+
+The `divide_dim` operation splits the last dimension of `C_reg` into two dimensions, the latter of which has 8 elements.
+Next, we use the `divide_loop` operator to split apart the loops that operate on the memory `C_reg`, and we meet our first *higher-order scheduling operator*, `repeat`, which applies a scheduling operator repeatedly until it fails.
+The final `simplify` simplifies the index expressions.
+
+These changes give us a couple of loop nests amenable to mapping onto vector instructions:
+```py
+    ...
+    for i3 in seq(0, 8):
+        C_reg[i0, i2, i3] = C[i0, i3 + 8 * i2]
+    ...
+    for jo in seq(0, 2):
+        for ji in seq(0, 8):
+            C_reg[i, jo, ji] += A[i, k] * B[k, ji + 8 * jo]
+    ...
+    for i3 in seq(0, 8):
+        C[i0, i3 + 8 * i2] = C_reg[i0, i2, i3]
+```
+
+In order of appearance, they load from `C` into `C_reg`, perform the computation on `C_reg`, and store the results from `C_reg` back into `C`.
+The second loop nest cannot be vectorized just yet, but the other two are vectorizable.
+
+### Instruction Mapping
+
+Exo supports *instruction mapping*, which takes a particular program fragment and replaces it with an equivalent instruction.
+For example, we can take the following loop nest:
+```py
+for i3 in seq(0, 8):
+    C_reg[i0, i2, i3] = C[i0, i3 + 8 * i2]
+```
+and turn it into the AVX2 `mm256_loadu_ps` instruction.
+
+To do this, we import the AVX2 instructions and use the `replace_all` operator to replace all matching loop nests:
+```py
+from exo.platforms.x86 import *
+...
+avx = set_memory(avx, 'C_reg:_', AVX2)
+avx = replace_all(avx, mm256_loadu_ps)
+avx = simplify(avx)
+print(avx)
+```
+
+This transforms the above loop nest into:
+```py
+mm256_loadu_ps(C_reg[i0, i2, 0:8], C[i0, 8 * i2:8 + 8 * i2])
+```
+
+The `set_memory` operator explicitly marks the `C_reg` memory as AVX2 vector registers, and `replace_all` attempts to rewrite every loop in the code that implements a load into the `mm256_loadu_ps` instruction.
+
+The latter is a bit magical! How does the scheduling operator know what the semantics of the instruction are and when it is safe to rewrite loops into that instruction?
+This is the final part of Exo's magic: the definitions of these instructions are *externalized*, i.e., provided by you:
+```py
+@instr("{dst_data} = _mm256_loadu_ps(&{src_data});")
+def mm256_loadu_ps(dst: [f32][8] @ AVX2, src: [f32][8] @ DRAM):
+    assert stride(src, 0) == 1
+    assert stride(dst, 0) == 1
+
+    for i in seq(0, 8):
+        dst[i] = src[i]
+```
+The definition implements the semantics of the instruction using plain old Python code, and the `replace_all` command knows how to perform the replacement using this definition.
+
+### Take a Breather
+
+Congratulations on getting through a whirlwind tour of Exo's capabilities. To review, we've seen a couple of concepts that work in tandem to enable productive performance engineering:
+- *Scheduling operators* allow you to rewrite programs.
+- *Instruction mapping* uses user-level instruction definitions to rewrite program fragments into backend instructions.
+
+### Vectorizing the Computation
+
+Next, we're going to vectorize the innermost computation. However, we have to work with our original constraint: AVX2 exposes 16 vector registers, and we've consumed 12 of those for our output memory. The rest of the computation needs to be staged carefully so that we don't end up using more than 4 registers.
+
+The scheduling will follow a similar pattern to the previous sections: we want to stage the memories `A` and `B` using vector registers and replace their uses in the computational kernel.
+
+Let's start off with `B`, which is the larger of the two:
+```py
+# B is easy, it is just two vector loads
+avx = stage_mem(avx, 'for i in _:_', 'B[k, 0:16]', 'B_reg')
+avx = simplify(avx)
+avx = divide_loop(avx, 'for i0 in _: _ #1', 8, ['io', 'ii'], perfect=True)
+avx = divide_dim(avx, 'B_reg:_', 0, 8)
+avx = set_memory(avx, 'B_reg:_', AVX2)
+avx = simplify(avx)
+avx = replace_all(avx, mm256_loadu_ps)
+avx = simplify(avx)
+print(avx)
+```
+
+We won't go into the details of each scheduling operator since you've already seen all of them before, but we encourage you to step through them and print out `avx` after each operation.
+
+The rewritten program exposes the reuse pattern available for the data in `B`:
+```py
+...
+    for k in seq(0, K):
+        B_reg: f32[2, 8] @ AVX2
+        for io in seq(0, 2):
+            mm256_loadu_ps(B_reg[io, 0:8], B[k, 8 * io:8 + 8 * io])
+        for i in seq(0, 6):
+            for jo in seq(0, 2):
+                for ji in seq(0, 8):
+                    C_reg[i, jo, ji] += A[i, k] * B_reg[jo, ji]
+```
+For each `k` value, we load 16 values from `B` (two vector registers' worth) and perform the computation using those.
+
+Next, we need to stage `A`:
+```py
+avx = bind_expr(avx, 'A[i, k]', 'A_reg')
+avx = expand_dim(avx, 'A_reg', 8, 'ji')
+avx = lift_alloc(avx, 'A_reg', n_lifts=2)
+avx = fission(avx, avx.find('A_reg[ji] = _').after(), n_lifts=2)
+avx = remove_loop(avx, 'for jo in _: _')
+avx = set_memory(avx, 'A_reg:_', AVX2)
+avx = replace_all(avx, mm256_broadcast_ss)
+print(avx)
+```
+
+Staging `A` is a little more complex because, unlike `C` and `B`, its reuse pattern is different: each value of `A` is broadcast into `A_reg`, which is then used to perform the innermost computation. There are a couple of new scheduling operators here:
+- `lift_alloc`: Moves a variable definition out through the specified number of loops.
+- `fission`: Splits a loop apart at the given cursor.
+- `remove_loop`: Eliminates an unused loop.
+
+Finally, we can vectorize the computation:
+```py
+avx = replace_all(avx, mm256_fmadd_ps)
+print(avx)
+```
+This is perhaps a bit underwhelming; however, under the hood, Exo has been performing analyses and automatically rewriting loop bounds and indexing expressions to make the process easier. These analyses serve as guard rails for the powerful rewrite rules and are the topic of another tutorial.
+
+## Compiling
+
+Finally, the code can be compiled and run on your machine if it supports AVX2 instructions.
+We provide a main function in `main.c` to call these procedures and to time them.
+Please run `make` or compile manually:
+
+```sh
+$ exocc -o . --stem avx2_matmul x86_matmul.py
+$ gcc -o avx2_matmul -march=native main.c avx2_matmul.c
+```
+
+This will print out the results of running the kernel with and without the AVX instructions.
+
+[blas]: https://www.netlib.org/blas/
+[blis]: https://github.com/flame/blis
+
+## More Automation?
+
+Congratulations on completing this example!
+You might have felt that the scheduling operations in this example were very low-level and could be laborious to write.
+We felt the same! We implemented a new feature called Cursors that provides scheduling automation *external* to the compiler implementation.
+To learn more, please take a look at the [cursors example](cursors/README.md) and our ASPLOS '25 paper.
diff --git a/examples/main.c b/examples/avx2_matmul/main.c
similarity index 100%
rename from examples/main.c
rename to examples/avx2_matmul/main.c
diff --git a/examples/x86_matmul.py b/examples/avx2_matmul/x86_matmul.py
similarity index 100%
rename from examples/x86_matmul.py
rename to examples/avx2_matmul/x86_matmul.py
diff --git a/examples/cursors/.gitignore b/examples/cursors/.gitignore
new file mode 100644
index 000000000..6cbfadf78
--- /dev/null
+++ b/examples/cursors/.gitignore
@@ -0,0 +1 @@
+cursors/
diff --git a/examples/cursors/README.md b/examples/cursors/README.md
new file mode 100644
index 000000000..90a3cef55
--- /dev/null
+++ b/examples/cursors/README.md
@@ -0,0 +1,22 @@
+# Cursor Step-by-Step Tutorial
+
+This example demonstrates Cursors using the tile2D example (as shown in our ASPLOS '25 paper).
+
+## Overview
+
+This example covers the key concepts presented in the paper:
+- Finding Cursors with pattern-matching
+- Cursor navigation
+- Applying scheduling primitives using cursors
+- Cursor forwarding after code transformations
+- Defining a new scheduling operation
+
+## Getting Started
+
+To run this example:
+```bash
+exocc cursors.py
+```
+Running `exocc` on `cursors.py` will generate the C code in the `cursors/cursors.c` file.
+It will also print out the intermediate steps of the example.
+ diff --git a/examples/cursors/cursors.py b/examples/cursors/cursors.py new file mode 100644 index 000000000..8cc58ddab --- /dev/null +++ b/examples/cursors/cursors.py @@ -0,0 +1,142 @@ +from __future__ import annotations +from exo import * +from exo.API_scheduling import * + +""" +Cursor Example + +This example introduces the concept of Cursors in Exo 2 paper and demonstrates +how to use scheduling operators with them to manipulate loops and optimize code. + +Cursors allow you to select and refer to parts of the code such as expressions, +statements, and code blocks. They also support spatial navigation within a procedure +to proximate locations. + +Key concepts covered: +- Finding Cursors with pattern-matching +- Cursor navigation +- Applying scheduling primitives using cursors +- Cursor forwarding after code transformations +- Defining a new scheduling operation +""" + + +""" +1: Basic loop example using Exo 2 + +GEMV kernel: y = A * x +Args: + M (size): Number of rows in matrix A + N (size): Number of columns in matrix A + A (tensor): M x N matrix stored in DRAM + x (tensor): N-dimensional vector stored in DRAM + y (tensor): M-dimensional vector stored in DRAM +""" + + +@proc +def gemv(M: size, N: size, A: f32[M, N], x: f32[N], y: f32[M]): + assert M % 8 == 0 + assert N % 8 == 0 + + for i in seq(0, M): + for j in seq(0, N): + y[i] += A[i, j] * x[j] + + +print("1: Original GEMV kernel") +print(gemv) +print() + + +""" +2: Finding cursors +""" +# Find a cursor to the i loop by name +i_loop = gemv.find_loop("i") + +# Find the same i loop by pattern +i_loop2 = gemv.find("for i in _: _") + +# Check that two cursors are pointing to the same 'i' loop +assert i_loop == i_loop2 + +print("2: i_loop points to:") +print(i_loop) +print() + + +""" +3: Navigating with cursors +""" +# Find cursors to key parts of the code +j_loop = i_loop.body()[0] # j is the only statement in i's body +C_store = j_loop.body()[0] # y[i] = ... is the only statement in j's body +j_loop_parent = j_loop.parent() # The parent of the j loop + +# Check that j_loop's parent is indeed pointing to the i_loop +assert i_loop == j_loop_parent + +print("3: j_loop points to:") +print(j_loop) +print() + + +""" +4: Applying scheduling primitives & Cursor forwarding +""" +# First, rename the gemv +g = rename(gemv, "gemv_scheduled") + +# Divide the i loop by 8 +g = divide_loop(g, i_loop, 8, ["io", "ii"], perfect=True) + +# Divide the j loop by 8 +g = divide_loop(g, j_loop, 8, ["jo", "ji"], perfect=True) + +# Now, we want to reorder ii and jo loops, by lifting the scope of j_loop +# We can still use the j_loop cursor! +g1 = lift_scope(g, j_loop) +g2 = lift_scope(g, g.forward(j_loop)) + +# Assert that g1 and g2 are the same (`j_loop` is implicitly forwarded in the first line) +assert g1 == g2 + +print("4: Tiled gemv") +print(g1) +print("4: g.forward(j_loop) points to:") +print(g.forward(j_loop)) +print() + + +""" +5: Defining a new scheduling operator +""" + + +def tile_2D(p, i_lp, j_lp, i_itrs, j_itrs, i_sz, j_sz): + """ + Perform a 2D tiling of the i and j loops. + Args: + p: Procedure to be tiled + i_lp: Name of the i loop + j_lp: Name of the j loop + i_itrs: New iterators for the i loop + j_itrs: New iterators for the j loop + i_sz: Tile size for the i loop + j_sz: Tile size for the j loop + """ + p = divide_loop(p, i_lp, i_sz, i_itrs, perfect=True) + p = divide_loop(p, j_lp, j_sz, j_itrs, perfect=True) + p = lift_scope(p, j_itrs[0]) + return p + + +# Example usage of tile_2D to perform 2D tiling on the gemv kernel. 
+final_g = tile_2D(gemv, i_loop, j_loop, ["io", "ii"], ["jo", "ji"], 8, 8) + +print("5: tile_2D applied gemv:") +print(final_g) + + +__all__ = ["final_g"] diff --git a/examples/rvm_conv1d/.gitignore b/examples/rvm_conv1d/.gitignore new file mode 100644 index 000000000..466e24805 --- /dev/null +++ b/examples/rvm_conv1d/.gitignore @@ -0,0 +1 @@ +out/ \ No newline at end of file diff --git a/examples/rvm_conv1d/Makefile b/examples/rvm_conv1d/Makefile new file mode 100644 index 000000000..4720f3da6 --- /dev/null +++ b/examples/rvm_conv1d/Makefile @@ -0,0 +1,31 @@ +PROG = conv1d +OUT = out/ +CC = "${RISCV}/bin/clang" +SPIKE = "${RISCV}/bin/spike" +ASFLAGS = -march=rv32imc_xtheadmatrix0p1 -menable-experimental-extensions +CFLAGS = -O2 -g3 $(ASFLAGS) + +default: sim +exo_comp: exo/conv1d_exo.c + +$(OUT)/$(PROG).elf: $(OUT)/$(PROG).o $(OUT)/conv1d_exo.o + $(CC) $(LDFLAGS) -o $@ $^ + +$(OUT)/$(PROG).o: main.c exo/conv1d_exo.h conv1Di32.h $(OUT) + $(CC) $(CFLAGS) -o $@ -c $< + +$(OUT)/conv1d_exo.o: exo/conv1d_exo.c $(OUT) + $(CC) $(CFLAGS) -o $@ -c $< + +$(OUT): + @mkdir -p $(OUT) + +exo/conv1d_exo.h: exo/conv1d_exo.c +exo/conv1d_exo.c: exo/conv1d.py + exocc -o exo/ --stem conv1d_exo exo/conv1d.py + +conv1Di32.h: gen_stimuli.py + python3 $< + +sim: $(OUT)/$(PROG).elf + @$(SPIKE) --isa=RV32IMC_xmatrix pk -s $< \ No newline at end of file diff --git a/examples/rvm_conv1d/README.md b/examples/rvm_conv1d/README.md new file mode 100644 index 000000000..93b2e77ca --- /dev/null +++ b/examples/rvm_conv1d/README.md @@ -0,0 +1,48 @@ +# Conv1D on RVM example + +This is an implementation of a simplified 1D convolution routine, using a custom [RISC-V ISA extension called RVM](https://github.com/esl-epfl/xheep_matrix_spec/tree/main). + +The tutorial accompanying this example is on [the main website](https://exo-lang.dev/tutorial.html). This page will just show you how to first compile the Exo program to C, and how to run it as well (optional.) + +## File organization + +* `main.c` - driver program testing handwritten vs Exo routine +* `gen_stimuli.py` - generate C arrays used as test vectors for conv1d routine, with expected output +* `conv1Di32.h` - generated output from `gen_stimuli.py` +* `exo/conv1d.py` - Exo code for conv1d +* `exo/conv1d_exo.{c,d,h}` - generated outputs from Exo + + +## Setup Exo & Compile + +First follow [the documentation](https://github.com/exo-lang/exo#install-exo) to install Exo, if you have not already. We assume `exocc` is in `$PATH`, and you have `make` installed. To compile the exo program in `exo/conv1d.py`, run: + +```bash +make exo_comp +``` + +The resulting C code for the example will be in `exo/conv1d_exo.c`. + +From here, if you would like to also compile the program to a RISC-V binary, and run it in a simulator, you will need the custom RVM toolchain. The following steps walk through that process. Otherwise, you can stop here. + +## Install RVM toolchain + +RVM is the custom RISC-V extension, which supports instructions and registers to do matrix operations. It requires a custom LLVM toolchain to build code, and in order to run programs, a fork of the Spike simulator. [The repo for RVM has a guide to set up these components](https://github.com/esl-epfl/xheep_matrix_spec/blob/main/BUILDING.md). In the end you should have the LLVM tools as well as Spike installed under `$RISCV/bin`. + + +## Build + +Run `make` to build the driver program, and simulate it in spike. 
**This assumes you have `$RISCV` defined from the installation step.** You should see an output like this: + +``` +$ make +... +handwritten err: 0 +exo err: 0 +2350 ticks +93797 cycles +93799 instructions +0.99 CPI +``` + +Note that the cycle counts are *not* accurate, and they should not be used to measure performance. Unfortunately, the hardware for RVM is not public as of today, and the Spike simulator is not meant to simulate these details, so it is only used for testing functional correctness. \ No newline at end of file diff --git a/examples/rvm_conv1d/conv1Di32.h b/examples/rvm_conv1d/conv1Di32.h new file mode 100644 index 000000000..b083e0bc3 --- /dev/null +++ b/examples/rvm_conv1d/conv1Di32.h @@ -0,0 +1,62 @@ +#ifndef _CONV1Di32 +#define _CONV1Di32 +// This file is automatically generated +int32_t __attribute__((section(".xheep_data_interleaved"))) +DATA[] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, + 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, + 12, 13, 14, 15, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; + +int32_t __attribute__((section(".xheep_data_interleaved"))) +KERNELS[] = {0, 1, 2, 3, 10, 11, 12, 13, 20, 21, 22, 23, 30, 31, 32, 33, 100, + 101, 102, 103, 110, 111, 112, 113, 120, 121, 122, 123, 130, 131, 132, 133, + 200, 201, 202, 203, 210, 211, 212, 213, 220, 221, 222, 223, 230, 231, 232, + 233, 300, 301, 302, 303, 310, 311, 312, 313, 320, 321, 322, 323, 330, 331, + 332, 333, 400, 401, 402, 403, 410, 411, 412, 413, 420, 421, 422, 423, 430, + 431, 432, 433, 500, 501, 502, 503, 510, 511, 512, 513, 520, 521, 522, 523, + 530, 531, 532, 533, 600, 601, 602, 603, 610, 611, 612, 613, 620, 621, 622, + 623, 630, 631, 632, 633, 700, 701, 702, 703, 710, 711, 712, 713, 720, 721, + 722, 723, 730, 731, 732, 733, 800, 801, 802, 803, 810, 811, 812, 813, 820, + 821, 822, 823, 830, 831, 832, 833, 900, 901, 902, 903, 910, 911, 912, 913, + 920, 921, 922, 923, 930, 931, 932, 933, 1000, 1001, 1002, 1003, 1010, 1011, + 1012, 1013, 1020, 1021, 1022, 1023, 1030, 1031, 1032, 1033, 1100, 1101, + 1102, 1103, 1110, 1111, 1112, 1113, 1120, 1121, 1122, 1123, 1130, 1131, + 1132, 1133, 1200, 1201, 1202, 1203, 1210, 1211, 1212, 1213, 1220, 1221, + 1222, 1223, 1230, 1231, 1232, 1233, 1300, 1301, 1302, 1303, 1310, 1311, + 1312, 1313, 1320, 1321, 1322, 1323, 1330, 1331, 1332, 1333, 1400, 1401, + 1402, 1403, 1410, 1411, 1412, 1413, 1420, 1421, 1422, 1423, 1430, 1431, + 1432, 1433, 1500, 1501, 1502, 1503, 1510, 1511, 1512, 1513, 1520, 1521, + 1522, 1523, 1530, 1531, 1532, 1533}; + +int32_t __attribute__((section(".xheep_data_interleaved"))) EXPECTED[] = {416, + 680, 944, 1208, 1472, 1736, 2000, 2264, 2528, 2792, 3056, 3320, 3584, 2696, + 1800, 900, 2816, 4680, 6544, 8408, 10272, 12136, 14000, 15864, 17728, 19592, + 21456, 23320, 25184, 19496, 13400, 6900, 5216, 8680, 12144, 15608, 19072, + 22536, 26000, 29464, 32928, 36392, 39856, 43320, 46784, 36296, 25000, 12900, + 7616, 12680, 17744, 22808, 27872, 32936, 38000, 43064, 48128, 53192, 58256, + 63320, 68384, 53096, 36600, 18900, 10016, 16680, 23344, 30008, 36672, 43336, + 50000, 56664, 63328, 69992, 76656, 83320, 89984, 69896, 48200, 24900, 12416, + 20680, 28944, 37208, 45472, 53736, 62000, 70264, 78528, 86792, 95056, + 103320, 111584, 86696, 59800, 30900, 14816, 24680, 34544, 44408, 54272, + 64136, 74000, 83864, 93728, 103592, 113456, 123320, 133184, 103496, 71400, + 36900, 17216, 28680, 40144, 51608, 63072, 74536, 86000, 97464, 108928, + 120392, 131856, 143320, 154784, 120296, 83000, 42900, 
19616, 32680, 45744, + 58808, 71872, 84936, 98000, 111064, 124128, 137192, 150256, 163320, 176384, + 137096, 94600, 48900, 22016, 36680, 51344, 66008, 80672, 95336, 110000, + 124664, 139328, 153992, 168656, 183320, 197984, 153896, 106200, 54900, + 24416, 40680, 56944, 73208, 89472, 105736, 122000, 138264, 154528, 170792, + 187056, 203320, 219584, 170696, 117800, 60900, 26816, 44680, 62544, 80408, + 98272, 116136, 134000, 151864, 169728, 187592, 205456, 223320, 241184, + 187496, 129400, 66900, 29216, 48680, 68144, 87608, 107072, 126536, 146000, + 165464, 184928, 204392, 223856, 243320, 262784, 204296, 141000, 72900, + 31616, 52680, 73744, 94808, 115872, 136936, 158000, 179064, 200128, 221192, + 242256, 263320, 284384, 221096, 152600, 78900, 34016, 56680, 79344, 102008, + 124672, 147336, 170000, 192664, 215328, 237992, 260656, 283320, 305984, + 237896, 164200, 84900, 36416, 60680, 84944, 109208, 133472, 157736, 182000, + 206264, 230528, 254792, 279056, 303320, 327584, 254696, 175800, 90900}; + +#define N 16 +#define IC 4 +#define W 4 +#define OC 16 +#define PAD 1 +#endif \ No newline at end of file diff --git a/examples/rvm_conv1d/exo/.gitignore b/examples/rvm_conv1d/exo/.gitignore new file mode 100644 index 000000000..c06c83d3c --- /dev/null +++ b/examples/rvm_conv1d/exo/.gitignore @@ -0,0 +1,4 @@ +__pycache__/ +conv1d_exo.c +conv1d_exo.h +conv1d_exo.d diff --git a/examples/rvm_conv1d/exo/conv1d.py b/examples/rvm_conv1d/exo/conv1d.py new file mode 100644 index 000000000..38b51d6cc --- /dev/null +++ b/examples/rvm_conv1d/exo/conv1d.py @@ -0,0 +1,410 @@ +from __future__ import annotations + +import os +import sys + +import exo.API_cursors as pc +from exo import proc +from exo.libs.memories import * +from exo.platforms.x86 import * +from exo.stdlib.scheduling import * +from exo.stdlib.stdlib import * + +############# +# ALGORITHM # +############# +N = 16 +IC = 4 +W = 4 +OC = 16 +TILE = 4 + + +def gen_conv1d(): + @proc + def generic_conv1d( + data: i32[IC, N], + kernels: i32[OC, IC, W], + out: i32[OC, N], + ): + # do the convolution + for i in seq(0, OC): + for j in seq(0, N): + # zero out the result memory + out[i, j] = 0.0 + for c in seq(0, IC): + for r in seq(0, W): + y: i32 + if j + r < N: + y = data[c, j + r] + else: + y = 0 + out[i, j] += kernels[i, c, r] * y + + return generic_conv1d + + +############## +# HW LIBRARY # +############## + + +class RVM_TILE(StaticMemory): + NUM_RVM_TILES = 8 + StaticMemory.init_state(NUM_RVM_TILES) + tile_dict = {} + + @classmethod + def reset_allocations(cls): + cls.init_state(cls.NUM_RVM_TILES) + cls.tile_dict = {} + + @classmethod + def can_read(cls): + return False + + @classmethod + def alloc(cls, new_name, prim_type, shape, srcinfo): + if not (len(shape) == 2): + raise MemGenError("Must be a 2D tile.") + if not (shape[0].isdecimal() and int(shape[0]) == 4): + raise MemGenError("Number of tile rows must be 4.") + if not (shape[1].isdecimal() and int(shape[1]) == 4): + raise MemGenError("Number of tile columns must be 4.") + + tile_num = cls.find_free_chunk() + cls.mark(tile_num) + cls.tile_dict[new_name] = tile_num + return f'#define {new_name} "m{7-tile_num}"' + + @classmethod + def free(cls, new_name, prim_type, shape, srcinfo): + tile_num = cls.tile_dict[new_name] + del cls.tile_dict[new_name] + cls.unmark(tile_num) + return f"#undef {new_name}" + + +@instr( + 'asm volatile("mld.w "{dst_int}", (%1), %0" :: "r"(4*({src}.strides[0])), "r"(&{src_data}));' +) +def rvm_mld(dst: [i32][4, 4] @ RVM_TILE, src: [i32][4, 4] @ DRAM): + assert stride(src, 1) == 
1 + assert stride(dst, 1) == 1 + + for i in seq(0, 4): + for j in seq(0, 4): + dst[i, j] = src[i, j] + + +@instr('asm volatile("mzero "{dst_int});') +def rvm_mzero(dst: [i32][4, 4] @ RVM_TILE): + assert stride(dst, 1) == 1 + + for i in seq(0, 4): + for j in seq(0, 4): + dst[i, j] = 0.0 + + +@instr( + 'asm volatile("mst.w "{src_int}", (%1), %0" :: "r"(4*({dst}.strides[0])), "r"(&{dst_data}));' +) +def rvm_mst(src: [i32][4, 4] @ RVM_TILE, dst: [i32][4, 4] @ DRAM): + assert stride(src, 1) == 1 + assert stride(dst, 1) == 1 + + for i in seq(0, 4): + for j in seq(0, 4): + dst[i, j] = src[i, j] + + +@instr('asm volatile("mmasa.w "{md_int}", "{ms1_int}", "{ms2_int});') +def rvm_mmasa( + md: [i32][4, 4] @ RVM_TILE, ms1: [i32][4, 4] @ RVM_TILE, ms2: [i32][4, 4] @ RVM_TILE +): + assert stride(md, 1) == 1 + assert stride(ms1, 1) == 1 + assert stride(ms2, 1) == 1 + for i in seq(0, 4): + for j in seq(0, 4): + for k in seq(0, 4): + md[i, j] += ms2[i, k] * ms1[j, k] + + +########################## +# CUSTOM REWRITING RULES # +########################## + + +def fuse_two_loops(p, c): + """ + for i in ...: <- c + for j in ...: + s1 + for k in ...: <- c.next() + for i in ...: + s2 + ----> + for i in ...: <- c + for j in ...: + s1 + for k in ...: + s2 + """ + try: + next_c = c.next() + except: + return p, False + + if isinstance(c, pc.ForCursor) and isinstance(next_c, pc.ForCursor): + if c.name() == next_c.name() and expr_to_string(c.hi()) == expr_to_string( + next_c.hi() + ): + p = fuse(p, c, next_c, unsafe_disable_check=False) + return p, True + else: + tgt_c, count = find_child_loop(next_c, c.name()) + if tgt_c: + p = lift_scope_n(p, tgt_c, n_lifts=count) + p = fuse(p, c, tgt_c, unsafe_disable_check=False) + return p, True + + return p, False + + +def fuse_all_loops(p, cursor): + """ + recursively calls fuse_two_loops to all the loops + """ + while True: + if isinstance(cursor, pc.ForCursor): + p = fuse_all_loops(p, cursor.body()[0]) + + # Fuse in current scope + p, b = fuse_two_loops(p, cursor) + + if b: + cursor = p.forward(cursor) + else: + try: + cursor = p.forward(cursor).next() + except: + break + + return p + + +def autolift_alloc(p, alloc_c, dep_set=None, max_size=0, lift=True): + """ + for i in seq(0, 10): + for j in seq(0, 20): + a : R <- alloc_c, dep_set = {'i'} + a[i] = ... + ----> + a : R[10] <- if size is less than max_size + for i in seq(0, n): + for j in seq(0, m): + a[i] = ... 
+ """ + alloc_c = p.forward(alloc_c) + loop_c = get_enclosing_loop(p, alloc_c) + accum_size = 1 + while True: + try: + if not isinstance(loop_c, pc.ForCursor): + break + if dep_set == None or loop_c.name() in dep_set: + if ( + isinstance(loop_c.hi(), LiteralCursor) + and accum_size * loop_c.hi().value() <= max_size + ): + p = expand_dim(p, alloc_c, loop_c.hi().value(), loop_c.name()) + accum_size = accum_size * loop_c.hi().value() + if lift: + p = lift_alloc(p, alloc_c) + loop_c = loop_c.parent() + except: + break + return p + + +def reorder_top(p, c): + """ + for i in seq(0, 10): + s1 + s2 + s3 <- c + ----> + for i in seq(0, 10): + s3 <- c + s1 + s2 + """ + c = p.forward(c) + while True: + try: + p = reorder_stmts(p, c.expand(1, 0)) + c = p.forward(c) + except: + break + return p + + +def fission_as_much_as_possible(p, cursor): + """ + for i in ...: + for j in ...: + s1 + s2 <- cursor + s3 + ---> + for i in ...: + for j in ...: + s2 + + for i in ...: + for j in ...: + s1 + s3 + """ + cursor = p.forward(cursor) + p = reorder_top(p, cursor) + gap_c = cursor.after() + while True: + try: + p = fission(p, gap_c) + gap_c = p.forward(gap_c).parent().after() + except: + break + + return p + + +def lift_scope_n(p, c, n_lifts=1): + """ + for i in seq(0, 10): + for j in seq(0, 10): + for k in seq(0, 10): + if ...: <- c + s1 + ----> if n_lifts == 2: + for i in seq(0, 10): + if ...: <- c + for j in seq(0, 10): + for k in seq(0, 10): + s1 + """ + for i in range(0, n_lifts): + p = lift_scope(p, c) + return p + + +def remove_redundant_loops(p, c, num=0): + """ + for i in ...: + for j in ...: + s1[j] <- c + ---> + for j in ...: + s1[j] <- c + """ + c = p.forward(c) + cur_depth = 0 + while True: + c = c.parent() + if not isinstance(c, pc.ForCursor): + break + try: + if cur_depth >= num: + break + hi = c.hi().value() + name = c.name() + child = p.forward(c).body()[0] + p = remove_loop(p, c) + cur_depth += 1 + except: + continue + return p + + +############## +# SCHEDULING # +############## + + +def optimize_conv(p): + p = rename(p, "exo_conv1d_tile_lt_kw") + + # Before scheduling, grab cursors to the object code. 
+ i_loop = p.find("for i in _:_") + j_loop = p.find("for j in _:_") + c_loop = p.find("for c in _:_") + y_alloc = p.find("y : _") + y_assign = p.find("y = data[_]") + + # Tile outer loops to TILE size for RVM + p, _ = tile_loops(p, [(i_loop, TILE), (j_loop, TILE)], perfect=True) + p, _ = tile_loops(p, [(i_loop, 4)], perfect=True) + i_loop_reg = p.find("for ioi in _:_") + p = reorder_loops(p, i_loop_reg) + + # Stage output to out_tile + p, (out_alloc, out_tile, body, _) = auto_stage_mem( + p, p.find_loop("c").expand(1, 0), "out", "out_tile", rc=True + ) + p = autolift_alloc(p, out_tile, max_size=4 * 4 * 4, dep_set=["ioi", "ii", "ji"]) + + # Block the zero initialization and store blocks + p = fission_as_much_as_possible(p, body) + p = fission_as_much_as_possible(p, body[0]) + + # Reorder c loop to the top + p = lift_scope_n(p, c_loop, 3) + + # Stage y + p = autolift_alloc(p, y_alloc, max_size=4 * 4, dep_set=["r", "ji"]) + p = lift_alloc(p, y_alloc, n_lifts=2) + + # Fission the initialization loop and remove redundant loops + p = fission_as_much_as_possible(p, y_assign.parent()) + p = remove_redundant_loops(p, y_assign.parent(), num=2) + + # Stage kernels to kernel_tile and y to data_tile + ii_loop = p.forward(c_loop).body()[2].body()[0] + p, (kernel_alloc, _, _, _) = auto_stage_mem( + p, ii_loop, "kernels", "kernel_tile", rc=True + ) + p = simplify(expand_dim(p, kernel_alloc, 4, ii_loop.parent().name())) + p = lift_alloc(p, kernel_alloc) + p, (data_alloc, _, _, _) = auto_stage_mem( + p, ii_loop.parent(), "y", "data_tile", rc=True + ) + + # Set adequate memories + p = set_memory(p, y_alloc, DRAM_STATIC) + p = set_memory(p, out_tile, RVM_TILE) + p = set_memory(p, kernel_alloc, RVM_TILE) + p = set_memory(p, data_alloc, RVM_TILE) + + # Replace inner loops to calls to RVM instructions + p = replace_all(p, [rvm_mzero, rvm_mst, rvm_mld, rvm_mmasa]) + + # Clean up + p = unroll_loop(p, "ioi") + p = unroll_loop(p, "ioi") + p = unroll_loop(p, "ioi") + p = simplify(p) + p = unroll_buffer(p, kernel_alloc, 0) + p = reuse_buffer(p, "kernel_tile_0: _", "kernel_tile_3: _") + p = unroll_buffer(p, "out_tile", 0) + + return p + + +def make_routine(): + generic_conv1d = gen_conv1d() + rvm_optimized = optimize_conv(generic_conv1d) + return rvm_optimized + + +exo_conv1d_tile_lt_kw = make_routine() diff --git a/examples/rvm_conv1d/gen_stimuli.py b/examples/rvm_conv1d/gen_stimuli.py new file mode 100644 index 000000000..341d972c5 --- /dev/null +++ b/examples/rvm_conv1d/gen_stimuli.py @@ -0,0 +1,109 @@ +#!/usr/bin/env python + +import sys +import random + +# Copyright 2017 ETH Zurich and University of Bologna. +# Copyright and related rights are licensed under the Solderpad Hardware +# License, Version 0.51 (the License); you may not use this file except in +# compliance with the License. You may obtain a copy of the License at +# http://solderpad.org/licenses/SHL-0.51. Unless required by applicable law +# or agreed to in writing, software, hardware and materials distributed under +# this License is distributed on an AS IS BASIS, WITHOUT WARRANTIES OR +# CONDITIONS OF ANY KIND, either express or implied. See the License for the +# specific language governing permissions and limitations under the License. 
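+ +# This script generates conv1Di32.h, which defines the DATA, KERNELS, and EXPECTED arrays (plus the N/IC/W/OC/PAD macros) that main.c uses to check both the handwritten kernel and the Exo-generated exo_conv1d_tile_lt_kw.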
+ + +def write_arr(f, name, arr, ctype, size, linebreak): + f.write(ctype + " " + name + "[] = {\n\t") + i = 1 + for v in arr: + if i % size == 0: + f.write("%d};\n\n" % (v)) + elif i % linebreak == 0: + f.write("%d,\n\t" % (v)) + else: + f.write("%d," % (v)) + i += 1 + return + + +################################################################################ +f = open("conv1Di32.h", "w") +f.write("#ifndef _CONV1Di32 \n") +f.write("#define _CONV1Di32 \n") +f.write("// This file is automatically generated\n") + + +N = 16 +IC = 4 +OC = 16 +W = 4 +RANGE = 4095 + +data = [] +kernel = [] +expected = [] + +pad = 1 + +# N C W format +for i in range(0, IC): + for j in range(0, N): + data.append(j) + # data.append(random.randint(-RANGE, RANGE-1)) + +# O I W format +for i in range(0, OC): + for j in range(0, IC): + for k in range(0, W): + kernel.append(i * 100 + j * 10 + k) + # kernel.append(random.randint(-RANGE, RANGE-1)) + +# O W format +for i in range(0, OC): + for j in range(0, N): + sum = 0 + for w_i in range(0, W): + for w_j in range(0, IC): + data_idx = j + w_i + data_at_idx = 0 + if data_idx < N: + data_at_idx = data[w_j * N + j + w_i] + sum += kernel[(IC * i + w_j) * W + w_i] * data_at_idx + expected.append(sum) + + +write_arr( + f, + "DATA", + data, + 'int32_t __attribute__((section(".xheep_data_interleaved")))', + IC * N, + 128, +) +write_arr( + f, + "KERNELS", + kernel, + 'int32_t __attribute__((section(".xheep_data_interleaved")))', + OC * IC * W, + 128, +) +write_arr( + f, + "EXPECTED", + expected, + 'int32_t __attribute__((section(".xheep_data_interleaved")))', + OC * N, + 128, +) + +f.write("#define N %d\n" % N) +f.write("#define IC %d\n" % IC) +f.write("#define W %d\n" % W) +f.write("#define OC %d\n" % OC) +f.write("#define PAD %d\n" % 1) + + +f.write("#endif") diff --git a/examples/rvm_conv1d/main.c b/examples/rvm_conv1d/main.c new file mode 100644 index 000000000..9f32513d3 --- /dev/null +++ b/examples/rvm_conv1d/main.c @@ -0,0 +1,162 @@ + +/* Includes */ +#include +#include +#include +#include + +#include "conv1Di32.h" +#include "exo/conv1d_exo.h" + +//////////////////// +// CONFIGURATION // +////////////////// + +#define TILE 4 + +///////////// +// MACROS // +/////////// + +#define CEIL_DIV(a, b) ((((a) % (b)) != 0) ? 
(((a) / (b)) + 1) : (a) / (b)) + +int32_t out[OC * N]; +int32_t data_tile[TILE][IC * W]; +int32_t result[OC * N]; +int32_t small_data_tile_a[TILE * TILE]; +int32_t small_data_tile_b[TILE * TILE]; + +//////////////// +// MAIN CODE // +////////////// + +void conv1d_tile_lt_kw_reord(int32_t *data, int32_t *kernels, int32_t *out) { + // should be ceil_div(ic*kw, tile) * tile + // and initialized to 0 + int tile_i_len = CEIL_DIV(OC, TILE * 4); + int tile_j_len = CEIL_DIV(N, TILE); + int data_base; + int cycles; + int32_t *kernel_base = kernels; + register int32_t *small_data_tile = small_data_tile_a; + register int32_t *temp; + for (int tile_i = 0; tile_i < tile_i_len; tile_i++) { + data_base = 0; + for (int tile_j = 0; tile_j < tile_j_len; tile_j++) { + asm volatile("mzero m1"); + asm volatile("mzero m2"); + asm volatile("mzero m3"); + asm volatile("mzero m4"); + int data_row = 0; + for (int tile_k = 0; tile_k < IC; tile_k++) { + // CSR_CLEAR_BITS(CSR_REG_MCOUNTINHIBIT, 0x1); + // CSR_WRITE(CSR_REG_MCYCLE, 0); + for (int replica = 0; replica < TILE; replica++) { + int ofs = data_base + replica; + int drow_ofs = data_row + ofs; + int dtile_ofs = replica * TILE; + for (int i = 0; i < W; i++) { + // Check that we are not out of bounds of the input in the current + // channel this should not block: addresses are different + small_data_tile[dtile_ofs] = 0; + if (ofs < N) { + small_data_tile[dtile_ofs] = data[drow_ofs]; + } + + ofs++; + drow_ofs++; + dtile_ofs++; + } + // CSR_READ(CSR_REG_MCYCLE, &cycles); + // printf("cyc: %d\n", cycles); + } + data_row += N; + + asm volatile( + "mld.w m0, (%1), %0" ::"r"(TILE * 4), "r"(small_data_tile)); + asm volatile("mld.w m5, (%1), %0" ::"r"(IC * W * 4), "r"(kernel_base)); + asm volatile("mmasa.w m1, m0, m5"); + asm volatile("mld.w m6, (%1), %0" ::"r"(IC * W * 4), + "r"(kernel_base + TILE * IC * W)); + asm volatile("mmasa.w m2, m0, m6"); + asm volatile("mld.w m7, (%1), %0" ::"r"(IC * W * 4), + "r"(kernel_base + TILE * IC * W * 2)); + asm volatile("mmasa.w m3, m0, m7"); + asm volatile("mld.w m5, (%1), %0" ::"r"(IC * W * 4), + "r"(kernel_base + TILE * IC * W * 3)); + asm volatile("mmasa.w m4, m0, m5"); + kernel_base += W; + // swap + // asm ("xor %0, %0, %1" : "=r"(small_data_tile_cur) : + // "r"(small_data_tile_old)); asm ("xor %0, %0, %1" : + // "=r"(small_data_tile_old) : "r"(small_data_tile_cur)); asm ("xor %0, + // %0, %1" : "=r"(small_data_tile_cur) : "r"(small_data_tile_old)); + } + int32_t *outptr = (out + (tile_i * N * 4 + tile_j) * TILE); + asm volatile("mst.w m1, (%1), %0" ::"r"(N * 4), "r"(outptr)); + asm volatile("mst.w m2, (%1), %0" ::"r"(N * 4), "r"(outptr + TILE * N)); + asm volatile( + "mst.w m3, (%1), %0" ::"r"(N * 4), "r"(outptr + TILE * N * 2)); + asm volatile( + "mst.w m4, (%1), %0" ::"r"(N * 4), "r"(outptr + TILE * N * 3)); + + data_base += TILE; + kernel_base -= W * IC; + } + kernel_base += TILE * IC * W * 4; + } +} + +#define BRANCHLESS_TERNARY(c, x, y) ((-(c)&x) | (~(-(c)) & y)); +void conv1d_cpu(int32_t *data, int32_t *kernels, int32_t *out) { + for (int i = 0; i < OC; i++) { + for (int j = 0; j < N; j++) { + out[N * i + j] = 0; + for (int w_i = 0; w_i < W; w_i++) { + for (int w_j = 0; w_j < IC; w_j++) { + int data_idx = j + w_i; + int kernel_idx = (IC * i + w_j) * W + w_i; + int data_at_idx = + BRANCHLESS_TERNARY(data_idx < N, data[w_j * N + j + w_i], 0); + out[N * i + j] += data_at_idx * kernels[kernel_idx]; + } + } + } + } +} + +int check_result(int32_t *result) { + int err = 0; + for (int i = 0; i < OC; i++) { + for (int j = 0; j < 
N; j++) { + if (result[N * i + j] != EXPECTED[N * i + j]) { + err++; + printf("exp %d got %d\n\r", EXPECTED[N * i + j], result[N * i + j]); + } + } + } + return err; +} + +int main() { + for (int i = 0; i < TILE; i++) { + for (int j = 0; j < TILE; j++) { + small_data_tile_a[i * TILE + j] = 0; + small_data_tile_b[i * TILE + j] = 0; + } + } + + conv1d_tile_lt_kw_reord(DATA, KERNELS, result); + printf("handwritten err: %d\n\r", check_result(result)); + + for (int i = 0; i < OC; i++) { + for (int j = 0; j < N; j++) { + result[i * N + j] = 0; + } + } + + exo_conv1d_tile_lt_kw(NULL, DATA, KERNELS, result); + printf("exo err: %d\n\r", check_result(result)); + + return 0; +} diff --git a/requirements.txt b/requirements.txt index ed467db0d..029815662 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ PySMT==0.9.6 asdl-adt==0.1.0 asdl==0.1.5 -build==1.2.2 -z3-solver==4.13.2.0 +build==1.2.2.post1 +z3-solver==4.13.3.0 yapf==0.40.2 diff --git a/src/exo/API.py b/src/exo/API.py index 7889f5509..3a690ca3d 100644 --- a/src/exo/API.py +++ b/src/exo/API.py @@ -5,27 +5,27 @@ from pathlib import Path from typing import Optional, Union, List -import exo.LoopIR_scheduling as scheduling -from exo.LoopIR_scheduling import SchedulingError +import exo.rewrite.LoopIR_scheduling as scheduling +from exo.rewrite.LoopIR_scheduling import SchedulingError from .API_types import ProcedureBase, ExoType -from . import LoopIR as LoopIR -from .LoopIR_compiler import run_compile, compile_to_strings -from .configs import Config -from .boundscheck import CheckBounds -from .memory import Memory -from .parse_fragment import parse_fragment -from .pattern_match import match_pattern -from .prelude import * -from .new_eff import Check_Aliasing +from .core import LoopIR as LoopIR +from .backend.LoopIR_compiler import run_compile, compile_to_strings +from .core.configs import Config +from .frontend.boundscheck import CheckBounds +from .core.memory import Memory +from .frontend.parse_fragment import parse_fragment +from .frontend.pattern_match import match_pattern +from .core.prelude import * +from .rewrite.new_eff import Check_Aliasing # Moved to new file -from .proc_eqv import decl_new_proc, derive_proc, assert_eqv_proc, check_eqv_proc -from .pyparser import get_ast_from_python, Parser, get_src_locals -from .typecheck import TypeChecker +from .core.proc_eqv import decl_new_proc, derive_proc, assert_eqv_proc, check_eqv_proc +from .frontend.pyparser import get_ast_from_python, Parser, get_src_locals +from .frontend.typecheck import TypeChecker from . import API_cursors as C -from . import internal_cursors as IC +from .core import internal_cursors as IC # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # @@ -245,7 +245,7 @@ def body(self): block = self._root()._child_block("body") return C.lift_cursor(block, self) - def find(self, pattern, many=False): + def find(self, pattern, many=False, call_depth=1): """ Find the most specific possible cursor for the given pattern. 
For example, a pattern matching a single assignment statement @@ -256,7 +256,7 @@ def find(self, pattern, many=False): In any event, if no matches are found, a SchedulingError is raised """ - return C.find(self._root(), self, pattern, many) + return C.find(self._root(), self, pattern, many, call_depth=call_depth + 1) def find_loop(self, pattern, many=False): """ @@ -273,7 +273,7 @@ def find_loop(self, pattern, many=False): name, count = results[1], (results[2] if results[2] else "") pattern = f"for {name} in _: _{count}" - return self.find(pattern, many) + return self.find(pattern, many, call_depth=1) def find_alloc_or_arg(self, pattern): _name_count_re = r"^([a-zA-Z_]\w*)\s*(\#\s*[0-9]+)?$" @@ -286,10 +286,10 @@ def find_alloc_or_arg(self, pattern): pattern = f"{name}: _{count}" - return self.find(pattern) + return self.find(pattern, call_depth=1) def find_all(self, pattern): - return self.find(pattern, many=True) + return self.find(pattern, many=True, call_depth=1) # ---------------------------------------------- # # execution / compilation operations diff --git a/src/exo/API_cursors.py b/src/exo/API_cursors.py index 2a8b8b755..e9b090544 100644 --- a/src/exo/API_cursors.py +++ b/src/exo/API_cursors.py @@ -7,18 +7,18 @@ from . import API # TODO: remove this circular import from .API_types import ExoType, loopir_type_to_exotype -from .LoopIR import LoopIR -from .configs import Config -from .memory import Memory +from .core.LoopIR import LoopIR +from .core.configs import Config +from .core.memory import Memory -from . import internal_cursors as C -from .pattern_match import match_pattern -from .prelude import Sym +from .core import internal_cursors as C +from .frontend.pattern_match import match_pattern +from .core.prelude import Sym # expose this particular exception as part of the API -from .internal_cursors import InvalidCursorError -from .LoopIR_pprint import _print_cursor -from .LoopIR_scheduling import SchedulingError +from .core.internal_cursors import InvalidCursorError +from .core.LoopIR_pprint import _print_cursor +from .rewrite.LoopIR_scheduling import SchedulingError # --------------------------------------------------------------------------- # @@ -72,9 +72,8 @@ class Cursor(ABC): | Literal( value : bool, int, or float ) | UnaryMinus( arg : Expr ) | BinaryOp( op : str, lhs : Expr, rhs : Expr ) - | BuiltIn( name : str, args : ExprList ) + | Extern( name : str, args : ExprList ) | WindowExpr( name : str, idx : *(see below) ) - | BuiltIn( name : str, args : ExprList ) The `idx` argument of `WindowExpr` is a list containing either `Expr` or `(Expr,Expr)` (a pair of expressions) at each position. 
@@ -128,8 +127,8 @@ def parent(self): return InvalidCursor() return lift_cursor(impl_parent, self._proc) - def find(self, pattern, many=False): - return find(self._impl, self._proc, pattern, many) + def find(self, pattern, many=False, call_depth=1): + return find(self._impl, self._proc, pattern, many, call_depth=call_depth + 1) def _child_node(self, *args, **kwargs): return lift_cursor(self._impl._child_node(*args, **kwargs), self._proc) @@ -783,7 +782,7 @@ def rhs(self) -> ExprCursor: return self._child_node("rhs") -class BuiltInFunctionCursor(ExprCursor): +class ExternFunctionCursor(ExprCursor): """ Cursor pointing to the call to some built-in function `name ( args )` @@ -791,13 +790,13 @@ class BuiltInFunctionCursor(ExprCursor): def name(self) -> str: assert isinstance(self._impl, C.Node) - assert isinstance(self._impl._node, LoopIR.BuiltIn) + assert isinstance(self._impl._node, LoopIR.Extern) return self._impl._node.f.name() def args(self) -> ExprListCursor: assert isinstance(self._impl, C.Node) - assert isinstance(self._impl._node, LoopIR.BuiltIn) + assert isinstance(self._impl._node, LoopIR.Extern) return ExprListCursor(self._impl._child_block("args"), self._proc) @@ -923,8 +922,8 @@ def lift_cursor(impl, proc): return UnaryMinusCursor(impl, proc) elif isinstance(n, LoopIR.BinOp): return BinaryOpCursor(impl, proc) - elif isinstance(n, LoopIR.BuiltIn): - return BuiltInFunctionCursor(impl, proc) + elif isinstance(n, LoopIR.Extern): + return ExternFunctionCursor(impl, proc) elif isinstance(n, LoopIR.WindowExpr): return WindowExprCursor(impl, proc) elif isinstance(n, LoopIR.StrideExpr): @@ -937,7 +936,7 @@ def lift_cursor(impl, proc): assert False, f"bad case: {type(impl)}" -def find(scope: C, proc: API.Procedure, pattern: str, many: bool): +def find(scope: C, proc: API.Procedure, pattern: str, many: bool, call_depth=1): """ Find the most specific possible cursor for the given pattern in the given scope of the proc. For example, a pattern matching a @@ -953,7 +952,7 @@ def find(scope: C, proc: API.Procedure, pattern: str, many: bool): raise TypeError("expected a pattern string") default_match_no = None if many else 0 raw_cursors = match_pattern( - scope, pattern, call_depth=1, default_match_no=default_match_no + scope, pattern, call_depth=call_depth + 1, default_match_no=default_match_no ) assert isinstance(raw_cursors, list) cursors = [] @@ -1000,7 +999,7 @@ def find(scope: C, proc: API.Procedure, pattern: str, many: bool): "LiteralCursor", "UnaryMinusCursor", "BinaryOpCursor", - "BuiltInFunctionCursor", + "ExternFunctionCursor", "WindowExprCursor", "StrideExprCursor", # diff --git a/src/exo/API_scheduling.py b/src/exo/API_scheduling.py index 54513d99a..c35bc280a 100644 --- a/src/exo/API_scheduling.py +++ b/src/exo/API_scheduling.py @@ -9,16 +9,16 @@ from .API import Procedure import exo.API_cursors as PC -from .LoopIR import LoopIR, T -import exo.LoopIR_scheduling as scheduling +from .core.LoopIR import LoopIR, T +import exo.rewrite.LoopIR_scheduling as scheduling from .API_types import ExoType -from .LoopIR_unification import DoReplace, UnificationError -from .configs import Config -from .memory import Memory -from .parse_fragment import parse_fragment -from .prelude import * -from . 
import internal_cursors as ic +from .rewrite.LoopIR_unification import DoReplace, UnificationError +from .core.configs import Config +from .core.memory import Memory +from .frontend.parse_fragment import parse_fragment +from .core.prelude import * +from .core import internal_cursors as ic def is_subclass_obj(x, cls): @@ -381,8 +381,7 @@ def _cursor_call(self, expr_pattern, all_args): self.err("expected an ExprCursor or pattern string") proc = all_args["proc"] - # TODO: Remove all need for `call_depth` - matches = proc.find(expr_pattern, many=self.match_many) + matches = proc.find(expr_pattern, many=self.match_many, call_depth=1) if self.match_many: for m in matches: @@ -411,8 +410,7 @@ def _cursor_call(self, stmt_pattern, all_args): self.err("expected a StmtCursor or pattern string") proc = all_args["proc"] - # TODO: Remove all need for `call_depth` - matches = proc.find(stmt_pattern, many=self.match_many) + matches = proc.find(stmt_pattern, many=self.match_many, call_depth=1) match = matches[0] if self.match_many else matches if not isinstance(match, PC.StmtCursor): @@ -441,8 +439,7 @@ def _cursor_call(self, block_pattern, all_args): self.err("expected a Cursor or pattern string") proc = all_args["proc"] - # TODO: Remove all need for `call_depth` - matches = proc.find(block_pattern, many=self.match_many) + matches = proc.find(block_pattern, many=self.match_many, call_depth=1) match = matches[0] if self.match_many else matches if isinstance(match, PC.StmtCursor): @@ -540,7 +537,7 @@ def _cursor_call(self, alloc_pattern, all_args): if not isinstance(cursor, (PC.AllocCursor, PC.ArgCursor)): proc = all_args["proc"] try: - cursor = proc.find(alloc_pattern) + cursor = proc.find(alloc_pattern, call_depth=1) except: for arg in proc.args(): if arg.name() == name: diff --git a/src/exo/API_types.py b/src/exo/API_types.py index e87e22f43..a3deabab9 100644 --- a/src/exo/API_types.py +++ b/src/exo/API_types.py @@ -1,5 +1,5 @@ from enum import Enum, auto -from .LoopIR import LoopIR, T +from .core.LoopIR import LoopIR, T class ProcedureBase: diff --git a/src/exo/__init__.py b/src/exo/__init__.py index 6eba4861b..95fe0c050 100644 --- a/src/exo/__init__.py +++ b/src/exo/__init__.py @@ -7,14 +7,15 @@ config, ExoType, ) -from .LoopIR_scheduling import SchedulingError -from .parse_fragment import ParseFragmentError -from .configs import Config -from .memory import Memory, DRAM +from .rewrite.LoopIR_scheduling import SchedulingError +from .frontend.parse_fragment import ParseFragmentError +from .core.configs import Config +from .core.memory import Memory, DRAM +from .core.extern import Extern from . 
import stdlib -__version__ = "0.2.1" +__version__ = "1.0.0" __all__ = [ "Procedure", @@ -25,6 +26,7 @@ "config", "Config", "Memory", + "Extern", "DRAM", "SchedulingError", "ParseFragmentError", diff --git a/src/exo/LoopIR_compiler.py b/src/exo/backend/LoopIR_compiler.py similarity index 96% rename from src/exo/LoopIR_compiler.py rename to src/exo/backend/LoopIR_compiler.py index b590c19bf..090a215de 100644 --- a/src/exo/LoopIR_compiler.py +++ b/src/exo/backend/LoopIR_compiler.py @@ -6,15 +6,15 @@ from dataclasses import dataclass from pathlib import Path -from .LoopIR import LoopIR, LoopIR_Do, get_writes_of_stmts, T, CIR -from .configs import ConfigError +from ..core.LoopIR import LoopIR, LoopIR_Do, get_writes_of_stmts, T, CIR +from ..core.configs import ConfigError from .mem_analysis import MemoryAnalysis -from .memory import MemGenError, Memory, DRAM, StaticMemory +from ..core.memory import MemGenError, Memory, DRAM, StaticMemory from .parallel_analysis import ParallelAnalysis from .prec_analysis import PrecisionAnalysis -from .prelude import * +from ..core.prelude import * from .win_analysis import WindowAnalysis -from .range_analysis import IndexRangeEnvironment +from ..rewrite.range_analysis import IndexRangeEnvironment def sanitize_str(s): @@ -196,18 +196,18 @@ def do_t(self, t): pass -class LoopIR_FindBuiltIns(LoopIR_Do): +class LoopIR_FindExterns(LoopIR_Do): def __init__(self, proc): - self._builtins = set() + self._externs = set() super().__init__(proc) def result(self): - return self._builtins + return self._externs # to improve efficiency def do_e(self, e): - if isinstance(e, LoopIR.BuiltIn): - self._builtins.add(e.f) + if isinstance(e, LoopIR.Extern): + self._externs.add((e.f, e.type.basetype().ctype())) else: super().do_e(e) @@ -247,12 +247,12 @@ def find_all_mems(proc_list): return [m for m in mems] -def find_all_builtins(proc_list): - builtins = set() +def find_all_externs(proc_list): + externs = set() for p in proc_list: - builtins.update(LoopIR_FindBuiltIns(p).result()) + externs.update(LoopIR_FindExterns(p).result()) - return [b for b in builtins] + return externs def find_all_configs(proc_list): @@ -376,10 +376,10 @@ def from_lines(x): # Body contents memory_code = _compile_memories(find_all_mems(proc_list)) - builtin_code = _compile_builtins(find_all_builtins(proc_list)) private_fwd_decls = [] proc_bodies = [] instrs_global = [] + analyzed_proc_list = [] needed_helpers = set() @@ -424,6 +424,8 @@ def from_lines(x): proc_bodies.append(b) + analyzed_proc_list.append(p) + # Structs are just blobs of code... 
still sort them for output stability struct_defns = [x.definition for x in sorted(struct_defns, key=lambda x: x.name)] @@ -454,12 +456,14 @@ def from_lines(x): {from_lines(public_fwd_decls)} """ + extern_code = _compile_externs(find_all_externs(analyzed_proc_list)) + helper_code = [_static_helpers[v] for v in needed_helpers] body_contents = [ helper_code, instrs_global, memory_code, - builtin_code, + extern_code, private_fwd_decls, proc_bodies, ] @@ -470,12 +474,12 @@ def from_lines(x): return header_contents, body_contents -def _compile_builtins(builtins): - builtin_code = [] - for b in sorted(builtins, key=lambda x: x.name()): - if glb := b.globl(): - builtin_code.append(glb) - return builtin_code +def _compile_externs(externs): + extern_code = [] + for f, t in sorted(externs, key=lambda x: x[0].name() + x[1]): + if glb := f.globl(t): + extern_code.append(glb) + return extern_code def _compile_memories(mems): @@ -971,7 +975,7 @@ def comp_fnarg(self, e, fn, i, *, prec=0): x for x, _ in get_writes_of_stmts(fn.body) ) else: - raise NotImplementedError("Passing windows to built-ins") + raise NotImplementedError("Passing windows to externs") win_struct = self.get_window_type(e.type, is_const) data, strides = self.window_struct_fields(e) return f"(struct {win_struct}){{ &{data}, {{ {strides} }} }}" @@ -1044,9 +1048,9 @@ def comp_e(self, e, prec=0): elif isinstance(e, LoopIR.USub): return f'-{self.comp_e(e.arg, op_prec["~"])}' - elif isinstance(e, LoopIR.BuiltIn): - args = [self.comp_fnarg(a, e, i) for i, a in enumerate(e.args)] - return e.f.compile(args) + elif isinstance(e, LoopIR.Extern): + args = [self.comp_e(a) for a in e.args] + return e.f.compile(args, e.type.basetype().ctype()) elif isinstance(e, LoopIR.StrideExpr): basetyp = self.envtyp[e.name] diff --git a/src/exo/backend/__init__.py b/src/exo/backend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/exo/mem_analysis.py b/src/exo/backend/mem_analysis.py similarity index 97% rename from src/exo/mem_analysis.py rename to src/exo/backend/mem_analysis.py index 0743503d3..39eaf267c 100644 --- a/src/exo/mem_analysis.py +++ b/src/exo/backend/mem_analysis.py @@ -1,7 +1,7 @@ from collections import ChainMap -from .LoopIR import LoopIR +from ..core.LoopIR import LoopIR -from .memory import Memory +from ..core.memory import Memory # --------------------------------------------------------------------------- # @@ -69,7 +69,7 @@ def used_e(e): elif isinstance(e, LoopIR.BinOp): res += used_e(e.lhs) res += used_e(e.rhs) - elif isinstance(e, LoopIR.BuiltIn): + elif isinstance(e, LoopIR.Extern): for ei in e.args: res += used_e(ei) elif isinstance(e, (LoopIR.WindowExpr, LoopIR.StrideExpr)): diff --git a/src/exo/parallel_analysis.py b/src/exo/backend/parallel_analysis.py similarity index 89% rename from src/exo/parallel_analysis.py rename to src/exo/backend/parallel_analysis.py index 03e070f38..f82ae1aee 100644 --- a/src/exo/parallel_analysis.py +++ b/src/exo/backend/parallel_analysis.py @@ -1,6 +1,6 @@ -from .LoopIR import LoopIR, LoopIR_Rewrite +from ..core.LoopIR import LoopIR, LoopIR_Rewrite -from .new_eff import Check_ParallelizeLoop +from ..rewrite.new_eff import Check_ParallelizeLoop class ParallelAnalysis(LoopIR_Rewrite): diff --git a/src/exo/prec_analysis.py b/src/exo/backend/prec_analysis.py similarity index 86% rename from src/exo/prec_analysis.py rename to src/exo/backend/prec_analysis.py index 224173ebc..3d5dac05a 100644 --- a/src/exo/prec_analysis.py +++ b/src/exo/backend/prec_analysis.py @@ -1,4 +1,4 @@ -from 
.LoopIR import LoopIR, LoopIR_Rewrite, T +from ..core.LoopIR import LoopIR, LoopIR_Rewrite, T # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # @@ -199,6 +199,28 @@ def map_e(self, e): typ = lhs.type return LoopIR.BinOp(e.op, lhs, rhs, typ, e.srcinfo) + elif isinstance(e, LoopIR.Extern): + typ = T.R + for a in e.args: + if a.type != T.R: + typ = a.type + + new_args = [] + for a in e.args: + a = self.apply_e(a) + if typ != a.type: + # coerce if const and real + if a.type == T.R: + a = self.coerce_e(a, typ) + else: + self.err( + e, + f"all extern arguments must have the same type, got {typ} and {a.type}", + ) + new_args.append(a) + + return LoopIR.Extern(e.f, new_args, typ, e.srcinfo) + return super().map_e(e) # this routine allows for us to retro-actively @@ -224,6 +246,18 @@ def coerce_e(self, e, btyp): assert rhs.type == btyp return LoopIR.BinOp(e.op, lhs, rhs, btyp, e.srcinfo) + elif isinstance(e, LoopIR.Extern): + assert e.type == T.R + # coerce if T.R + args = [] + for a in e.args: + if a.type == T.R: + args.append(self.coerce_e(a, btyp)) + else: + assert a.type == btyp + args.append(a) + return LoopIR.Extern(e.f, args, btyp, e.srcinfo) + else: assert False, f"Should not be coercing a {type(e)} Node" diff --git a/src/exo/win_analysis.py b/src/exo/backend/win_analysis.py similarity index 97% rename from src/exo/win_analysis.py rename to src/exo/backend/win_analysis.py index 1a8ed77c6..293e63fbb 100644 --- a/src/exo/win_analysis.py +++ b/src/exo/backend/win_analysis.py @@ -1,4 +1,4 @@ -from .LoopIR import LoopIR, T, LoopIR_Rewrite +from ..core.LoopIR import LoopIR, T, LoopIR_Rewrite # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # diff --git a/src/exo/builtins.py b/src/exo/builtins.py deleted file mode 100644 index 913620171..000000000 --- a/src/exo/builtins.py +++ /dev/null @@ -1,142 +0,0 @@ -# --------------------------------------------------------------------------- # -# --------------------------------------------------------------------------- # -# BuiltIn superclass - - -class BuiltIn_Typecheck_Error(Exception): - def __init__(self, msg): - self._builtin_err_msg = str(msg) - - def __str__(self): - return self._builtin_err_msg - - -_BErr = BuiltIn_Typecheck_Error - - -class BuiltIn: - def __init__(self, name): - self._name = name - - def name(self): - return self._name - - def globl(self): - raise NotImplementedError() - - def typecheck(self, args): - raise NotImplementedError() - - def compile(self, args): - raise NotImplementedError() - - -class _Sin(BuiltIn): - def __init__(self): - super().__init__("sin") - - def typecheck(self, args): - if len(args) != 1: - raise _BErr(f"expected 1 argument, got {len(args)}") - - atyp = args[0].type - if not atyp.is_real_scalar(): - raise _BErr( - f"expected argument 1 to be a real scalar value, but " - f"got type {atyp}" - ) - return atyp - - def globl(self): - return "#include <math.h>" - - def compile(self, args): - return f"sin((double)*{args[0]})" - - -sin = _Sin() - - -class _Relu(BuiltIn): - def __init__(self): - super().__init__("relu") - - def typecheck(self, args): - if len(args) != 1: - raise _BErr(f"expected 1 argument, got {len(args)}") - - atyp = args[0].type - if not atyp.is_real_scalar(): - raise _BErr( - f"expected argument 1 to be a real scalar value, but " - f"got type {atyp}" - ) - return atyp - - def globl(self):
s = ( - "double _relu_(double x) {\n" - " if (x > 0.0) return x;\n" - " else return 0.0;\n" - "}\n" - ) - return s - - def compile(self, args): - return f"_relu_((double)*{args[0]})" - - -relu = _Relu() - - -class _Select(BuiltIn): - def __init__(self): - super().__init__("select") - - def typecheck(self, args): - if len(args) != 4: - raise _BErr(f"expected 4 arguments, got {len(args)}") - - atyp = args[0].type - if not atyp.is_real_scalar(): - raise _BErr( - f"expected argument 1 to be a real scalar value, but " - f"got type {atyp}" - ) - - atyp = args[1].type - if not atyp.is_real_scalar(): - raise _BErr( - f"expected argument 2 to be a real scalar value, but " - f"got type {atyp}" - ) - - atyp = args[2].type - if not atyp.is_real_scalar(): - raise _BErr( - f"expected argument 3 to be a real scalar value, but " - f"got type {atyp}" - ) - - atyp = args[3].type - if not atyp.is_real_scalar(): - raise _BErr( - f"expected argument 4 to be a real scalar value, but " - f"got type {atyp}" - ) - return atyp - - def globl(self): - s = ( - "double _select_(double x, double v, double y, double z) {\n" - " if (x < v) return y;\n" - " else return z;\n" - "}\n" - ) - return s - - def compile(self, args): - return f"_select_((double)*{args[0]}, (double)*{args[1]}, (double)*{args[2]}, (double)*{args[3]})" - - -select = _Select() diff --git a/src/exo/LoopIR.py b/src/exo/core/LoopIR.py similarity index 99% rename from src/exo/LoopIR.py rename to src/exo/core/LoopIR.py index 733b0d133..9ee862779 100644 --- a/src/exo/LoopIR.py +++ b/src/exo/core/LoopIR.py @@ -4,7 +4,7 @@ from asdl_adt import ADT, validators -from .builtins import BuiltIn +from .extern import Extern from .configs import Config from .memory import Memory from .prelude import Sym, SrcInfo, extclass @@ -92,7 +92,7 @@ def __new__(cls, op): | Const( object val ) | USub( expr arg ) -- i.e. -(...) | BinOp( binop op, expr lhs, expr rhs ) - | BuiltIn( builtin f, expr* args ) + | Extern( extern f, expr* args ) | WindowExpr( sym name, w_access* idx ) | StrideExpr( sym name, int dim ) | ReadConfig( config config, string field ) @@ -130,7 +130,7 @@ def __new__(cls, op): "name": validators.instance_of(Identifier, convert=True), "sym": Sym, "mem": Type[Memory], - "builtin": BuiltIn, + "extern": Extern, "config": Config, "binop": validators.instance_of(Operator, convert=True), "srcinfo": SrcInfo, @@ -190,7 +190,7 @@ def __new__(cls, op): | Const ( object val ) | USub ( expr arg ) -- i.e. -(...) | BinOp ( op op, expr lhs, expr rhs ) - | BuiltIn( builtin f, expr* args ) + | Extern( extern f, expr* args ) | WindowExpr( sym name, w_access* idx ) | StrideExpr( sym name, int dim ) | ParRange( expr lo, expr hi ) -- only use for loop cond @@ -221,7 +221,7 @@ def __new__(cls, op): "name": validators.instance_of(Identifier, convert=True), "sym": Sym, "mem": Type[Memory], - "builtin": BuiltIn, + "extern": Extern, "config": Config, "loopir_proc": LoopIR.proc, "op": validators.instance_of(Operator, convert=True), @@ -270,14 +270,13 @@ def __new__(cls, op): | Const ( object val ) | USub ( expr arg ) -- i.e. -(...) 
| BinOp ( op op, expr lhs, expr rhs ) - | BuiltIn ( builtin f, expr* args ) + | Extern ( name f, expr* args ) | ReadConfig( string config, string field ) attributes( srcinfo srcinfo ) } """, ext_types={ "name": validators.instance_of(IdentifierOrHole, convert=True), - "builtin": BuiltIn, "op": validators.instance_of(Operator, convert=True), "srcinfo": SrcInfo, }, @@ -673,7 +672,7 @@ def map_e(self, e): rhs=new_rhs or e.rhs, type=new_type or e.type, ) - elif isinstance(e, LoopIR.BuiltIn): + elif isinstance(e, LoopIR.Extern): new_type = self.map_t(e.type) new_args = self.map_exprs(e.args) if any((new_type, new_args is not None)): @@ -810,7 +809,7 @@ def do_e(self, e): elif etyp is LoopIR.BinOp: self.do_e(e.lhs) self.do_e(e.rhs) - elif etyp is LoopIR.BuiltIn: + elif etyp is LoopIR.Extern: for a in e.args: self.do_e(a) elif etyp is LoopIR.USub: @@ -914,7 +913,7 @@ def match_e(self, e1, e2): and self.match_e(e1.lhs, e2.lhs) and self.match_e(e1.rhs, e2.rhs) ) - elif isinstance(e1, LoopIR.BuiltIn): + elif isinstance(e1, LoopIR.Extern): # TODO: check f equality return e1.f is e2.f and all( self.match_e(a1, a2) for a1, a2 in zip(e1.args, e2.args) diff --git a/src/exo/LoopIR_pprint.py b/src/exo/core/LoopIR_pprint.py similarity index 96% rename from src/exo/LoopIR_pprint.py rename to src/exo/core/LoopIR_pprint.py index 79eb13e5f..4464976e3 100644 --- a/src/exo/LoopIR_pprint.py +++ b/src/exo/core/LoopIR_pprint.py @@ -271,7 +271,7 @@ def pacc(w): return f"{self.get_name(e.name)}[{', '.join([pacc(w) for w in e.idx])}]" elif isinstance(e, UAST.StrideExpr): return f"stride({self.get_name(e.name)}, {e.dim})" - elif isinstance(e, UAST.BuiltIn): + elif isinstance(e, UAST.Extern): pname = e.f.name() or "_anon_" args = [self.pexpr(a) for a in e.args] return f"{pname}({','.join(args)})" @@ -507,7 +507,7 @@ def _print_expr(e, env: PrintEnv, prec: int = 0) -> str: elif isinstance(e, LoopIR.StrideExpr): return f"stride({env.get_name(e.name)}, {e.dim})" - elif isinstance(e, LoopIR.BuiltIn): + elif isinstance(e, LoopIR.Extern): pname = e.f.name() or "_anon_" args = [_print_expr(a, env) for a in e.args] return f"{pname}({', '.join(args)})" @@ -581,6 +581,9 @@ def _print_w_access(node, env: PrintEnv) -> str: def _print_cursor(cur): + if cur == None: + raise InvalidCursorError("Trying to print the Invalid Cursor!") + if isinstance(cur, Node) and not isinstance(cur._node, (LoopIR.proc, LoopIR.stmt)): raise NotImplementedError( "Cursor printing is only implemented for procs and statements" @@ -625,31 +628,43 @@ def _print_cursor_proc( def _print_cursor_block( cur: Block, target: Cursor, env: PrintEnv, indent: str ) -> list[str]: - def while_cursor(c, move, k): + def while_next(c): s = [] while True: try: - c = move(c) - s.expand(k(c)) + c = c.next() + s.extend(local_stmt(c)) except: return s + def while_prev(c): + s = [] + while True: + try: + c = c.prev() + s.append(local_stmt(c)) + except: + s.reverse() + return [x for xs in s for x in xs] + def local_stmt(c): return _print_cursor_stmt(c, target, env, indent) if isinstance(target, Gap) and target in cur: if target._type == GapType.Before: return [ - *while_cursor(target.anchor(), lambda g: g.prev(), local_stmt), + *while_prev(target.anchor()), f"{indent}[GAP - Before]", - *while_cursor(target.anchor(), lambda g: g.next(), local_stmt), + *local_stmt(target.anchor()), + *while_next(target.anchor()), ] else: assert target._type == GapType.After return [ - *while_cursor(target.anchor(), lambda g: g.prev(), local_stmt), + *while_prev(target.anchor()), + 
*local_stmt(target.anchor()), f"{indent}[GAP - After]", - *while_cursor(target.anchor(), lambda g: g.next(), local_stmt), + *while_next(target.anchor()), ] elif isinstance(target, Block) and target in cur: @@ -658,9 +673,9 @@ def local_stmt(c): block.extend(local_stmt(stmt)) block.append(f"{indent}# BLOCK END") return [ - *while_cursor(target[0], lambda g: g.prev(), local_stmt), + *while_prev(target[0]), *block, - *while_cursor(target[-1], lambda g: g.next(), local_stmt), + *while_next(target[-1]), ] else: diff --git a/src/exo/core/__init__.py b/src/exo/core/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/exo/configs.py b/src/exo/core/configs.py similarity index 100% rename from src/exo/configs.py rename to src/exo/core/configs.py diff --git a/src/exo/core/extern.py b/src/exo/core/extern.py new file mode 100644 index 000000000..b1ae39d6d --- /dev/null +++ b/src/exo/core/extern.py @@ -0,0 +1,36 @@ +import math + +# --------------------------------------------------------------------------- # +# --------------------------------------------------------------------------- # +# Extern superclass + + +class Extern_Typecheck_Error(Exception): + def __init__(self, msg): + self._builtin_err_msg = str(msg) + + def __str__(self): + return self._builtin_err_msg + + +_EErr = Extern_Typecheck_Error + + +class Extern: + def __init__(self, name): + self._name = name + + def name(self): + return self._name + + def globl(self, prim_type): + raise NotImplementedError() + + def typecheck(self, args): + raise NotImplementedError() + + def interpret(self, args): + raise NotImplementedError() + + def compile(self, args, prim_type): + raise NotImplementedError() diff --git a/src/exo/internal_cursors.py b/src/exo/core/internal_cursors.py similarity index 99% rename from src/exo/internal_cursors.py rename to src/exo/core/internal_cursors.py index 21ed814f7..1d4c98a8f 100644 --- a/src/exo/internal_cursors.py +++ b/src/exo/core/internal_cursors.py @@ -337,7 +337,7 @@ def _delete(self): internal classes and modules, but not from end-user code. 
""" # TODO: refactor this; LoopIR should not be imported here - from exo.LoopIR import LoopIR + from exo.core.LoopIR import LoopIR pass_stmt = [LoopIR.Pass(self.parent()._node.srcinfo)] return self._replace([], empty_default=pass_stmt) diff --git a/src/exo/memory.py b/src/exo/core/memory.py similarity index 100% rename from src/exo/memory.py rename to src/exo/core/memory.py diff --git a/src/exo/prelude.py b/src/exo/core/prelude.py similarity index 100% rename from src/exo/prelude.py rename to src/exo/core/prelude.py diff --git a/src/exo/proc_eqv.py b/src/exo/core/proc_eqv.py similarity index 100% rename from src/exo/proc_eqv.py rename to src/exo/core/proc_eqv.py diff --git a/src/exo/frontend/__init__.py b/src/exo/frontend/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/exo/boundscheck.py b/src/exo/frontend/boundscheck.py similarity index 99% rename from src/exo/boundscheck.py rename to src/exo/frontend/boundscheck.py index da6ea372c..b82bddc93 100644 --- a/src/exo/boundscheck.py +++ b/src/exo/frontend/boundscheck.py @@ -4,8 +4,8 @@ import pysmt from pysmt import shortcuts as SMT -from .LoopIR import LoopIR, T, Operator, Config -from .prelude import * +from ..core.LoopIR import LoopIR, T, Operator, Config +from ..core.prelude import * # --------------------------------------------------------------------------- # @@ -1117,7 +1117,7 @@ def eff_e(self, e, type_env): return eff_null(e.srcinfo) elif isinstance(e, LoopIR.WindowExpr): return eff_null(e.srcinfo) - elif isinstance(e, LoopIR.BuiltIn): + elif isinstance(e, LoopIR.Extern): return eff_null(e.srcinfo) elif isinstance(e, LoopIR.StrideExpr): return eff_null(e.srcinfo) diff --git a/src/exo/parse_fragment.py b/src/exo/frontend/parse_fragment.py similarity index 93% rename from src/exo/parse_fragment.py rename to src/exo/frontend/parse_fragment.py index 35ac914da..8cb164401 100644 --- a/src/exo/parse_fragment.py +++ b/src/exo/frontend/parse_fragment.py @@ -2,7 +2,7 @@ from collections import ChainMap from . 
import pyparser -from .LoopIR import T, LoopIR_Do, LoopIR, PAST +from ..core.LoopIR import T, LoopIR_Do, LoopIR, PAST # --------------------------------------------------------------------------- # @@ -20,13 +20,23 @@ class ParseFragmentError(Exception): def parse_fragment( - proc, fragment, ctx_stmt, call_depth=0, configs=[], scope="before", expr_holes=None + proc, fragment, ctx_stmt, call_depth=1, configs=[], scope="before", expr_holes=None ): + stack_frames: [inspect.FrameInfo] = inspect.stack() # get source location where this is getting called from - caller = inspect.getframeinfo(inspect.stack()[call_depth + 1][0]) + caller = inspect.getframeinfo(stack_frames[call_depth][0]) + func_locals = ChainMap(stack_frames[call_depth].frame.f_locals) + func_globals = ChainMap(stack_frames[call_depth].frame.f_globals) # parse the pattern we're going to use to match - p_ast = pyparser.pattern(fragment, filename=caller.filename, lineno=caller.lineno) + p_ast = pyparser.pattern( + fragment, + filename=caller.filename, + lineno=caller.lineno, + srclocals=func_locals, + srcglobals=func_globals, + ) + if isinstance(p_ast, PAST.expr): return ParseFragment( p_ast, proc, ctx_stmt, configs, scope, expr_holes @@ -47,7 +57,7 @@ def parse_fragment( PAST.USub: LoopIR.USub, PAST.BinOp: LoopIR.BinOp, PAST.StrideExpr: LoopIR.StrideExpr, - PAST.BuiltIn: LoopIR.BuiltIn, + PAST.Extern: LoopIR.Extern, PAST.ReadConfig: LoopIR.ReadConfig, } @@ -234,14 +244,14 @@ def parse_e(self, pat): typ = {float: T.R, bool: T.bool, int: T.int}.get(type(pat.val)) assert typ is not None, "bad type!" return LoopIR.Const(pat.val, typ, self.srcinfo) - elif isinstance(pat, PAST.BuiltIn): + elif isinstance(pat, PAST.Extern): args = [self.parse_e(a) for a in pat.args] try: typ = pat.f.typecheck(args) - except BuiltIn_Typecheck_Error as err: + except Extern_Typecheck_Error as err: raise ParseFragmentError(err) - return LoopIR.BuiltIn(pat.f, args, typ, self.srcinfo) + return LoopIR.Extern(pat.f, args, typ, self.srcinfo) elif isinstance(pat, PAST.ReadConfig): if pat.config not in self.configs: raise ParseFragmentError( @@ -304,12 +314,12 @@ def check_sym_consistency(sym): rhs=self.rebuild_ast(loopIR_expr.rhs), srcinfo=self.srcinfo, ) - elif isinstance(loopIR_expr, LoopIR.BuiltIn): + elif isinstance(loopIR_expr, LoopIR.Extern): args = [self.rebuild_ast(a) for a in loopIR_expr.args] try: typ = loopIR_expr.f.typecheck(args) - except BuiltIn_Typecheck_Error as err: + except Extern_Typecheck_Error as err: raise ParseFragmentError(err) if typ != loopIR_expr.typ: diff --git a/src/exo/pattern_match.py b/src/exo/frontend/pattern_match.py similarity index 94% rename from src/exo/pattern_match.py rename to src/exo/frontend/pattern_match.py index 7dca84c83..55eca676b 100644 --- a/src/exo/pattern_match.py +++ b/src/exo/frontend/pattern_match.py @@ -3,14 +3,15 @@ import inspect import re from typing import Optional, Iterable +from collections import ChainMap -import exo.pyparser as pyparser -from exo.LoopIR import LoopIR, PAST +import exo.frontend.pyparser as pyparser +from exo.core.LoopIR import LoopIR, PAST # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # # Pattern Matching Errors -from exo.internal_cursors import Cursor, Node, Block +from exo.core.internal_cursors import Cursor, Node, Block class PatternMatchError(Exception): @@ -59,7 +60,7 @@ def get_match_no(pattern_str: str) -> Optional[int]: def match_pattern( context: Cursor, pattern_str: 
str, - call_depth=0, + call_depth=1, default_match_no=None, use_sym_id=False, ): @@ -78,12 +79,19 @@ def match_pattern( else: match_no = default_match_no # None means match-all + stack_frames: [inspect.FrameInfo] = inspect.stack() # get source location where this is getting called from - caller = inspect.getframeinfo(inspect.stack()[call_depth + 1][0]) + caller = inspect.getframeinfo(stack_frames[call_depth][0]) + func_locals = ChainMap(stack_frames[call_depth].frame.f_locals) + func_globals = ChainMap(stack_frames[call_depth].frame.f_globals) # parse the pattern we're going to use to match p_ast = pyparser.pattern( - pattern_str, filename=caller.filename, lineno=caller.lineno + pattern_str, + filename=caller.filename, + lineno=caller.lineno, + srclocals=func_locals, + srcglobals=func_globals, ) # do the pattern match, to find the nodes in ast @@ -109,7 +117,7 @@ def match_pattern( PAST.Const: [LoopIR.Const], PAST.USub: [LoopIR.USub], PAST.BinOp: [LoopIR.BinOp], - PAST.BuiltIn: [LoopIR.BuiltIn], + PAST.Extern: [LoopIR.Extern], PAST.ReadConfig: [LoopIR.ReadConfig], PAST.E_Hole: None, } @@ -324,8 +332,8 @@ def match_e(self, pat, e): ) elif isinstance(e, LoopIR.USub): return self.match_e(pat.arg, e.arg) - elif isinstance(e, LoopIR.BuiltIn): - return pat.f is e.f and all( + elif isinstance(e, LoopIR.Extern): + return self.match_name(pat.f, e.f.name()) and all( self.match_e(pa, sa) for pa, sa in zip(pat.args, e.args) ) elif isinstance(e, LoopIR.ReadConfig): @@ -383,7 +391,7 @@ def _children(cur) -> Iterable[Node]: yield from _children_from_attrs(cur, n, "arg") elif isinstance(n, LoopIR.BinOp): yield from _children_from_attrs(cur, n, "lhs", "rhs") - elif isinstance(n, LoopIR.BuiltIn): + elif isinstance(n, LoopIR.Extern): yield from _children_from_attrs(cur, n, "args") else: assert False, f"case {type(n)} unsupported" diff --git a/src/exo/pyparser.py b/src/exo/frontend/pyparser.py similarity index 95% rename from src/exo/pyparser.py rename to src/exo/frontend/pyparser.py index f997543c6..b341b42eb 100644 --- a/src/exo/pyparser.py +++ b/src/exo/frontend/pyparser.py @@ -9,11 +9,11 @@ from asdl_adt.validators import ValidationError -from .API_types import ProcedureBase -from .builtins import * -from .configs import Config -from .LoopIR import UAST, PAST, front_ops -from .prelude import * +from ..API_types import ProcedureBase +from ..core.configs import Config +from ..core.LoopIR import UAST, PAST, front_ops +from ..core.prelude import * +from ..core.extern import Extern # --------------------------------------------------------------------------- # @@ -90,7 +90,7 @@ def get_src_locals(*, depth): # Pattern-Parser top-level, invoked on strings rather than as a decorator -def pattern(s, filename=None, lineno=None): +def pattern(s, filename=None, lineno=None, srclocals=None, srcglobals=None): assert isinstance(s, str) src = s @@ -119,7 +119,13 @@ def getsrcinfo(node): ), ) - parser = Parser(module.body, getsrcinfo, is_fragment=True) + parser = Parser( + module.body, + getsrcinfo, + is_fragment=True, + func_globals=srcglobals, + srclocals=srclocals, + ) return parser.result() @@ -166,15 +172,18 @@ def __init__( self.is_fragment = is_fragment self.push() + special_cases = ["stride"] + for key, val in self.globals.items(): + if isinstance(val, Extern): + special_cases.append(key) + for key, val in self.locals.items(): + if isinstance(val, Extern): + special_cases.append(key) - builtins = {"sin": sin, "relu": relu, "select": select} if is_fragment: self.AST = PAST else: self.AST = UAST - # add builtins - 
for key, val in builtins.items(): - self.locals[key] = val if as_func: self._cached_result = self.parse_fdef(module_ast, instr=instr) @@ -184,7 +193,6 @@ def __init__( is_expr = False if len(module_ast) == 1: s = module_ast[0] - special_cases = list(builtins.keys()) + ["stride"] if isinstance(s, pyast.Expr) and ( not isinstance(s.value, pyast.Call) or s.value.func.id in special_cases @@ -875,8 +883,8 @@ def parse_expr(self, e): nm = self.locals[nm_node.id] elif nm_node.id in self.globals: nm = self.globals[nm_node.id] - else: - nm = None + else: # could not resolve name to anything + self.err(nm_node, f"variable '{nm_node.id}' undefined") if isinstance(nm, SizeStub): nm = nm.nm @@ -891,8 +899,11 @@ def parse_expr(self, e): ) else: return UAST.Const(nm, self.getsrcinfo(e)) - else: # could not resolve name to anything - self.err(nm_node, f"variable '{nm_node.id}' undefined") + else: + self.err( + nm_node, + f"variable '{nm_node.id}' has unsupported type {type(nm)}", + ) if is_window: return UAST.WindowExpr(nm, idxs, self.getsrcinfo(e)) @@ -1064,21 +1075,30 @@ def parse_expr(self, e): # handle built-in functions else: - f = self.eval_expr(e.func) fname = e.func.id + if self.is_fragment: + if len(e.keywords) > 0: + self.err( + e.func, "cannot call an extern function " "with keyword arguments" + ) + args = [self.parse_expr(a) for a in e.args] - if not isinstance(f, BuiltIn): - self.err( - e.func, f"expected called object " "to be a builtin function" - ) + return self.AST.Extern(fname, args, self.getsrcinfo(e)) + else: + f = self.eval_expr(e.func) - if len(e.keywords) > 0: - self.err( - f, "cannot call a builtin function " "with keyword arguments" - ) - args = [self.parse_expr(a) for a in e.args] + if not isinstance(f, Extern): + self.err( + e.func, f"expected called object " "to be an extern function" + ) + + if len(e.keywords) > 0: + self.err( + f, "cannot call an extern function " "with keyword arguments" + ) + args = [self.parse_expr(a) for a in e.args] - return self.AST.BuiltIn(f, args, self.getsrcinfo(e)) + return self.AST.Extern(f, args, self.getsrcinfo(e)) else: self.err(e, "unsupported form of expression") diff --git a/src/exo/syntax.py b/src/exo/frontend/syntax.py similarity index 100% rename from src/exo/syntax.py rename to src/exo/frontend/syntax.py diff --git a/src/exo/typecheck.py b/src/exo/frontend/typecheck.py similarity index 98% rename from src/exo/typecheck.py rename to src/exo/frontend/typecheck.py index a9581f9b2..9b27197a7 100644 --- a/src/exo/typecheck.py +++ b/src/exo/frontend/typecheck.py @@ -1,4 +1,4 @@ -from .LoopIR import ( +from ..core.LoopIR import ( T, UAST, LoopIR, @@ -6,8 +6,8 @@ get_writeconfigs, get_loop_iters, ) -from .builtins import BuiltIn_Typecheck_Error -from .memory import * +from ..core.extern import Extern_Typecheck_Error +from ..core.memory import * # --------------------------------------------------------------------------- # @@ -555,17 +555,17 @@ def check_e(self, e, is_index=False): return LoopIR.BinOp(e.op, lhs, rhs, typ, e.srcinfo) - elif isinstance(e, UAST.BuiltIn): + elif isinstance(e, UAST.Extern): args = [self.check_e(a) for a in e.args] try: typ = e.f.typecheck(args) - except BuiltIn_Typecheck_Error as err: + except Extern_Typecheck_Error as err: typ = T.err self.err(e, str(err)) - return LoopIR.BuiltIn(e.f, args, typ, e.srcinfo) + return LoopIR.Extern(e.f, args, typ, e.srcinfo) elif isinstance(e, UAST.StrideExpr): idx, typ = self.check_access(e, e.name, [], lvalue=False) diff --git a/src/exo/libs/externs.py b/src/exo/libs/externs.py new file mode
100644 index 000000000..8752ed695 --- /dev/null +++ b/src/exo/libs/externs.py @@ -0,0 +1,234 @@ +from exo.core.extern import Extern, _EErr + + +class _Sin(Extern): + def __init__(self): + super().__init__("sin") + + def typecheck(self, args): + if len(args) != 1: + raise _EErr(f"expected 1 argument, got {len(args)}") + + atyp = args[0].type + if not atyp.is_real_scalar(): + raise _EErr( + f"expected argument 1 to be a real scalar value, but " + f"got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + return "#include <math.h>" + + # def interpret(self, args): + # return math.sin(args[0]) + + def compile(self, args, prim_type): + return f"sin(({prim_type}){args[0]})" + + +sin = _Sin() + + +class _Relu(Extern): + def __init__(self): + super().__init__("relu") + + def typecheck(self, args): + if len(args) != 1: + raise _EErr(f"expected 1 argument, got {len(args)}") + + atyp = args[0].type + if not atyp.is_real_scalar(): + raise _EErr( + f"expected argument 1 to be a real scalar value, but " + f"got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + s = ( + f"{prim_type} _relu_{prim_type}({prim_type} x) " + "{\n" + " if (x > 0.0) return x;\n" + " else return 0.0;\n" + "}\n" + ) + return s + + # def interpret(self, args): + # if args[0] > 0: + # return args[0] + # else: + # return 0 + + def compile(self, args, prim_type): + return f"_relu_{prim_type}(({prim_type}){args[0]})" + + +relu = _Relu() + + +class _Select(Extern): + def __init__(self): + super().__init__("select") + + def typecheck(self, args): + if len(args) != 4: + raise _EErr(f"expected 4 arguments, got {len(args)}") + + for i in range(len(args)): + atyp = args[i].type + if not atyp.is_real_scalar(): + raise _EErr( + f"expected argument {i+1} to be a real scalar value, but " + f"got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + s = ( + f"{prim_type} _select_{prim_type}({prim_type} x,{prim_type} v,{prim_type} y,{prim_type} z)" + + " {\n" + " if (x < v) return y;\n" + " else return z;\n" + "}\n" + ) + return s + + # def interpret(self, args): + # x = args[0] + # v = args[1] + # y = args[2] + # z = args[3] + # if x < v: + # return y + # else: + # return z + + def compile(self, args, prim_type): + return f"_select_{prim_type}(({prim_type}){args[0]}, ({prim_type}){args[1]}, ({prim_type}){args[2]}, ({prim_type}){args[3]})" + + +select = _Select() + + +class _Expf(Extern): + def __init__(self): + super().__init__("expf") + + def typecheck(self, args): + if len(args) != 1: + raise _EErr(f"expected 1 argument, got {len(args)}") + + atyp = args[0].type + if not atyp.is_real_scalar(): + raise _EErr( + f"expected argument 1 to be a real scalar value, but " + f"got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + return "#include <math.h>" + + # def interpret(self, args): + # return math.expf(args[0]) + + def compile(self, args, prim_type): + return f"expf(({prim_type})({args[0]}))" + + +expf = _Expf() + + +class _FmaxF(Extern): + def __init__(self): + super().__init__("fmaxf") + + def typecheck(self, args): + if len(args) != 2: + raise _EErr(f"expected 2 arguments, got {len(args)}") + + for i in range(len(args)): + atyp = args[i].type + if not atyp.is_real_scalar(): + raise _EErr( + f"expected argument {i+1} to be a real scalar value, but " + f"got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + return "#include <math.h>" + + # def interpret(self, args): + # return math.fmaxf(args[0], args[1]) + + def compile(self, args, prim_type): + return f"fmaxf(({prim_type})({args[0]}),
({prim_type})({args[1]}))" + + +fmaxf = _FmaxF() + + +class _Sigmoid(Extern): + def __init__(self): + super().__init__("sigmoid") + + def typecheck(self, args): + if len(args) != 1: + raise _EErr(f"expected 1 argument, got {len(args)}") + + atyp = args[0].type + if not atyp.is_real_scalar(): + raise _EErr( + f"expected argument 1 to be a real scalar value, but " + f"got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + return f""" +#include <math.h> +{prim_type} sigmoid({prim_type} x) {{ + return 1 / (1 + exp(-x)); +}} +""" + + # def interpret(self, args): + # return math.sigmoid(args[0]) + + def compile(self, args, prim_type): + return f"sigmoid(({prim_type})({args[0]}))" + + +sigmoid = _Sigmoid() + + +class _Sqrt(Extern): + def __init__(self): + super().__init__("sqrt") + + def typecheck(self, args): + if len(args) != 1: + raise _EErr(f"expected 1 argument, got {len(args)}") + + atyp = args[0].type + if not atyp.is_real_scalar(): + raise _EErr( + f"expected argument 1 to be a real scalar value, but " + f"got type {atyp}" + ) + return atyp + + def globl(self, prim_type): + return "#include <math.h>" + + # def interpret(self, args): + # return math.sqrt(args[0]) + + def compile(self, args, prim_type): + return f"sqrt(({prim_type})({args[0]}))" + + +sqrt = _Sqrt() diff --git a/src/exo/libs/memories.py b/src/exo/libs/memories.py index b37b11867..8ce892997 100644 --- a/src/exo/libs/memories.py +++ b/src/exo/libs/memories.py @@ -1,4 +1,4 @@ -from ..memory import Memory, DRAM, StaticMemory, MemGenError, generate_offset +from exo.core.memory import Memory, DRAM, StaticMemory, MemGenError, generate_offset def _is_const_size(sz, c): diff --git a/src/exo/platforms/gemmini.py b/src/exo/platforms/gemmini.py index c91e4a460..5f598817b 100644 --- a/src/exo/platforms/gemmini.py +++ b/src/exo/platforms/gemmini.py @@ -2,6 +2,7 @@ from exo import proc, instr, DRAM, config from exo.libs.memories import GEMM_SCRATCH, GEMM_ACCUM +from exo.libs.externs import select, relu from exo.stdlib.scheduling import * @@ -800,8 +801,10 @@ def clamp(src: f32, dst: i8): h: f32 l = -128.0 h = 127.0 - dst = select(h, src, h, src) - dst = select(src, l, l, dst) + tmp: f32 + tmp = select(h, src, h, src) + tmp = select(src, l, l, tmp) + dst = tmp def new_config_st(): diff --git a/src/exo/platforms/x86.py b/src/exo/platforms/x86.py index 9a7b6bc70..f049e1c49 100644 --- a/src/exo/platforms/x86.py +++ b/src/exo/platforms/x86.py @@ -2,6 +2,7 @@ from .. 
import instr, DRAM from ..libs.memories import AVX2, AVX512 +from ..libs.externs import relu, select # --------------------------------------------------------------------------- # # Prefetching diff --git a/src/exo/LoopIR_scheduling.py b/src/exo/rewrite/LoopIR_scheduling.py similarity index 99% rename from src/exo/LoopIR_scheduling.py rename to src/exo/rewrite/LoopIR_scheduling.py index 0c65580d7..c67f0ec70 100644 --- a/src/exo/LoopIR_scheduling.py +++ b/src/exo/rewrite/LoopIR_scheduling.py @@ -2,7 +2,7 @@ from collections import ChainMap from typing import List, Tuple, Optional -from .LoopIR import ( +from ..core.LoopIR import ( LoopIR, LoopIR_Rewrite, Alpha_Rename, @@ -35,13 +35,13 @@ from .range_analysis import IndexRangeEnvironment, IndexRange, index_range_analysis -from .prelude import * -from .proc_eqv import get_strictest_eqv_proc -import exo.internal_cursors as ic +from ..core.prelude import * +from ..core.proc_eqv import get_strictest_eqv_proc +import exo.core.internal_cursors as ic import exo.API as api -from .pattern_match import match_pattern -from .memory import DRAM -from .typecheck import check_call_types +from ..frontend.pattern_match import match_pattern +from ..core.memory import DRAM +from ..frontend.typecheck import check_call_types from functools import partial @@ -2410,7 +2410,7 @@ def wrapper(body): if cur_c._node in par_s.body: def wrapper(body): - return par_s.update(body=body) + return par_s.update(body=body, orelse=[]) ir, fwd_wrap = pre_c._wrap(wrapper, "body") fwd = _compose(fwd_wrap, fwd) @@ -2424,7 +2424,9 @@ def wrapper(body): assert cur_c._node in par_s.orelse def wrapper(orelse): - return par_s.update(body=None, orelse=orelse) + return par_s.update( + body=[LoopIR.Pass(par_s.srcinfo)], orelse=orelse + ) ir, fwd_wrap = post_c._wrap(wrapper, "orelse") fwd = _compose(fwd_wrap, fwd) diff --git a/src/exo/LoopIR_unification.py b/src/exo/rewrite/LoopIR_unification.py similarity index 99% rename from src/exo/LoopIR_unification.py rename to src/exo/rewrite/LoopIR_unification.py index 9a6e49621..16cf2182d 100644 --- a/src/exo/LoopIR_unification.py +++ b/src/exo/rewrite/LoopIR_unification.py @@ -7,7 +7,7 @@ from asdl_adt import ADT from pysmt import shortcuts as SMT -from .LoopIR import ( +from ..core.LoopIR import ( LoopIR, T, LoopIR_Do, @@ -17,9 +17,9 @@ LoopIR_Dependencies, ) from .LoopIR_scheduling import SchedulingError -from .prelude import * +from ..core.prelude import * from .new_eff import Check_Aliasing -import exo.internal_cursors as ic +import exo.core.internal_cursors as ic def _get_smt_solver(): @@ -797,7 +797,7 @@ def all_bound_e(self, be): return self.all_bound_e(be.arg) elif isinstance(be, LoopIR.BinOp): return self.all_bound_e(be.lhs) and self.all_bound_e(be.rhs) - elif isinstance(be, LoopIR.BuiltIn): + elif isinstance(be, LoopIR.Extern): return all(self.all_bound_e(a) for a in be.args) else: assert False, "unsupported case" @@ -819,7 +819,7 @@ def is_exact_e(self, e0, e1): and self.is_exact_e(e0.lhs, e1.lhs) and self.is_exact_e(e0.rhs, e1.rhs) ) - elif isinstance(e0, LoopIR.BuiltIn): + elif isinstance(e0, LoopIR.Extern): return e0.f == e1.f and all( self.is_exact_e(a0, a1) for a0, a1 in zip(e0.args, e1.args) ) @@ -1165,7 +1165,7 @@ def unify_e(self, pe, be): ) self.unify_e(pe.lhs, be.lhs) self.unify_e(pe.rhs, be.rhs) - elif isinstance(pe, LoopIR.BuiltIn): + elif isinstance(pe, LoopIR.Extern): if pe.f != be.f: raise UnificationError( f"cannot unify builtin '{pe.f.name()}' (@{pe.srcinfo}) " diff --git a/src/exo/rewrite/__init__.py 
b/src/exo/rewrite/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/exo/analysis_simplify.py b/src/exo/rewrite/analysis_simplify.py similarity index 100% rename from src/exo/analysis_simplify.py rename to src/exo/rewrite/analysis_simplify.py diff --git a/src/exo/new_analysis_core.py b/src/exo/rewrite/new_analysis_core.py similarity index 99% rename from src/exo/new_analysis_core.py rename to src/exo/rewrite/new_analysis_core.py index 0fa410f2e..11ed09da2 100644 --- a/src/exo/new_analysis_core.py +++ b/src/exo/rewrite/new_analysis_core.py @@ -11,8 +11,8 @@ from asdl_adt import ADT, validators from asdl_adt.validators import ValidationError -from .LoopIR import T, LoopIR -from .prelude import * +from ..core.LoopIR import T, LoopIR +from ..core.prelude import * _first_run = True diff --git a/src/exo/new_eff.py b/src/exo/rewrite/new_eff.py similarity index 99% rename from src/exo/new_eff.py rename to src/exo/rewrite/new_eff.py index ec03b6b3c..cf8c5d73a 100644 --- a/src/exo/new_eff.py +++ b/src/exo/rewrite/new_eff.py @@ -2,10 +2,10 @@ from enum import Enum from itertools import chain -from .LoopIR import Alpha_Rename, SubstArgs, LoopIR_Do -from .configs import reverse_config_lookup, Config +from ..core.LoopIR import Alpha_Rename, SubstArgs, LoopIR_Do +from ..core.configs import reverse_config_lookup, Config from .new_analysis_core import * -from .proc_eqv import get_repr_proc +from ..core.proc_eqv import get_repr_proc # --------------------------------------------------------------------------- # # --------------------------------------------------------------------------- # @@ -1133,7 +1133,7 @@ def expr_effs(e): return expr_effs(e.arg) elif isinstance(e, LoopIR.BinOp): return expr_effs(e.lhs) + expr_effs(e.rhs) - elif isinstance(e, LoopIR.BuiltIn): + elif isinstance(e, LoopIR.Extern): return list_expr_effs(e.args) elif isinstance(e, LoopIR.WindowExpr): @@ -1570,7 +1570,7 @@ def Shadows(a1, a2): import inspect import textwrap -from .API_types import ProcedureBase +from ..API_types import ProcedureBase class SchedulingError(Exception): diff --git a/src/exo/range_analysis.py b/src/exo/rewrite/range_analysis.py similarity index 99% rename from src/exo/range_analysis.py rename to src/exo/rewrite/range_analysis.py index c84014c60..3b810195b 100644 --- a/src/exo/range_analysis.py +++ b/src/exo/rewrite/range_analysis.py @@ -3,9 +3,9 @@ from dataclasses import dataclass from typing import Optional, Tuple -from .LoopIR import LoopIR, T, LoopIR_Compare +from ..core.LoopIR import LoopIR, T, LoopIR_Compare from .new_eff import Check_ExprBound -from .prelude import Sym, _null_srcinfo_obj +from ..core.prelude import Sym, _null_srcinfo_obj # TODO: we should implement a more general index analysis which diff --git a/src/exo/stdlib/halide_scheduling_ops.py b/src/exo/stdlib/halide_scheduling_ops.py index c5b2fa48f..baf19ad6d 100644 --- a/src/exo/stdlib/halide_scheduling_ops.py +++ b/src/exo/stdlib/halide_scheduling_ops.py @@ -1,7 +1,7 @@ from __future__ import annotations from exo.API_cursors import * -from exo.LoopIR import get_reads_of_expr, LoopIR # TODO: get rid of this +from exo.core.LoopIR import get_reads_of_expr, LoopIR # TODO: get rid of this from .range_analysis import bounds_inference from .scheduling import * diff --git a/src/exo/stdlib/inspection.py b/src/exo/stdlib/inspection.py index 4bbedf04a..ad299c4f9 100644 --- a/src/exo/stdlib/inspection.py +++ b/src/exo/stdlib/inspection.py @@ -4,7 +4,7 @@ from exo.libs.memories import * from exo.platforms.x86 import * from 
exo.platforms.neon import * -from exo.syntax import * +from exo.frontend.syntax import * from exo.API_cursors import * from exo.stdlib.analysis import * @@ -24,7 +24,7 @@ def expr_children(expr): elif isinstance(expr, BinaryOpCursor): yield expr.lhs() yield expr.rhs() - elif isinstance(expr, BuiltInFunctionCursor): + elif isinstance(expr, ExternFunctionCursor): yield from expr.args() elif isinstance(expr, (LiteralCursor, ReadConfigCursor)): pass @@ -381,7 +381,7 @@ def is_mod(proc, expr): def is_builtin(proc, expr, name): expr = proc.forward(expr) - return isinstance(expr, BuiltInFunctionCursor) and expr.name() == name + return isinstance(expr, ExternFunctionCursor) and expr.name() == name def is_select(proc, expr): @@ -563,7 +563,7 @@ def expr_list_to_string(expr_list, subst): lhs_str = expr_to_string(expr_cursor.lhs(), subst) rhs_str = expr_to_string(expr_cursor.rhs(), subst) return f"({lhs_str}{binop_str}{rhs_str})" - elif isinstance(expr_cursor, BuiltInFunctionCursor): + elif isinstance(expr_cursor, ExternFunctionCursor): name = expr_cursor.name() args_str = expr_list_to_string(expr_cursor.args(), subst) return f"({name}({args_str[1:-1]}))" diff --git a/src/exo/stdlib/range_analysis.py b/src/exo/stdlib/range_analysis.py index 8334f7e40..ead38f578 100644 --- a/src/exo/stdlib/range_analysis.py +++ b/src/exo/stdlib/range_analysis.py @@ -2,7 +2,7 @@ from typing import Tuple from exo.API_cursors import * -from exo.range_analysis import IndexRange +from exo.rewrite.range_analysis import IndexRange from .inspection import get_parents diff --git a/src/exo/stdlib/rc_wrappers.py b/src/exo/stdlib/rc_wrappers.py index 4d6e9bd56..72982cae5 100644 --- a/src/exo/stdlib/rc_wrappers.py +++ b/src/exo/stdlib/rc_wrappers.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from exo import * -from exo.syntax import * +from exo.frontend.syntax import * from exo.stdlib.scheduling import * from exo.API_cursors import * diff --git a/src/exo/stdlib/scheduling.py b/src/exo/stdlib/scheduling.py index 43e18939c..9e30eb177 100644 --- a/src/exo/stdlib/scheduling.py +++ b/src/exo/stdlib/scheduling.py @@ -96,7 +96,7 @@ from .analysis import check_call_mem_types from ..API_cursors import * -from ..LoopIR_unification import UnificationError as _UnificationError +from ..rewrite.LoopIR_unification import UnificationError as _UnificationError # --------------------------------------------------------------------------- # diff --git a/src/exo/stdlib/stdlib.py b/src/exo/stdlib/stdlib.py index 6900a2dd0..7b3fd4b3a 100644 --- a/src/exo/stdlib/stdlib.py +++ b/src/exo/stdlib/stdlib.py @@ -2,8 +2,7 @@ from dataclasses import dataclass from exo import * -from exo.syntax import * -from exo.API_cursors import * +from exo.frontend.syntax import * from .scheduling import * from .inspection import * diff --git a/tests/amx/test_amx_instr.py b/tests/amx/test_amx_instr.py index b34cbe599..b9226d855 100644 --- a/tests/amx/test_amx_instr.py +++ b/tests/amx/test_amx_instr.py @@ -7,7 +7,7 @@ from exo.stdlib.scheduling import * from .amx import * from .harness_amx import AMXTestBuilder -from exo.memory import MemGenError +from exo.core.memory import MemGenError def reorder_back(proc, pattern): diff --git a/tests/golden/asplos25/test_gemmini_matmul_new/test_matmul.txt b/tests/golden/asplos25/test_gemmini_matmul_new/test_matmul.txt index cd100f9dc..b55d9f095 100644 --- a/tests/golden/asplos25/test_gemmini_matmul_new/test_matmul.txt +++ b/tests/golden/asplos25/test_gemmini_matmul_new/test_matmul.txt @@ -99,11 +99,6 @@ void matmul_on_gemmini( 
c_code_str_Context *ctxt, int_fast32_t N, int_fast32_t M #include "gemm_acc_malloc.h" #include #include "gemm_malloc.h" -double _relu_(double x) { - if (x > 0.0) return x; - else return 0.0; -} - /* relying on the following instruction..." config_ld_i8_id1(src_stride) diff --git a/tests/golden/asplos25/test_gemmini_matmul_old/test_matmul.txt b/tests/golden/asplos25/test_gemmini_matmul_old/test_matmul.txt index c892fcf29..f43b34b48 100644 --- a/tests/golden/asplos25/test_gemmini_matmul_old/test_matmul.txt +++ b/tests/golden/asplos25/test_gemmini_matmul_old/test_matmul.txt @@ -99,11 +99,6 @@ void matmul_on_cpu( c_code_str_Context *ctxt, int_fast32_t N, int_fast32_t M, co #include "gemm_acc_malloc.h" #include #include "gemm_malloc.h" -double _relu_(double x) { - if (x > 0.0) return x; - else return 0.0; -} - /* relying on the following instruction..." config_ld_i8_id1(src_stride) diff --git a/tests/golden/test_apps/test_gemmini_conv.txt b/tests/golden/test_apps/test_gemmini_conv.txt index 0f4ceb63f..c50409bde 100644 --- a/tests/golden/test_apps/test_gemmini_conv.txt +++ b/tests/golden/test_apps/test_gemmini_conv.txt @@ -161,12 +161,12 @@ void conv_3_cpu( test_case_Context *ctxt, int8_t* output, const int32_t* bias, c #include "gemm_acc_malloc.h" #include #include "gemm_malloc.h" -double _relu_(double x) { +int8_t _relu_int8_t(int8_t x) { if (x > 0.0) return x; else return 0.0; } -double _select_(double x, double v, double y, double z) { +float _select_float(float x,float v,float y,float z) { if (x < v) return y; else return z; } @@ -191,8 +191,10 @@ float l; float h; l = -128.0f; h = 127.0f; -*dst = (int8_t)(_select_((double)*&h, (double)*src, (double)*&h, (double)*src)); -*dst = _select_((double)*src, (double)*&l, (double)*&l, (double)*dst); +float tmp; +tmp = _select_float((float)h, (float)*src, (float)h, (float)*src); +tmp = _select_float((float)*src, (float)l, (float)l, (float)tmp); +*dst = (int8_t)(tmp); } @@ -1336,7 +1338,7 @@ for (int_fast32_t b = 0; b < 4; b++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } output[b * 100352 + orow * 3584 + ocol * 128 + och] = tmp_res2; } @@ -1570,7 +1572,7 @@ for (int_fast32_t b = 0; b < 4; b++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } output[b * 50176 + orow * 3584 + ocol * 256 + och] = tmp_res2; } @@ -1617,7 +1619,7 @@ for (int_fast32_t b = 0; b < 4; b++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } output[b * 200704 + orow * 3584 + ocol * 64 + och] = tmp_res2; } diff --git a/tests/golden/test_apps/test_gemmini_matmul.txt b/tests/golden/test_apps/test_gemmini_matmul.txt index 205df4e61..89a5be56e 100644 --- a/tests/golden/test_apps/test_gemmini_matmul.txt +++ b/tests/golden/test_apps/test_gemmini_matmul.txt @@ -213,12 +213,12 @@ void matmul_6( test_case_Context *ctxt, const float* scale, bool act, const int8 #include "gemm_acc_malloc.h" #include #include "gemm_malloc.h" -double _relu_(double x) { +int8_t _relu_int8_t(int8_t x) { if (x > 0.0) return x; else return 0.0; } -double _select_(double x, double v, double y, double z) { +float _select_float(float x,float v,float y,float z) { if (x < v) return y; else return z; } @@ -243,8 +243,10 @@ float l; float h; l = -128.0f; h = 127.0f; -*dst = 
(int8_t)(_select_((double)*&h, (double)*src, (double)*&h, (double)*src)); -*dst = _select_((double)*src, (double)*&l, (double)*&l, (double)*dst); +float tmp; +tmp = _select_float((float)h, (float)*src, (float)h, (float)*src); +tmp = _select_float((float)*src, (float)l, (float)l, (float)tmp); +*dst = (int8_t)(tmp); } @@ -307,7 +309,7 @@ for (int_fast32_t i = 0; i < 3136; i++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } C[i * 512 + j] = tmp_res2; } @@ -344,7 +346,7 @@ for (int_fast32_t i = 0; i < 3136; i++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } C[i * 128 + j] = tmp_res2; } @@ -381,7 +383,7 @@ for (int_fast32_t i = 0; i < 784; i++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } C[i * 1024 + j] = tmp_res2; } @@ -418,7 +420,7 @@ for (int_fast32_t i = 0; i < 12544; i++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } C[i * 256 + j] = tmp_res2; } @@ -455,7 +457,7 @@ for (int_fast32_t i = 0; i < 512; i++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } C[i * 512 + j] = tmp_res2; } @@ -492,7 +494,7 @@ for (int_fast32_t i = 0; i < 12544; i++) { int8_t tmp_res2; clamp(ctxt,&tmp_res1,&tmp_res2); if (act == true) { - tmp_res2 = _relu_((double)*&tmp_res2); + tmp_res2 = _relu_int8_t((int8_t)tmp_res2); } C[i * 64 + j] = tmp_res2; } diff --git a/tests/golden/test_apps/test_x86_conv.txt b/tests/golden/test_apps/test_x86_conv.txt index a2b815842..7070eca2d 100644 --- a/tests/golden/test_apps/test_x86_conv.txt +++ b/tests/golden/test_apps/test_x86_conv.txt @@ -65,11 +65,6 @@ void conv_specialized( void *ctxt, const float* inp, float* output, const float* #include #include -double _relu_(double x) { - if (x > 0.0) return x; - else return 0.0; -} - // conv_specialized( // inp : f32[5, 82, 102, 128] @DRAM, // output : f32[5, 80, 100, 128] @DRAM, diff --git a/tests/golden/test_cursors/test_cursor_print.txt b/tests/golden/test_cursors/test_cursor_print.txt new file mode 100644 index 000000000..dbdcd3046 --- /dev/null +++ b/tests/golden/test_cursors/test_cursor_print.txt @@ -0,0 +1,81 @@ +def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): # <-- NODE + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + if n > 1: + [GAP - Before] + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + [GAP - After] + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): # <-- NODE + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + [GAP - Before] + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in 
seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + [GAP - After] + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + # BLOCK START + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + # BLOCK END + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0def foo(n: size, x: i8[n] @ DRAM): + for j in seq(0, n - 1): + x[j] = 2.0 + # BLOCK START + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + # BLOCK END + for j in seq(0, n - 1): + x[j] = 3.0 \ No newline at end of file diff --git a/tests/golden/test_externs/test_expf.txt b/tests/golden/test_externs/test_expf.txt new file mode 100644 index 000000000..ebbab2553 --- /dev/null +++ b/tests/golden/test_externs/test_expf.txt @@ -0,0 +1,61 @@ +#include "test.h" + +#include +#include + +#include +// foo( +// x : i8[16] @DRAM, +// y : i8[16] @DRAM +// ) +void foo( void *ctxt, const int8_t* x, int8_t* y ) { +for (int_fast32_t i = 0; i < 16; i++) { + y[i] = expf((int8_t)(x[i] + y[i])); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : i8[16] @DRAM, +// y : i8[16] @DRAM +// ) +void foo( void *ctxt, const int8_t* x, int8_t* y ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_extern_find.txt b/tests/golden/test_externs/test_extern_find.txt new file mode 100644 index 000000000..bc820ef0e --- /dev/null +++ b/tests/golden/test_externs/test_extern_find.txt @@ -0,0 +1,2 @@ +def foo(a: f32 @ DRAM): + a = sin(a) # <-- NODE \ No newline at end of file diff --git a/tests/golden/test_externs/test_fmaxf.txt b/tests/golden/test_externs/test_fmaxf.txt new file mode 100644 index 000000000..af16a7798 --- /dev/null +++ b/tests/golden/test_externs/test_fmaxf.txt @@ -0,0 +1,61 @@ +#include "test.h" + +#include +#include + +#include +// foo( +// x : f32[16] @DRAM, +// y : f32[16] @DRAM +// ) +void foo( void *ctxt, const float* x, float* y ) { +for (int_fast32_t i = 0; i < 16; i++) { + y[i] = fmaxf((float)(x[i]), (float)(y[i] * 2.0f)); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 
1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : f32[16] @DRAM, +// y : f32[16] @DRAM +// ) +void foo( void *ctxt, const float* x, float* y ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_relu.txt b/tests/golden/test_externs/test_relu.txt new file mode 100644 index 000000000..f2fd00d91 --- /dev/null +++ b/tests/golden/test_externs/test_relu.txt @@ -0,0 +1,63 @@ +#include "test.h" + +#include +#include + +float _relu_float(float x) { + if (x > 0.0) return x; + else return 0.0; +} + +// foo( +// x : f32[16] @DRAM +// ) +void foo( void *ctxt, float* x ) { +for (int_fast32_t i = 0; i < 16; i++) { + x[i] = _relu_float((float)3.0f); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : f32[16] @DRAM +// ) +void foo( void *ctxt, float* x ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_relu2.txt b/tests/golden/test_externs/test_relu2.txt new file mode 100644 index 000000000..8d5174c56 --- /dev/null +++ b/tests/golden/test_externs/test_relu2.txt @@ -0,0 +1,63 @@ +#include "test.h" + +#include +#include + +float _relu_float(float x) { + if (x > 0.0) return x; + else return 0.0; +} + +// foo( +// x : f32[16] @DRAM +// ) +void foo( void *ctxt, float* x ) { +for (int_fast32_t i = 0; i < 16; i++) { + x[i] = _relu_float((float)x[i]); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 
1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : f32[16] @DRAM +// ) +void foo( void *ctxt, float* x ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_relu3.txt b/tests/golden/test_externs/test_relu3.txt new file mode 100644 index 000000000..d1b294fc3 --- /dev/null +++ b/tests/golden/test_externs/test_relu3.txt @@ -0,0 +1,67 @@ +#include "test.h" + +#include +#include + +float _relu_float(float x) { + if (x > 0.0) return x; + else return 0.0; +} + +// foo( +// x : f32[16] @DRAM, +// y : f32[16] @DRAM, +// z : f32[16] @DRAM +// ) +void foo( void *ctxt, const float* x, const float* y, float* z ) { +for (int_fast32_t i = 0; i < 16; i++) { + z[i] = _relu_float((float)x[i] + y[i]); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : f32[16] @DRAM, +// y : f32[16] @DRAM, +// z : f32[16] @DRAM +// ) +void foo( void *ctxt, const float* x, const float* y, float* z ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_relu4.txt b/tests/golden/test_externs/test_relu4.txt new file mode 100644 index 000000000..e1d141c51 --- /dev/null +++ b/tests/golden/test_externs/test_relu4.txt @@ -0,0 +1,63 @@ +#include "test.h" + +#include +#include + +int8_t _relu_int8_t(int8_t x) { + if (x > 0.0) return x; + else return 0.0; +} + +// foo( +// x : i8[16] @DRAM +// ) +void foo( void *ctxt, int8_t* x ) { +for (int_fast32_t i = 0; i < 16; i++) { + x[i] = _relu_int8_t((int8_t)((int8_t) 3.0)); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 
1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : i8[16] @DRAM +// ) +void foo( void *ctxt, int8_t* x ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_select.txt b/tests/golden/test_externs/test_select.txt new file mode 100644 index 000000000..fa71ccbad --- /dev/null +++ b/tests/golden/test_externs/test_select.txt @@ -0,0 +1,67 @@ +#include "test.h" + +#include +#include + +int8_t _select_int8_t(int8_t x,int8_t v,int8_t y,int8_t z) { + if (x < v) return y; + else return z; +} + +// foo( +// x : i8[16] @DRAM, +// y : i8[16] @DRAM, +// z : i8[16] @DRAM +// ) +void foo( void *ctxt, const int8_t* x, const int8_t* y, int8_t* z ) { +for (int_fast32_t i = 0; i < 16; i++) { + z[i] = _select_int8_t((int8_t)x[i] * ((int8_t) 2), (int8_t)y[i], (int8_t)z[i] + y[i], (int8_t)-x[i]); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : i8[16] @DRAM, +// y : i8[16] @DRAM, +// z : i8[16] @DRAM +// ) +void foo( void *ctxt, const int8_t* x, const int8_t* y, int8_t* z ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_sigmoid.txt b/tests/golden/test_externs/test_sigmoid.txt new file mode 100644 index 000000000..bc202a82b --- /dev/null +++ b/tests/golden/test_externs/test_sigmoid.txt @@ -0,0 +1,66 @@ +#include "test.h" + +#include +#include + + +#include +float sigmoid(float x) { + return 1 / (1 + exp(-x)); +} + +// foo( +// x : f32[16] @DRAM, +// y : f32[16] @DRAM +// ) +void foo( void *ctxt, const float* x, float* y ) { +for (int_fast32_t i = 0; i < 16; i++) { + y[i] = sigmoid((float)(x[i] + y[i])); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 
1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : f32[16] @DRAM, +// y : f32[16] @DRAM +// ) +void foo( void *ctxt, const float* x, float* y ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_sin.txt b/tests/golden/test_externs/test_sin.txt new file mode 100644 index 000000000..3c6784c39 --- /dev/null +++ b/tests/golden/test_externs/test_sin.txt @@ -0,0 +1,59 @@ +#include "test.h" + +#include +#include + +#include +// foo( +// x : i8[16] @DRAM +// ) +void foo( void *ctxt, int8_t* x ) { +for (int_fast32_t i = 0; i < 16; i++) { + x[i] = sin((int8_t)x[i] * ((int8_t) 2)); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : i8[16] @DRAM +// ) +void foo( void *ctxt, int8_t* x ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_externs/test_sqrt.txt b/tests/golden/test_externs/test_sqrt.txt new file mode 100644 index 000000000..d37ce59b5 --- /dev/null +++ b/tests/golden/test_externs/test_sqrt.txt @@ -0,0 +1,61 @@ +#include "test.h" + +#include +#include + +#include +// foo( +// x : f32[16] @DRAM, +// y : f32[16] @DRAM +// ) +void foo( void *ctxt, const float* x, float* y ) { +for (int_fast32_t i = 0; i < 16; i++) { + y[i] = sqrt((float)(x[i] + y[i])); +} +} + + +#pragma once +#ifndef TEST_H +#define TEST_H + +#ifdef __cplusplus +extern "C" { +#endif + + +#include +#include + +// Compiler feature macros adapted from Hedley (public domain) +// https://github.com/nemequ/hedley + +#if defined(__has_builtin) +# define EXO_HAS_BUILTIN(builtin) __has_builtin(builtin) +#else +# define EXO_HAS_BUILTIN(builtin) (0) +#endif + +#if EXO_HAS_BUILTIN(__builtin_assume) +# define EXO_ASSUME(expr) __builtin_assume(expr) +#elif EXO_HAS_BUILTIN(__builtin_unreachable) +# define EXO_ASSUME(expr) \ + ((void)((expr) ? 
1 : (__builtin_unreachable(), 1))) +#else +# define EXO_ASSUME(expr) ((void)(expr)) +#endif + + + +// foo( +// x : f32[16] @DRAM, +// y : f32[16] @DRAM +// ) +void foo( void *ctxt, const float* x, float* y ); + + + +#ifdef __cplusplus +} +#endif +#endif // TEST_H diff --git a/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt index e31e3e652..7667657dd 100644 --- a/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_block_replace_forwarding_for_blocks.txt @@ -26,10 +26,18 @@ def baz(n: size, m: size): # BLOCK START x: f32 @ DRAM # BLOCK END + pass + pass + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM + pass + pass # BLOCK START for k in seq(0, n): pass @@ -51,7 +59,11 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM # BLOCK START pass pass - # BLOCK END \ No newline at end of file + # BLOCK END + for k in seq(0, n): + pass + pass \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt b/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt index de57d7e17..cacb83a11 100644 --- a/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt +++ b/tests/golden/test_internal_cursors/test_cursor_pretty_print_blocks.txt @@ -15,11 +15,14 @@ def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): + x = 0.0 # BLOCK START x = 1.0 x = 2.0 x = 3.0 # BLOCK END + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM @@ -29,11 +32,19 @@ def bar(n: size, m: size): x = 0.0 x = 1.0 # BLOCK END + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 # BLOCK START x = 4.0 x = 5.0 diff --git a/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt b/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt index f8cfced04..bca928ff4 100644 --- a/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt +++ b/tests/golden/test_internal_cursors/test_cursor_pretty_print_gaps.txt @@ -1,28 +1,71 @@ def bar(n: size, m: size): [GAP - Before] + x: f32 @ DRAM + for i in seq(0, n): + for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): + x: f32 @ DRAM [GAP - Before] + for i in seq(0, n): + for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): [GAP - Before] + for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): [GAP - Before] + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): + x = 0.0 + x = 1.0 [GAP - Before] + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 def bar(n: size, m: size): x: f32 @ DRAM for i in seq(0, n): for j in seq(0, m): + x = 0.0 + x = 1.0 + x = 2.0 + x = 3.0 + x = 4.0 + x = 5.0 [GAP - After] \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_delete_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_delete_forwarding_for_blocks.txt index 8fbffdb76..1b45337ab 100644 --- 
a/tests/golden/test_internal_cursors/test_delete_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_delete_forwarding_for_blocks.txt @@ -24,10 +24,15 @@ def baz(n: size, m: size): # BLOCK START x: f32 @ DRAM # BLOCK END + y = 1.1 + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM # BLOCK START y = 1.1 for k in seq(0, n): @@ -49,6 +54,7 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM # BLOCK START y = 1.1 for k in seq(0, n): @@ -61,4 +67,8 @@ def baz(n: size, m: size): for j in seq(0, m): # BLOCK START x: f32 @ DRAM - # BLOCK END \ No newline at end of file + # BLOCK END + y = 1.1 + for k in seq(0, n): + pass + pass \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt index 16fad7def..df3f44091 100644 --- a/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_insert_forwarding_for_blocks.txt @@ -31,10 +31,19 @@ def baz(n: size, m: size): x: f32 @ DRAM x = 0.0 # BLOCK END + pass + y: f32 @ DRAM + y = 1.1 + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM + x = 0.0 + pass # BLOCK START y: f32 @ DRAM y = 1.1 diff --git a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt index 2ab01ea00..9db487954 100644 --- a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks.txt @@ -31,10 +31,17 @@ def baz(n: size, m: size): y = 1.1 x = 0.0 # BLOCK END + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM + y: f32 @ DRAM + y = 1.1 + x = 0.0 # BLOCK START for k in seq(0, n): pass @@ -57,7 +64,12 @@ def baz(n: size, m: size): def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM # BLOCK START y: f32 @ DRAM y = 1.1 - # BLOCK END \ No newline at end of file + # BLOCK END + x = 0.0 + for k in seq(0, n): + pass + pass \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks_gap_after.txt b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks_gap_after.txt index da231454d..cd1464eea 100644 --- a/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks_gap_after.txt +++ b/tests/golden/test_internal_cursors/test_move_forwarding_for_blocks_gap_after.txt @@ -41,7 +41,9 @@ def baz(n: size, m: size): x: f32 @ DRAM x = 0.0 for k in seq(0, n): + pass # BLOCK START y: f32 @ DRAM y = 1.1 - # BLOCK END \ No newline at end of file + # BLOCK END + pass \ No newline at end of file diff --git a/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt b/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt index e161e1281..161b23e34 100644 --- a/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt +++ b/tests/golden/test_internal_cursors/test_wrap_forwarding_for_blocks.txt @@ -31,10 +31,21 @@ def baz(n: size, m: size): x: f32 @ DRAM x = 0.0 # BLOCK END + for k in seq(0, 8): + y: f32 @ DRAM + y = 1.1 + for k in seq(0, n): + pass + pass def baz(n: size, m: size): for i in seq(0, n): for j in seq(0, m): + x: f32 @ DRAM + 
x = 0.0 + for k in seq(0, 8): + y: f32 @ DRAM + y = 1.1 # BLOCK START for k in seq(0, n): pass diff --git a/tests/test_codegen.py b/tests/test_codegen.py index bdc95a3f9..3fe2ab678 100644 --- a/tests/test_codegen.py +++ b/tests/test_codegen.py @@ -8,6 +8,7 @@ from exo import proc, instr, Procedure, DRAM, compile_procs_to_strings from exo.libs.memories import MDRAM, MemGenError, StaticMemory, DRAM_STACK +from exo.libs.externs import * from exo.stdlib.scheduling import * mock_registers = 0 diff --git a/tests/test_config.py b/tests/test_config.py index 88447d88c..c30d7fbb9 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -4,6 +4,7 @@ from exo import proc, DRAM, config, instr from exo.libs.memories import GEMM_SCRATCH +from exo.libs.externs import * from exo.stdlib.scheduling import * diff --git a/tests/test_cursors.py b/tests/test_cursors.py index 672ca0f52..0e26f62ee 100644 --- a/tests/test_cursors.py +++ b/tests/test_cursors.py @@ -4,6 +4,7 @@ from exo import proc, ExoType from exo.libs.memories import * +from exo.libs.externs import * from exo.API_cursors import * from exo.stdlib.inspection import * @@ -729,3 +730,34 @@ def foo(n: size, x: i8[n]): if_stmt = foo.find("if _: _ ") i_loop_alternative = if_stmt.find("for i in _: _") assert i_loop2 == i_loop_alternative + + +def test_cursor_print(golden): + @proc + def foo(n: size, x: i8[n]): + for j in seq(0, n - 1): + x[j] = 2.0 + for i in seq(0, n): + pass + if n > 1: + for i in seq(0, n): + x[i] = 0.0 + for j in seq(0, n - 1): + x[j] = 3.0 + + i_loop2 = foo.find("for i in _:_ #1") + i_loop1 = foo.find("for i in _:_ #0") + + res = str(i_loop2) + str(i_loop2.before()) + str(i_loop2.after()) + res += ( + str(i_loop1) + + str(i_loop1.before()) + + str(i_loop1.after()) + + str(i_loop1.expand(1, 0)) + + str(i_loop1.expand(0, 1)) + ) + + assert res == golden + + with pytest.raises(InvalidCursorError, match="Trying to print the Invalid Cursor!"): + print(i_loop1.parent()) diff --git a/tests/test_error_reporting.py b/tests/test_error_reporting.py index 60874805c..f3126a822 100644 --- a/tests/test_error_reporting.py +++ b/tests/test_error_reporting.py @@ -6,7 +6,7 @@ from exo import SchedulingError from exo import proc -from exo.syntax import * +from exo.frontend.syntax import * from exo.stdlib.scheduling import * # skipping because the API has changed to invalidate this particular diff --git a/tests/test_externs.py b/tests/test_externs.py new file mode 100644 index 000000000..5b5d1033f --- /dev/null +++ b/tests/test_externs.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import pytest + +from exo import proc, DRAM, Procedure, config, compile_procs_to_strings +from exo.libs.externs import * +from exo.stdlib.scheduling import SchedulingError + + +def test_relu(golden, compiler): + @proc + def foo(x: f32[16]): + for i in seq(0, 16): + x[i] = relu(3.0) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_relu2(golden, compiler): + @proc + def foo(x: f32[16]): + for i in seq(0, 16): + x[i] = relu(x[i]) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_relu3(golden, compiler): + @proc + def foo(x: f32[16], y: f32[16], z: f32[16]): + for i in seq(0, 16): + z[i] = relu(x[i] + y[i]) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_relu4(golden, compiler): + @proc + def foo(x: 
i8[16]): + for i in seq(0, 16): + x[i] = relu(3.0) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_relu5(): + with pytest.raises(TypeError, match="expected 1 argument, got 2"): + + @proc + def foo(x: i8[16]): + for i in seq(0, 16): + x[i] = relu(3.0, 2.0) + + with pytest.raises( + TypeError, match="expected argument 1 to be a real scalar value," + ): + + @proc + def foo(x: i8[16]): + for i in seq(0, 16): + x[i] = relu(i) + + +def test_sin(golden, compiler): + @proc + def foo(x: i8[16]): + for i in seq(0, 16): + x[i] = sin(x[i] * 2) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_sin2(golden, compiler): + with pytest.raises(TypeError, match="expected 1 argument, got 2"): + + @proc + def foo(x: i8[16]): + for i in seq(0, 16): + x[i] = sin(x[i] * 2, 3) + + with pytest.raises( + TypeError, match="expected argument 1 to be a real scalar value," + ): + + @proc + def foo(x: i8[16]): + for i in seq(0, 16): + x[i] = sin(i) + + +def test_select(golden, compiler): + @proc + def foo(x: i8[16], y: i8[16], z: i8[16]): + for i in seq(0, 16): + z[i] = select(x[i] * 2, y[i], z[i] + y[i], -x[i]) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_expf(golden, compiler): + @proc + def foo(x: i8[16], y: i8[16]): + for i in seq(0, 16): + y[i] = expf(x[i] + y[i]) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_expf2(): + with pytest.raises(TypeError, match="expected 1 argument, got"): + + @proc + def foo(x: i8[16], y: i8[16]): + y[0] = expf(x[0], x[0]) + + with pytest.raises( + TypeError, match="expected argument 1 to be a real scalar value," + ): + + @proc + def foo(x: i8[16], y: i8[16]): + y[0] = expf(True) + + +def test_fmaxf(golden, compiler): + @proc + def foo(x: f32[16], y: f32[16]): + for i in seq(0, 16): + y[i] = fmaxf(x[i], y[i] * 2) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_fmaxf2(): + with pytest.raises(TypeError, match="expected 2 argument, got 1"): + + @proc + def foo(x: f32[16], y: f32[16]): + for i in seq(0, 16): + y[i] = fmaxf(x[i]) + + with pytest.raises( + TypeError, match="expected argument 1 to be a real scalar value," + ): + + @proc + def foo(x: f32[16], y: f32[16]): + for i in seq(0, 16): + y[i] = fmaxf(i, x[i]) + + with pytest.raises( + TypeError, match="expected argument 2 to be a real scalar value," + ): + + @proc + def foo(x: f32[16], y: f32[16]): + for i in seq(0, 16): + y[i] = fmaxf(x[i], i) + + +def test_sigmoid(golden, compiler): + @proc + def foo(x: f32[16], y: f32[16]): + for i in seq(0, 16): + y[i] = sigmoid(x[i] + y[i]) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_sigmoid2(): + with pytest.raises(TypeError, match="expected 1 argument, got"): + + @proc + def foo(x: i8[16], y: i8[16]): + y[0] = sigmoid(x[0], x[0]) + + with pytest.raises( + TypeError, match="expected argument 1 to be a real scalar value," + ): + + @proc + def foo(x: i8[16], y: i8[16]): + y[0] = sigmoid(True) + + +def test_sqrt(golden, compiler): + @proc + def foo(x: f32[16], y: f32[16]): + for i in seq(0, 16): + y[i] = sqrt(x[i] + y[i]) + + c_file, h_file = 
compile_procs_to_strings([foo], "test.h") + assert c_file + h_file == golden + + compiler.compile(foo) + + +def test_sqrt2(): + with pytest.raises(TypeError, match="expected 1 argument, got"): + + @proc + def foo(x: i8[16], y: i8[16]): + y[0] = sqrt(x[0], x[0]) + + with pytest.raises( + TypeError, match="expected argument 1 to be a real scalar value," + ): + + @proc + def foo(x: i8[16], y: i8[16]): + y[0] = sqrt(True) + + +def test_select_error(): + @proc + def foo(x: i8[16], y: f32[16], z: f64[16]): + for i in seq(0, 16): + z[i] = select(x[i] * 2, y[i], z[i], -x[i]) + + with pytest.raises(TypeError, match="all extern arguments must have a same type"): + c_file, h_file = compile_procs_to_strings([foo], "test.h") + + +def test_type_error(): + with pytest.raises(TypeError, match="expected scalar type"): + + @proc + def foo(x: i8[16], y: f32[16], z: f64[16]): + for i in seq(0, 16): + z[i] = select(i * 2, y[i], z[i], -x[i]) + + with pytest.raises(TypeError, match="expected 4 arguments, got 3"): + + @proc + def foo(x: i8[16], y: f32[16], z: f64[16]): + for i in seq(0, 16): + z[i] = select(i * 2, y[i], z[i]) + + +def test_select_fine(): + @proc + def foo(x: i8[16], y: i8[16], z: i8[16]): + for i in seq(0, 16): + z[i] = select(0.0, y[i], z[i], -x[i]) + + c_file, h_file = compile_procs_to_strings([foo], "test.h") + + +def test_two(): + c = 2 + + @proc + def foo(a: f32): + a = a + c + + with pytest.raises(SchedulingError, match="find: failed to find matches"): + foo.find("a + c").parent() + + +def test_extern_find(golden): + @proc + def foo(a: f32): + a = sin(a) + + assert golden == str(foo.find("sin(a)").parent()) diff --git a/tests/test_internal_cursors.py b/tests/test_internal_cursors.py index b7c941f7f..d4c0ef99e 100644 --- a/tests/test_internal_cursors.py +++ b/tests/test_internal_cursors.py @@ -3,17 +3,17 @@ import pytest from exo import proc, SchedulingError, Procedure -from exo.LoopIR import LoopIR, T -from exo.LoopIR_pprint import _print_cursor -from exo.internal_cursors import ( +from exo.core.LoopIR import LoopIR, T +from exo.core.LoopIR_pprint import _print_cursor +from exo.core.internal_cursors import ( Cursor, Block, InvalidCursorError, Node, ) -from exo.pattern_match import match_pattern -from exo.prelude import Sym -from exo.syntax import size, f32 +from exo.frontend.pattern_match import match_pattern +from exo.core.prelude import Sym +from exo.frontend.syntax import size, f32 def _find_cursors(ctx, pattern): diff --git a/tests/test_neon.py b/tests/test_neon.py index ce69364fe..be6819e41 100644 --- a/tests/test_neon.py +++ b/tests/test_neon.py @@ -9,7 +9,7 @@ from exo import proc from exo.platforms.neon import * from exo.stdlib.scheduling import * -from exo.memory import MemGenError +from exo.core.memory import MemGenError import numpy as np diff --git a/tests/test_new_eff.py b/tests/test_new_eff.py index 31f07361c..7f079e4a7 100644 --- a/tests/test_new_eff.py +++ b/tests/test_new_eff.py @@ -2,7 +2,7 @@ import pytest -from exo.new_eff import * +from exo.rewrite.new_eff import * from exo import proc, config, DRAM, SchedulingError from exo.stdlib.scheduling import * diff --git a/tests/test_range_analysis.py b/tests/test_range_analysis.py index bdb371eeb..195eaac22 100644 --- a/tests/test_range_analysis.py +++ b/tests/test_range_analysis.py @@ -2,7 +2,7 @@ from exo.stdlib.scheduling import * from exo import proc -from exo.range_analysis import ( +from exo.rewrite.range_analysis import ( constant_bound, arg_range_analysis, IndexRangeEnvironment, @@ -11,7 +11,7 @@ infer_range, 
bounds_inference, ) -from exo.LoopIR import LoopIR, T +from exo.core.LoopIR import LoopIR, T def test_affine_index_range(): diff --git a/tests/test_rvv.py b/tests/test_rvv.py index 0f224c5f2..87addc109 100644 --- a/tests/test_rvv.py +++ b/tests/test_rvv.py @@ -9,7 +9,7 @@ from exo import proc from exo.platforms.rvv import * from exo.stdlib.scheduling import * -from exo.memory import MemGenError +from exo.core.memory import MemGenError import numpy as np diff --git a/tests/test_schedules.py b/tests/test_schedules.py index a0e42ebb3..f8be47285 100644 --- a/tests/test_schedules.py +++ b/tests/test_schedules.py @@ -481,6 +481,48 @@ def foo(): fission(foo, foo.find("x = 0.0").after(), n_lifts=2) +def test_if_fission(): + @proc + def before(x: size, y: f32): + if x < 10: + y += 1 + y += 2 + else: + y += 3 + y += 4 + + @proc + def fission_if(x: size, y: f32): + if x < 10: + y += 1 + if x < 10: + y += 2 + else: + y += 3 + y += 4 + + @proc + def fission_else(x: size, y: f32): + if x < 10: + y += 1 + y += 2 + else: + y += 3 + if x < 10: + pass + else: + y += 4 + + test_fission_if = rename(before, "fission_if") + test_fission_if = fission(test_fission_if, test_fission_if.find("y += 1").after()) + assert str(fission_if) == str(test_fission_if) + test_fission_else = rename(before, "fission_else") + test_fission_else = fission( + test_fission_else, test_fission_else.find("y += 3").after() + ) + assert str(fission_else) == str(test_fission_else) + + def test_resize_dim(golden): @proc def foo(): diff --git a/tests/test_typecheck.py b/tests/test_typecheck.py index 393bc4d74..fe9f86d0a 100644 --- a/tests/test_typecheck.py +++ b/tests/test_typecheck.py @@ -4,7 +4,8 @@ from exo import proc, config from exo.libs.memories import GEMM_SCRATCH -from exo.pyparser import ParseError +from exo.frontend.pyparser import ParseError +from exo.libs.externs import * # --- Typechecking tests --- diff --git a/tests/test_uast.py b/tests/test_uast.py index a670e71f5..08f771e8f 100644 --- a/tests/test_uast.py +++ b/tests/test_uast.py @@ -1,7 +1,14 @@ from __future__ import annotations +import pytest + from exo import DRAM -from exo.pyparser import Parser, get_src_locals, get_ast_from_python +from exo.frontend.pyparser import ( + Parser, + get_src_locals, + get_ast_from_python, + ParseError, +) def to_uast(f): @@ -10,7 +17,7 @@ def to_uast(f): body, getsrcinfo, func_globals=f.__globals__, - srclocals=get_src_locals(depth=3), + srclocals=get_src_locals(depth=2), instr=("TEST", ""), as_func=True, ) @@ -57,3 +64,58 @@ def alloc_nest( res[i, j] = rloc[j] assert str(to_uast(alloc_nest)) == golden + + +global_str = "What is 6 times 9?" 
+global_num = 42 + + +def test_variable_lookup_positive(): + def func(f: f32): + for i in seq(0, 42): + f += 1 + + reference = to_uast(func) + + def func(f: f32): + for i in seq(0, global_num): + f += 1 + + test_global = to_uast(func) + assert str(test_global) == str(reference) + + local_num = 42 + + def func(f: f32): + for i in seq(0, local_num): + f += 1 + + test_local = to_uast(func) + assert str(test_local) == str(reference) + + +def test_variable_lookup_type_error(): + def func(f: f32): + for i in seq(0, global_str): + f += 1 + + with pytest.raises(ParseError, match="type "): + to_uast(func) + + local_str = "xyzzy" + + def func(f: f32): + for i in seq(0, local_str): + f += 1 + + with pytest.raises(ParseError, match="type "): + to_uast(func) + + +def test_variable_lookup_name_error(): + def func(f: f32): + for i in seq(0, xyzzy): + f += 1 + + with pytest.raises(ParseError, match="'xyzzy' undefined"): + to_uast(func)
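For reference, a minimal usage sketch of the extern facility this patch introduces, modeled on the cases exercised in tests/test_externs.py above. The proc name `foo` and the header name `"foo.h"` are illustrative only; the imports assume the package layout introduced by this diff (`exo.libs.externs` exporting `relu`, `select`, etc.).

```python
from __future__ import annotations

# Sketch only: mirrors the usage pattern in tests/test_externs.py.
from exo import proc, compile_procs_to_strings
from exo.libs.externs import relu, select


@proc
def foo(x: f32[16], y: f32[16], z: f32[16]):
    for i in seq(0, 16):
        # relu/select are Extern instances: typecheck() validates the call,
        # and globl()/compile() emit the _relu_float/_select_float C helpers
        # seen in the golden outputs above.
        y[i] = relu(x[i] + y[i])
        z[i] = select(x[i], y[i], z[i], -x[i])


# Compiling yields the C source and header as a pair of strings.
c_file, h_file = compile_procs_to_strings([foo], "foo.h")
```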