diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index 0dfde89765..965d457f52 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -1508,6 +1508,26 @@ LogicalResult DynamicReshapeOp::reifyReturnTypeShapes( return success(); } +mlir::Speculation::Speculatability DynamicReshapeOp::getSpeculatability() { + // If the output type's shape is fully dynamic, there is no expectation + // for the shape so the op is speculatable. + if (llvm::all_of(llvm::seq(getType().getRank()), + [this](int64_t i) { return getType().isDynamicDim(i); })) + return mlir::Speculation::Speculatable; + + // If the input is static and the shape operand is constant, the output + // shape can be inferred and any mismatch will be caught statically. + // If any dimension in the input is dynamic, the number of elements may + // disagree with either the output. + // If the shape operand is not constant, it could disagree with the output, + // which has at least 1 static dimension at this point in the function. + if (getOperand().getType().hasStaticShape() && + matchPattern(getOutputShape(), m_Constant())) + return mlir::Speculation::Speculatable; + + return mlir::Speculation::NotSpeculatable; +} + //===----------------------------------------------------------------------===// // DynamicSliceOp //===----------------------------------------------------------------------===// diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 701a35c0a4..2f5384b60d 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -2637,7 +2637,8 @@ def StableHLO_ReshapeOp: StableHLO_Op<"reshape", }]; } -def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [Pure]> { +def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", + [ConditionallySpeculatable, NoMemoryEffect]> { let summary = "DynamicReshape operation"; let description = [{ This operation is a work in progress, so it is not yet included in @@ -2659,6 +2660,11 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", [ let hasVerifier = 1; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; + + let extraClassDeclaration = commonClassDeclaration # [{ + /// Interface method for ConditionallySpeculatable. + mlir::Speculation::Speculatability getSpeculatability(); + }]; } def StableHLO_ScatterOp: StableHLO_Op<"scatter", diff --git a/stablehlo/tests/ops_speculatability.mlir b/stablehlo/tests/ops_speculatability.mlir index f142c77922..7457903fb0 100644 --- a/stablehlo/tests/ops_speculatability.mlir +++ b/stablehlo/tests/ops_speculatability.mlir @@ -1701,6 +1701,37 @@ func.func @dynamic_iota(%unknown_shape: tensor<2xi32>) { // ----- +// CHECK-LABEL: func @dynamic_reshape +// CHECK-NEXT: return +func.func @dynamic_reshape( + %static_arg: tensor<4x5xf64>, %dynamic_arg: tensor, + %unknown_shape: tensor<2xi32> +) { + %constant_shape = stablehlo.constant dense<[5, 4]> : tensor<2xi32> + + // Static input, constant shape + %0 = stablehlo.dynamic_reshape %static_arg, %constant_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<5x4xf64> + "hlo_test_speculatability.is_speculatable"(%0) : (tensor<5x4xf64>) -> () + %1 = stablehlo.dynamic_reshape %static_arg, %constant_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor + "hlo_test_speculatability.is_speculatable"(%1) : (tensor) -> () + + // Dynamic input + %2 = stablehlo.dynamic_reshape %dynamic_arg, %constant_shape : (tensor, tensor<2xi32>) -> tensor<5x4xf64> + "hlo_test_speculatability.is_not_speculatable"(%2) : (tensor<5x4xf64>) -> () + %3 = stablehlo.dynamic_reshape %dynamic_arg, %constant_shape : (tensor, tensor<2xi32>) -> tensor + "hlo_test_speculatability.is_speculatable"(%3) : (tensor) -> () + + // Unknown shape + %4 = stablehlo.dynamic_reshape %static_arg, %unknown_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<5x4xf64> + "hlo_test_speculatability.is_not_speculatable"(%4) : (tensor<5x4xf64>) -> () + %5 = stablehlo.dynamic_reshape %static_arg, %unknown_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor + "hlo_test_speculatability.is_speculatable"(%5) : (tensor) -> () + + return +} + +// ----- + // Recursively speculatable ops // -----