Skip conversion of shape.shape_of with a rank-0 tensor operand (#2107)
Currently `--shape-legalize-to-stablehlo` fails on the following code:

```
func.func @test1() -> tensor<0xindex> {
  %0 = arith.constant dense<0> : tensor<i32>
  %1 = shape.shape_of %0 : tensor<i32> -> tensor<0xindex>
  func.return %1 : tensor<0xindex>
}
```
at `stablehlo/dialect/TypeInference.cpp:1700`:
```
// concatenate_c5
auto elementType = inputTypes[0].cast<ShapedType>().getElementType();
```
as `inputTypes.size()` is zero.
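In other words, nothing rejects an empty operand list before `inputTypes[0]` is read. A defensive guard at that call site (purely illustrative, not part of this commit) could look like:

```cpp
#include "llvm/ADT/ArrayRef.h"
#include "mlir/IR/Types.h"
#include "mlir/Support/LogicalResult.h"

using namespace mlir;

// Illustrative only (not part of this commit): turn the out-of-bounds
// access into a clean inference failure when concatenate receives no
// operands, as happens when shape_of is lowered for a rank-0 tensor.
LogicalResult checkConcatenateHasOperands(ArrayRef<Type> inputTypes) {
  if (inputTypes.empty())
    return failure();
  return success();
}
```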

I have checked how it works on a tensor type of non-zero rank:
```
func.func @test2() -> tensor<2xindex> {
  %1 = arith.constant dense<0> : tensor<2x128xi32>
  %3 = shape.shape_of %1 : tensor<2x128xi32> -> tensor<2xindex>
  func.return %3 : tensor<2xindex>
}
```
This produces:
```
func.func @test2() -> tensor<2xindex> {
  %cst = arith.constant dense<0> : tensor<2x128xi32>
  %0 = stablehlo.get_dimension_size %cst, dim = 0 : (tensor<2x128xi32>) -> tensor<i32>
  %1 = stablehlo.reshape %0 : (tensor<i32>) -> tensor<1xi32>
  %2 = stablehlo.get_dimension_size %cst, dim = 1 : (tensor<2x128xi32>) -> tensor<i32>
  %3 = stablehlo.reshape %2 : (tensor<i32>) -> tensor<1xi32>
  %4 = stablehlo.concatenate %1, %3, dim = 0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<2xi32>
  %5 = builtin.unrealized_conversion_cast %4 : tensor<2xi32> to tensor<2xindex>
  return %5 : tensor<2xindex>
}
```

Instead of simply bailing out, I suggest considering an alternative approach; one option could be to generate a zero-element constant tensor for the shape:
```
func.func @test1() -> tensor<0xindex> {
  %cst = arith.constant dense<0> : tensor<i32>
  %0 = stablehlo.constant dense<> : tensor<0xi32>
  %1 = builtin.unrealized_conversion_cast %0 : tensor<0xi32> to tensor<0xindex>
  return %1 : tensor<0xindex>
}
```
but I am not entirely certain about the semantic equivalence.
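At the builder level, materializing that empty shape constant is a small API sequence. A minimal sketch, distilled from the pattern change in the diff below (`rewriter` and `loc` are assumed to come from the surrounding rewrite pattern):

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinAttributes.h"
#include "mlir/IR/BuiltinTypes.h"
#include "stablehlo/dialect/StablehloOps.h"

using namespace mlir;

// Sketch: build `stablehlo.constant dense<> : tensor<0xi32>`, the empty
// shape for a rank-0 operand.
Value buildEmptyShapeConstant(OpBuilder &rewriter, Location loc) {
  auto shapeType = RankedTensorType::get({0}, rewriter.getI32Type());
  return rewriter.create<stablehlo::ConstantOp>(
      loc, DenseElementsAttr::get(shapeType, ArrayRef<Attribute>()));
}
```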
mvpant authored Mar 18, 2024
1 parent 9c1bccf commit ec4ec78
Showing 2 changed files with 30 additions and 10 deletions.
12 changes: 12 additions & 0 deletions stablehlo/tests/shape_legalize_to_stablehlo.mlir
```
@@ -387,3 +387,15 @@ func.func @tensor_extract_dynamic(%arg0: tensor<?x3xindex>) -> index {
   %0 = tensor.extract %arg0[%c1, %c2] : tensor<?x3xindex>
   return %0 : index
 }
+
+// -----
+
+// CHECK-LABEL: func @shape_of_zero_ranked_tensor
+func.func @shape_of_zero_ranked_tensor(%arg0: tensor<?x3xindex>) -> tensor<0xindex> {
+  // CHECK: %[[CONST:.*]] = stablehlo.constant dense<> : tensor<0xi32>
+  // CHECK-NEXT: %[[RES_DIM0_INDEX:.*]] = builtin.unrealized_conversion_cast %[[CONST]] : tensor<0xi32> to tensor<0xindex>
+  // CHECK-NEXT: return %[[RES_DIM0_INDEX]] : tensor<0xindex>
+  %0 = arith.constant dense<0> : tensor<i32>
+  %1 = shape.shape_of %0 : tensor<i32> -> tensor<0xindex>
+  func.return %1 : tensor<0xindex>
+}
```
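For reference, the new case follows this file's existing FileCheck conventions, and the `// -----` separator indicates the file is run with `--split-input-file`; the exact RUN invocation lives in the test file's header, which this diff does not touch.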
28 changes: 18 additions & 10 deletions stablehlo/transforms/ShapeLegalizeToStablehlo.cpp
```
@@ -256,17 +256,25 @@ struct ConvertShapeOfOpPattern : public OpRewritePattern<shape::ShapeOfOp> {
     // Produce a StableHLO equivalent of this shape::ShapeOfOp.
     // This is a very laborious representation because StableHLO is currently
     // lacking convenient tools to express this.
-    SmallVector<Value> sizesI32x1;
-    for (auto i = 0; i < operandType.getRank(); ++i) {
-      auto sizeI32 =
-          rewriter.create<GetDimensionSizeOp>(op.getLoc(), op.getArg(), i);
-      auto sizeI32x1 = rewriter.create<ReshapeOp>(
-          op.getLoc(), RankedTensorType::get({1}, rewriter.getI32Type()),
-          sizeI32);
-      sizesI32x1.push_back(sizeI32x1);
+    Value shapeI32;
+    if (operandType.getRank() > 0) {
+      SmallVector<Value> sizesI32x1;
+      for (auto i = 0; i < operandType.getRank(); ++i) {
+        auto sizeI32 =
+            rewriter.create<GetDimensionSizeOp>(op.getLoc(), op.getArg(), i);
+        auto sizeI32x1 = rewriter.create<ReshapeOp>(
+            op.getLoc(), RankedTensorType::get({1}, rewriter.getI32Type()),
+            sizeI32);
+        sizesI32x1.push_back(sizeI32x1);
+      }
+      shapeI32 = rewriter.create<ConcatenateOp>(op.getLoc(), sizesI32x1,
+                                                /*dimension=*/0);
+    } else {
+      shapeI32 = rewriter.create<ConstantOp>(
+          op.getLoc(), DenseElementsAttr::get(
+                           RankedTensorType::get({0}, rewriter.getI32Type()),
+                           ArrayRef<Attribute>()));
     }
-    auto shapeI32 = rewriter.create<ConcatenateOp>(op.getLoc(), sizesI32x1,
-                                                   /*dimension=*/0);

     // Cast result from tensor<Nxi32> to tensor<Nxindex>.
     // This will error out if the result is !shape.shape.
```
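The added `else` branch emits exactly the IR suggested in the description: an empty `stablehlo.constant dense<> : tensor<0xi32>`, which then flows through the unchanged i32-to-index cast logic that follows, so the rank-0 case needs no further special handling downstream.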
