diff --git a/lib/Transform/XTenMinimizeLiveTensors.cpp b/lib/Transform/XTenMinimizeLiveTensors.cpp index 6b5f5cc9..855b683f 100644 --- a/lib/Transform/XTenMinimizeLiveTensors.cpp +++ b/lib/Transform/XTenMinimizeLiveTensors.cpp @@ -199,7 +199,7 @@ FailureOr> getFmOperands(Operation *op) { return {getSubgraphIFMs(op)}; if (isTemplatedGraph(op)) - return {op->getOperands()}; + return {getSubgraphIFMs(op)}; // Otherwise, this is a PseudoOp and IFM is the first operand. if (!(isAnyPseudoOp(op) || isInterfaceOp(op))) { @@ -225,7 +225,8 @@ size_t getSize(Value val) { if (auto complexType = elementType.dyn_cast()) { elementType = complexType.getElementType(); - return (elementType.getIntOrFloatBitWidth() * type.getNumElements() * 2) / 8; + return (elementType.getIntOrFloatBitWidth() * type.getNumElements() * 2) / + 8; } llvm_unreachable("Does not know how to compute size"); } @@ -299,7 +300,8 @@ class XTenMinimizeLiveTensorsPass } else { fmResults = SmallVector(currFn.getBody().front().getArguments()); } - std::optional const sharesResultMemory = sharesMemoryWithResult(defOp); + std::optional const sharesResultMemory = + sharesMemoryWithResult(defOp); OpInfo info = {.op = defOp, .operands = *fmOperands, .results = fmResults, diff --git a/test/Transform/XTenMinimizeLiveTensors/other_subgraphs.mlir b/test/Transform/XTenMinimizeLiveTensors/other_subgraphs.mlir index f2df3a4f..c75e5976 100644 --- a/test/Transform/XTenMinimizeLiveTensors/other_subgraphs.mlir +++ b/test/Transform/XTenMinimizeLiveTensors/other_subgraphs.mlir @@ -361,4 +361,19 @@ func.func @support_for_inteface_op(%arg0: tensor<1x3x224x224xf32>) -> tensor<1x6 xten_nn.output %6 : tensor<1x64x56x56xf32> } -> tensor<1x64x56x56xf32> return %3 : tensor<1x64x56x56xf32> +} + +// ----- + +// CHECK-LABEL: func.func @tg_with_constant_ops +// CHECK: LayerName = "TGConst"{{.*}} Reason = "TemplatedGraph" +func.func @tg_with_constant_ops(%arg0: tensor<1x1x64x8xbf16>) -> tensor<1x1x64x8xbf16> { + %0 = xten_nn.load_external_const {file = "constants.h5", key = "Test/Constant_2_0"} -> tensor<8xbf16> + %1 = xten_nn.load_external_const {file = "constants.h5", key = "Test/Constant_1_0"} -> tensor<8xbf16> + %2 = xten_nn.subgraph (%arg1 = %arg0: tensor<1x1x64x8xbf16>, %arg2 = %1: tensor<8xbf16>, %arg3 = %0: tensor<8xbf16>) attributes {IfmOperands = [0 : index], LayerName = "TGConst", Reason = "TemplatedGraph"} + { + %6 = tensor.empty() : tensor<1x1x64x8xbf16> + xten_nn.output %6 : tensor<1x1x64x8xbf16> + } -> tensor<1x1x64x8xbf16> + return %2 : tensor<1x1x64x8xbf16> } \ No newline at end of file