diff --git a/include/polygeist/Passes/Passes.td b/include/polygeist/Passes/Passes.td index 66ec6f894241..6569b6b4fdc6 100644 --- a/include/polygeist/Passes/Passes.td +++ b/include/polygeist/Passes/Passes.td @@ -10,6 +10,7 @@ def AffineOpt : Pass<"affine-opt"> { let dependentDialects = [ "scf::SCFDialect", "arith::ArithDialect", + "memref::MemRefDialect", ]; } diff --git a/lib/polygeist/Passes/AffineOpt.cpp b/lib/polygeist/Passes/AffineOpt.cpp index b52e09f7965c..a5b0e0f62852 100644 --- a/lib/polygeist/Passes/AffineOpt.cpp +++ b/lib/polygeist/Passes/AffineOpt.cpp @@ -133,8 +133,9 @@ static void inlineAll(func::CallOp callOp, ModuleOp m = nullptr) { m = callOp->getParentOfType(); auto name = callOp.getCallee(); if (auto funcOp = dyn_cast(m.lookupSymbol(name))) { - funcOp->walk( - [&](func::CallOp nestedCallOp) { inlineAll(nestedCallOp, m); }); + if (funcOp->hasAttr(SCOP_STMT_ATTR_NAME)) + funcOp->walk( + [&](func::CallOp nestedCallOp) { inlineAll(nestedCallOp, m); }); alwaysInlineCall(callOp); } else { llvm::errs() << "Unexpected call to non-FuncOp\n"; @@ -166,6 +167,7 @@ void AffineOptPass::runOnOperation() { mlir::func::CallOp call = pair.second; // Reg2Mem polymer::separateAffineIfBlocks(f, b); + polymer::demoteLoopReduction(f, b); polymer::demoteRegisterToMemory(f, b); // Extract scop stmt polymer::replaceUsesByStored(f, b); @@ -179,16 +181,17 @@ void AffineOptPass::runOnOperation() { signalPassFailure(); return; } - if (mlir::func::FuncOp g = polymer::plutoTransform(f, b, "")) { + mlir::func::FuncOp g = nullptr; + if ((g = polymer::plutoTransform(f, b, ""))) { g.setPublic(); g->setAttrs(f->getAttrs()); g.setName(f.getName()); f.erase(); - inlineAll(call); - if (true) { - polymer::plutoParallelize(g, b); - } + } + inlineAll(call); + if (g && /*options.parallelize=*/true) { + polymer::plutoParallelize(g, b); } } } diff --git a/tools/polymer/lib/Support/ScopStmt.cc b/tools/polymer/lib/Support/ScopStmt.cc index 6bf200e568a5..c935178b41bb 100644 --- a/tools/polymer/lib/Support/ScopStmt.cc +++ b/tools/polymer/lib/Support/ScopStmt.cc @@ -216,13 +216,35 @@ void ScopStmt::getAccessMapAndMemRef(mlir::Operation *op, IRMapping argMap; impl->getArgsValueMapping(argMap); - // TODO: assert op is in the callee. - affine::MemRefAccess access(op); + SmallVector indices; + if (auto loadOp = dyn_cast(op)) { + *memref = loadOp.getMemRef(); + llvm::append_range(indices, loadOp.getMapOperands()); + } else { + assert(isa(op) && + "Affine read/write op expected"); + auto storeOp = cast(op); + *memref = storeOp.getMemRef(); + llvm::append_range(indices, storeOp.getMapOperands()); + } + + // Get affine map from AffineLoad/Store. + AffineMap map; + if (auto loadOp = dyn_cast(op)) + map = loadOp.getAffineMap(); + else + map = cast(op).getAffineMap(); - // Collect the access affine::AffineValueMap that binds to operands in the - // callee. affine::AffineValueMap aMap; - access.getAccessMap(&aMap); + + // SmallVector operands2(indices.begin(), indices.end()); + // affine::fullyComposeAffineMapAndOperands(&map, &operands2); + // map = simplifyAffineMap(map); + // affine::canonicalizeMapAndOperands(&map, &operands2); + aMap.reset(map, indices); + + // TODO: assert op is in the callee. + affine::MemRefAccess access(op); // Replace its operands by what the caller uses. SmallVector operands; diff --git a/tools/polymer/lib/Transforms/Reg2Mem.cc b/tools/polymer/lib/Transforms/Reg2Mem.cc index bd4ea0981824..eaae00e55291 100644 --- a/tools/polymer/lib/Transforms/Reg2Mem.cc +++ b/tools/polymer/lib/Transforms/Reg2Mem.cc @@ -520,8 +520,9 @@ cloneAffineForWithoutIterArgs(mlir::affine::AffineForOp forOp, OpBuilder &b) { return newForOp; } -static void demoteLoopReduction(mlir::func::FuncOp f, - mlir::affine::AffineForOp forOp, OpBuilder &b) { +namespace polymer { +void demoteLoopReduction(mlir::func::FuncOp f, mlir::affine::AffineForOp forOp, + OpBuilder &b) { SmallVector initVals{forOp.getInits()}; mlir::Block *body = forOp.getBody(); mlir::affine::AffineYieldOp yieldOp = findYieldOp(forOp); @@ -546,13 +547,14 @@ static void demoteLoopReduction(mlir::func::FuncOp f, forOp.erase(); } -static void demoteLoopReduction(mlir::func::FuncOp f, OpBuilder &b) { +void demoteLoopReduction(mlir::func::FuncOp f, OpBuilder &b) { SmallVector forOps; findReductionLoops(f, forOps); for (mlir::affine::AffineForOp forOp : forOps) demoteLoopReduction(f, forOp, b); } +} // namespace polymer class DemoteLoopReductionPass : public mlir::PassWrapper