diff --git a/stablehlo/dialect/StablehloOps.cpp b/stablehlo/dialect/StablehloOps.cpp index a68f6b60bb..b5af46b677 100644 --- a/stablehlo/dialect/StablehloOps.cpp +++ b/stablehlo/dialect/StablehloOps.cpp @@ -2146,6 +2146,19 @@ LogicalResult TransposeOp::inferReturnTypes( adaptor.getPermutation(), inferredReturnTypes); } +mlir::Speculation::Speculatability TransposeOp::getSpeculatability() { + // This is the same logic as SpeculatableIfStaticDimInOutputIsStaticInInput, + // except it accounts for the permutation. + auto inputType = getOperand().getType(); + auto resultType = getType(); + auto perm = getPermutation(); + for (size_t i : llvm::seq(resultType.getRank())) { + if (!resultType.isDynamicDim(i) && inputType.isDynamicDim(perm[i])) + return mlir::Speculation::NotSpeculatable; + } + return mlir::Speculation::Speculatable; +} + //===----------------------------------------------------------------------===// // TriangularSolveOp //===----------------------------------------------------------------------===// diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td index 6c3f5f26a7..dade4b80d7 100644 --- a/stablehlo/dialect/StablehloOps.td +++ b/stablehlo/dialect/StablehloOps.td @@ -2909,7 +2909,8 @@ def StableHLO_TraceOp: StableHLO_Op<"trace"> { } def StableHLO_TransposeOp: StableHLO_ShapedInterfaceOp<"transpose", - [Pure, HLO_CompatibleOperandsAndResultElementType, /*transpose_c1*/ + [Pure, + HLO_CompatibleOperandsAndResultElementType, /*transpose_c1*/ DeclareOpInterfaceMethods]> { let summary = "Transpose operation"; let description = [{ @@ -2936,6 +2937,11 @@ def StableHLO_TransposeOp: StableHLO_ShapedInterfaceOp<"transpose", $operand `,` `dims` `=` $permutation attr-dict `:` functional-type(operands, results) }]; + + let extraClassDeclaration = commonClassDeclaration # [{ + /// Interface method for ConditionallySpeculatable. + mlir::Speculation::Speculatability getSpeculatability(); + }]; } def StableHLO_TriangularSolveOp: StableHLO_Op<"triangular_solve", diff --git a/stablehlo/tests/ops_speculatability.mlir b/stablehlo/tests/ops_speculatability.mlir index f3fe2d80fa..eb79b7bfc2 100644 --- a/stablehlo/tests/ops_speculatability.mlir +++ b/stablehlo/tests/ops_speculatability.mlir @@ -709,6 +709,43 @@ func.func @all_to_all(%static_arg: tensor<2x4x8xf64>, %dynamic_arg: tensor, %first_dim_dynamic: tensor, + %second_dim_dynamic: tensor<2x?xf64>, %dynamic: tensor, + %three_d: tensor<1x2x3xf64>, %three_d_dynamic: tensor<1x2x?xf64> +) { + %speculatable_0 = stablehlo.transpose %static, dims = [1, 0] : (tensor<2x4xf64>) -> tensor<4x2xf64> + %not_speculatable_0 = stablehlo.transpose %second_dim_dynamic, dims = [1, 0] : (tensor<2x?xf64>) -> tensor<4x2xf64> + %speculatable_1 = stablehlo.transpose %second_dim_dynamic, dims = [1, 0] : (tensor<2x?xf64>) -> tensor + %not_speculatable_1 = stablehlo.transpose %first_dim_dynamic, dims = [1, 0] : (tensor) -> tensor<4x2xf64> + %speculatable_2 = stablehlo.transpose %first_dim_dynamic, dims = [1, 0] : (tensor) -> tensor<4x?xf64> + %not_speculatable_2 = stablehlo.transpose %dynamic, dims = [1, 0] : (tensor) -> tensor<4x2xf64> + %not_speculatable_3 = stablehlo.transpose %dynamic, dims = [1, 0] : (tensor) -> tensor + %not_speculatable_4 = stablehlo.transpose %dynamic, dims = [1, 0] : (tensor) -> tensor<4x?xf64> + %speculatable_3 = stablehlo.transpose %dynamic, dims = [1, 0] : (tensor) -> tensor + "hlo_test_speculatability.is_speculatable"(%speculatable_0) : (tensor<4x2xf64>) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_0) : (tensor<4x2xf64>) -> () + "hlo_test_speculatability.is_speculatable"(%speculatable_1) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_1) : (tensor<4x2xf64>) -> () + "hlo_test_speculatability.is_speculatable"(%speculatable_2) : (tensor<4x?xf64>) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_2) : (tensor<4x2xf64>) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_3) : (tensor) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_4) : (tensor<4x?xf64>) -> () + "hlo_test_speculatability.is_speculatable"(%speculatable_3) : (tensor) -> () + + %speculatable_4 = stablehlo.transpose %three_d, dims = [1, 0, 2] : (tensor<1x2x3xf64>) -> tensor<2x1x3xf64> + %not_speculatable_5 = stablehlo.transpose %three_d_dynamic, dims = [1, 0, 2] : (tensor<1x2x?xf64>) -> tensor<2x1x3xf64> + %speculatable_5 = stablehlo.transpose %three_d_dynamic, dims = [1, 0, 2] : (tensor<1x2x?xf64>) -> tensor<2x1x?xf64> + "hlo_test_speculatability.is_speculatable"(%speculatable_4) : (tensor<2x1x3xf64>) -> () + "hlo_test_speculatability.is_not_speculatable"(%not_speculatable_5) : (tensor<2x1x3xf64>) -> () + "hlo_test_speculatability.is_speculatable"(%speculatable_5) : (tensor<2x1x?xf64>) -> () + return +} + +// ----- + // BinaryElementwise and BinaryBitwiseOrLogicalElementwise ops // -----