diff --git a/BUILD.bazel b/BUILD.bazel
index 79f20cf9bc..274e0cfb2e 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -963,6 +963,7 @@ cc_library(
         "@llvm-project//mlir:AsmParser",
         "@llvm-project//mlir:CommonFolders",
         "@llvm-project//mlir:ComplexDialect",
+        "@llvm-project//mlir:DialectUtils",
        "@llvm-project//mlir:FuncDialect",
         "@llvm-project//mlir:FunctionInterfaces",
         "@llvm-project//mlir:IR",
diff --git a/stablehlo/tests/stablehlo_refine_shapes.mlir b/stablehlo/tests/stablehlo_refine_shapes.mlir
index 44c215b101..11bdf93289 100644
--- a/stablehlo/tests/stablehlo_refine_shapes.mlir
+++ b/stablehlo/tests/stablehlo_refine_shapes.mlir
@@ -424,6 +424,70 @@ func.func @eval_slice() -> tensor<2xi64> {
 
 // -----
 
+// CHECK-LABEL: func @eval_slice_wild_stride
+func.func @eval_slice_wild_stride() -> tensor<1x1x1xi64> {
+  // CHECK-NOT: stablehlo.slice
+  // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<2> : tensor<1x1x1xi64>
+  // CHECK: return [[RESULT]]
+  %0 = stablehlo.constant dense<[[[1, 2], [3, 4]]]> : tensor<1x2x2xi64>
+  %1 = "stablehlo.slice"(%0) {
+    start_indices = array<i64: 0, 0, 1>,
+    limit_indices = array<i64: 1, 1, 2>,
+    strides = array<i64: 2, 2, 2>
+  } : (tensor<1x2x2xi64>) -> tensor<1x1x1xi64>
+  func.return %1 : tensor<1x1x1xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func @eval_slice_unit_prefix
+func.func @eval_slice_unit_prefix() -> (tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>) {
+  // CHECK-NOT: stablehlo.slice
+  // CHECK: [[RESULT1:%.*]] = stablehlo.constant dense<{{\[\[\[}}[1, 2]]]]> : tensor<1x1x1x2xi64>
+  // CHECK: [[RESULT2:%.*]] = stablehlo.constant dense<{{\[\[\[}}[7, 8]]]]> : tensor<1x1x1x2xi64>
+  // CHECK: [[RESULT3:%.*]] = stablehlo.constant dense<{{\[\[\[}}[11, 12]]]]> : tensor<1x1x1x2xi64>
+  // CHECK: return [[RESULT1]], [[RESULT2]], [[RESULT3]]
+  %0 = stablehlo.constant dense<[[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]]> : tensor<1x3x2x2xi64>
+
+  %1 = "stablehlo.slice"(%0) {
+    start_indices = array<i64: 0, 0, 0, 0>,
+    limit_indices = array<i64: 1, 1, 1, 2>,
+    strides = array<i64: 1, 1, 1, 1>
+  } : (tensor<1x3x2x2xi64>) -> tensor<1x1x1x2xi64>
+
+  %2 = "stablehlo.slice"(%0) {
+    start_indices = array<i64: 0, 1, 1, 0>,
+    limit_indices = array<i64: 1, 2, 2, 2>,
+    strides = array<i64: 1, 1, 1, 1>
+  } : (tensor<1x3x2x2xi64>) -> tensor<1x1x1x2xi64>
+
+  %3 = "stablehlo.slice"(%0) {
+    start_indices = array<i64: 0, 2, 1, 0>,
+    limit_indices = array<i64: 1, 3, 2, 2>,
+    strides = array<i64: 1, 1, 1, 1>
+  } : (tensor<1x3x2x2xi64>) -> tensor<1x1x1x2xi64>
+
+  func.return %1, %2, %3 : tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>
+}
+
+// -----
+
+// CHECK-LABEL: func @eval_slice_non_unit_prefix
+func.func @eval_slice_non_unit_prefix() -> tensor<1x2x1xi64> {
+  // CHECK: stablehlo.constant {{.*}} : tensor<1x2x2xi64>
+  // CHECK: [[RESULT:%.*]] = stablehlo.slice{{.*}}
+  // CHECK: return [[RESULT]]
+  %0 = stablehlo.constant dense<[[[1, 2], [3, 4]]]> : tensor<1x2x2xi64>
+  %1 = "stablehlo.slice"(%0) {
+    start_indices = array<i64: 0, 0, 1>,
+    limit_indices = array<i64: 1, 2, 2>,
+    strides = array<i64: 1, 1, 1>
+  } : (tensor<1x2x2xi64>) -> tensor<1x2x1xi64>
+  func.return %1 : tensor<1x2x1xi64>
+}
+
+// -----
+
 // CHECK-LABEL: func @eval_subtract
 func.func @eval_subtract() -> tensor<i64> {
   // CHECK-NOT: stablehlo.subtract
diff --git a/stablehlo/transforms/CMakeLists.txt b/stablehlo/transforms/CMakeLists.txt
index bc196efd3b..753e8983e1 100644
--- a/stablehlo/transforms/CMakeLists.txt
+++ b/stablehlo/transforms/CMakeLists.txt
@@ -45,6 +45,7 @@ add_mlir_dialect_library(StablehloPasses
   MLIRArithDialect
   MLIRAsmParser
   MLIRComplexDialect
+  MLIRDialectUtils
   MLIRFuncDialect
   MLIRFunctionInterfaces
   MLIRIR
diff --git a/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/transforms/StablehloRefineShapes.cpp
index a1b8809455..dc4b3d2aad 100644
--- a/stablehlo/transforms/StablehloRefineShapes.cpp
+++ b/stablehlo/transforms/StablehloRefineShapes.cpp
@@ -29,6 +29,7 @@ limitations under the License.
 #include "llvm/Support/ErrorHandling.h"
 #include "llvm/Support/FormatVariadic.h"
 #include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/BuiltinAttributes.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/BuiltinTypes.h"
@@ -577,20 +578,42 @@ struct EvalSliceOpPattern : public OpRewritePattern<SliceOp> {
   LogicalResult matchAndRewrite(SliceOp op,
                                 PatternRewriter& rewriter) const override {
     auto resultType = op.getType();
-    if (!resultType.hasRank() || resultType.getRank() != 1)
-      return rewriter.notifyMatchFailure(op, "expected 1-dimensional type");
-
-    SmallVector<APSInt> operand;
-    if (failed(hlo::matchInts(op.getOperand(), operand)))
+    if (resultType.getRank() < 1)
+      return rewriter.notifyMatchFailure(
+          op, "expected non-0 ranked tensor result type");
+
+    auto operand = op.getOperand().cast<TypedValue<RankedTensorType>>();
+    RankedTensorType operandType = operand.getType();
+    if (!operandType.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "expected operand with static ranked tensor type");
+
+    // A ranked tensor type with unit dimension prefix of R-1 size is physically
+    // compatible with 1-dimensional type.
+    if (!llvm::all_of(resultType.getShape().drop_back(),
+                      [](int64_t s) { return s == 1; }))
+      return rewriter.notifyMatchFailure(
+          op, "expected 1-dimensional compatible result type");
+
+    SmallVector<APSInt> operandData;
+    if (failed(hlo::matchInts(operand, operandData)))
       return rewriter.notifyMatchFailure(op, "expected constant operand");
 
-    int64_t start = op.getStartIndices()[0];
-    int64_t limit = op.getLimitIndices()[0];
-    int64_t stride = op.getStrides()[0];
+    const auto dimOffsets = computeSuffixProduct(operandType.getShape());
+    auto startIndices = op.getStartIndices();
+    auto limitIndices = op.getLimitIndices();
+    auto strides = op.getStrides();
+
+    int64_t start = 0;
+    for (size_t i = 0; i < startIndices.size(); ++i)
+      start += startIndices[i] * dimOffsets[i];
+
+    auto slicedDim = operandType.getRank() - 1;
+    int64_t limit = start + limitIndices[slicedDim] - startIndices[slicedDim];
+    int64_t stride = strides[slicedDim];
     SmallVector<APSInt> result;
-    for (auto i = start; i < limit; i += stride) {
-      result.push_back(operand[i]);
-    }
+    for (auto i = start; i < limit; i += stride)
+      result.push_back(operandData[i]);
     rewriter.replaceOpWithNewOp<ConstantOp>(op, getTensorAttr(resultType, result));