Skip to content

Commit

Permalink
Implement ConditionallySpeculatable for SelectAndScatter
Browse files Browse the repository at this point in the history
The spec says (C11): shape(operand) = shape(result).
  • Loading branch information
Michael Levesque-Dion committed Apr 17, 2024
1 parent 5b75941 commit 88e02e7
Show file tree
Hide file tree
Showing 4 changed files with 92 additions and 1 deletion.
17 changes: 17 additions & 0 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -421,6 +421,23 @@ struct SpeculatableIfStaticDimInOutputIsStaticInInputImplTrait
}
};

template <typename ConcreteType>
struct RecursivelySpeculatableIfStaticDimInOutputIsStaticInInputImplTrait
    : public mlir::OpTrait::TraitBase<
          ConcreteType,
          RecursivelySpeculatableIfStaticDimInOutputIsStaticInInputImplTrait> {
  // An op carrying this trait is recursively speculatable (the ops inside its
  // regions must still be checked) unless some dimension is static in the
  // result but dynamic in the first operand, in which case the op is not
  // speculatable at all: the dynamic operand dimension may disagree with the
  // static result dimension at runtime.
  mlir::Speculation::Speculatability getSpeculatability() {
    auto thisOp = this->getOperation();
    auto operandType = cast<RankedTensorType>(thisOp->getOperand(0).getType());
    auto outputType = cast<RankedTensorType>(thisOp->getResult(0).getType());
    for (int64_t dim = 0, rank = outputType.getRank(); dim < rank; ++dim) {
      bool staticInOutput = !outputType.isDynamicDim(dim);
      if (staticInOutput && operandType.isDynamicDim(dim))
        return mlir::Speculation::NotSpeculatable;
    }
    return mlir::Speculation::RecursivelySpeculatable;
  }
};

template <typename ConcreteType>
struct SpeculatableIfAllInputsStaticImplTrait
: public mlir::OpTrait::TraitBase<ConcreteType,
Expand Down
10 changes: 10 additions & 0 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -336,6 +336,16 @@ def HLO_SpeculatableIfStaticDimInOutputIsStaticInInputImplTrait
// Combines the ConditionallySpeculatable interface with the native impl
// trait declared above, for ops without regions: such an op is Speculatable
// if every dimension that is static in the output is also static in the
// input.
def HLO_SpeculatableIfStaticDimInOutputIsStaticInInput : TraitList<[
    ConditionallySpeculatable, HLO_SpeculatableIfStaticDimInOutputIsStaticInInputImplTrait]>;

// Binds the native C++ trait of the same name (implemented in Base.h).
def HLO_RecursivelySpeculatableIfStaticDimInOutputIsStaticInInputImplTrait
    : HLO_NativeOpTrait<"RecursivelySpeculatableIfStaticDimInOutputIsStaticInInputImplTrait">;

// This trait is the same as HLO_SpeculatableIfStaticDimInOutputIsStaticInInput,
// but for ops that have regions. If all static dimensions in the output are
// static in the input, such an op is RecursivelySpeculatable (the ops in its
// regions have to be checked for speculatability); otherwise it is
// NotSpeculatable.
def HLO_RecursivelySpeculatableIfStaticDimInOutputIsStaticInInput : TraitList<[
    ConditionallySpeculatable, HLO_RecursivelySpeculatableIfStaticDimInOutputIsStaticInInputImplTrait]>;

// Binds the SpeculatableIfAllInputsStaticImplTrait native C++ trait
// (declared in Base.h) so ops can include it in their trait lists.
def HLO_SpeculatableIfAllInputsStaticImplTrait
    : HLO_NativeOpTrait<"SpeculatableIfAllInputsStaticImplTrait">;

Expand Down
3 changes: 2 additions & 1 deletion stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -2739,7 +2739,8 @@ def StableHLO_SelectOp: StableHLO_Op<"select",
}

