Skip to content

Commit

Permalink
Fix ConditionalSpeculatability implementation for DynamicReshape (#2261)
Browse files Browse the repository at this point in the history
Both the input and the shape must be fully known statically. Whether the
output is dynamic or not does not matter.

Indeed, consider e.g.:

```
func.func @foo(%arg0: tensor<?x?xf64>) {
    %constant_shape = stablehlo.constant dense<[2, 3]> : tensor<2xi32>
    %0 = stablehlo.dynamic_reshape %arg0, %constant_shape : (tensor<?x?xf64>, tensor<2xi32>) -> tensor<?x?xf64>
    return
}
```

The input is dynamic, so it could turn out to have e.g. 10 elements
instead of the expected 6. Similarly, if the shape is unknown:

```
func.func @foo(%arg0: tensor<2x3xf64>, %unknown_shape: tensor<2xi32>) {
    %0 = stablehlo.dynamic_reshape %arg0, %unknown_shape : (tensor<2x3xf64>, tensor<2xi32>) -> tensor<?x?xf64>
    return
}
```

Again, the shape could turn out to be e.g. `[2, 5]` at runtime and so
the reshape's behavior would be undefined.
  • Loading branch information
mlevesquedion authored Apr 29, 2024
1 parent b6406a4 commit 7d95a16
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 12 deletions.
12 changes: 2 additions & 10 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1510,18 +1510,10 @@ LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
}

mlir::Speculation::Speculatability DynamicReshapeOp::getSpeculatability() {
// If the output type's shape is fully dynamic, there is no expectation
// for the shape so the op is speculatable.
if (llvm::all_of(llvm::seq(getType().getRank()),
[this](int64_t i) { return getType().isDynamicDim(i); }))
return mlir::Speculation::Speculatable;

// If the input is static and the shape operand is constant, the output
// shape can be inferred and any mismatch will be caught statically.
// If any dimension in the input is dynamic, the number of elements may
// disagree with either the output.
// If the shape operand is not constant, it could disagree with the output,
// which has at least 1 static dimension at this point in the function.
// If any dimension in the input is dynamic, or if the shape is not known,
// the number of elements may disagree at runtime.
if (getOperand().getType().hasStaticShape() &&
matchPattern(getOutputShape(), m_Constant()))
return mlir::Speculation::Speculatable;
Expand Down
4 changes: 2 additions & 2 deletions stablehlo/tests/ops_speculatability.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1750,13 +1750,13 @@ func.func @dynamic_reshape(
%2 = stablehlo.dynamic_reshape %dynamic_arg, %constant_shape : (tensor<?x?xf64>, tensor<2xi32>) -> tensor<5x4xf64>
"hlo_test_speculatability.is_not_speculatable"(%2) : (tensor<5x4xf64>) -> ()
%3 = stablehlo.dynamic_reshape %dynamic_arg, %constant_shape : (tensor<?x?xf64>, tensor<2xi32>) -> tensor<?x?xf64>
"hlo_test_speculatability.is_speculatable"(%3) : (tensor<?x?xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%3) : (tensor<?x?xf64>) -> ()

// Unknown shape
%4 = stablehlo.dynamic_reshape %static_arg, %unknown_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<5x4xf64>
"hlo_test_speculatability.is_not_speculatable"(%4) : (tensor<5x4xf64>) -> ()
%5 = stablehlo.dynamic_reshape %static_arg, %unknown_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<?x?xf64>
"hlo_test_speculatability.is_speculatable"(%5) : (tensor<?x?xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%5) : (tensor<?x?xf64>) -> ()

return
}
Expand Down

0 comments on commit 7d95a16

Please sign in to comment.