From a760d415759fb12d892b226325571661e1143db2 Mon Sep 17 00:00:00 2001 From: "Rickert, Jonas" Date: Fri, 4 Oct 2024 13:35:16 +0100 Subject: [PATCH] Skip external consts when walking operands --- lib/Transform/XTenMinimizeLiveTensors.cpp | 19 ++++++++++++++++--- 1 file changed, 16 insertions(+), 3 deletions(-) diff --git a/lib/Transform/XTenMinimizeLiveTensors.cpp b/lib/Transform/XTenMinimizeLiveTensors.cpp index 6b5f5cc9..326b5a9d 100644 --- a/lib/Transform/XTenMinimizeLiveTensors.cpp +++ b/lib/Transform/XTenMinimizeLiveTensors.cpp @@ -192,6 +192,17 @@ FailureOr> getFmOperands(Operation *op) { if (isa(op)) return {{}}; + const auto filterOutExternalConst = [](SmallVector operands) { + SmallVector filteredOperands; + for (const auto operand : operands) { + if (!isa_and_nonnull( + operand.getDefiningOp())) { + filteredOperands.push_back(operand); + } + } + return filteredOperands; + }; + if (isInCoreChain(op)) return {getSubgraphIFMs(op)}; @@ -199,7 +210,7 @@ FailureOr> getFmOperands(Operation *op) { return {getSubgraphIFMs(op)}; if (isTemplatedGraph(op)) - return {op->getOperands()}; + return {filterOutExternalConst(op->getOperands())}; // Otherwise, this is a PseudoOp and IFM is the first operand. if (!(isAnyPseudoOp(op) || isInterfaceOp(op))) { @@ -225,7 +236,8 @@ size_t getSize(Value val) { if (auto complexType = elementType.dyn_cast()) { elementType = complexType.getElementType(); - return (elementType.getIntOrFloatBitWidth() * type.getNumElements() * 2) / 8; + return (elementType.getIntOrFloatBitWidth() * type.getNumElements() * 2) / + 8; } llvm_unreachable("Does not know how to compute size"); } @@ -299,7 +311,8 @@ class XTenMinimizeLiveTensorsPass } else { fmResults = SmallVector(currFn.getBody().front().getArguments()); } - std::optional const sharesResultMemory = sharesMemoryWithResult(defOp); + std::optional const sharesResultMemory = + sharesMemoryWithResult(defOp); OpInfo info = {.op = defOp, .operands = *fmOperands, .results = fmResults,