Skip to content

Commit

Permalink
debug test_interp.py - all tests expected to pass do pass
Browse files Browse the repository at this point in the history
  • Loading branch information
meganfrisella committed Nov 12, 2024
1 parent 9bd7383 commit 430e6de
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 16 deletions.
2 changes: 1 addition & 1 deletion src/exo/API.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from .frontend.pattern_match import match_pattern
from .core.prelude import *
from .rewrite.new_eff import Check_Aliasing
from .LoopIR_interpreter import run_interpreter
from .backend.LoopIR_interpreter import run_interpreter

# Moved to new file
from .core.proc_eqv import decl_new_proc, derive_proc, assert_eqv_proc, check_eqv_proc
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@

import numpy as np

from .LoopIR import LoopIR
from .LoopIR import T
from .prelude import *
from ..core.LoopIR import LoopIR
from ..core.LoopIR import T
from ..core.prelude import *

from .parallel_analysis import ParallelAnalysis
from .prec_analysis import PrecisionAnalysis
Expand Down Expand Up @@ -303,10 +303,10 @@ def stringify_w_access(a):

# BuiltIns don't go to the interpreter, they are just called (via call) like a proc
# TODO Discuss to make sure
elif isinstance(e, LoopIR.BuiltIn):
assert False, "Not implemented"
# args = [self.eval_e(a) for a in e.args]
# return e.f.interpret(args)
# elif isinstance(e, LoopIR.BuiltIn):
# assert False, "Not implemented"
# args = [self.eval_e(a) for a in e.args]
# return e.f.interpret(args)

elif isinstance(e, LoopIR.StrideExpr):
buf = self.env[e.name]
Expand Down
16 changes: 8 additions & 8 deletions tests/test_interp.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import numpy as np

from exo import proc, config
from exo import proc, config, instr
from exo.libs.memories import GEMM_SCRATCH
from exo.stdlib.scheduling import SchedulingError

Expand All @@ -15,9 +15,9 @@ def test_mat_mul(compiler):
@proc
def rank_k_reduce(
K: size,
A: f32[6, K] @ DRAM,
B: f32[K, 16] @ DRAM,
C: f32[6, 16] @ DRAM,
A: f32[6, K],
B: f32[K, 16],
C: f32[6, 16],
):
for i in seq(0, 6):
for j in seq(0, 16):
Expand Down Expand Up @@ -254,7 +254,7 @@ def foo(A: i8[3, 4]):

fn = compiler.compile(foo)

A = np.arange(3 * 4, dtype=float).reshape((3, 4))
A = np.arange(3 * 4, dtype=np.int8).reshape((3, 4))

fn(None, A)
foo.interpret(A=A)
Expand All @@ -273,7 +273,7 @@ def foo(A: [i8][3, 4]):

fn = compiler.compile(foo)

A = np.arange(6 * 8, dtype=float).reshape((6, 8))
A = np.arange(6 * 8, dtype=np.int8).reshape((6, 8))

fn(None, A[::2, ::2])
foo.interpret(A=A[::2, ::2])
Expand All @@ -289,7 +289,7 @@ def foo(A: [i8][3, 2, 3]):

fn = compiler.compile(foo)

A = np.arange(3 * 4 * 5, dtype=float).reshape((3, 4, 5))
A = np.arange(3 * 4 * 5, dtype=np.int8).reshape((3, 4, 5))

fn(None, A[::1, ::2, ::2])
foo.interpret(A=A[::1, ::2, ::2])
Expand All @@ -305,7 +305,7 @@ def foo(A: [i8][2, 4, 2]):

fn = compiler.compile(foo)

A = np.arange(3 * 4 * 5, dtype=float).reshape((3, 4, 5))
A = np.arange(3 * 4 * 5, dtype=np.int8).reshape((3, 4, 5))

fn(None, A[::2, ::1, ::3])
foo.interpret(A=A[::2, ::1, ::3])
Expand Down

0 comments on commit 430e6de

Please sign in to comment.