Skip to content

Commit

Permalink
Implement ConditionallySpeculatable for Transpose (#2199)
Browse files Browse the repository at this point in the history
  • Loading branch information
mlevesquedion authored Apr 12, 2024
1 parent e4d3c1e commit 2d7c943
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 1 deletion.
13 changes: 13 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2146,6 +2146,19 @@ LogicalResult TransposeOp::inferReturnTypes(
adaptor.getPermutation(), inferredReturnTypes);
}

mlir::Speculation::Speculatability TransposeOp::getSpeculatability() {
// This is the same logic as SpeculatableIfStaticDimInOutputIsStaticInInput,
// except it accounts for the permutation.
auto inputType = getOperand().getType();
auto resultType = getType();
auto perm = getPermutation();
for (size_t i : llvm::seq(resultType.getRank())) {
if (!resultType.isDynamicDim(i) && inputType.isDynamicDim(perm[i]))
return mlir::Speculation::NotSpeculatable;
}
return mlir::Speculation::Speculatable;
}

//===----------------------------------------------------------------------===//
// TriangularSolveOp
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 7 additions & 1 deletion stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2909,7 +2909,8 @@ def StableHLO_TraceOp: StableHLO_Op<"trace"> {
}

def StableHLO_TransposeOp: StableHLO_ShapedInterfaceOp<"transpose",
[Pure, HLO_CompatibleOperandsAndResultElementType, /*transpose_c1*/
[Pure,
HLO_CompatibleOperandsAndResultElementType, /*transpose_c1*/
DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
let summary = "Transpose operation";
let description = [{
Expand All @@ -2936,6 +2937,11 @@ def StableHLO_TransposeOp: StableHLO_ShapedInterfaceOp<"transpose",
$operand `,` `dims` `=` $permutation
attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_TriangularSolveOp: StableHLO_Op<"triangular_solve",
Expand Down
37 changes: 37 additions & 0 deletions stablehlo/tests/ops_speculatability.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -709,6 +709,43 @@ func.func @all_to_all(%static_arg: tensor<2x4x8xf64>, %dynamic_arg: tensor<?x?x?

// -----

// CHECK-LABEL: func @transpose
// CHECK-NEXT: return
func.func @transpose(
%static: tensor<2x4xf64>, %first_dim_dynamic: tensor<?x4xf64>,
%second_dim_dynamic: tensor<2x?xf64>, %dynamic: tensor<?x?xf64>,
%three_d: tensor<1x2x3xf64>, %three_d_dynamic: tensor<1x2x?xf64>
) {
%speculatable_0 = stablehlo.transpose %static, dims = [1, 0] : (tensor<2x4xf64>) -> tensor<4x2xf64>
%not_speculatable_0 = stablehlo.transpose %second_dim_dynamic, dims = [1, 0] : (tensor<2x?xf64>) -> tensor<4x2xf64>
%speculatable_1 = stablehlo.transpose %second_dim_dynamic, dims = [1, 0] : (tensor<2x?xf64>) -> tensor<?x2xf64>
%not_speculatable_1 = stablehlo.transpose %first_dim_dynamic, dims = [1, 0] : (tensor<?x4xf64>) -> tensor<4x2xf64>
%speculatable_2 = stablehlo.transpose %first_dim_dynamic, dims = [1, 0] : (tensor<?x4xf64>) -> tensor<4x?xf64>
%not_speculatable_2 = stablehlo.transpose %dynamic, dims = [1, 0] : (tensor<?x?xf64>) -> tensor<4x2xf64>
%not_speculatable_3 = stablehlo.transpose %dynamic, dims = [1, 0] : (tensor<?x?xf64>) -> tensor<?x2xf64>
%not_speculatable_4 = stablehlo.transpose %dynamic, dims = [1, 0] : (tensor<?x?xf64>) -> tensor<4x?xf64>
%speculatable_3 = stablehlo.transpose %dynamic, dims = [1, 0] : (tensor<?x?xf64>) -> tensor<?x?xf64>
"hlo_test_speculatability.is_speculatable"(%speculatable_0) : (tensor<4x2xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_0) : (tensor<4x2xf64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_1) : (tensor<?x2xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_1) : (tensor<4x2xf64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_2) : (tensor<4x?xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_2) : (tensor<4x2xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_3) : (tensor<?x2xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_4) : (tensor<4x?xf64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_3) : (tensor<?x?xf64>) -> ()

%speculatable_4 = stablehlo.transpose %three_d, dims = [1, 0, 2] : (tensor<1x2x3xf64>) -> tensor<2x1x3xf64>
%not_speculatable_5 = stablehlo.transpose %three_d_dynamic, dims = [1, 0, 2] : (tensor<1x2x?xf64>) -> tensor<2x1x3xf64>
%speculatable_5 = stablehlo.transpose %three_d_dynamic, dims = [1, 0, 2] : (tensor<1x2x?xf64>) -> tensor<2x1x?xf64>
"hlo_test_speculatability.is_speculatable"(%speculatable_4) : (tensor<2x1x3xf64>) -> ()
"hlo_test_speculatability.is_not_speculatable"(%not_speculatable_5) : (tensor<2x1x3xf64>) -> ()
"hlo_test_speculatability.is_speculatable"(%speculatable_5) : (tensor<2x1x?xf64>) -> ()
return
}

// -----

// BinaryElementwise and BinaryBitwiseOrLogicalElementwise ops

// -----
Expand Down

0 comments on commit 2d7c943

Please sign in to comment.