Skip to content

Commit

Permalink
add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
yamaguchi1024 committed Feb 7, 2024
1 parent fd200ab commit 8823473
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 2 deletions.
6 changes: 4 additions & 2 deletions src/exo/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,8 +382,10 @@ def fix_stmt(self, pre_env, stmt: DataflowIR.stmt, post_env):
# if reducing, then expand to x = x + rhs
rhs_e = stmt.rhs
if isinstance(stmt, DataflowIR.Reduce):
read_buf = DataflowIR.Read(stmt.name, stmt.idx)
rhs_e = DataflowIR.BinOp("+", read_buf, rhs_e)
read_buf = DataflowIR.Read(
stmt.name, stmt.idx, rhs_e.type, stmt.srcinfo
)
rhs_e = DataflowIR.BinOp("+", read_buf, rhs_e, rhs_e.type, stmt.srcinfo)
# now we can handle both cases uniformly
rval = self.fix_expr(pre_env, rhs_e)
# need to be careful for buffers (no overwrite guarantee)
Expand Down
59 changes: 59 additions & 0 deletions tests/test_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,3 +83,62 @@ def foo(x: R[3], y: R[3], z: R):
print()
print(foo.dataflow())
print()


# TODO: Currently add_unsafe_guard lacks analysis, but we should be able to analyze this
def test_sliding_window():
@proc
def foo(n: size, m: size, dst: i8[n + m], src: i8[n + m]):
for i in seq(0, n):
for j in seq(0, m):
dst[i + j] = src[i + j]

foo = add_unsafe_guard(foo, "dst[_] = src[_]", "i == 0 or j == m - 1")
print()
print(foo.dataflow())
print()


# TODO: fission should be able to handle this
def test_fission_fail():
@proc
def foo(n: size, dst: i8[n + 1], src: i8[n + 1]):
for i in seq(0, n):
dst[i] = src[i]
dst[i + 1] = src[i + 1]

with pytest.raises(SchedulingError, match="Cannot fission"):
foo = fission(foo, foo.find("dst[i] = _").after())
print(foo)


# TODO: This is unsafe, lift_alloc should give an error
def test_lift_alloc_unsafe():
@proc
def foo():
for i in seq(0, 10):
a: i8[11] @ DRAM
a[i] = 1.0
a[i + 1] += 1.0

foo = lift_alloc(foo, "a : _")
print()
print(foo.dataflow())
print()


# TODO: We are not supporting this AFAIK but should keep this example in mind
def test_reduc():
@proc
def foo(n: size, a: f32, c: f32):
tmp: f32[n]
for i in seq(0, n):
for j in seq(0, 4):
tmp[i] = a
a = tmp[i] + 1.0
for i in seq(0, n):
c += tmp[i] # some use of tmp

print()
print(foo.dataflow())
print()

0 comments on commit 8823473

Please sign in to comment.