From 5c4a2de37165c4434b7940a36f508a683792a204 Mon Sep 17 00:00:00 2001 From: Michael Levesque-Dion Date: Thu, 18 Apr 2024 15:00:29 -0700 Subject: [PATCH] Implement ConditionallySpeculatable for remaining dynamic ops Included ops: - DynamicPad - RealDynamicSlice - DynamicConv - DynamicGather I refactored the logic to check speculatability for shaped ops to enable reuse and allow ops with more than one shape-related operand to be checked. This should be the last change adding new speculatability implementations. All the other ops are either done, pure, or deprecated. I will confirm this shortly by reviewing the entire opset and making sure everything is covered (I have been tracking progress in a personal document). --- stablehlo/dialect/Base.cpp | 17 ++ stablehlo/dialect/Base.h | 18 ++ stablehlo/dialect/Base.td | 9 + stablehlo/dialect/StablehloOps.cpp | 28 +-- stablehlo/dialect/StablehloOps.td | 31 ++- stablehlo/tests/ops_speculatability.mlir | 254 ++++++++++++++++++++++- 6 files changed, 324 insertions(+), 33 deletions(-) diff --git a/stablehlo/dialect/Base.cpp b/stablehlo/dialect/Base.cpp index e5cf93f97d..f938f2c233 100644 --- a/stablehlo/dialect/Base.cpp +++ b/stablehlo/dialect/Base.cpp @@ -611,5 +611,22 @@ bool isSplatArray(ArrayRef<int64_t> arr, int64_t val) { [val](int64_t x) { return x == val; }); } +mlir::Speculation::Speculatability getShapedSpeculatability( + Operation* op, int64_t shapeCount) { + // If all inputs are static and the shape-related operands are constant + // then any relationship between the input, the shapes and the output can be + // verified statically. + bool allInputsStatic = llvm::all_of(op->getOperandTypes(), [](Type t) { + return cast<ShapedType>(t).hasStaticShape(); + }); + bool allShapesConstant = llvm::all_of(llvm::seq(shapeCount), [&](int64_t i) { + return matchPattern(op->getOperand(op->getNumOperands() - 1 - i), + m_Constant()); + }); + return allInputsStatic && allShapesConstant + ? 
mlir::Speculation::Speculatable : mlir::Speculation::NotSpeculatable; } + } // namespace hlo } // namespace mlir diff --git a/stablehlo/dialect/Base.h b/stablehlo/dialect/Base.h index 36ce89c126..b2b7decbdf 100644 --- a/stablehlo/dialect/Base.h +++ b/stablehlo/dialect/Base.h @@ -251,6 +251,13 @@ void writeEnumAttribute(EnumTypeAttr val, DialectBytecodeWriter &writer) { } } // namespace bytecode +// Determines the speculatability for a shaped operation `op` with `shapeCount` +// shape operands. The last `shapeCount` operands are assumed to be shape operands. +// To be speculatable, such an op must have only static inputs and constant +// shape operands. +mlir::Speculation::Speculatability getShapedSpeculatability(Operation *op, + int64_t shapeCount); + namespace OpTrait { template <typename ConcreteType> @@ -466,6 +473,17 @@ struct RecursivelySpeculatableIfAllInputsStaticImplTrait } }; +template <typename ConcreteType> +struct SpeculatableIfAllInputsStaticAndShapeConstantImplTrait + : public mlir::OpTrait::TraitBase< + ConcreteType, + SpeculatableIfAllInputsStaticAndShapeConstantImplTrait> { + mlir::Speculation::Speculatability getSpeculatability() { + auto op = this->getOperation(); + return getShapedSpeculatability(op, 1); + } +}; + } // namespace OpTrait } // namespace hlo } // namespace mlir diff --git a/stablehlo/dialect/Base.td b/stablehlo/dialect/Base.td index 5c668d83e0..8b324e080d 100644 --- a/stablehlo/dialect/Base.td +++ b/stablehlo/dialect/Base.td @@ -368,4 +368,13 @@ def HLO_RecursivelySpeculatableIfAllInputsStaticImplTrait def HLO_RecursivelySpeculatableIfAllInputsStatic : TraitList<[ ConditionallySpeculatable, HLO_RecursivelySpeculatableIfAllInputsStaticImplTrait]>; +def HLO_SpeculatableIfAllInputsStaticAndShapeConstantImplTrait + : HLO_NativeOpTrait<"SpeculatableIfAllInputsStaticAndShapeConstantImplTrait">; + +// This trait is the same as HLO_SpeculatableIfAllInputsStatic, but for ops that +// take a shape as their last operand. 
Such ops are speculatable only if +// all inputs are static and the shape is constant. +def HLO_SpeculatableIfAllInputsStaticAndShapeConstant : TraitList<[ + ConditionallySpeculatable, HLO_SpeculatableIfAllInputsStaticAndShapeConstantImplTrait]>; + #endif // STABLEHLO_DIALECT_BASE diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index 7406fb9edf..c79359c9ad 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -1509,26 +1509,6 @@ LogicalResult DynamicReshapeOp::reifyReturnTypeShapes( return success(); } -mlir::Speculation::Speculatability DynamicReshapeOp::getSpeculatability() { - // If the output type's shape is fully dynamic, there is no expectation - // for the shape so the op is speculatable. - if (llvm::all_of(llvm::seq(getType().getRank()), - [this](int64_t i) { return getType().isDynamicDim(i); })) - return mlir::Speculation::Speculatable; - - // If the input is static and the shape operand is constant, the output - // shape can be inferred and any mismatch will be caught statically. - // If any dimension in the input is dynamic, the number of elements may - // disagree with either the output. - // If the shape operand is not constant, it could disagree with the output, - // which has at least 1 static dimension at this point in the function. 
- if (getOperand().getType().hasStaticShape() && - matchPattern(getOutputShape(), m_Constant())) - return mlir::Speculation::Speculatable; - - return mlir::Speculation::NotSpeculatable; -} - //===----------------------------------------------------------------------===// // DynamicSliceOp //===----------------------------------------------------------------------===// @@ -1598,6 +1578,10 @@ LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes( return success(); } +mlir::Speculation::Speculatability RealDynamicSliceOp::getSpeculatability() { + return hlo::getShapedSpeculatability(getOperation(), /*shapeCount=*/3); +} + //===----------------------------------------------------------------------===// // InfeedOp //===----------------------------------------------------------------------===// @@ -2145,6 +2129,10 @@ LogicalResult DynamicPadOp::reifyReturnTypeShapes( return success(); } +mlir::Speculation::Speculatability DynamicPadOp::getSpeculatability() { + return hlo::getShapedSpeculatability(getOperation(), /*shapeCount=*/3); +} + //===----------------------------------------------------------------------===// // ReshapeOp //===----------------------------------------------------------------------===// diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index a4d2ec4e4e..7ec559de56 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -145,7 +145,7 @@ def StableHLO_DynamicIotaOp: StableHLO_ShapedInterfaceOp<"dynamic_iota", [Condit $output_shape `,` `dim` `=` $iota_dimension attr-dict `:` functional-type(operands, results) }]; - let extraClassDeclaration = [{ + let extraClassDeclaration = commonClassDeclaration # [{ /// Interface method for ConditionallySpeculatable. 
mlir::Speculation::Speculatability getSpeculatability(); }]; @@ -2646,7 +2646,7 @@ def StableHLO_ReshapeOp: StableHLO_Op<"reshape", } def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", - [ConditionallySpeculatable, NoMemoryEffect]> { + [HLO_SpeculatableIfAllInputsStaticAndShapeConstant, NoMemoryEffect]> { let summary = "DynamicReshape operation"; let description = [{ This operation is a work in progress, so it is not yet included in @@ -2668,11 +2668,6 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", let hasVerifier = 1; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; - - let extraClassDeclaration = commonClassDeclaration # [{ - /// Interface method for ConditionallySpeculatable. - mlir::Speculation::Speculatability getSpeculatability(); - }]; } def StableHLO_ScatterOp: StableHLO_Op<"scatter", @@ -3368,7 +3363,8 @@ def StableHLO_ReducePrecisionOp : StableHLO_Op<"reduce_precision", def StableHLO_RealDynamicSliceOp: StableHLO_ShapedInterfaceOp< "real_dynamic_slice", - [Pure, AllElementTypesMatch<["operand", "result"]>, + [ConditionallySpeculatable, NoMemoryEffect, + AllElementTypesMatch<["operand", "result"]>, AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> { let summary = "RealDynamicSlice operation"; let description = [{ @@ -3396,10 +3392,16 @@ def StableHLO_RealDynamicSliceOp: StableHLO_ShapedInterfaceOp< let hasVerifier = 1; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; + + let extraClassDeclaration = commonClassDeclaration # [{ + /// Interface method for ConditionallySpeculatable. 
+ mlir::Speculation::Speculatability getSpeculatability(); + }]; } def StableHLO_DynamicPadOp: StableHLO_ShapedInterfaceOp<"dynamic_pad", - [Pure, AllElementTypesMatch<["operand", "padding_value", "result"]>, + [ConditionallySpeculatable, NoMemoryEffect, + AllElementTypesMatch<["operand", "padding_value", "result"]>, AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> { let summary = "DynamicPad operation"; let description = [{ @@ -3433,10 +3435,16 @@ def StableHLO_DynamicPadOp: StableHLO_ShapedInterfaceOp<"dynamic_pad", let hasVerifier = 1; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; + + let extraClassDeclaration = commonClassDeclaration # [{ + /// Interface method for ConditionallySpeculatable. + mlir::Speculation::Speculatability getSpeculatability(); + }]; } def StableHLO_DynamicGatherOp: StableHLO_Op<"dynamic_gather", - [InferTensorTypeWithReify, Pure]> { + [HLO_SpeculatableIfAllInputsStaticAndShapeConstant, NoMemoryEffect, + InferTensorTypeWithReify]> { let summary = "DynamicGather operation"; let description = [{ This operation is a work in progress, so it is not yet included in @@ -3469,7 +3477,8 @@ def StableHLO_DynamicGatherOp: StableHLO_Op<"dynamic_gather", let results = (outs HLO_Tensor); } -def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv", [Pure]> { +def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv", + [HLO_SpeculatableIfAllInputsStaticAndShapeConstant, NoMemoryEffect]> { let summary = "DynamicConv operation"; let description = [{ This operation is a work in progress, so it is not yet included in diff --git a/stablehlo/tests/ops_speculatability.mlir b/stablehlo/tests/ops_speculatability.mlir index fae177f5db..76432d1199 100644 --- a/stablehlo/tests/ops_speculatability.mlir +++ b/stablehlo/tests/ops_speculatability.mlir @@ -1676,6 +1676,148 @@ func.func @scatter( // ----- +// CHECK-LABEL: func @dynamic_conv +// CHECK-NEXT: return +func.func @dynamic_conv( + %static_input: 
tensor<100x26x26x32xf64>, %static_kernel: tensor<3x3x1x32xf64>, + %dynamic_input: tensor, %dynamic_kernel: tensor, + %unknown_shape: tensor<2x2xi32> +) { + %constant_shape = stablehlo.constant dense<2> : tensor<2x2xi32> + + // Static inputs, constant shape + %0 = "stablehlo.dynamic_conv"(%static_input, %static_kernel, %constant_shape) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, + window_strides = array, lhs_dilation = array, rhs_dilation = array, + feature_group_count = 1 : i64, batch_group_count = 1 : i64 + } : (tensor<100x26x26x32xf64>, tensor<3x3x1x32xf64>, tensor<2x2xi32>) -> tensor<100x28x28x1xf64> + "hlo_test_speculatability.is_speculatable"(%0) : (tensor<100x28x28x1xf64>) -> () + %1 = "stablehlo.dynamic_conv"(%static_input, %static_kernel, %constant_shape) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, + window_strides = array, lhs_dilation = array, rhs_dilation = array, + feature_group_count = 1 : i64, batch_group_count = 1 : i64 + } : (tensor<100x26x26x32xf64>, tensor<3x3x1x32xf64>, tensor<2x2xi32>) -> tensor + "hlo_test_speculatability.is_speculatable"(%1) : (tensor) -> () + + // Dynamic input, static kernel, constant shape + %2 = "stablehlo.dynamic_conv"(%dynamic_input, %static_kernel, %constant_shape) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, + window_strides = array, lhs_dilation = array, rhs_dilation = array, + feature_group_count = 1 : i64, batch_group_count = 1 : i64 + } : (tensor, tensor<3x3x1x32xf64>, tensor<2x2xi32>) -> tensor<100x28x28x1xf64> + "hlo_test_speculatability.is_not_speculatable"(%2) : (tensor<100x28x28x1xf64>) -> () + %3 = "stablehlo.dynamic_conv"(%dynamic_input, %static_kernel, %constant_shape) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, + window_strides = array, lhs_dilation = array, rhs_dilation = array, + feature_group_count = 1 : i64, batch_group_count = 1 : i64 + } : (tensor, 
tensor<3x3x1x32xf64>, tensor<2x2xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%3) : (tensor) -> () + + // Static input, dynamic kernel, constant shape + %4 = "stablehlo.dynamic_conv"(%static_input, %dynamic_kernel, %constant_shape) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, + window_strides = array, lhs_dilation = array, rhs_dilation = array, + feature_group_count = 1 : i64, batch_group_count = 1 : i64 + } : (tensor<100x26x26x32xf64>, tensor, tensor<2x2xi32>) -> tensor<100x28x28x1xf64> + "hlo_test_speculatability.is_not_speculatable"(%4) : (tensor<100x28x28x1xf64>) -> () + %5 = "stablehlo.dynamic_conv"(%static_input, %dynamic_kernel, %constant_shape) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, + window_strides = array, lhs_dilation = array, rhs_dilation = array, + feature_group_count = 1 : i64, batch_group_count = 1 : i64 + } : (tensor<100x26x26x32xf64>, tensor, tensor<2x2xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%5) : (tensor) -> () + + // Static input, static kernel, unknown shape + %6 = "stablehlo.dynamic_conv"(%static_input, %static_kernel, %unknown_shape) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, + window_strides = array, lhs_dilation = array, rhs_dilation = array, + feature_group_count = 1 : i64, batch_group_count = 1 : i64 + } : (tensor<100x26x26x32xf64>, tensor<3x3x1x32xf64>, tensor<2x2xi32>) -> tensor<100x28x28x1xf64> + "hlo_test_speculatability.is_not_speculatable"(%6) : (tensor<100x28x28x1xf64>) -> () + %7 = "stablehlo.dynamic_conv"(%static_input, %static_kernel, %unknown_shape) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, + window_strides = array, lhs_dilation = array, rhs_dilation = array, + feature_group_count = 1 : i64, batch_group_count = 1 : i64 + } : (tensor<100x26x26x32xf64>, tensor<3x3x1x32xf64>, tensor<2x2xi32>) -> tensor + 
"hlo_test_speculatability.is_not_speculatable"(%7) : (tensor) -> () + + return +} + +// ----- + +// CHECK-LABEL: func @dynamic_gather +// CHECK-NEXT: return +func.func @dynamic_gather( + %static_input: tensor<3x4x2xi32>, %static_indices: tensor<2x3x2xi64>, + %dynamic_input: tensor, %dynamic_indices: tensor, + %unknown_slice_sizes: tensor<3xi32> +) { + %constant_slice_sizes = stablehlo.constant dense<[1, 2, 2]> : tensor<3xi32> + + // Static inputs, constant shape + %0 = "stablehlo.dynamic_gather"(%static_input, %static_indices, %constant_slice_sizes) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x4x2xi32>, tensor<2x3x2xi64>, tensor<3xi32>) -> tensor + "hlo_test_speculatability.is_speculatable"(%0) : (tensor) -> () + %1 = "stablehlo.dynamic_gather"(%static_input, %static_indices, %constant_slice_sizes) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = true + } : (tensor<3x4x2xi32>, tensor<2x3x2xi64>, tensor<3xi32>) -> tensor + "hlo_test_speculatability.is_speculatable"(%1) : (tensor) -> () + + // Dynamic input, static start_indices, constant slice_sizes + %2 = "stablehlo.dynamic_gather"(%dynamic_input, %static_indices, %constant_slice_sizes) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor, tensor<2x3x2xi64>, tensor<3xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%2) : (tensor) -> () + + // Static input, dynamic start_indices, constant slice_sizes + %3 = "stablehlo.dynamic_gather"(%static_input, %dynamic_indices, %constant_slice_sizes) { + dimension_numbers = 
#stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x4x2xi32>, tensor, tensor<3xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%3) : (tensor) -> () + + // Static input, static start_indices, unknown slice_sizes + %4 = "stablehlo.dynamic_gather"(%static_input, %static_indices, %unknown_slice_sizes) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x4x2xi32>, tensor<2x3x2xi64>, tensor<3xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%4) : (tensor) -> () + + return +} + +// ----- + // CHECK-LABEL: func @dynamic_iota // CHECK-NEXT: return func.func @dynamic_iota(%unknown_shape: tensor<2xi32>) { @@ -1732,6 +1874,65 @@ func.func @set_dimension_size( // ----- +// CHECK-LABEL: func @dynamic_pad +// CHECK-NEXT: return +func.func @dynamic_pad( + %static_arg: tensor<4xf64>, %dynamic_arg: tensor, + %padding_value: tensor, %unknown_padding: tensor<1xi32> +) { + %constant_padding = stablehlo.constant dense<0> : tensor<1xi32> + + // Static input, constant padding + %0 = stablehlo.dynamic_pad %static_arg, %padding_value, + %constant_padding, %constant_padding, %constant_padding + : (tensor<4xf64>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf64> + "hlo_test_speculatability.is_speculatable"(%0) : (tensor<4xf64>) -> () + %1 = stablehlo.dynamic_pad %static_arg, %padding_value, + %constant_padding, %constant_padding, %constant_padding + : (tensor<4xf64>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_speculatable"(%1) : (tensor) -> () + + // Dynamic input, constant padding + %2 = stablehlo.dynamic_pad %dynamic_arg, %padding_value, + %constant_padding, %constant_padding, 
%constant_padding + : (tensor, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf64> + "hlo_test_speculatability.is_not_speculatable"(%2) : (tensor<4xf64>) -> () + %3 = stablehlo.dynamic_pad %dynamic_arg, %padding_value, + %constant_padding, %constant_padding, %constant_padding + : (tensor, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%3) : (tensor) -> () + + // Static input, unknown paddings + %4 = stablehlo.dynamic_pad %static_arg, %padding_value, + %unknown_padding, %constant_padding, %constant_padding + : (tensor<4xf64>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf64> + "hlo_test_speculatability.is_not_speculatable"(%4) : (tensor<4xf64>) -> () + %5 = stablehlo.dynamic_pad %static_arg, %padding_value, + %unknown_padding, %constant_padding, %constant_padding + : (tensor<4xf64>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%5) : (tensor) -> () + %6 = stablehlo.dynamic_pad %static_arg, %padding_value, + %constant_padding, %unknown_padding, %constant_padding + : (tensor<4xf64>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf64> + "hlo_test_speculatability.is_not_speculatable"(%6) : (tensor<4xf64>) -> () + %7 = stablehlo.dynamic_pad %static_arg, %padding_value, + %constant_padding, %unknown_padding, %constant_padding + : (tensor<4xf64>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%7) : (tensor) -> () + %8 = stablehlo.dynamic_pad %static_arg, %padding_value, + %constant_padding, %constant_padding, %unknown_padding + : (tensor<4xf64>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf64> + "hlo_test_speculatability.is_not_speculatable"(%8) : (tensor<4xf64>) -> () + %9 = stablehlo.dynamic_pad %static_arg, %padding_value, + %constant_padding, %constant_padding, %unknown_padding + : (tensor<4xf64>, tensor, 
tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%9) : (tensor) -> () + + return +} + +// ----- + // CHECK-LABEL: func @dynamic_reshape // CHECK-NEXT: return func.func @dynamic_reshape( @@ -1750,13 +1951,62 @@ func.func @dynamic_reshape( %2 = stablehlo.dynamic_reshape %dynamic_arg, %constant_shape : (tensor, tensor<2xi32>) -> tensor<5x4xf64> "hlo_test_speculatability.is_not_speculatable"(%2) : (tensor<5x4xf64>) -> () %3 = stablehlo.dynamic_reshape %dynamic_arg, %constant_shape : (tensor, tensor<2xi32>) -> tensor - "hlo_test_speculatability.is_speculatable"(%3) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%3) : (tensor) -> () // Unknown shape %4 = stablehlo.dynamic_reshape %static_arg, %unknown_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor<5x4xf64> "hlo_test_speculatability.is_not_speculatable"(%4) : (tensor<5x4xf64>) -> () %5 = stablehlo.dynamic_reshape %static_arg, %unknown_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor - "hlo_test_speculatability.is_speculatable"(%5) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%5) : (tensor) -> () + + return +} + +// ----- + +// CHECK-LABEL: func @real_dynamic_slice +// CHECK-NEXT: return +func.func @real_dynamic_slice( + %static_arg: tensor<4xf64>, %dynamic_arg: tensor, + %unknown_value: tensor<1xi32> +) { + %constant_value = stablehlo.constant dense<1> : tensor<1xi32> + + // Static input, constant values + %0 = stablehlo.real_dynamic_slice %static_arg, %constant_value, %constant_value, %constant_value + : (tensor<4xf64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf64> + "hlo_test_speculatability.is_speculatable"(%0) : (tensor<0xf64>) -> () + %1 = stablehlo.real_dynamic_slice %static_arg, %constant_value, %constant_value, %constant_value + : (tensor<4xf64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_speculatable"(%1) : (tensor) -> () + + // Dynamic input, constant 
values + %2 = stablehlo.real_dynamic_slice %dynamic_arg, %constant_value, %constant_value, %constant_value + : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf64> + "hlo_test_speculatability.is_not_speculatable"(%2) : (tensor<0xf64>) -> () + %3 = stablehlo.real_dynamic_slice %dynamic_arg, %constant_value, %constant_value, %constant_value + : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%3) : (tensor) -> () + + // Static input, unknown values + %4 = stablehlo.real_dynamic_slice %static_arg, %unknown_value, %constant_value, %constant_value + : (tensor<4xf64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf64> + "hlo_test_speculatability.is_not_speculatable"(%4) : (tensor<0xf64>) -> () + %5 = stablehlo.real_dynamic_slice %static_arg, %unknown_value, %constant_value, %constant_value + : (tensor<4xf64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%5) : (tensor) -> () + %6 = stablehlo.real_dynamic_slice %static_arg, %constant_value, %unknown_value, %constant_value + : (tensor<4xf64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf64> + "hlo_test_speculatability.is_not_speculatable"(%6) : (tensor<0xf64>) -> () + %7 = stablehlo.real_dynamic_slice %static_arg, %constant_value, %unknown_value, %constant_value + : (tensor<4xf64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%7) : (tensor) -> () + %8 = stablehlo.real_dynamic_slice %static_arg, %constant_value, %constant_value, %unknown_value + : (tensor<4xf64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf64> + "hlo_test_speculatability.is_not_speculatable"(%8) : (tensor<0xf64>) -> () + %9 = stablehlo.real_dynamic_slice %static_arg, %constant_value, %constant_value, %unknown_value + : (tensor<4xf64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + 
"hlo_test_speculatability.is_not_speculatable"(%9) : (tensor) -> () return }