From 9a2fecd434812cec9f8ff01bfcd425f5b2547f92 Mon Sep 17 00:00:00 2001 From: Yuka Ikarashi Date: Sat, 17 Aug 2024 15:57:41 -0400 Subject: [PATCH] runtime --- src/exo/analysis.py | 148 +++++++++++++++++++++++++++++++++++--------- src/exo/dataflow.py | 97 ----------------------------- 2 files changed, 118 insertions(+), 127 deletions(-) diff --git a/src/exo/analysis.py b/src/exo/analysis.py index 9cfa42e4f..0d7e70162 100644 --- a/src/exo/analysis.py +++ b/src/exo/analysis.py @@ -9,7 +9,6 @@ from .dataflow import ( LoopIR_to_DataflowIR, ScalarPropagation, - GetControlPredicates, GetValues, D, ) @@ -376,6 +375,71 @@ def lift_es(es): return [lift_e(e) for e in es] +# --------------------------------------------------------------------------- # +# Getting control flow on DataflowIR. Will be unnecessary when we +# integrate control flow into abstract values. +# --------------------------------------------------------------------------- # + + +class GetControlPredicates(LoopIR_Do): + def __init__(self, proc, stmts): + self.proc = proc + self.stmts = stmts + self.preds = None + self.done = False + self.cur_preds = [] + + for a in self.proc.args: + if isinstance(a.type, T.Size): + size_pred = A.BinOp( + "<", + A.Const(0, T.int, null_srcinfo()), + A.Var(a.name, T.size, a.srcinfo), + T.bool, + null_srcinfo(), + ) + self.cur_preds.append(size_pred) + self.do_t(a.type) + + for pred in self.proc.preds: + self.cur_preds.append(lift_e(pred)) + self.do_e(pred) + + self.do_stmts(self.proc.body) + + def do_s(self, s): + if self.done: + return + + if s == self.stmts[0]: + self.preds = AAnd(*self.cur_preds) + self.done = True + + styp = type(s) + if styp is LoopIR.If: + self.cur_preds.append(lift_e(s.cond)) + self.do_stmts(s.body) + self.cur_preds.pop() + + self.cur_preds.append(A.Not(lift_e(s.cond), T.int, null_srcinfo())) + self.do_stmts(s.orelse) + self.cur_preds.pop() + + elif styp is LoopIR.For: + a_iter = A.Var(s.iter, T.int, s.srcinfo) + b1 = A.BinOp("<=", lift_e(s.lo), a_iter, T.bool, null_srcinfo()) + b2 = A.BinOp("<", a_iter, lift_e(s.hi), T.bool, null_srcinfo()) + cond = A.BinOp("and", b1, b2, T.bool, null_srcinfo()) + self.cur_preds.append(cond) + self.do_stmts(s.body) + self.cur_preds.pop() + + super().do_s(s) + + def result(self): + return self.preds.simplify() + + # Produce a set of AExprs which occur as right-hand-sides # of config writes. def possible_config_writes(stmts): @@ -1531,11 +1595,13 @@ def loop_globenv(i, lo_expr, hi_expr, body): def Check_ReorderStmts(proc, s1, s2): - datair, stmts = LoopIR_to_DataflowIR(proc, [s1, s2]).result() + # datair, stmts = LoopIR_to_DataflowIR(proc, [s1, s2]).result() + + # print("here in ReorderStmts") - assert len(stmts) == 2 + assert isinstance(s1, LoopIR.stmt) and isinstance(s2, LoopIR.stmt) - p = GetControlPredicates(datair, stmts).result() + p = GetControlPredicates(proc, [s1, s2]).result() slv = SMTSolver(verbose=False) slv.push() @@ -1554,11 +1620,13 @@ def Check_ReorderStmts(proc, s1, s2): def Check_ReorderLoops(proc, s): - datair, stmts = LoopIR_to_DataflowIR(proc, [s]).result() + # datair, stmts = LoopIR_to_DataflowIR(proc, [s]).result() - assert len(stmts) == 1 + # print("here in ReorderLoops") - p = GetControlPredicates(datair, stmts).result() + assert isinstance(s, LoopIR.For) + + p = GetControlPredicates(proc, [s]).result() slv = SMTSolver(verbose=False) slv.push() @@ -1632,11 +1700,13 @@ def bds(x, lo, hi): # /\ ( forall i,i'. May(InBound(i,i',e) /\ i < i') => Commutes(a1', a1) ) # def Check_ParallelizeLoop(proc, s): - datair, stmts = LoopIR_to_DataflowIR(proc, [s]).result() + # datair, stmts = LoopIR_to_DataflowIR(proc, [s]).result() + + # print("Check_ParallelizeLoop") - assert len(stmts) == 1 + assert isinstance(s, LoopIR.For) - p = GetControlPredicates(datair, stmts).result() + p = GetControlPredicates(proc, [s]).result() slv = SMTSolver(verbose=False) slv.push() @@ -1688,9 +1758,11 @@ def bds(x, lo, hi): # def Check_FissionLoop(proc, loop, stmts1, stmts2, no_loop_var_1=False): - datair, d_loop = LoopIR_to_DataflowIR(proc, [loop]).result() + # print("Check_FissionLoop") - p = GetControlPredicates(datair, d_loop).result() + # datair, d_loop = LoopIR_to_DataflowIR(proc, [loop]).result() + + p = GetControlPredicates(proc, [loop]).result() slv = SMTSolver(verbose=False) slv.push() @@ -1774,9 +1846,9 @@ def lift_dexpr(e, key=None): def Check_DeleteConfigWrite(proc, stmts): assert len(stmts) > 0 - ir1, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result() - p = GetControlPredicates(ir1, d_stmts).result() + # print("here in DeleteConfigWrite") + p = GetControlPredicates(proc, stmts).result() slv = SMTSolver(verbose=False) slv.push() slv.assume(AMay(p)) @@ -1801,6 +1873,7 @@ def Check_DeleteConfigWrite(proc, stmts): ) # Below are the actual checks + ir1, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result() ScalarPropagation(ir1) @@ -1869,6 +1942,8 @@ def Check_ExtendEqv(proc1, proc2, stmts1, stmts2, cfg_mod): assert len(stmts1) == 1 assert len(stmts2) == 1 + # print("here in Check_ExtendEqv") + slv = SMTSolver(verbose=False) slv.push() @@ -1928,16 +2003,18 @@ def make_point(key): def Check_ExprEqvInContext(proc, expr0, stmts0, expr1, stmts1=None): + + # print("Check_ExprEqvInContext") assert len(stmts0) > 0 stmts1 = stmts1 or stmts0 - len_0 = len(stmts0) - datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts0 + stmts1).result() - d_stmts0 = d_stmts[0:len_0] - d_stmts1 = d_stmts[len_0:] + # len_0 = len(stmts0) + # datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts0 + stmts1).result() + # d_stmts0 = d_stmts[0:len_0] + # d_stmts1 = d_stmts[len_0:] - p0 = GetControlPredicates(datair, d_stmts0).result() - p1 = GetControlPredicates(datair, d_stmts1).result() + p0 = GetControlPredicates(proc, stmts0).result() + p1 = GetControlPredicates(proc, stmts1).result() slv = SMTSolver(verbose=False) slv.push() @@ -1954,11 +2031,13 @@ def Check_ExprEqvInContext(proc, expr0, stmts0, expr1, stmts1=None): def Check_BufferReduceOnly(proc, stmts, buf, ndim): + + # print("Check_BufferReduceOnly") assert len(stmts) > 0 - datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result() + # datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result() - p = GetControlPredicates(datair, d_stmts).result() + p = GetControlPredicates(proc, stmts).result() slv = SMTSolver(verbose=False) slv.push() @@ -1988,13 +2067,15 @@ def Check_Access_In_Window(proc, access_cursor, w_exprs, block_cursor): block_cursor is the context in which to interpret the access in. """ + # print("Check_Access_In_Window") + access = access_cursor._node block = [x._node for x in block_cursor] idxs = access.idx assert len(idxs) == len(w_exprs) - datair, d_stmts = LoopIR_to_DataflowIR(proc, block).result() - p = GetControlPredicates(datair, d_stmts).result() + # datair, d_stmts = LoopIR_to_DataflowIR(proc, block).result() + p = GetControlPredicates(proc, block).result() slv = SMTSolver(verbose=False) slv.push() @@ -2067,9 +2148,10 @@ def Check_Bounds(proc, alloc_stmt, block): if len(block) == 0: return - datair, stmts = LoopIR_to_DataflowIR(proc, block).result() + # print("Check_Bounds") + # datair, stmts = LoopIR_to_DataflowIR(proc, block).result() - p = GetControlPredicates(datair, stmts).result() + p = GetControlPredicates(proc, block).result() slv = SMTSolver(verbose=False) slv.push() @@ -2105,6 +2187,8 @@ def Check_Bounds(proc, alloc_stmt, block): def Check_IsDeadAfter(proc, stmts, bufname, ndim): + + # print("Check_IsDeadAfter") assert len(stmts) > 0 ap = PostEnv(proc, stmts).get_posteffs() @@ -2126,11 +2210,13 @@ def Check_IsDeadAfter(proc, stmts, bufname, ndim): def Check_IsIdempotent(proc, stmts): + + # print("Check_IsIdempotent") assert len(stmts) > 0 - datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result() + # datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result() - p = GetControlPredicates(datair, d_stmts).result() + p = GetControlPredicates(proc, stmts).result() slv = SMTSolver(verbose=False) slv.push() @@ -2144,10 +2230,11 @@ def Check_IsIdempotent(proc, stmts): def Check_ExprBound(proc, stmts, expr, op, value, exception=True): + # print("Check_ExprBound") assert len(stmts) > 0 - datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result() - p = GetControlPredicates(datair, d_stmts).result() + # datair, d_stmts = LoopIR_to_DataflowIR(proc, stmts).result() + p = GetControlPredicates(proc, stmts).result() # TODO: Check_ExprBound does not depend on configuration states so this can be skipped, but more fundamentally running abstract interpretation this many times is simply too slow. # ScalarPropagation(datair) @@ -2335,5 +2422,6 @@ def do_s(self, s): def Check_Aliasing(proc): + # print("Check_Aliasing") helper = _Check_Aliasing_Helper(proc) # that's it diff --git a/src/exo/dataflow.py b/src/exo/dataflow.py index 8ec275f78..5f1da8a77 100644 --- a/src/exo/dataflow.py +++ b/src/exo/dataflow.py @@ -829,100 +829,3 @@ def abs_builtin(self, builtin, args): # TODO: write a short circuit for select builtin return D.Const(builtin.interpret(vargs), args[0].typ) - - -# --------------------------------------------------------------------------- # -# Getting control flow on DataflowIR. Will be unnecessary when we -# integrate control flow into abstract values. -# --------------------------------------------------------------------------- # - - -def lift_dataflow(e): - if e.type.is_indexable() or e.type.is_stridable() or e.type == T.bool: - if isinstance(e, DataflowIR.Read): - assert len(e.idx) == 0 - return A.Var(e.name, e.type, e.srcinfo) - elif isinstance(e, DataflowIR.Const): - return A.Const(e.val, e.type, e.srcinfo) - elif isinstance(e, DataflowIR.BinOp): - return A.BinOp( - e.op, lift_dataflow(e.lhs), lift_dataflow(e.rhs), e.type, e.srcinfo - ) - elif isinstance(e, DataflowIR.USub): - return A.USub(lift_dataflow(e.arg), e.type, e.srcinfo) - elif isinstance(e, DataflowIR.StrideExpr): - return A.Stride(e.name, e.dim, e.type, e.srcinfo) - elif isinstance(e, DataflowIR.ReadConfig): - return A.Var(e.config_field, e.type, e.srcinfo) - else: - f"bad case: {type(e)}" - else: - assert e.type.is_numeric() - if e.type.is_real_scalar(): - if isinstance(e, DataflowIR.Const): - return A.Const(e.val, e.type, e.srcinfo) - elif isinstance(e, DataflowIR.Read): - return A.ConstSym(e.name, e.type, e.srcinfo) - elif isinstance(e, DataflowIR.ReadConfig): - return A.Var(e.config_field, e.type, e.srcinfo) - - return A.Unk(T.err, e.srcinfo) - - -class GetControlPredicates(DataflowIR_Do): - def __init__(self, datair, stmts): - self.datair = datair - self.stmts = stmts - self.preds = None - self.done = False - self.cur_preds = [] - - for a in self.datair.args: - if isinstance(a.type, T.Size): - size_pred = A.BinOp( - "<", - A.Const(0, T.int, null_srcinfo()), - A.Var(a.name, T.size, a.srcinfo), - T.bool, - null_srcinfo(), - ) - self.cur_preds.append(size_pred) - self.do_t(a.type) - - for pred in self.datair.preds: - self.cur_preds.append(lift_dataflow(pred)) - self.do_e(pred) - - self.do_stmts(self.datair.body.stmts) - - def do_s(self, s): - if self.done: - return - - if s == self.stmts[0]: - self.preds = AAnd(*self.cur_preds) - self.done = True - - styp = type(s) - if styp is DataflowIR.If: - self.cur_preds.append(lift_dataflow(s.cond)) - self.do_stmts(s.body.stmts) - self.cur_preds.pop() - - self.cur_preds.append(A.Not(lift_dataflow(s.cond), T.int, null_srcinfo())) - self.do_stmts(s.orelse.stmts) - self.cur_preds.pop() - - elif styp is DataflowIR.For: - a_iter = A.Var(s.iter, T.int, s.srcinfo) - b1 = A.BinOp("<=", lift_dataflow(s.lo), a_iter, T.bool, null_srcinfo()) - b2 = A.BinOp("<", a_iter, lift_dataflow(s.hi), T.bool, null_srcinfo()) - cond = A.BinOp("and", b1, b2, T.bool, null_srcinfo()) - self.cur_preds.append(cond) - self.do_stmts(s.body.stmts) - self.cur_preds.pop() - - super().do_s(s) - - def result(self): - return self.preds.simplify()