Skip to content

Commit

Permalink
Implement ConditionallySpeculatable for SelectAndScatter
Browse files Browse the repository at this point in the history
This op isn't specced yet but the speculation logic is
relatively straightforward and follows from the tablegen/type
inference/XLA documentation: https://openxla.org/xla/operation_semantics#setdimensionsize.
  • Loading branch information
mlevesquedion committed Apr 17, 2024
1 parent 5b75941 commit cdfb404
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 1 deletion.
22 changes: 22 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1910,6 +1910,28 @@ LogicalResult SetDimensionSizeOp::inferReturnTypeComponents(
adaptor.getSize(), adaptor.getDimension(), inferredReturnShapes);
}

mlir::Speculation::Speculatability SetDimensionSizeOp::getSpeculatability() {
// If the dimension being set is constant, the verifier will have checked that
// it matches the corresponding dimension in the output.
if (matchPattern(getSize(), m_Constant()))
return mlir::Speculation::Speculatable;

// If the dimension being set is not constant, it is only speculatable if it
// is dynamic in the output.
auto resultType = getType();
if (!resultType.isDynamicDim(getDimension())) return mlir::Speculation::NotSpeculatable;

// For all other dimensions, if the dimension is static in the output, it must
// be static in the input.
auto inputType = getOperand().getType();
for (size_t i : llvm::seq(resultType.getRank())) {
if (i == getDimension()) continue;
if (!resultType.isDynamicDim(i) && inputType.isDynamicDim(i))
return mlir::Speculation::NotSpeculatable;
}
return mlir::Speculation::Speculatable;
}

//===----------------------------------------------------------------------===//
// TransposeOp
//===----------------------------------------------------------------------===//
Expand Down
8 changes: 7 additions & 1 deletion stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2786,7 +2786,8 @@ def StableHLO_SelectAndScatterOp: StableHLO_Op<"select_and_scatter",
let hasVerifier = 1;
}

def StableHLO_SetDimensionSizeOp: StableHLO_Op<"set_dimension_size", [Pure,
def StableHLO_SetDimensionSizeOp: StableHLO_Op<"set_dimension_size",
[ConditionallySpeculatable, NoMemoryEffect,
InferTensorType]> {
let summary = "SetDimensionSize operation";
let description = [{
Expand All @@ -2812,6 +2813,11 @@ def StableHLO_SetDimensionSizeOp: StableHLO_Op<"set_dimension_size", [Pure,
$operand `,` $size `,` `dim` `=` $dimension attr-dict
`:` functional-type(operands, results)
}];

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_SortOp : StableHLO_Op<"sort",
Expand Down
29 changes: 29 additions & 0 deletions stablehlo/tests/ops_speculatability.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1594,6 +1594,35 @@ func.func @dynamic_iota(%unknown_shape: tensor<2xi32>) {
return
}

// CHECK-LABEL: func @set_dimension_size
// CHECK-NEXT: return
func.func @set_dimension_size(
%static_arg: tensor<2x3xf64>, %dynamic_arg: tensor<2x?xf64>,
%unknown_size: tensor<i32>
) {
%constant_size = stablehlo.constant dense<4> : tensor<i32>

// Unknown size
%0 = stablehlo.set_dimension_size %static_arg, %unknown_size, dim = 0 : (tensor<2x3xf64>, tensor<i32>) -> tensor<2x3xf64>
"hlo_test_speculatability.is_not_speculatable"(%0) : (tensor<2x3xf64>) -> ()
%1 = stablehlo.set_dimension_size %static_arg, %unknown_size, dim = 0 : (tensor<2x3xf64>, tensor<i32>) -> tensor<?x3xf64>
"hlo_test_speculatability.is_speculatable"(%1) : (tensor<?x3xf64>) -> ()

// Constant size
%2 = stablehlo.set_dimension_size %static_arg, %constant_size, dim = 0 : (tensor<2x3xf64>, tensor<i32>) -> tensor<4x3xf64>
"hlo_test_speculatability.is_speculatable"(%2) : (tensor<4x3xf64>) -> ()
%3 = stablehlo.set_dimension_size %static_arg, %constant_size, dim = 0 : (tensor<2x3xf64>, tensor<i32>) -> tensor<?x3xf64>
"hlo_test_speculatability.is_speculatable"(%3) : (tensor<?x3xf64>) -> ()

// Dimension not being set is dynamic
%4 = stablehlo.set_dimension_size %dynamic_arg, %unknown_size, dim = 0 : (tensor<2x?xf64>, tensor<i32>) -> tensor<?x3xf64>
"hlo_test_speculatability.is_not_speculatable"(%4) : (tensor<?x3xf64>) -> ()
%5 = stablehlo.set_dimension_size %dynamic_arg, %unknown_size, dim = 0 : (tensor<2x?xf64>, tensor<i32>) -> tensor<?x?xf64>
"hlo_test_speculatability.is_speculatable"(%5) : (tensor<?x?xf64>) -> ()

return
}

// -----

// Recursively speculatable ops
Expand Down

0 comments on commit cdfb404

Please sign in to comment.