From 969b9ad925113d740195d72bb01cc41fe11fcfdd Mon Sep 17 00:00:00 2001 From: Michael Date: Wed, 20 Mar 2024 00:03:38 +0300 Subject: [PATCH] Add support of tensor types with unit dimension prefix in EvalSlice (#2090) Edit: Sure, sorry for missing description. Currently our small research team is experimenting with a BERT model represented in a number of *-hlo dialects and wants to simplify it in terms of variety of operators. This PR fixes an issue we stumbled upon: ``` %42 = stablehlo.constant dense<"..."> : tensor<1x512xi64> ... %66 = stablehlo.slice %42 [0:1, 0:128] : (tensor<1x512xi64>) -> tensor<1x128xi64> ``` Previous implementation of EvalSlice can't handle such case -- tensor type prefixed with unit dimension(s) i.e. 1x128. This PR adds support of the above case and can slice from any position, e.g. ``` %0 = stablehlo.constant dense<[[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]]> : tensor<1x3x2x2xi64> %1 = "stablehlo.slice"(%0) { start_indices = array<i64: 0, 1, 1, 0>, limit_indices = array<i64: 1, 2, 2, 2>, strides = array<i64: 1, 1, 1, 1> } : (tensor<1x3x2x2xi64>) -> tensor<1x1x1x2xi64> ``` is folded to ``` %1 = stablehlo.constant dense<[[[[7, 8]]]]> : tensor<1x1x1x2xi64> ``` --- BUILD.bazel | 1 + stablehlo/tests/stablehlo_refine_shapes.mlir | 64 +++++++++++++++++++ stablehlo/transforms/CMakeLists.txt | 1 + .../transforms/StablehloRefineShapes.cpp | 45 +++++++++---- 4 files changed, 100 insertions(+), 11 deletions(-) diff --git a/BUILD.bazel b/BUILD.bazel index 79f20cf9bc..274e0cfb2e 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -963,6 +963,7 @@ cc_library( "@llvm-project//mlir:AsmParser", "@llvm-project//mlir:CommonFolders", "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:DialectUtils", "@llvm-project//mlir:FuncDialect", "@llvm-project//mlir:FunctionInterfaces", "@llvm-project//mlir:IR", diff --git a/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/tests/stablehlo_refine_shapes.mlir index 44c215b101..11bdf93289 100644 --- a/stablehlo/tests/stablehlo_refine_shapes.mlir +++ 
b/stablehlo/tests/stablehlo_refine_shapes.mlir @@ -424,6 +424,70 @@ func.func @eval_slice() -> tensor<2xi64> { // ----- +// CHECK-LABEL: func @eval_slice_wild_stride +func.func @eval_slice_wild_stride() -> tensor<1x1x1xi64> { + // CHECK-NOT: stablehlo.slice + // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<2> : tensor<1x1x1xi64> + // CHECK: return [[RESULT]] + %0 = stablehlo.constant dense<[[[1, 2], [3, 4]]]> : tensor<1x2x2xi64> + %1 = "stablehlo.slice"(%0) { + start_indices = array<i64: 0, 0, 1>, + limit_indices = array<i64: 1, 1, 2>, + strides = array<i64: 1, 2, 2> + } : (tensor<1x2x2xi64>) -> tensor<1x1x1xi64> + func.return %1 : tensor<1x1x1xi64> +} + +// ----- + +// CHECK-LABEL: func @eval_slice_unit_prefix +func.func @eval_slice_unit_prefix() -> (tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>) { + // CHECK-NOT: stablehlo.slice + // CHECK: [[RESULT1:%.*]] = stablehlo.constant dense<{{\[\[\[}}[1, 2]]]]> : tensor<1x1x1x2xi64> + // CHECK: [[RESULT2:%.*]] = stablehlo.constant dense<{{\[\[\[}}[7, 8]]]]> : tensor<1x1x1x2xi64> + // CHECK: [[RESULT3:%.*]] = stablehlo.constant dense<{{\[\[\[}}[11, 12]]]]> : tensor<1x1x1x2xi64> + // CHECK: return [[RESULT1]], [[RESULT2]], [[RESULT3]] + %0 = stablehlo.constant dense<[[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]]> : tensor<1x3x2x2xi64> + + %1 = "stablehlo.slice"(%0) { + start_indices = array<i64: 0, 0, 0, 0>, + limit_indices = array<i64: 1, 1, 1, 2>, + strides = array<i64: 1, 1, 1, 1> + } : (tensor<1x3x2x2xi64>) -> tensor<1x1x1x2xi64> + + %2 = "stablehlo.slice"(%0) { + start_indices = array<i64: 0, 1, 1, 0>, + limit_indices = array<i64: 1, 2, 2, 2>, + strides = array<i64: 1, 1, 1, 1> + } : (tensor<1x3x2x2xi64>) -> tensor<1x1x1x2xi64> + + %3 = "stablehlo.slice"(%0) { + start_indices = array<i64: 0, 2, 1, 0>, + limit_indices = array<i64: 1, 3, 2, 2>, + strides = array<i64: 1, 1, 1, 1> + } : (tensor<1x3x2x2xi64>) -> tensor<1x1x1x2xi64> + + func.return %1, %2, %3 : tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64> +} + +// ----- + +// CHECK-LABEL: func @eval_slice_non_unit_prefix +func.func @eval_slice_non_unit_prefix() -> tensor<1x2x1xi64> { + // CHECK: stablehlo.constant 
{{.*}} : tensor<1x2x2xi64> + // CHECK: [[RESULT:%.*]] = stablehlo.slice{{.*}} + // CHECK: return [[RESULT]] + %0 = stablehlo.constant dense<[[[1, 2], [3, 4]]]> : tensor<1x2x2xi64> + %1 = "stablehlo.slice"(%0) { + start_indices = array<i64: 0, 0, 1>, + limit_indices = array<i64: 1, 2, 2>, + strides = array<i64: 1, 1, 1> + } : (tensor<1x2x2xi64>) -> tensor<1x2x1xi64> + func.return %1 : tensor<1x2x1xi64> +} + +// ----- + // CHECK-LABEL: func @eval_subtract func.func @eval_subtract() -> tensor<i64> { // CHECK-NOT: stablehlo.subtract diff --git a/stablehlo/transforms/CMakeLists.txt b/stablehlo/transforms/CMakeLists.txt index bc196efd3b..753e8983e1 100644 --- a/stablehlo/transforms/CMakeLists.txt +++ b/stablehlo/transforms/CMakeLists.txt @@ -45,6 +45,7 @@ add_mlir_dialect_library(StablehloPasses MLIRArithDialect MLIRAsmParser MLIRComplexDialect + MLIRDialectUtils MLIRFuncDialect MLIRFunctionInterfaces MLIRIR diff --git a/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/transforms/StablehloRefineShapes.cpp index a1b8809455..dc4b3d2aad 100644 --- a/stablehlo/transforms/StablehloRefineShapes.cpp +++ b/stablehlo/transforms/StablehloRefineShapes.cpp @@ -29,6 +29,7 @@ limitations under the License. 
#include "llvm/Support/ErrorHandling.h" #include "llvm/Support/FormatVariadic.h" #include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/BuiltinTypes.h" @@ -577,20 +578,42 @@ struct EvalSliceOpPattern : public OpRewritePattern<SliceOp> { LogicalResult matchAndRewrite(SliceOp op, PatternRewriter& rewriter) const override { auto resultType = op.getType(); - if (!resultType.hasRank() || resultType.getRank() != 1) - return rewriter.notifyMatchFailure(op, "expected 1-dimensional type"); - - SmallVector<APSInt> operand; - if (failed(hlo::matchInts(op.getOperand(), operand))) + if (resultType.getRank() < 1) + return rewriter.notifyMatchFailure( + op, "expected non-0 ranked tensor result type"); + + auto operand = op.getOperand().cast<TypedValue<RankedTensorType>>(); + RankedTensorType operandType = operand.getType(); + if (!operandType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "expected operand with static ranked tensor type"); + + // A ranked tensor type with unit dimension prefix of R-1 size is physically + // compatible with 1-dimensional type. 
+ if (!llvm::all_of(resultType.getShape().drop_back(), + [](int64_t s) { return s == 1; })) + return rewriter.notifyMatchFailure( + op, "expected 1-dimensional compatible result type"); + + SmallVector<APSInt> operandData; + if (failed(hlo::matchInts(operand, operandData))) return rewriter.notifyMatchFailure(op, "expected constant operand"); - int64_t start = op.getStartIndices()[0]; - int64_t limit = op.getLimitIndices()[0]; - int64_t stride = op.getStrides()[0]; + const auto dimOffsets = computeSuffixProduct(operandType.getShape()); + auto startIndices = op.getStartIndices(); + auto limitIndices = op.getLimitIndices(); + auto strides = op.getStrides(); + + int64_t start = 0; + for (size_t i = 0; i < startIndices.size(); ++i) + start += startIndices[i] * dimOffsets[i]; + + auto slicedDim = operandType.getRank() - 1; + int64_t limit = start + limitIndices[slicedDim] - startIndices[slicedDim]; + int64_t stride = strides[slicedDim]; SmallVector<APSInt> result; - for (auto i = start; i < limit; i += stride) { - result.push_back(operand[i]); - } + for (auto i = start; i < limit; i += stride) + result.push_back(operandData[i]); rewriter.replaceOpWithNewOp<ConstantOp>(op, getTensorAttr(resultType, result));