diff --git a/src/exo/LoopIR_interpreter.py b/src/exo/LoopIR_interpreter.py index f16cfdcd..3ef2390f 100644 --- a/src/exo/LoopIR_interpreter.py +++ b/src/exo/LoopIR_interpreter.py @@ -35,64 +35,16 @@ def run_interpreter(proc, kwargs): Interpreter(proc, kwargs) -# context is global -ctxt = defaultdict(dict) - class Interpreter: def __init__(self, proc, kwargs, use_randomization=False): - assert isinstance(proc, LoopIR.proc) - - proc = ParallelAnalysis().run(proc) - proc = PrecisionAnalysis().run(proc) # TODO: need this? - proc = WindowAnalysis().apply_proc(proc) - proc = MemoryAnalysis().run(proc) # TODO: need this? + if not isinstance(proc, LoopIR.proc): + raise TypeError(f"Expected {proc.name} to be of type proc") - self.proc = proc self.env = ChainMap() self.use_randomization = use_randomization + self.ctxt = defaultdict(dict) - # type check args - for a in proc.args: - if not str(a.name) in kwargs: - raise TypeError(f"expected argument '{a.name}' to be supplied") - - if a.type is T.size: - if not is_pos_int(kwargs[str(a.name)]): - raise TypeError( - f"expected size '{a.name}' to have positive integer value" - ) - self.env[a.name] = kwargs[str(a.name)] - elif a.type is T.index: - if type(kwargs[str(a.name)]) is not int: - raise TypeError( - f"expected index variable '{a.name}' to be an integer" - ) - self.env[a.name] = kwargs[str(a.name)] - elif a.type is T.bool: - if type(kwargs[str(a.name)]) is not bool: - raise TypeError(f"expected bool variable '{a.name}' to be a bool") - self.env[a.name] = kwargs[str(a.name)] - elif a.type is T.stride: - if type(kwargs[str(a.name)]) is not int: - raise TypeError( - f"expected stride variable '{a.name}' to be an integer" - ) - self.env[a.name] = kwargs[str(a.name)] - else: - self.typecheck_input_buffer(a, kwargs) - self.env[a.name] = kwargs[str(a.name)] - - # evaluate preconditions - for pred in proc.preds: - if isinstance(pred, LoopIR.Const): - continue - else: - assert self.eval_e(pred), "precondition not satisfied" - - # eval statements - self.env = self.env.new_child() - self.eval_stmts(proc.body) - self.env = self.env.parents + self.eval_proc(proc, kwargs) def _new_scope(self): self.env = self.env.new_child() @@ -154,6 +106,52 @@ def typecheck_input_buffer(self, proc_arg, kwargs): f"but got shape {tuple(buf.shape)}" ) + def eval_proc(self, proc, kwargs): + proc = ParallelAnalysis().run(proc) + proc = PrecisionAnalysis().run(proc) # TODO: need this? + proc = WindowAnalysis().apply_proc(proc) + proc = MemoryAnalysis().run(proc) # TODO: need this? + + for a in proc.args: + if not str(a.name) in kwargs: + raise TypeError(f"expected argument '{a.name}' to be supplied") + + if a.type is T.size: + if not is_pos_int(kwargs[str(a.name)]): + raise TypeError( + f"expected size '{a.name}' to have positive integer value" + ) + self.env[a.name] = kwargs[str(a.name)] + elif a.type is T.index: + if type(kwargs[str(a.name)]) is not int: + raise TypeError( + f"expected index variable '{a.name}' to be an integer" + ) + self.env[a.name] = kwargs[str(a.name)] + elif a.type is T.bool: + if type(kwargs[str(a.name)]) is not bool: + raise TypeError(f"expected bool variable '{a.name}' to be a bool") + self.env[a.name] = kwargs[str(a.name)] + elif a.type is T.stride: + if type(kwargs[str(a.name)]) is not int: + raise TypeError( + f"expected stride variable '{a.name}' to be an integer" + ) + self.env[a.name] = kwargs[str(a.name)] + else: + self.typecheck_input_buffer(a, kwargs) + self.env[a.name] = kwargs[str(a.name)] + + # evaluate preconditions + for pred in proc.preds: + if isinstance(pred, LoopIR.Const): + continue + else: + assert self.eval_e(pred), "precondition not satisfied" + + # eval statements + self.eval_stmts(proc.body) + def eval_stmts(self, stmts): for s in stmts: self.eval_s(s) @@ -161,7 +159,7 @@ def eval_stmts(self, stmts): def eval_s(self, s): if isinstance(s, LoopIR.Pass): pass - + elif isinstance(s, (LoopIR.Assign, LoopIR.Reduce)): lbuf = self.env[s.name] if len(s.idx) == 0: @@ -179,12 +177,14 @@ def eval_s(self, s): elif isinstance(s, LoopIR.WriteConfig): nm = s.config.name() rhs = self.eval_e(s.rhs) - ctxt[nm][s.field] = rhs + self.ctxt[nm][s.field] = rhs elif isinstance(s, LoopIR.WindowStmt): # nm = rbuf[...] assert s.name not in self.env, "WindowStmt should be a fresh assignment" - assert isinstance(s.rhs, LoopIR.WindowExpr), "WindowStmt rhs should be WindowExpr" + assert isinstance( + s.rhs, LoopIR.WindowExpr + ), "WindowStmt rhs should be WindowExpr" self.env[s.name] = self.eval_e(s.rhs) elif isinstance(s, LoopIR.If): @@ -225,7 +225,9 @@ def eval_s(self, s): argvals = [self.eval_e(a, call_arg=True) for a in s.args] argnames = [str(a.name) for a in s.f.args] kwargs = {nm: val for nm, val in zip(argnames, argvals)} - Interpreter(s.f, kwargs, use_randomization=self.use_randomization) + self._new_scope() + self.eval_proc(s.f, kwargs) + self._del_scope() else: assert False, "bad statement case" @@ -253,10 +255,14 @@ def stringify_w_access(a): assert False, "bad w_access case" # hack to handle interval indexes: LoopIR.Interval returns a string representing the interval - idx = ("0",) if len(e.idx) == 0 else tuple(stringify_w_access(a) for a in e.idx) + idx = ( + ("0",) + if len(e.idx) == 0 + else tuple(stringify_w_access(a) for a in e.idx) + ) res = eval(f"buf[{','.join(idx)}]") return res - + elif isinstance(e, LoopIR.Const): return e.val @@ -268,9 +274,12 @@ def stringify_w_access(a): return lhs - rhs elif e.op == "*": return lhs * rhs - elif e.op == "/": # is this right? - if isinstance(lhs, int): - return (lhs + rhs - 1) // rhs + elif e.op == "/": + if isinstance(lhs, int) and isinstance(rhs, int): + # this is what was here before and without the rhs check + # counter example of why this is wrong -3 / 2 == -1 in C and 0 in this impl + # return (lhs + rhs - 1) // rhs + return int(lhs / rhs) else: return lhs / rhs elif e.op == "%": @@ -293,9 +302,12 @@ def stringify_w_access(a): elif isinstance(e, LoopIR.USub): return -self.eval_e(e.arg) + # BuiltIns don't go to the interpreter, they are just called (via call) like a proc + # TODO Discuss to make sure elif isinstance(e, LoopIR.BuiltIn): - args = [self.eval_e(a) for a in e.args] - return e.f.interpret(args) + assert False, "Not implemented" + # args = [self.eval_e(a) for a in e.args] + # return e.f.interpret(args) elif isinstance(e, LoopIR.StrideExpr): buf = self.env[e.name] @@ -305,7 +317,7 @@ def stringify_w_access(a): elif isinstance(e, LoopIR.ReadConfig): nm = e.config.name() - return ctxt[nm][e.field] + return self.ctxt[nm][e.field] else: print(e) diff --git a/tests/test_interp.py b/tests/test_interp.py index 4dc2ae71..ad13df94 100644 --- a/tests/test_interp.py +++ b/tests/test_interp.py @@ -4,12 +4,13 @@ import numpy as np -from exo import proc, config +from exo import proc, instr, config from exo.libs.memories import GEMM_SCRATCH from exo.stdlib.scheduling import SchedulingError # ------- Interpreter tests --------- +# CURRENTLY FAILING def test_mat_mul(compiler): @proc def rank_k_reduce( @@ -22,18 +23,19 @@ def rank_k_reduce( for j in seq(0, 16): for k in seq(0, K): C[i, j] += A[i, k] * B[k, j] - + fn = compiler.compile(rank_k_reduce) K = 8 - A = np.arange(6*K, dtype=np.float32).reshape((6,K)) - B = np.arange(K*16, dtype=np.float32).reshape((K,16)) - C1 = np.zeros(6*16, dtype=np.float32).reshape((6,16)) - C2 = np.zeros(6*16, dtype=np.float32).reshape((6,16)) + A = np.arange(6 * K, dtype=np.float32).reshape((6, K)) + B = np.arange(K * 16, dtype=np.float32).reshape((K, 16)) + C1 = np.zeros(6 * 16, dtype=np.float32).reshape((6, 16)) + C2 = np.zeros(6 * 16, dtype=np.float32).reshape((6, 16)) fn(None, K, A, B, C1) rank_k_reduce.interpret(K=K, A=A, B=B, C=C2) - assert((C1 == C2).all()) + assert (C1 == C2).all() + def test_reduce_add(compiler): @proc @@ -51,14 +53,15 @@ def acc(N: size, A: f32[N], acc: f32): fn(None, n, A, x) acc.interpret(N=n, A=A, acc=y) - assert(x == y) + assert x == y + def test_scope1(compiler): @proc def foo(res: f32): a: f32 a = 1 - for i in seq(0,4): + for i in seq(0, 4): a: f32 a = 2 res = a @@ -70,14 +73,15 @@ def foo(res: f32): fn(None, x) foo.interpret(res=y) - assert(x == y) + assert x == y + def test_scope2(compiler): @proc def foo(res: f32): a: f32 a = 1 - for i in seq(0,4): + for i in seq(0, 4): a = 2 res = a @@ -88,7 +92,8 @@ def foo(res: f32): fn(None, x) foo.interpret(res=y) - assert(x == y) + assert x == y + def test_empty_seq(compiler): @proc @@ -103,8 +108,9 @@ def foo(res: f32): fn(None, x) foo.interpret(res=y) - assert(x == y) - + assert x == y + + def test_cond(compiler): @proc def foo(res: f32, p: bool): @@ -120,7 +126,8 @@ def foo(res: f32, p: bool): fn(None, x, False) foo.interpret(res=y, p=False) - assert(x == y) + assert x == y + def test_call(compiler): @proc @@ -140,7 +147,8 @@ def foo(res: f32): fn(None, x) foo.interpret(res=y) - assert(x == y) + assert x == y + def test_window_assert(compiler): @proc @@ -159,11 +167,12 @@ def foo( n = 6 m = 8 - src = np.arange(n*m, dtype=np.float32).reshape((n,m)) - dst = np.zeros(n*16, dtype=np.float32).reshape((n,16)) + src = np.arange(n * m, dtype=np.float32).reshape((n, m)) + dst = np.zeros(n * 16, dtype=np.float32).reshape((n, 16)) foo.interpret(n=n, m=m, src=src, dst=dst) - assert((dst[:,:8] == src).all()) + assert (dst[:, :8] == src).all() + def test_window_stmt1(compiler): @proc @@ -171,11 +180,11 @@ def foo(n: size, A: f32[n, 16], C: f32[n]): B = A[:, 0] for i in seq(0, n): C[i] = B[i] - + fn = compiler.compile(foo) n = 6 - A = np.arange(n*16, dtype=np.float32).reshape((n,16)) + A = np.arange(n * 16, dtype=np.float32).reshape((n, 16)) C1 = np.arange(n, dtype=np.float32) C2 = np.arange(n, dtype=np.float32) @@ -184,26 +193,28 @@ def foo(n: size, A: f32[n, 16], C: f32[n]): assert (C1 == C2).all() + def test_window_stmt2(compiler): @proc - def foo(n: size, A: f32[n], B: f32[n], C: f32[2*n]): + def foo(n: size, A: f32[n], B: f32[n], C: f32[2 * n]): for i in seq(0, n): C[i] = A[i] - for i in seq(n, 2*n): - C[i] = B[i-n] - + for i in seq(n, 2 * n): + C[i] = B[i - n] + fn = compiler.compile(foo) n = 6 A = np.arange(n, dtype=np.float32) B = np.arange(n, dtype=np.float32) - C1 = np.zeros(2*n, dtype=np.float32) - C2 = np.zeros(2*n, dtype=np.float32) + C1 = np.zeros(2 * n, dtype=np.float32) + C2 = np.zeros(2 * n, dtype=np.float32) fn(None, n, A, B, C1) foo.interpret(n=n, A=A, B=B, C=C2) assert (C1 == C2).all() + def test_window_stmt3(compiler): @proc def foo(A: f32[8], res: f32): @@ -218,9 +229,11 @@ def foo(A: f32[8], res: f32): fn(None, A, x) foo.interpret(A=A, res=y) - assert(x[0] == 4 and x == y) + assert x[0] == 4 and x == y + # TODO: discuss +# CURRENTLY FAILING # error can be better here def test_window_stmt4(compiler): @proc @@ -228,129 +241,143 @@ def foo(A: f32[8], C: [f32][4]): B = A[4:] C = B[:] + def test_stride_simple1(compiler): @proc - def bar(s0: stride, s1: stride, B: [i8][3,4]): - assert stride(B,0) == s0 - assert stride(B,1) == s1 + def bar(s0: stride, s1: stride, B: [i8][3, 4]): + assert stride(B, 0) == s0 + assert stride(B, 1) == s1 pass + @proc - def foo(A: i8[3,4]): - bar(stride(A, 0), stride(A, 1), A[:,:]) - + def foo(A: i8[3, 4]): + bar(stride(A, 0), stride(A, 1), A[:, :]) + fn = compiler.compile(foo) - A = np.arange(3*4, dtype=float).reshape((3,4)) + A = np.arange(3 * 4, dtype=np.int8).reshape((3, 4)) fn(None, A) foo.interpret(A=A) + def test_stride_simple2(compiler): @proc - def bar(s0: stride, s1: stride, B: [i8][1,1]): - assert stride(B,0) == s0 - assert stride(B,1) == s1 + def bar(s0: stride, s1: stride, B: [i8][1, 1]): + assert stride(B, 0) == s0 + assert stride(B, 1) == s1 pass + @proc - def foo(A: [i8][3,4]): - bar(stride(A, 0), stride(A, 1), A[0:1,1:2]) - + def foo(A: [i8][3, 4]): + bar(stride(A, 0), stride(A, 1), A[0:1, 1:2]) + fn = compiler.compile(foo) - A = np.arange(6*8, dtype=float).reshape((6,8)) + A = np.arange(6 * 8, dtype=np.int8).reshape((6, 8)) + + fn(None, A[::2, ::2]) + foo.interpret(A=A[::2, ::2]) - fn(None, A[::2,::2]) - foo.interpret(A=A[::2,::2]) def test_stride1(compiler): @proc - def foo(A: [i8][3,2,3]): - assert stride(A,0) == 20 - assert stride(A,1) == 5 * 2 - assert stride(A,2) == 1 * 2 + def foo(A: [i8][3, 2, 3]): + assert stride(A, 0) == 20 + assert stride(A, 1) == 5 * 2 + assert stride(A, 2) == 1 * 2 pass - + fn = compiler.compile(foo) - A = np.arange(3*4*5, dtype=float).reshape((3,4,5)) + A = np.arange(3 * 4 * 5, dtype=np.int8).reshape((3, 4, 5)) + + fn(None, A[::1, ::2, ::2]) + foo.interpret(A=A[::1, ::2, ::2]) - fn(None, A[::1,::2,::2]) - foo.interpret(A=A[::1,::2,::2]) def test_stride2(compiler): @proc - def foo(A: [i8][2,4,2]): - assert stride(A,0) == 20 * 2 - assert stride(A,1) == 5 * 1 - assert stride(A,2) == 1 * 3 + def foo(A: [i8][2, 4, 2]): + assert stride(A, 0) == 20 * 2 + assert stride(A, 1) == 5 * 1 + assert stride(A, 2) == 1 * 3 pass fn = compiler.compile(foo) - A = np.arange(3*4*5, dtype=float).reshape((3,4,5)) + A = np.arange(3 * 4 * 5, dtype=np.int8).reshape((3, 4, 5)) + + fn(None, A[::2, ::1, ::3]) + foo.interpret(A=A[::2, ::1, ::3]) - fn(None, A[::2,::1,::3]) - foo.interpret(A=A[::2,::1,::3]) # TODO: discuss +# CURRENTLY FAILING # updating param within stride conditional triggers validation error def test_branch_stride1(compiler): @proc - def bar(B: [i8][3,4], res: f32): - if (stride(B, 0) == 8): + def bar(B: [i8][3, 4], res: f32): + if stride(B, 0) == 8: res = 1 + @proc - def foo(A: i8[3,4], res: f32): - bar(A[:,:], res) + def foo(A: i8[3, 4], res: f32): + bar(A[:, :], res) + # but this is okay: def test_branch_stride2(compiler): @proc - def bar(B: [i8][3,4], res: f32): - if (stride(B, 0) == 8): + def bar(B: [i8][3, 4], res: f32): + if stride(B, 0) == 8: res = 1 + @proc - def foo(A: i8[3,4], res: f32): + def foo(A: i8[3, 4], res: f32): bar(A, res) - + fn = compiler.compile(foo) - A = np.arange(3*4, dtype=np.float32).reshape((3,4)) - x = np.zeros(1, dtype=np.float32) + A = np.arange(3 * 4, dtype=np.int8).reshape((3, 4)) + x = np.zeros(1, dtype=np.int8) y = np.zeros(1, dtype=np.float32) fn(None, A, x) foo.interpret(A=A, res=y) - assert(x == y) + assert x == y + # so is this def test_branch_stride3(compiler): @proc - def bar(B: [i8][3,4], res: f32): + def bar(B: [i8][3, 4], res: f32): a: f32 a = 0 - if (stride(B, 0) == 8): + if stride(B, 0) == 8: a = 1 res = a + @proc - def foo(A: i8[3,4], res: f32): - bar(A[:,:], res) + def foo(A: i8[3, 4], res: f32): + bar(A[:, :], res) fn = compiler.compile(foo) - A = np.arange(3*4, dtype=np.float32).reshape((3,4)) - x = np.zeros(1, dtype=np.float32) + A = np.arange(3 * 4, dtype=np.int8).reshape((3, 4)) + x = np.zeros(1, dtype=np.int8) y = np.zeros(1, dtype=np.float32) fn(None, A, x) foo.interpret(A=A, res=y) - assert(x == y) + assert x == y + def test_bounds_err_interp(): with pytest.raises(TypeError): @proc - def foo(N: size, A:f32[N], res: f32): + def foo(N: size, A: f32[N], res: f32): a: f32 res = A[3] @@ -360,11 +387,12 @@ def foo(N: size, A:f32[N], res: f32): foo.interpret(N=N, A=A, res=x) + def test_precond_interp_simple(): with pytest.raises(AssertionError): @proc - def foo(N: size, A:f32[N], res: f32): + def foo(N: size, A: f32[N], res: f32): assert N == 4 res = A[3] @@ -374,11 +402,13 @@ def foo(N: size, A:f32[N], res: f32): foo.interpret(N=N, A=A, res=x) + # TODO: discuss -# shouldn't this raise a runtime error? +# CURRENTLY FAILING +# shouldn't this raise a runtime error? def test_precond_comp_simple(compiler): @proc - def foo(N: size, A:f32[N], res: f32): + def foo(N: size, A: f32[N], res: f32): assert N == 4 res = A[3] @@ -391,29 +421,32 @@ def foo(N: size, A:f32[N], res: f32): print(x) assert 1 == 0 + def test_precond_interp_stride(): with pytest.raises(AssertionError): @proc - def foo(A:f32[1,8]): + def foo(A: f32[1, 8]): assert stride(A, 0) == 8 pass - A = np.arange(16, dtype=np.float32).reshape((1,16)) - foo.interpret(A=A[:,::2]) + A = np.arange(16, dtype=np.float32).reshape((1, 16)) + foo.interpret(A=A[:, ::2]) + # TODO: discuss # incorrectly informs the compiler about the stride def test_precond_comp_stride(compiler): @proc - def foo(A:f32[1,8]): + def foo(A: f32[1, 8]): assert stride(A, 0) == 8 pass fn = compiler.compile(foo) - A = np.arange(16, dtype=np.float32).reshape((1,16)) - fn(None, A[:,::2]) + A = np.arange(16, dtype=np.float32).reshape((1, 16)) + fn(None, A[:, ::2]) + def new_config(): @config @@ -423,6 +456,7 @@ class Config: return Config + def test_config(compiler): Config = new_config() @@ -437,12 +471,14 @@ def foo(x: f32): foo.interpret(x=x) assert x == 32.0 + def test_config_nested(compiler): Config = new_config() - + @proc def bar(x: f32): x = Config.a + Config.b + @proc def foo(x: f32): Config.a = 32.0 @@ -455,6 +491,7 @@ def foo(x: f32): foo.interpret(x=x) assert x == 48.0 + def test_par_bad(): with pytest.raises(TypeError): @@ -462,12 +499,13 @@ def test_par_bad(): def foo(x: f32[10], acc: f32): for i in par(0, 10): acc += x[i] - + x = np.arange(10, dtype=np.float32) a = np.zeros(1, dtype=np.float32) foo.interpret(x=x, acc=a) + def test_par_good(): @proc def foo(x: f32[10]): @@ -479,4 +517,40 @@ def foo(x: f32[10]): foo.interpret(x=x) assert (x == np.ones(10, dtype=np.float32)).all() -# TODO: test builtin + +def test_built_in(): + @instr("") + def four_wide_vector_add(m: size, A: [f64][m], B: [f64][m], C: [f64][m]): + assert m >= 4 + for i in seq(0, 4): + C[i] = A[i] + B[i] + + @proc + def dumb_vector_add(n: size, A: f64[n], B: f64[n], C: f64[n]): + assert n >= 5 + four_wide_vector_add(n - 1, A[1:], B[1:], C[1:]) + + @proc + def slightly_smarter_vector_add(n: size, A: f64[n], B: f64[n], C: f64[n]): + assert (n % 4) == 0 + assert n >= 8 + for j in seq(0, n / 4): + four_wide_vector_add( + 4, + A[j * 4 : (j * 4) + 4], + B[j * 4 : (j * 4) + 4], + C[j * 4 : (j * 4) + 4], + ) + + A = np.array([1] * 5, dtype=np.float64) + B = np.array([2] * 5, dtype=np.float64) + C = np.zeros(5, dtype=np.float64) + + dumb_vector_add.interpret(n=5, A=A, B=B, C=C) + assert (C == np.array([0, 3, 3, 3, 3], dtype=np.float64)).all() + + A = np.array([1] * 8, dtype=np.float64) + B = np.array([2] * 8, dtype=np.float64) + C = np.zeros(8, dtype=np.float64) + slightly_smarter_vector_add.interpret(n=8, A=A, B=B, C=C) + assert (C == np.array([3] * 8, dtype=np.float64)).all()