Add support of tensor types with unit dimension prefix in EvalSlice (#2090)

Our small research team is currently experimenting with a BERT model represented in a number of *-hlo dialects, and we want to reduce the variety of operators it uses. This PR fixes an issue we stumbled upon:

```
%42 = stablehlo.constant dense<"..."> : tensor<1x512xi64>
...
%66 = stablehlo.slice %42 [0:1, 0:128] : (tensor<1x512xi64>) -> tensor<1x128xi64> 
```
The previous implementation of EvalSlice could not handle this case: a tensor type prefixed with unit dimension(s), e.g. 1x128.

This PR adds support for the above case and can slice from any position, e.g.

```
  %0 = stablehlo.constant dense<[[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]]> : tensor<1x3x2x2xi64>
  %1 = "stablehlo.slice"(%0) {
    start_indices = array<i64: 0, 1, 1, 0>,
    limit_indices = array<i64: 1, 2, 2, 2>,
    strides = array<i64: 1, 1, 1, 1>
  } : (tensor<1x3x2x2xi64>) -> tensor<1x1x1x2xi64>
```
is folded to 
```
%1 = stablehlo.constant dense<[[[[7, 8]]]]> : tensor<1x1x1x2xi64>
```
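
For intuition, the fold can be reproduced with plain row-major index arithmetic. Below is a minimal standalone sketch (an illustration only, not the PR's code): the start indices are linearized against the suffix products of the operand shape, and the result is then read as one run along the last dimension.

```
#include <cstdint>
#include <iostream>
#include <vector>

int main() {
  // tensor<1x3x2x2xi64> from the example above, flattened row-major.
  std::vector<int64_t> data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
  std::vector<int64_t> shape = {1, 3, 2, 2};
  std::vector<int64_t> starts = {0, 1, 1, 0};
  std::vector<int64_t> limits = {1, 2, 2, 2};

  // Suffix products of the shape are its row-major strides: {12, 4, 2, 1}.
  std::vector<int64_t> dimOffsets(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i)
    dimOffsets[i] = dimOffsets[i + 1] * shape[i + 1];

  // Linearized start offset: 0*12 + 1*4 + 1*2 + 0*1 = 6.
  int64_t start = 0;
  for (size_t i = 0; i < shape.size(); ++i) start += starts[i] * dimOffsets[i];

  // All result dims but the last are 1, so the slice is one contiguous run:
  // limit = 6 + (2 - 0) = 8.
  int64_t limit = start + limits.back() - starts.back();
  for (int64_t i = start; i < limit; ++i)
    std::cout << data[i] << ' ';  // prints: 7 8
}
```

Running this prints 7 8, matching the folded constant above.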
mvpant authored Mar 19, 2024
1 parent 3442dbe commit 969b9ad
Showing 4 changed files with 100 additions and 11 deletions.
1 change: 1 addition & 0 deletions BUILD.bazel
@@ -963,6 +963,7 @@ cc_library(
"@llvm-project//mlir:AsmParser",
"@llvm-project//mlir:CommonFolders",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:FunctionInterfaces",
"@llvm-project//mlir:IR",
64 changes: 64 additions & 0 deletions stablehlo/tests/stablehlo_refine_shapes.mlir
@@ -424,6 +424,70 @@ func.func @eval_slice() -> tensor<2xi64> {

// -----

// CHECK-LABEL: func @eval_slice_wild_stride
func.func @eval_slice_wild_stride() -> tensor<1x1x1xi64> {
// CHECK-NOT: stablehlo.slice
// CHECK: [[RESULT:%.*]] = stablehlo.constant dense<2> : tensor<1x1x1xi64>
// CHECK: return [[RESULT]]
%0 = stablehlo.constant dense<[[[1, 2], [3, 4]]]> : tensor<1x2x2xi64>
%1 = "stablehlo.slice"(%0) {
start_indices = array<i64: 0, 0, 1>,
limit_indices = array<i64: 1, 1, 2>,
strides = array<i64: 99, 42, 1>
} : (tensor<1x2x2xi64>) -> tensor<1x1x1xi64>
func.return %1 : tensor<1x1x1xi64>
}

// -----

// CHECK-LABEL: func @eval_slice_unit_prefix
func.func @eval_slice_unit_prefix() -> (tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>) {
// CHECK-NOT: stablehlo.slice
// CHECK: [[RESULT1:%.*]] = stablehlo.constant dense<{{\[\[\[}}[1, 2]]]]> : tensor<1x1x1x2xi64>
// CHECK: [[RESULT2:%.*]] = stablehlo.constant dense<{{\[\[\[}}[7, 8]]]]> : tensor<1x1x1x2xi64>
// CHECK: [[RESULT3:%.*]] = stablehlo.constant dense<{{\[\[\[}}[11, 12]]]]> : tensor<1x1x1x2xi64>
// CHECK: return [[RESULT1]], [[RESULT2]], [[RESULT3]]
%0 = stablehlo.constant dense<[[[[1, 2], [3, 4]], [[5, 6], [7, 8]], [[9, 10], [11, 12]]]]> : tensor<1x3x2x2xi64>

%1 = "stablehlo.slice"(%0) {
start_indices = array<i64: 0, 0, 0, 0>,
limit_indices = array<i64: 1, 1, 1, 2>,
strides = array<i64: 1, 1, 1, 1>
} : (tensor<1x3x2x2xi64>) -> tensor<1x1x1x2xi64>

%2 = "stablehlo.slice"(%0) {
start_indices = array<i64: 0, 1, 1, 0>,
limit_indices = array<i64: 1, 2, 2, 2>,
strides = array<i64: 1, 1, 1, 1>
} : (tensor<1x3x2x2xi64>) -> tensor<1x1x1x2xi64>

%3 = "stablehlo.slice"(%0) {
start_indices = array<i64: 0, 2, 1, 0>,
limit_indices = array<i64: 1, 3, 2, 2>,
strides = array<i64: 1, 1, 1, 1>
} : (tensor<1x3x2x2xi64>) -> tensor<1x1x1x2xi64>

func.return %1, %2, %3 : tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>, tensor<1x1x1x2xi64>
}

// -----

// CHECK-LABEL: func @eval_slice_non_unit_prefix
func.func @eval_slice_non_unit_prefix() -> tensor<1x2x1xi64> {
// CHECK: stablehlo.constant {{.*}} : tensor<1x2x2xi64>
// CHECK: [[RESULT:%.*]] = stablehlo.slice{{.*}}
// CHECK: return [[RESULT]]
%0 = stablehlo.constant dense<[[[1, 2], [3, 4]]]> : tensor<1x2x2xi64>
%1 = "stablehlo.slice"(%0) {
start_indices = array<i64: 0, 0, 1>,
limit_indices = array<i64: 1, 2, 2>,
strides = array<i64: 1, 1, 1>
} : (tensor<1x2x2xi64>) -> tensor<1x2x1xi64>
func.return %1 : tensor<1x2x1xi64>
}

// -----

// CHECK-LABEL: func @eval_subtract
func.func @eval_subtract() -> tensor<i64> {
// CHECK-NOT: stablehlo.subtract
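A side note on the CHECK lines above: FileCheck reserves [[...]] for variable captures such as [[RESULT1:%.*]], so a literal [[[ in an expected constant has to be matched through an inline regex, which is why the patterns escape the brackets as {{\[\[\[}}.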
1 change: 1 addition & 0 deletions stablehlo/transforms/CMakeLists.txt
@@ -45,6 +45,7 @@ add_mlir_dialect_library(StablehloPasses
MLIRArithDialect
MLIRAsmParser
MLIRComplexDialect
MLIRDialectUtils
MLIRFuncDialect
MLIRFunctionInterfaces
MLIRIR
45 changes: 34 additions & 11 deletions stablehlo/transforms/StablehloRefineShapes.cpp
@@ -29,6 +29,7 @@ limitations under the License.
#include "llvm/Support/ErrorHandling.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/BuiltinTypes.h"
@@ -577,20 +578,42 @@ struct EvalSliceOpPattern : public OpRewritePattern<SliceOp> {
   LogicalResult matchAndRewrite(SliceOp op,
                                 PatternRewriter& rewriter) const override {
     auto resultType = op.getType();
-    if (!resultType.hasRank() || resultType.getRank() != 1)
-      return rewriter.notifyMatchFailure(op, "expected 1-dimensional type");
+    if (resultType.getRank() < 1)
+      return rewriter.notifyMatchFailure(
+          op, "expected non-0 ranked tensor result type");
 
-    SmallVector<APSInt> operand;
-    if (failed(hlo::matchInts(op.getOperand(), operand)))
+    auto operand = op.getOperand().cast<TypedValue<RankedTensorType>>();
+    RankedTensorType operandType = operand.getType();
+    if (!operandType.hasStaticShape())
+      return rewriter.notifyMatchFailure(
+          op, "expected operand with static ranked tensor type");
+
+    // A ranked tensor type with unit dimension prefix of R-1 size is physically
+    // compatible with 1-dimensional type.
+    if (!llvm::all_of(resultType.getShape().drop_back(),
+                      [](int64_t s) { return s == 1; }))
+      return rewriter.notifyMatchFailure(
+          op, "expected 1-dimensional compatible result type");
+
+    SmallVector<APSInt> operandData;
+    if (failed(hlo::matchInts(operand, operandData)))
       return rewriter.notifyMatchFailure(op, "expected constant operand");
 
-    int64_t start = op.getStartIndices()[0];
-    int64_t limit = op.getLimitIndices()[0];
-    int64_t stride = op.getStrides()[0];
+    const auto dimOffsets = computeSuffixProduct(operandType.getShape());
+    auto startIndices = op.getStartIndices();
+    auto limitIndices = op.getLimitIndices();
+    auto strides = op.getStrides();
+
+    int64_t start = 0;
+    for (size_t i = 0; i < startIndices.size(); ++i)
+      start += startIndices[i] * dimOffsets[i];
+
+    auto slicedDim = operandType.getRank() - 1;
+    int64_t limit = start + limitIndices[slicedDim] - startIndices[slicedDim];
+    int64_t stride = strides[slicedDim];
     SmallVector<APSInt> result;
-    for (auto i = start; i < limit; i += stride) {
-      result.push_back(operand[i]);
-    }
+    for (auto i = start; i < limit; i += stride)
+      result.push_back(operandData[i]);
 
     rewriter.replaceOpWithNewOp<ConstantOp>(op,
                                             getTensorAttr(resultType, result));
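In other words: computeSuffixProduct (from mlir/Dialect/Utils/IndexingUtils.h, hence the new DialectUtils dependency in BUILD.bazel and CMakeLists.txt above) returns the row-major strides of a shape, e.g. {12, 4, 2, 1} for 1x3x2x2. Because every result dimension except the last is required to be 1, the folded elements form a single strided run over the flattened constant data, starting at the linearized offset sum over i of startIndices[i] * dimOffsets[i].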
