diff --git a/stablehlo/dialect/Base.cpp b/stablehlo/dialect/Base.cpp index e5cf93f97d..f938f2c233 100644 --- a/stablehlo/dialect/Base.cpp +++ b/stablehlo/dialect/Base.cpp @@ -611,5 +611,22 @@ bool isSplatArray(ArrayRef arr, int64_t val) { [val](int64_t x) { return x == val; }); } +mlir::Speculation::Speculatability getShapedSpeculatability( + Operation* op, int64_t shapeCount) { + // If all inputs are static and the shape-related operands are constant + // then any relationship between the input, the shapes and the output can be + // verified statically. + bool allInputsStatic = llvm::all_of(op->getOperandTypes(), [](Type t) { + return cast(t).hasStaticShape(); + }); + bool allShapesConstant = llvm::all_of(llvm::seq(shapeCount), [&](int64_t i) { + return matchPattern(op->getOperand(op->getNumOperands() - 1 - i), + m_Constant()); + }); + return allInputsStatic && allShapesConstant + ? mlir::Speculation::Speculatable + : mlir::Speculation::NotSpeculatable; +} + } // namespace hlo } // namespace mlir diff --git a/stablehlo/dialect/Base.h b/stablehlo/dialect/Base.h index 1956fbf062..7b045db706 100644 --- a/stablehlo/dialect/Base.h +++ b/stablehlo/dialect/Base.h @@ -255,6 +255,13 @@ void writeEnumAttribute(EnumTypeAttr val, DialectBytecodeWriter &writer) { } } // namespace bytecode +// Determines the speculatability for a shaped operation `op` with `shapeCount` +// shape operands. The last `shapeCount` operands are assumed to be shape operands. +// To be speculatable, such an op must have only static inputs and constant +// shape operands. 
+mlir::Speculation::Speculatability getShapedSpeculatability(Operation *op, + int64_t shapeCount); + namespace OpTrait { template @@ -470,6 +477,16 @@ struct RecursivelySpeculatableIfAllInputsStaticImplTrait } }; +template +struct SpeculatableIfAllInputsStaticAndShapeConstantImplTrait + : public mlir::OpTrait::TraitBase< + ConcreteType, + SpeculatableIfAllInputsStaticAndShapeConstantImplTrait> { + mlir::Speculation::Speculatability getSpeculatability() { + return getShapedSpeculatability(this->getOperation(), 1); + } +}; + } // namespace OpTrait } // namespace hlo } // namespace mlir diff --git a/stablehlo/dialect/Base.td b/stablehlo/dialect/Base.td index 5c668d83e0..29bbf0c1ef 100644 --- a/stablehlo/dialect/Base.td +++ b/stablehlo/dialect/Base.td @@ -368,4 +368,13 @@ def HLO_RecursivelySpeculatableIfAllInputsStaticImplTrait def HLO_RecursivelySpeculatableIfAllInputsStatic : TraitList<[ ConditionallySpeculatable, HLO_RecursivelySpeculatableIfAllInputsStaticImplTrait]>; +def HLO_SpeculatableIfAllInputsStaticAndShapeConstantImplTrait + : HLO_NativeOpTrait<"SpeculatableIfAllInputsStaticAndShapeConstantImplTrait">; + +// This trait is the same as HLO_SpeculatableIfAllInputsStatic, but for ops that +// take a shape as their last operand. Such ops are speculatable if all inputs +// are static and the shape is constant. 
+def HLO_SpeculatableIfAllInputsStaticAndShapeConstant : TraitList<[ + ConditionallySpeculatable, HLO_SpeculatableIfAllInputsStaticAndShapeConstantImplTrait]>; + #endif // STABLEHLO_DIALECT_BASE diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index 86a8da90ff..adc829d801 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -1532,18 +1532,6 @@ LogicalResult DynamicReshapeOp::reifyReturnTypeShapes( return success(); } -mlir::Speculation::Speculatability DynamicReshapeOp::getSpeculatability() { - // If the input is static and the shape operand is constant, the output - // shape can be inferred and any mismatch will be caught statically. - // If any dimension in the input is dynamic, or if the shape is not known, - // the number of elements may disagree at runtime. - if (getOperand().getType().hasStaticShape() && - matchPattern(getOutputShape(), m_Constant())) - return mlir::Speculation::Speculatable; - - return mlir::Speculation::NotSpeculatable; -} - //===----------------------------------------------------------------------===// // DynamicSliceOp //===----------------------------------------------------------------------===// @@ -1613,6 +1601,10 @@ LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes( return success(); } +mlir::Speculation::Speculatability RealDynamicSliceOp::getSpeculatability() { + return hlo::getShapedSpeculatability(getOperation(), /*shapeCount=*/3); +} + //===----------------------------------------------------------------------===// // InfeedOp //===----------------------------------------------------------------------===// @@ -2160,6 +2152,10 @@ LogicalResult DynamicPadOp::reifyReturnTypeShapes( return success(); } +mlir::Speculation::Speculatability DynamicPadOp::getSpeculatability() { + return hlo::getShapedSpeculatability(getOperation(), /*shapeCount=*/3); +} + //===----------------------------------------------------------------------===// // ReshapeOp 
//===----------------------------------------------------------------------===// diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 3d73557b4e..dabf8b3ef7 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -145,7 +145,7 @@ def StableHLO_DynamicIotaOp: StableHLO_ShapedInterfaceOp<"dynamic_iota", [Condit $output_shape `,` `dim` `=` $iota_dimension attr-dict `:` functional-type(operands, results) }]; - let extraClassDeclaration = [{ + let extraClassDeclaration = commonClassDeclaration # [{ /// Interface method for ConditionallySpeculatable. mlir::Speculation::Speculatability getSpeculatability(); }]; @@ -2653,7 +2653,7 @@ def StableHLO_ReshapeOp: StableHLO_Op<"reshape", } def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", - [ConditionallySpeculatable, NoMemoryEffect]> { + [HLO_SpeculatableIfAllInputsStaticAndShapeConstant, NoMemoryEffect]> { let summary = "DynamicReshape operation"; let description = [{ This operation is a work in progress, so it is not yet included in @@ -2675,11 +2675,6 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape", let hasVerifier = 1; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; - - let extraClassDeclaration = commonClassDeclaration # [{ - /// Interface method for ConditionallySpeculatable. 
- mlir::Speculation::Speculatability getSpeculatability(); - }]; } def StableHLO_ScatterOp: StableHLO_Op<"scatter", @@ -3375,7 +3370,8 @@ def StableHLO_ReducePrecisionOp : StableHLO_Op<"reduce_precision", def StableHLO_RealDynamicSliceOp: StableHLO_ShapedInterfaceOp< "real_dynamic_slice", - [Pure, AllElementTypesMatch<["operand", "result"]>, + [ConditionallySpeculatable, NoMemoryEffect, + AllElementTypesMatch<["operand", "result"]>, AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> { let summary = "RealDynamicSlice operation"; let description = [{ @@ -3403,10 +3399,16 @@ def StableHLO_RealDynamicSliceOp: StableHLO_ShapedInterfaceOp< let hasVerifier = 1; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; + + let extraClassDeclaration = commonClassDeclaration # [{ + /// Interface method for ConditionallySpeculatable. + mlir::Speculation::Speculatability getSpeculatability(); + }]; } def StableHLO_DynamicPadOp: StableHLO_ShapedInterfaceOp<"dynamic_pad", - [Pure, AllElementTypesMatch<["operand", "padding_value", "result"]>, + [ConditionallySpeculatable, NoMemoryEffect, + AllElementTypesMatch<["operand", "padding_value", "result"]>, AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> { let summary = "DynamicPad operation"; let description = [{ @@ -3440,10 +3442,16 @@ def StableHLO_DynamicPadOp: StableHLO_ShapedInterfaceOp<"dynamic_pad", let hasVerifier = 1; let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)"; + + let extraClassDeclaration = commonClassDeclaration # [{ + /// Interface method for ConditionallySpeculatable. 
+ mlir::Speculation::Speculatability getSpeculatability(); + }]; } def StableHLO_DynamicGatherOp: StableHLO_Op<"dynamic_gather", - [InferTensorTypeWithReify, Pure]> { + [HLO_SpeculatableIfAllInputsStaticAndShapeConstant, NoMemoryEffect, + InferTensorTypeWithReify]> { let summary = "DynamicGather operation"; let description = [{ This operation is a work in progress, so it is not yet included in @@ -3476,7 +3484,8 @@ def StableHLO_DynamicGatherOp: StableHLO_Op<"dynamic_gather", let results = (outs HLO_Tensor); } -def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv", [Pure]> { +def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv", + [HLO_SpeculatableIfAllInputsStaticAndShapeConstant, NoMemoryEffect]> { let summary = "DynamicConv operation"; let description = [{ This operation is a work in progress, so it is not yet included in diff --git a/stablehlo/tests/ops_speculatability.mlir b/stablehlo/tests/ops_speculatability.mlir index 6da14350e2..98314835d4 100644 --- a/stablehlo/tests/ops_speculatability.mlir +++ b/stablehlo/tests/ops_speculatability.mlir @@ -1514,7 +1514,6 @@ func.func @convolution( window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []} {batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo, #stablehlo]} : (tensor<100x26x26x32xf64>, tensor) -> tensor<100x?x?x2xf64> "hlo_test_speculatability.is_speculatable"(%14) : (tensor<100x?x?x2xf64>) -> () - return } @@ -1565,7 +1564,6 @@ func.func @dot_general( "hlo_test_speculatability.is_not_speculatable"(%10) : (tensor<2x4x1x6x7x8xf64>) -> () %11 = stablehlo.dot_general %large_static_lhs, %large_dynamic_rhs, batching_dims = [1, 3] x [0, 4], contracting_dims = [2, 4] x [2, 1], precision = [DEFAULT, DEFAULT] : (tensor<1x2x3x4x5x6xf64>, tensor<2x5x3x?x4x?xf64>) -> tensor<2x4x1x6x7x8xf64> "hlo_test_speculatability.is_not_speculatable"(%11) : (tensor<2x4x1x6x7x8xf64>) -> () - return } @@ -1777,6 +1775,146 @@ func.func @dynamic_broadcast_in_dim( // 
----- +// CHECK-LABEL: func @dynamic_conv +// CHECK-NEXT: return +func.func @dynamic_conv( + %static_input: tensor<100x26x26x32xf64>, %static_kernel: tensor<3x3x1x32xf64>, + %dynamic_input: tensor, %dynamic_kernel: tensor, + %unknown_shape: tensor<2x2xi32> +) { + %constant_shape = stablehlo.constant dense<2> : tensor<2x2xi32> + + // Static inputs, constant shape + %0 = "stablehlo.dynamic_conv"(%static_input, %static_kernel, %constant_shape) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, + window_strides = array, lhs_dilation = array, rhs_dilation = array, + feature_group_count = 1 : i64, batch_group_count = 1 : i64 + } : (tensor<100x26x26x32xf64>, tensor<3x3x1x32xf64>, tensor<2x2xi32>) -> tensor<100x28x28x1xf64> + "hlo_test_speculatability.is_speculatable"(%0) : (tensor<100x28x28x1xf64>) -> () + %1 = "stablehlo.dynamic_conv"(%static_input, %static_kernel, %constant_shape) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, + window_strides = array, lhs_dilation = array, rhs_dilation = array, + feature_group_count = 1 : i64, batch_group_count = 1 : i64 + } : (tensor<100x26x26x32xf64>, tensor<3x3x1x32xf64>, tensor<2x2xi32>) -> tensor + "hlo_test_speculatability.is_speculatable"(%1) : (tensor) -> () + + // Dynamic input, static kernel, constant shape + %2 = "stablehlo.dynamic_conv"(%dynamic_input, %static_kernel, %constant_shape) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, + window_strides = array, lhs_dilation = array, rhs_dilation = array, + feature_group_count = 1 : i64, batch_group_count = 1 : i64 + } : (tensor, tensor<3x3x1x32xf64>, tensor<2x2xi32>) -> tensor<100x28x28x1xf64> + "hlo_test_speculatability.is_not_speculatable"(%2) : (tensor<100x28x28x1xf64>) -> () + %3 = "stablehlo.dynamic_conv"(%dynamic_input, %static_kernel, %constant_shape) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, + window_strides = array, lhs_dilation = 
array, rhs_dilation = array, + feature_group_count = 1 : i64, batch_group_count = 1 : i64 + } : (tensor, tensor<3x3x1x32xf64>, tensor<2x2xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%3) : (tensor) -> () + + // Static input, dynamic kernel, constant shape + %4 = "stablehlo.dynamic_conv"(%static_input, %dynamic_kernel, %constant_shape) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, + window_strides = array, lhs_dilation = array, rhs_dilation = array, + feature_group_count = 1 : i64, batch_group_count = 1 : i64 + } : (tensor<100x26x26x32xf64>, tensor, tensor<2x2xi32>) -> tensor<100x28x28x1xf64> + "hlo_test_speculatability.is_not_speculatable"(%4) : (tensor<100x28x28x1xf64>) -> () + %5 = "stablehlo.dynamic_conv"(%static_input, %dynamic_kernel, %constant_shape) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, + window_strides = array, lhs_dilation = array, rhs_dilation = array, + feature_group_count = 1 : i64, batch_group_count = 1 : i64 + } : (tensor<100x26x26x32xf64>, tensor, tensor<2x2xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%5) : (tensor) -> () + + // Static input, static kernel, unknown shape + %6 = "stablehlo.dynamic_conv"(%static_input, %static_kernel, %unknown_shape) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, + window_strides = array, lhs_dilation = array, rhs_dilation = array, + feature_group_count = 1 : i64, batch_group_count = 1 : i64 + } : (tensor<100x26x26x32xf64>, tensor<3x3x1x32xf64>, tensor<2x2xi32>) -> tensor<100x28x28x1xf64> + "hlo_test_speculatability.is_not_speculatable"(%6) : (tensor<100x28x28x1xf64>) -> () + %7 = "stablehlo.dynamic_conv"(%static_input, %static_kernel, %unknown_shape) { + dimension_numbers = #stablehlo.conv<[b, 0, 1, f]x[0, 1, o, i]->[b, 0, 1, f]>, + window_strides = array, lhs_dilation = array, rhs_dilation = array, + feature_group_count = 1 : i64, batch_group_count = 1 : i64 + } : 
(tensor<100x26x26x32xf64>, tensor<3x3x1x32xf64>, tensor<2x2xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%7) : (tensor) -> () + return +} + +// ----- + +// CHECK-LABEL: func @dynamic_gather +// CHECK-NEXT: return +func.func @dynamic_gather( + %static_input: tensor<3x4x2xi32>, %static_indices: tensor<2x3x2xi64>, + %dynamic_input: tensor, %dynamic_indices: tensor, + %unknown_slice_sizes: tensor<3xi32> +) { + %constant_slice_sizes = stablehlo.constant dense<[1, 2, 2]> : tensor<3xi32> + + // Static inputs, constant shape + %0 = "stablehlo.dynamic_gather"(%static_input, %static_indices, %constant_slice_sizes) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x4x2xi32>, tensor<2x3x2xi64>, tensor<3xi32>) -> tensor + "hlo_test_speculatability.is_speculatable"(%0) : (tensor) -> () + %1 = "stablehlo.dynamic_gather"(%static_input, %static_indices, %constant_slice_sizes) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = true + } : (tensor<3x4x2xi32>, tensor<2x3x2xi64>, tensor<3xi32>) -> tensor + "hlo_test_speculatability.is_speculatable"(%1) : (tensor) -> () + + // Dynamic input, static start_indices, constant slice_sizes + %2 = "stablehlo.dynamic_gather"(%dynamic_input, %static_indices, %constant_slice_sizes) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor, tensor<2x3x2xi64>, tensor<3xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%2) : (tensor) -> () + + // Static input, dynamic start_indices, constant slice_sizes + %3 = "stablehlo.dynamic_gather"(%static_input, 
%dynamic_indices, %constant_slice_sizes) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x4x2xi32>, tensor, tensor<3xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%3) : (tensor) -> () + + // Static input, static start_indices, unknown slice_sizes + %4 = "stablehlo.dynamic_gather"(%static_input, %static_indices, %unknown_slice_sizes) { + dimension_numbers = #stablehlo.gather< + offset_dims = [2, 3], + collapsed_slice_dims = [0], + start_index_map = [1, 0], + index_vector_dim = 2>, + slice_sizes = array, + indices_are_sorted = false + } : (tensor<3x4x2xi32>, tensor<2x3x2xi64>, tensor<3xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%4) : (tensor) -> () + return +} + +// ----- + // CHECK-LABEL: func @dynamic_iota // CHECK-NEXT: return func.func @dynamic_iota(%unknown_shape: tensor<2xi32>) { @@ -1829,7 +1967,64 @@ func.func @set_dimension_size( "hlo_test_speculatability.is_not_speculatable"(%5) : (tensor) -> () %6 = stablehlo.set_dimension_size %dynamic_arg, %unknown_size, dim = 0 : (tensor<4x?xf64>, tensor) -> tensor "hlo_test_speculatability.is_speculatable"(%6) : (tensor) -> () + return +} +// ----- + +// CHECK-LABEL: func @dynamic_pad +// CHECK-NEXT: return +func.func @dynamic_pad( + %static_arg: tensor<4xf64>, %dynamic_arg: tensor, + %padding_value: tensor, %unknown_padding: tensor<1xi32> +) { + %constant_padding = stablehlo.constant dense<0> : tensor<1xi32> + + // Static input, constant padding + %0 = stablehlo.dynamic_pad %static_arg, %padding_value, + %constant_padding, %constant_padding, %constant_padding + : (tensor<4xf64>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf64> + "hlo_test_speculatability.is_speculatable"(%0) : (tensor<4xf64>) -> () + %1 = stablehlo.dynamic_pad %static_arg, %padding_value, + %constant_padding, 
%constant_padding, %constant_padding + : (tensor<4xf64>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_speculatable"(%1) : (tensor) -> () + + // Dynamic input, constant padding + %2 = stablehlo.dynamic_pad %dynamic_arg, %padding_value, + %constant_padding, %constant_padding, %constant_padding + : (tensor, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf64> + "hlo_test_speculatability.is_not_speculatable"(%2) : (tensor<4xf64>) -> () + %3 = stablehlo.dynamic_pad %dynamic_arg, %padding_value, + %constant_padding, %constant_padding, %constant_padding + : (tensor, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%3) : (tensor) -> () + + // Static input, unknown paddings + %4 = stablehlo.dynamic_pad %static_arg, %padding_value, + %unknown_padding, %constant_padding, %constant_padding + : (tensor<4xf64>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf64> + "hlo_test_speculatability.is_not_speculatable"(%4) : (tensor<4xf64>) -> () + %5 = stablehlo.dynamic_pad %static_arg, %padding_value, + %unknown_padding, %constant_padding, %constant_padding + : (tensor<4xf64>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%5) : (tensor) -> () + %6 = stablehlo.dynamic_pad %static_arg, %padding_value, + %constant_padding, %unknown_padding, %constant_padding + : (tensor<4xf64>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf64> + "hlo_test_speculatability.is_not_speculatable"(%6) : (tensor<4xf64>) -> () + %7 = stablehlo.dynamic_pad %static_arg, %padding_value, + %constant_padding, %unknown_padding, %constant_padding + : (tensor<4xf64>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%7) : (tensor) -> () + %8 = stablehlo.dynamic_pad %static_arg, %padding_value, + %constant_padding, %constant_padding, 
%unknown_padding + : (tensor<4xf64>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<4xf64> + "hlo_test_speculatability.is_not_speculatable"(%8) : (tensor<4xf64>) -> () + %9 = stablehlo.dynamic_pad %static_arg, %padding_value, + %constant_padding, %constant_padding, %unknown_padding + : (tensor<4xf64>, tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%9) : (tensor) -> () return } @@ -1860,7 +2055,54 @@ func.func @dynamic_reshape( "hlo_test_speculatability.is_not_speculatable"(%4) : (tensor<5x4xf64>) -> () %5 = stablehlo.dynamic_reshape %static_arg, %unknown_shape : (tensor<4x5xf64>, tensor<2xi32>) -> tensor "hlo_test_speculatability.is_not_speculatable"(%5) : (tensor) -> () + return +} + +// ----- +// CHECK-LABEL: func @real_dynamic_slice +// CHECK-NEXT: return +func.func @real_dynamic_slice( + %static_arg: tensor<4xf64>, %dynamic_arg: tensor, + %unknown_value: tensor<1xi32> +) { + %constant_value = stablehlo.constant dense<1> : tensor<1xi32> + + // Static input, constant values + %0 = stablehlo.real_dynamic_slice %static_arg, %constant_value, %constant_value, %constant_value + : (tensor<4xf64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf64> + "hlo_test_speculatability.is_speculatable"(%0) : (tensor<0xf64>) -> () + %1 = stablehlo.real_dynamic_slice %static_arg, %constant_value, %constant_value, %constant_value + : (tensor<4xf64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_speculatable"(%1) : (tensor) -> () + + // Dynamic input, constant values + %2 = stablehlo.real_dynamic_slice %dynamic_arg, %constant_value, %constant_value, %constant_value + : (tensor, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf64> + "hlo_test_speculatability.is_not_speculatable"(%2) : (tensor<0xf64>) -> () + %3 = stablehlo.real_dynamic_slice %dynamic_arg, %constant_value, %constant_value, %constant_value + : (tensor, tensor<1xi32>, tensor<1xi32>, 
tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%3) : (tensor) -> () + + // Static input, unknown values + %4 = stablehlo.real_dynamic_slice %static_arg, %unknown_value, %constant_value, %constant_value + : (tensor<4xf64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf64> + "hlo_test_speculatability.is_not_speculatable"(%4) : (tensor<0xf64>) -> () + %5 = stablehlo.real_dynamic_slice %static_arg, %unknown_value, %constant_value, %constant_value + : (tensor<4xf64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%5) : (tensor) -> () + %6 = stablehlo.real_dynamic_slice %static_arg, %constant_value, %unknown_value, %constant_value + : (tensor<4xf64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf64> + "hlo_test_speculatability.is_not_speculatable"(%6) : (tensor<0xf64>) -> () + %7 = stablehlo.real_dynamic_slice %static_arg, %constant_value, %unknown_value, %constant_value + : (tensor<4xf64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%7) : (tensor) -> () + %8 = stablehlo.real_dynamic_slice %static_arg, %constant_value, %constant_value, %unknown_value + : (tensor<4xf64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<0xf64> + "hlo_test_speculatability.is_not_speculatable"(%8) : (tensor<0xf64>) -> () + %9 = stablehlo.real_dynamic_slice %static_arg, %constant_value, %constant_value, %unknown_value + : (tensor<4xf64>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%9) : (tensor) -> () return } @@ -2048,7 +2290,6 @@ func.func @select_and_scatter( window_strides = array } : (tensor, tensor<10x12x12x64xf64>, tensor) -> tensor "hlo_test_speculatability.is_recursively_speculatable"(%3) : (tensor) -> () - return }