From 7d95a16568f44cd0f2f997bfe18345dcaa645ad6 Mon Sep 17 00:00:00 2001 From: mlevesquedion Date: Mon, 29 Apr 2024 09:37:27 -0700 Subject: [PATCH] Fix ConditionalSpeculatability implementation for DynamicReshape (#2261) Both the input and the shape must be fully known statically. Whether the output is dynamic or not does not matter. Indeed, consider e.g.: ``` func.func @foo(%arg0: tensor) { %constant_shape = stablehlo.constant dense<[2, 3]> : tensor<2xi32> %0 = stablehlo.dynamic_reshape %arg0, %constant_shape : (tensor, tensor<2xi32>) -> tensor return } ``` The input is dynamic, so it could turn out to have e.g. 10 elements instead of the expected 6. Similarly, if the shape is unknown: ``` func.func @foo(%arg0: tensor<2x3xf64>, %unknown_shape: tensor<2xi32>) { %0 = stablehlo.dynamic_reshape %arg0, %unknown_shape : (tensor<2x3xf64>, tensor<2xi32>) -> tensor return } ``` Again, the shape could turn out to be e.g. `[2, 5]` at runtime and so the reshape's behavior would be undefined. --- stablehlo/dialect/StablehloOps.cpp | 12 ++---------- stablehlo/tests/ops_speculatability.mlir | 4 ++-- 2 files changed, 4 insertions(+), 12 deletions(-) diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index 7406fb9edf..f5d35b0a72 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -1510,18 +1510,10 @@ LogicalResult DynamicReshapeOp::reifyReturnTypeShapes( } mlir::Speculation::Speculatability DynamicReshapeOp::getSpeculatability() { - // If the output type's shape is fully dynamic, there is no expectation - // for the shape so the op is speculatable. - if (llvm::all_of(llvm::seq(getType().getRank()), - [this](int64_t i) { return getType().isDynamicDim(i); })) - return mlir::Speculation::Speculatable; - // If the input is static and the shape operand is constant, the output // shape can be inferred and any mismatch will be caught statically. - // If any dimension in the input is dynamic, the number of elements may - // disagree with either the output. - // If the shape operand is not constant, it could disagree with the output, - // which has at least 1 static dimension at this point in the function. + // If any dimension in the input is dynamic, or if the shape is not known, + // the number of elements may disagree at runtime. if (getOperand().getType().hasStaticShape() && matchPattern(getOutputShape(), m_Constant())) return mlir::Speculation::Speculatable; diff --git a/stablehlo/tests/ops_speculatability.mlir b/stablehlo/tests/ops_speculatability.mlir index fae177f5db..1dd3fb90c7 100644 --- a/stablehlo/tests/ops_speculatability.mlir +++ b/stablehlo/tests/ops_speculatability.mlir @@ -1750,13 +1750,13 @@ func.func @dynamic_reshape( %2 = stablehlo.dynamic_reshape %dynamic_arg, %constant_shape : (tensor, tensor<2xi32>) -> tensor<5x4xf64> "hlo_test_speculatability.is_not_speculatable"(%2) : (tensor<5x4xf64>) -> () %3 = stablehlo.dynamic_reshape %dynamic_arg, %constant_shape : (tensor, tensor<2xi32>) -> tensor - "hlo_test_speculatability.is_speculatable"(%3) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%3) : (tensor) -> () // Unknown shape %4 = stablehlo.dynamic_reshape %static_arg, %unknown_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<5x4xf64> "hlo_test_speculatability.is_not_speculatable"(%4) : (tensor<5x4xf64>) -> () %5 = stablehlo.dynamic_reshape %static_arg, %unknown_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor - "hlo_test_speculatability.is_speculatable"(%5) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%5) : (tensor) -> () return }