Skip to content

Commit

Permalink
Add specialization of ComputeReshapeShapeOp (#2102)
Browse files Browse the repository at this point in the history
This enables specialization of the stablehlo.compute_reshape_shape operator
when it operates on constant data:

```
func.func @test() -> (tensor<4xi32>, tensor<1xi32>) {
  %0 = arith.constant dense<[2, -1, 2, 64]> : tensor<4xi32>
  %1 = arith.constant dense<[-1]> : tensor<1xi32>
  %2 = arith.constant 32768 : index
  %3 = stablehlo.compute_reshape_shape %2, %0 : (index, tensor<4xi32>) -> tensor<4xi32>
  %4 = stablehlo.compute_reshape_shape %2, %1 : (index, tensor<1xi32>) -> tensor<1xi32>
  func.return %3, %4 : tensor<4xi32>, tensor<1xi32>
}
```
is transformed into:

```
func.func @test() -> (tensor<4xi32>, tensor<1xi32>) {
  %0 = stablehlo.constant dense<[2, 128, 2, 64]> : tensor<4xi32>
  %1 = stablehlo.constant dense<32768> : tensor<1xi32>
  return %0, %1 : tensor<4xi32>, tensor<1xi32>
}
```
  • Loading branch information
mvpant authored Mar 22, 2024
1 parent 7c8815a commit 1860c4c
Show file tree
Hide file tree
Showing 5 changed files with 147 additions and 2 deletions.
7 changes: 7 additions & 0 deletions stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,13 @@ bool isCompatibleForHloTypeInference(Value shape1, Type tp2) {
return isCompatibleForHloTypeInference(tp1, tp2);
}

// Extracts the value of an integer-constant `value` into `result`,
// sign-extending to int64_t. Fails when `value` is not produced by an
// integer constant.
LogicalResult matchInt(Value value, int64_t& result) {
  APInt apValue;
  if (matchPattern(value, m_ConstantInt(&apValue))) {
    result = apValue.getSExtValue();
    return success();
  }
  return failure();
}

LogicalResult matchInts(Value value, SmallVector<int64_t>& result) {
DenseIntElementsAttr attr;
if (!matchPattern(value, m_Constant(&attr))) return failure();
Expand Down
3 changes: 3 additions & 0 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,9 @@ LogicalResult inferMostSpecificTypeComponents(
std::optional<Location> location, TypeRange inputTypes,
SmallVectorImpl<ShapedTypeComponents> &inferredReturnShapes);

// Matches a constant with integer value into int64_t.
LogicalResult matchInt(Value value, int64_t &result);

// Matches a constant tensor with integer values into a 1-dimensional vector.
// Doesn't preserve the bitness or the signedness of the underlying values,
// extracting them into int64_t.
Expand Down
5 changes: 3 additions & 2 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3332,8 +3332,9 @@ def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv", [Pure]> {
let results = (outs HLO_Tensor);
}

def StableHLO_ComputeReshapeShapeOp :
StableHLO_Op<"compute_reshape_shape", [Pure]> {
def StableHLO_ComputeReshapeShapeOp : StableHLO_Op<
"compute_reshape_shape",
[Pure, AllShapesMatch<["dynamic_shape", "result"]>]> {
let summary = "ComputeReshapeShape operation";
let description = [{
This operation is a work in progress, so it is not yet included in
Expand Down
79 changes: 79 additions & 0 deletions stablehlo/tests/stablehlo_refine_shapes.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,85 @@ func.func @eval_compare_lt() -> tensor<i1> {

// -----

// CHECK-LABEL: func @eval_compute_reshape_shape
// Fully-specified constant shape (no -1): the op folds to a stablehlo.constant
// holding the shape operand verbatim.
func.func @eval_compute_reshape_shape() -> tensor<4xi32> {
  // CHECK-NOT: stablehlo.compute_reshape_shape
  // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<[2, 128, 2, 64]> : tensor<4xi32>
  // CHECK: return [[RESULT]]
  %0 = arith.constant dense<[2, 128, 2, 64]> : tensor<4xi32>
  %1 = arith.constant 32768 : index
  %2 = stablehlo.compute_reshape_shape %1, %0 : (index, tensor<4xi32>) -> tensor<4xi32>
  func.return %2 : tensor<4xi32>
}

// -----

// CHECK-LABEL: func @eval_compute_reshape_shape_zero_dynamic_shape
// Rank-0 shape operand (empty tensor): folding still applies and yields an
// empty constant.
func.func @eval_compute_reshape_shape_zero_dynamic_shape() -> tensor<0xi32> {
  // CHECK-NOT: stablehlo.compute_reshape_shape
  // CHECK: [[RESULT:%.*]] = stablehlo.constant dense<> : tensor<0xi32>
  // CHECK: return [[RESULT]]
  %0 = arith.constant dense<[]> : tensor<0xi32>
  %1 = arith.constant 32768 : index
  %2 = stablehlo.compute_reshape_shape %1, %0 : (index, tensor<0xi32>) -> tensor<0xi32>
  func.return %2 : tensor<0xi32>
}

// -----

// CHECK-LABEL: func @eval_compute_reshape_shape_unknown_dimension
// A single -1 dimension is inferred from num_elements: 32768 / (2*2*64) = 128
// for the first op, and 32768 / 1 = 32768 for the rank-1 case.
func.func @eval_compute_reshape_shape_unknown_dimension() -> (tensor<4xi32>, tensor<1xi32>) {
  // CHECK-NOT: stablehlo.compute_reshape_shape
  // CHECK: [[RESULT1:%.*]] = stablehlo.constant dense<[2, 128, 2, 64]> : tensor<4xi32>
  // CHECK: [[RESULT2:%.*]] = stablehlo.constant dense<32768> : tensor<1xi32>
  // CHECK: return [[RESULT1]], [[RESULT2]]
  %0 = arith.constant dense<[2, -1, 2, 64]> : tensor<4xi32>
  %1 = arith.constant dense<[-1]> : tensor<1xi32>
  %2 = arith.constant 32768 : index
  %3 = stablehlo.compute_reshape_shape %2, %0 : (index, tensor<4xi32>) -> tensor<4xi32>
  %4 = stablehlo.compute_reshape_shape %2, %1 : (index, tensor<1xi32>) -> tensor<1xi32>
  func.return %3, %4 : tensor<4xi32>, tensor<1xi32>
}

// -----

// CHECK-LABEL: func @eval_compute_reshape_shape_two_unknown_dims
// More than one -1 is undefined behavior, so the pattern must NOT fold: the op
// is expected to survive.
func.func @eval_compute_reshape_shape_two_unknown_dims() -> tensor<4xi32> {
  // CHECK: [[RESULT:%.*]] = stablehlo.compute_reshape_shape
  // CHECK: return [[RESULT]]
  %0 = arith.constant dense<[2, -1, -1, 64]> : tensor<4xi32>
  %1 = arith.constant 32768 : index
  %2 = stablehlo.compute_reshape_shape %1, %0 : (index, tensor<4xi32>) -> tensor<4xi32>
  func.return %2 : tensor<4xi32>
}

// -----

// CHECK-LABEL: func @eval_compute_reshape_shape_non_divisible_shape
// Known dimensions that do not evenly divide num_elements (with or without a
// -1 present) are undefined behavior: both ops must be left unfolded.
func.func @eval_compute_reshape_shape_non_divisible_shape() -> (tensor<4xi32>, tensor<4xi32>) {
  // CHECK: [[RESULT1:%.*]] = stablehlo.compute_reshape_shape
  // CHECK: [[RESULT2:%.*]] = stablehlo.compute_reshape_shape
  // CHECK: return [[RESULT1]], [[RESULT2]]
  %0 = arith.constant dense<[2, 128, 3, -1]> : tensor<4xi32>
  %1 = arith.constant dense<[2, 128, 2, 63]> : tensor<4xi32>
  %2 = arith.constant 32768 : index
  %3 = stablehlo.compute_reshape_shape %2, %0 : (index, tensor<4xi32>) -> tensor<4xi32>
  %4 = stablehlo.compute_reshape_shape %2, %1 : (index, tensor<4xi32>) -> tensor<4xi32>
  func.return %3, %4 : tensor<4xi32>, tensor<4xi32>
}

// -----

// CHECK-LABEL: func @eval_compute_reshape_shape_non_specializable
// Non-constant operands (block arguments): the pattern cannot apply and the op
// must remain.
func.func @eval_compute_reshape_shape_non_specializable(%arg0 : tensor<4xi32>, %arg1 : index) -> tensor<4xi32> {
  // CHECK: [[RESULT:%.*]] = stablehlo.compute_reshape_shape
  // CHECK: return [[RESULT]]
  %0 = stablehlo.compute_reshape_shape %arg1, %arg0 : (index, tensor<4xi32>) -> tensor<4xi32>
  func.return %0 : tensor<4xi32>
}

// -----

// CHECK-LABEL: func @eval_concatenate_1d
func.func @eval_concatenate_1d() -> tensor<4xi64> {
// CHECK-NOT: stablehlo.concatenate
Expand Down
55 changes: 55 additions & 0 deletions stablehlo/transforms/StablehloRefineShapes.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -404,6 +404,60 @@ struct EvalCompareOpPattern : public OpRewritePattern<CompareOp> {
}
};

// Folds stablehlo.compute_reshape_shape when both operands are compile-time
// constants, replacing the op with a stablehlo.constant holding the fully
// resolved shape. A single -1 ("unspecified") dimension is inferred from the
// constant number of elements. Any input whose behavior is undefined (multiple
// -1s, other negative sizes, zero-sized dimensions, non-divisible products) is
// conservatively left unfolded.
struct EvalComputeReshapeShapeOpPattern
    : public OpRewritePattern<ComputeReshapeShapeOp> {
  using OpRewritePattern::OpRewritePattern;
  LogicalResult matchAndRewrite(ComputeReshapeShapeOp op,
                                PatternRewriter& rewriter) const override {
    auto resultType = op.getType();

    // Both operands must be constants; otherwise the op stays as-is.
    int64_t numElems;
    if (failed(hlo::matchInt(op.getNumElements(), numElems)))
      return rewriter.notifyMatchFailure(
          op, "expected constant number of elements");

    SmallVector<int64_t> dynShapeValues;
    if (failed(hlo::matchInts(op.getDynamicShape(), dynShapeValues)))
      return rewriter.notifyMatchFailure(op, "expected constant dynamic shape");

    // Scan for the single allowed -1 dimension while accumulating the product
    // of all specified dimensions.
    std::optional<size_t> unspecifiedDimIdx;
    int64_t dimProduct = 1;
    constexpr int64_t kUnspecifiedDimSize = -1;
    for (size_t i = 0; i < dynShapeValues.size(); ++i) {
      if (dynShapeValues[i] == kUnspecifiedDimSize) {
        if (unspecifiedDimIdx.has_value())
          return rewriter.notifyMatchFailure(
              op, "multiple -1 values in dimensions is an undefined behavior");

        unspecifiedDimIdx = i;
        continue;
      }

      // Negative sizes other than -1 make the shape ill-formed; don't fold
      // garbage values into a constant.
      if (dynShapeValues[i] < 0)
        return rewriter.notifyMatchFailure(
            op, "negative dimension size is an undefined behavior");

      dimProduct *= dynShapeValues[i];
    }

    // A zero-sized dimension would make dimProduct == 0, and the `%` / `/`
    // below would be division by zero (C++ UB). Bail out instead of folding.
    if (dimProduct == 0)
      return rewriter.notifyMatchFailure(
          op, "zero dimension size cannot be used to infer dimensions");

    if (numElems % dimProduct != 0)
      return rewriter.notifyMatchFailure(
          op,
          "dimensions that can't evenly divide num elements is an undefined "
          "behavior");

    if (unspecifiedDimIdx.has_value())
      dynShapeValues[unspecifiedDimIdx.value()] = numElems / dimProduct;

    // Re-materialize the resolved sizes at the result's element bit width as
    // signed values (isUnsigned = false), matching e.g. a tensor<Nxi32> result.
    const auto resultBitWidth = resultType.getElementTypeBitWidth();
    auto result = llvm::to_vector(
        llvm::map_range(dynShapeValues, [&](int64_t value) -> APSInt {
          return APSInt(APInt(resultBitWidth, value), false);
        }));

    rewriter.replaceOpWithNewOp<ConstantOp>(op,
                                            getTensorAttr(resultType, result));

    return success();
  }
};

struct EvalConcatenateOpPattern : public OpRewritePattern<ConcatenateOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ConcatenateOp op,
Expand Down Expand Up @@ -1195,6 +1249,7 @@ void populateStablehloRefineShapesPatterns(RewritePatternSet* patterns,
patterns->add<EvalBroadcastInDimOpPattern>(context);
patterns->add<EvalClampOpPattern>(context);
patterns->add<EvalCompareOpPattern>(context);
patterns->add<EvalComputeReshapeShapeOpPattern>(context);
patterns->add<EvalConcatenateOpPattern>(context);
patterns->add<EvalConvertOpPattern>(context);
patterns->add<EvalDivOpPattern>(context);
Expand Down

0 comments on commit 1860c4c

Please sign in to comment.