diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index da5b84da27..e2739a4b86 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -1910,6 +1910,28 @@ LogicalResult SetDimensionSizeOp::inferReturnTypeComponents( adaptor.getSize(), adaptor.getDimension(), inferredReturnShapes); } +mlir::Speculation::Speculatability SetDimensionSizeOp::getSpeculatability() { + // If the dimension being set is constant, the verifier will have checked that + // it matches the corresponding dimension in the output. + if (matchPattern(getSize(), m_Constant())) + return mlir::Speculation::Speculatable; + + // If the dimension being set is not constant, it is only speculatable if it + // is dynamic in the output. + auto resultType = getType(); + if (!resultType.isDynamicDim(getDimension())) return mlir::Speculation::NotSpeculatable; + + // For all other dimensions, if the dimension is static in the output, it must + // be static in the input. + auto inputType = getOperand().getType(); + for (size_t i : llvm::seq(resultType.getRank())) { + if (i == getDimension()) continue; + if (!resultType.isDynamicDim(i) && inputType.isDynamicDim(i)) + return mlir::Speculation::NotSpeculatable; + } + return mlir::Speculation::Speculatable; +} + //===----------------------------------------------------------------------===// // TransposeOp //===----------------------------------------------------------------------===// diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index dce23c766f..d29b4b9728 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -2786,7 +2786,8 @@ def StableHLO_SelectAndScatterOp: StableHLO_Op<"select_and_scatter", let hasVerifier = 1; } -def StableHLO_SetDimensionSizeOp: StableHLO_Op<"set_dimension_size", [Pure, +def StableHLO_SetDimensionSizeOp: StableHLO_Op<"set_dimension_size", + [ConditionallySpeculatable, NoMemoryEffect, InferTensorType]> { let summary = "SetDimensionSize operation"; let description = [{ @@ -2812,6 +2813,11 @@ def StableHLO_SetDimensionSizeOp: StableHLO_Op<"set_dimension_size", [Pure, $operand `,` $size `,` `dim` `=` $dimension attr-dict `:` functional-type(operands, results) }]; + + let extraClassDeclaration = commonClassDeclaration # [{ + /// Interface method for ConditionallySpeculatable. + mlir::Speculation::Speculatability getSpeculatability(); + }]; } def StableHLO_SortOp : StableHLO_Op<"sort", diff --git a/stablehlo/tests/ops_speculatability.mlir b/stablehlo/tests/ops_speculatability.mlir index 3dcffb1eda..bb4d569013 100644 --- a/stablehlo/tests/ops_speculatability.mlir +++ b/stablehlo/tests/ops_speculatability.mlir @@ -1594,6 +1594,35 @@ func.func @dynamic_iota(%unknown_shape: tensor<2xi32>) { return } +// CHECK-LABEL: func @set_dimension_size +// CHECK-NEXT: return +func.func @set_dimension_size( + %static_arg: tensor<2x3xf64>, %dynamic_arg: tensor<2x?xf64>, + %unknown_size: tensor +) { + %constant_size = stablehlo.constant dense<4> : tensor + + // Unknown size + %0 = stablehlo.set_dimension_size %static_arg, %unknown_size, dim = 0 : (tensor<2x3xf64>, tensor) -> tensor<2x3xf64> + "hlo_test_speculatability.is_not_speculatable"(%0) : (tensor<2x3xf64>) -> () + %1 = stablehlo.set_dimension_size %static_arg, %unknown_size, dim = 0 : (tensor<2x3xf64>, tensor) -> tensor + "hlo_test_speculatability.is_speculatable"(%1) : (tensor) -> () + + // Constant size + %2 = stablehlo.set_dimension_size %static_arg, %constant_size, dim = 0 : (tensor<2x3xf64>, tensor) -> tensor<4x3xf64> + "hlo_test_speculatability.is_speculatable"(%2) : (tensor<4x3xf64>) -> () + %3 = stablehlo.set_dimension_size %static_arg, %constant_size, dim = 0 : (tensor<2x3xf64>, tensor) -> tensor + "hlo_test_speculatability.is_speculatable"(%3) : (tensor) -> () + + // Dimension not being set is dynamic + %4 = stablehlo.set_dimension_size %dynamic_arg, %unknown_size, dim = 0 : (tensor<2x?xf64>, tensor) -> tensor + "hlo_test_speculatability.is_not_speculatable"(%4) : (tensor) -> () + %5 = stablehlo.set_dimension_size %dynamic_arg, %unknown_size, dim = 0 : (tensor<2x?xf64>, tensor) -> tensor + "hlo_test_speculatability.is_speculatable"(%5) : (tensor) -> () + + return +} + // ----- // Recursively speculatable ops