Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
yamaguchi1024 committed Apr 24, 2024
1 parent 7a152a5 commit 959d90c
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 43 deletions.
114 changes: 71 additions & 43 deletions src/exo/dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@
| Var( sym name )
| Const( object val, type type )
| BinOp( binop op, mexpr lhs, mexpr rhs )
| Array( mexpr arg, sym idx ) -- !!This is not enough b/c we'd like to be able to affine-index-access arrays, but is a literal implementation of Fluid update
path = ( aexpr constraints, mexpr tgt )
| Array( sym name, avar *dims )
path = ( aexpr nc, aexpr sc, mexpr tgt ) -- perform weak update for now
env = ( avar *dims, path* paths ) -- This can handle index access uniformly!
}
""",
Expand Down Expand Up @@ -109,6 +109,9 @@ def validateAbsEnv(obj):
# Top Level Call to Dataflow analysis
# --------------------------------------------------------------------------- #

aexpr_false = A.Const(False, T.bool, null_srcinfo())
aexpr_true = A.Const(True, T.bool, null_srcinfo())


class LoopIR_to_DataflowIR:
def __init__(self, proc):
Expand Down Expand Up @@ -257,14 +260,18 @@ def __str__(self):
if isinstance(self, D.BinOp):
return str(self.lhs) + str(e.op) + str(self.rhs)
if isinstance(self, D.Array):
return str(self.arg) + "[" + str(self.idx) + "]"
dim_str = "["
for d in self.dims:
dim_str += str(d)
dim_str += "]"
return str(self.name) + dim_str

assert False, "bad case"


@extclass(AbstractDomains.path)
def __str__(self):
return "(" + str(self.constraints) + ", " + str(self.tgt) + ")"
return "(" + str(self.nc) + ", " + str(self.sc) + ") : " + str(self.tgt)


@extclass(AbstractDomains.env)
Expand All @@ -291,13 +298,14 @@ def update(env: D.env, rval: list[D.path]):
merge_paths = []
for pre_path in env.paths:
for rval_path in rval:
pre_cons = pre_path.constraints.simplify()
rval_cons = rval_path.constraints.simplify()
pre_cons = pre_path.nc.simplify()
rval_cons = rval_path.nc.simplify()

if isinstance(pre_path.tgt, D.Unk):
pre_paths.remove(pre_path)
elif pre_cons == rval_cons:
merge_paths.append(D.path(rval_cons, rval_path.tgt))
# TODO: Handle strong update
merge_paths.append(D.path(rval_cons, rval_path.sc, rval_path.tgt))
pre_paths.remove(pre_path)
rval_paths.remove(rval_path)

Expand All @@ -308,21 +316,39 @@ def bind_cons(cons: A.expr, rval: list[D.path]):
new_paths = []

for path in rval:
new_cons = A.BinOp("and", path.constraints, cons, T.bool, null_srcinfo())
new_path = D.path(new_cons.simplify(), path.tgt)
new_nc = A.BinOp("and", path.nc, cons, T.bool, null_srcinfo())
new_path = D.path(new_nc.simplify(), path.sc, path.tgt)
new_paths.append(new_path)

return new_paths


def propagate_cons(cons: A.expr, env: D.env):
return D.env(env.dims, bind_cons(cons, env.paths))


def ir_to_aexpr(e: DataflowIR.expr):
if isinstance(e, DataflowIR.Const):
ae = A.Const(e.val, e.type, null_srcinfo())
elif isinstance(e, DataflowIR.Read):
ae = A.Var(e.name, e.type, null_srcinfo())
elif isinstance(e, Sym):
ae = A.Var(e, T.index, null_srcinfo())
else:
assert False, f"got {e} of type {type(e)}"

return ae


class AbstractInterpretation(ABC):
def __init__(self, proc: DataflowIR.proc):
self.proc = proc

# setup initial values
init_env = self.proc.body.ctxts[0]
for a in proc.args:
init_env[a.name] = self.abs_init_val(a.name, a.type)
if a.type.is_numeric():
init_env[a.name] = self.abs_init_val(a.name, a.type)

# We probably ought to somehow use precondition assertions
# TODO: leave it for now
Expand Down Expand Up @@ -356,14 +382,7 @@ def fix_stmt(self, pre_env, stmt: DataflowIR.stmt, post_env):
# Handle constraints
cons = A.Const(True, T.bool, null_srcinfo())
for b, e in zip(pre_env[stmt.name].dims, stmt.idx):
# TODO!!: Replace this with a general pass to convert DataflowIR to Aexpr
if isinstance(e, DataflowIR.Const):
e = A.Const(e.val, e.type, null_srcinfo())
elif isinstance(e, DataflowIR.Read):
e = A.Var(e.name, e.type, null_srcinfo())
else:
assert False, "???"
eq = A.BinOp("==", b, e, T.bool, null_srcinfo())
eq = A.BinOp("==", b, ir_to_aexpr(e), T.bool, null_srcinfo())
cons = A.BinOp("and", cons, eq, T.bool, null_srcinfo())

rval = bind_cons(cons, rval)
Expand Down Expand Up @@ -421,28 +440,31 @@ def fix_stmt(self, pre_env, stmt: DataflowIR.stmt, post_env):
elif isinstance(stmt, DataflowIR.For):
# TODO: Add support for loop-condition analysis in some way?

# set up the loop body for fixed-point iteration
pre_body = stmt.body.ctxts[0]
iter_cons = self.abs_iter_val(
ir_to_aexpr(stmt.iter), ir_to_aexpr(stmt.lo), ir_to_aexpr(stmt.hi)
)
for nm, val in pre_env.items():
pre_body[nm] = val
# initialize the loop iteration variable
lo = self.fix_expr(pre_env, stmt.lo)
hi = self.fix_expr(pre_env, stmt.hi)
pre_body[stmt.iter] = self.abs_iter_val(lo, hi)
pre_body[nm] = propagate_cons(iter_cons, val)

# Commenting out the following. We don't need to run a fixed-point

# set up the loop body for fixed-point iteration
# run this loop until we reach a fixed-point
at_fixed_point = False
while not at_fixed_point:
# propagate in the loop
self.fix_block(stmt.body)
at_fixed_point = True
# copy the post-values for the loop back around to
# the pre-values, by joining them together
for nm, prev_val in pre_body.items():
next_val = stmt.body.ctxts[-1][nm]
val = self.abs_join(prev_val, next_val)
at_fixed_point = at_fixed_point and prev_val == val
pre_body[nm] = val
# at_fixed_point = False
# while not at_fixed_point:
# propagate in the loop
# self.fix_block(stmt.body)
# at_fixed_point = True
# copy the post-values for the loop back around to
# the pre-values, by joining them together
# for nm, prev_val in pre_body.items():
# next_val = stmt.body.ctxts[-1][nm]
# val = self.abs_join(prev_val, next_val)
# at_fixed_point = at_fixed_point and prev_val == val
# pre_body[nm] = val

self.fix_block(stmt.body)

# determine the post-env as join of pre-env and loop results
for nm, pre_val in pre_env.items():
Expand Down Expand Up @@ -496,7 +518,7 @@ def abs_alloc_val(self, name, typ):
"""Define initial value of an allocation"""

@abstractmethod
def abs_iter_val(self, lo, hi):
def abs_iter_val(self, name, lo, hi):
"""Define value of an iteration variable"""

@abstractmethod
Expand Down Expand Up @@ -525,11 +547,15 @@ def abs_builtin(self, builtin, args):


def make_empty_path(me: D.mexpr) -> D.path:
return [D.path(A.Const(True, T.bool, null_srcinfo()), me)]
return [D.path(aexpr_true, aexpr_false, me)]


def make_unk() -> D.path:
return [D.path(A.Const(True, T.bool, null_srcinfo()), D.Unk())]
return [D.path(aexpr_true, aexpr_false, D.Unk())]


def make_unk_array(buf_name: Sym, dims: list) -> D.path:
return [D.path(aexpr_true, aexpr_false, D.Array(buf_name, dims))]


class ConstantPropagation(AbstractInterpretation):
Expand All @@ -540,7 +566,7 @@ def abs_init_val(self, name, typ):
dims.append(
A.Var(Sym(name.name() + "_" + str(i)), T.index, null_srcinfo())
)
return D.env(dims, make_unk())
return D.env(dims, make_unk_array(name, dims))
else:
return D.env([], make_unk())

Expand All @@ -551,12 +577,14 @@ def abs_alloc_val(self, name, typ):
dims.append(
A.Var(Sym(name.name() + "_" + str(i)), T.index, null_srcinfo())
)
return D.env(dims, make_unk())
return D.env(dims, make_unk_array(name, dims))
else:
return D.env([], make_unk())

def abs_iter_val(self, lo, hi):
return D.env([], make_unk())
def abs_iter_val(self, name, lo, hi):
lo_cons = A.BinOp("<=", lo, name, T.index, null_srcinfo())
hi_cons = A.BinOp("<", name, hi, T.index, null_srcinfo())
return AAnd(lo_cons, hi_cons)

def abs_stride_expr(self, name, dim):
assert False, "unimplemented"
Expand Down
18 changes: 18 additions & 0 deletions tests/test_dataflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,3 +182,21 @@ def foo(n: size, a: f32, c: f32):
print()
print(foo.dataflow())
print()


def test_absval_init():
@proc
def foo(n: size, dst: f32[n]):
for i in seq(0, n):
dst[i] = 0.0

print()
print(foo.dataflow())

@proc
def foo(n: size, dst: f32[n], src: f32[n]):
for i in seq(0, n):
dst[i] = src[i]

print()
print(foo.dataflow())

0 comments on commit 959d90c

Please sign in to comment.