def StableHLO_SelectAndScatterOp: StableHLO_Op<"select_and_scatter",
[DeclareOpInterfaceMethods<InferTypeOpInterface> /*select_and_scatter_c11,
[HLO_RecursivelySpeculatableIfStaticDimInOutputIsStaticInInput,
DeclareOpInterfaceMethods<InferTypeOpInterface> /*select_and_scatter_c11,
select_and_scatter_c12*/, RecursiveMemoryEffects]> {
let summary = "SelectAndScatter operation";
let description = [{
Expand Down
63 changes: 63 additions & 0 deletions stablehlo/tests/ops_speculatability.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -1708,6 +1708,69 @@ func.func @reduce_window(%static_arg: tensor<2x4xf64>, %dynamic_arg: tensor<?x?x

// -----

// Speculatability of select_and_scatter follows shape(operand) =
// shape(result) (spec constraint C11): the op is recursively speculatable
// unless a dimension is static in the result but dynamic in the operand.
// CHECK-LABEL: func @select_and_scatter
// CHECK-NEXT:  return
func.func @select_and_scatter(
    %static_arg: tensor<10x24x24x64xf64>, %dynamic_arg: tensor<?x?x?x?xf64>,
    %source: tensor<10x12x12x64xf64>, %init: tensor<f64>
) {
  // Static operand, static result: recursively speculatable.
  %0 = "stablehlo.select_and_scatter"(%static_arg, %source, %init) ({
    ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
      %c0 = stablehlo.constant dense<false> : tensor<i1>
      stablehlo.return %c0 : tensor<i1>
  }, {
    ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
      stablehlo.return %arg0 : tensor<f64>
  }) {
    window_dimensions = array<i64: 1, 2, 2, 1>,
    window_strides = array<i64: 1, 2, 2, 1>
  } : (tensor<10x24x24x64xf64>, tensor<10x12x12x64xf64>, tensor<f64>) -> tensor<10x24x24x64xf64>
  "hlo_test_speculatability.is_recursively_speculatable"(%0) : (tensor<10x24x24x64xf64>) -> ()

  // Static operand, dynamic result: recursively speculatable (no static
  // result dimension can conflict with the operand).
  %1 = "stablehlo.select_and_scatter"(%static_arg, %source, %init) ({
    ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
      %c0 = stablehlo.constant dense<false> : tensor<i1>
      stablehlo.return %c0 : tensor<i1>
  }, {
    ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
      stablehlo.return %arg0 : tensor<f64>
  }) {
    window_dimensions = array<i64: 1, 2, 2, 1>,
    window_strides = array<i64: 1, 2, 2, 1>
  } : (tensor<10x24x24x64xf64>, tensor<10x12x12x64xf64>, tensor<f64>) -> tensor<?x?x?x?xf64>
  "hlo_test_speculatability.is_recursively_speculatable"(%1) : (tensor<?x?x?x?xf64>) -> ()

  // Dynamic operand, static result: not speculatable (the operand's runtime
  // shape may not match the static result shape).
  %2 = "stablehlo.select_and_scatter"(%dynamic_arg, %source, %init) ({
    ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
      %c0 = stablehlo.constant dense<false> : tensor<i1>
      stablehlo.return %c0 : tensor<i1>
  }, {
    ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
      stablehlo.return %arg0 : tensor<f64>
  }) {
    window_dimensions = array<i64: 1, 2, 2, 1>,
    window_strides = array<i64: 1, 2, 2, 1>
  } : (tensor<?x?x?x?xf64>, tensor<10x12x12x64xf64>, tensor<f64>) -> tensor<10x24x24x64xf64>
  "hlo_test_speculatability.is_not_speculatable"(%2) : (tensor<10x24x24x64xf64>) -> ()

  // Dynamic operand, dynamic result: recursively speculatable.
  %3 = "stablehlo.select_and_scatter"(%dynamic_arg, %source, %init) ({
    ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
      %c0 = stablehlo.constant dense<false> : tensor<i1>
      stablehlo.return %c0 : tensor<i1>
  }, {
    ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
      stablehlo.return %arg0 : tensor<f64>
  }) {
    window_dimensions = array<i64: 1, 2, 2, 1>,
    window_strides = array<i64: 1, 2, 2, 1>
  } : (tensor<?x?x?x?xf64>, tensor<10x12x12x64xf64>, tensor<f64>) -> tensor<?x?x?x?xf64>
  "hlo_test_speculatability.is_recursively_speculatable"(%3) : (tensor<?x?x?x?xf64>) -> ()

  return
}

// -----

// CHECK-LABEL: func @sort
// CHECK-NEXT: return
func.func @sort(%static_arg: tensor<2x4xf64>, %dynamic_arg: tensor<?x?xf64>) {
Expand Down

0 comments on commit 88e02e7

Please sign in to comment.