From 88e02e74f0cdffa0bb5bc75e5a3c194a7d2e01e7 Mon Sep 17 00:00:00 2001
From: Michael Levesque-Dion <mlevesquedion@google.com>
Date: Wed, 17 Apr 2024 11:23:39 -0700
Subject: [PATCH] Implement ConditionallySpeculatable for SelectAndScatter

The spec says (C11): shape(operand) = shape(result).
---
 stablehlo/dialect/Base.h                 | 17 +++++++
 stablehlo/dialect/Base.td                | 10 ++++
 stablehlo/dialect/StablehloOps.td        |  3 +-
 stablehlo/tests/ops_speculatability.mlir | 63 ++++++++++++++++++++++++
 4 files changed, 92 insertions(+), 1 deletion(-)

diff --git a/stablehlo/dialect/Base.h b/stablehlo/dialect/Base.h
index f7cbec3c50..36ce89c126 100644
--- a/stablehlo/dialect/Base.h
+++ b/stablehlo/dialect/Base.h
@@ -421,6 +421,23 @@ struct SpeculatableIfStaticDimInOutputIsStaticInInputImplTrait
   }
 };
 
+template <typename ConcreteType>
+struct RecursivelySpeculatableIfStaticDimInOutputIsStaticInInputImplTrait
+    : public mlir::OpTrait::TraitBase<
+          ConcreteType,
+          RecursivelySpeculatableIfStaticDimInOutputIsStaticInInputImplTrait> {
+  mlir::Speculation::Speculatability getSpeculatability() {
+    auto op = this->getOperation();
+    auto inputType = cast<RankedTensorType>(op->getOperand(0).getType());
+    auto resultType = cast<RankedTensorType>(op->getResult(0).getType());
+    for (size_t i : llvm::seq(resultType.getRank())) {
+      if (!resultType.isDynamicDim(i) && inputType.isDynamicDim(i))
+        return mlir::Speculation::NotSpeculatable;
+    }
+    return mlir::Speculation::RecursivelySpeculatable;
+  }
+};
+
 template <typename ConcreteType>
 struct SpeculatableIfAllInputsStaticImplTrait
     : public mlir::OpTrait::TraitBase<ConcreteType,
diff --git a/stablehlo/dialect/Base.td b/stablehlo/dialect/Base.td
--- a/stablehlo/dialect/Base.td
+++ b/stablehlo/dialect/Base.td
@@ -128,6 +128,16 @@ def HLO_SpeculatableIfStaticDimInOutputIsStaticInInputImplTrait
 def HLO_SpeculatableIfStaticDimInOutputIsStaticInInput : TraitList<[
   ConditionallySpeculatable, HLO_SpeculatableIfStaticDimInOutputIsStaticInInputImplTrait]>;
 
+def HLO_RecursivelySpeculatableIfStaticDimInOutputIsStaticInInputImplTrait
+  : HLO_NativeOpTrait<"RecursivelySpeculatableIfStaticDimInOutputIsStaticInInputImplTrait">;
+
+// This trait is the same as HLO_SpeculatableIfStaticDimInOutputIsStaticInInput,
+// but for ops that have regions. If all static dimensions in the output are
+// static in the input, such an op is RecursivelySpeculatable (the ops in its
+// regions have to be checked for speculatability).
+def HLO_RecursivelySpeculatableIfStaticDimInOutputIsStaticInInput : TraitList<[
+  ConditionallySpeculatable, HLO_RecursivelySpeculatableIfStaticDimInOutputIsStaticInInputImplTrait]>;
+
 def HLO_SpeculatableIfAllInputsStaticImplTrait
   : HLO_NativeOpTrait<"SpeculatableIfAllInputsStaticImplTrait">;
 
diff --git a/stablehlo/dialect/StablehloOps.td b/stablehlo/dialect/StablehloOps.td
index dce23c766f..2bb8723742 100644
--- a/stablehlo/dialect/StablehloOps.td
+++ b/stablehlo/dialect/StablehloOps.td
@@ -2739,7 +2739,8 @@ def StableHLO_SelectOp: StableHLO_Op<"select",
 }
 
 def StableHLO_SelectAndScatterOp: StableHLO_Op<"select_and_scatter",
-    [DeclareOpInterfaceMethods<InferTypeOpInterface> /*select_and_scatter_c11,
+    [HLO_RecursivelySpeculatableIfStaticDimInOutputIsStaticInInput,
+     DeclareOpInterfaceMethods<InferTypeOpInterface> /*select_and_scatter_c11,
     select_and_scatter_c12*/, RecursiveMemoryEffects]> {
   let summary = "SelectAndScatter operation";
   let description = [{
diff --git a/stablehlo/tests/ops_speculatability.mlir b/stablehlo/tests/ops_speculatability.mlir
index 3dcffb1eda..31f59ee3a3 100644
--- a/stablehlo/tests/ops_speculatability.mlir
+++ b/stablehlo/tests/ops_speculatability.mlir
@@ -1708,6 +1708,69 @@ func.func @reduce_window(%static_arg: tensor<2x4xf64>, %dynamic_arg: tensor<?x?xf64>) {
 
 // -----
 
+// CHECK-LABEL: func @select_and_scatter
+// CHECK-NEXT: return
+func.func @select_and_scatter(
+  %static_arg: tensor<10x24x24x64xf64>, %dynamic_arg: tensor<?x?x?x?xf64>,
+  %source: tensor<10x12x12x64xf64>, %init: tensor<f64>
+) {
+  %0 = "stablehlo.select_and_scatter"(%static_arg, %source, %init) ({
+    ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
+      %c0 = stablehlo.constant dense<true> : tensor<i1>
+      stablehlo.return %c0 : tensor<i1>
+  }, {
+    ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
+      stablehlo.return %arg0 : tensor<f64>
+  }) {
+    window_dimensions = array<i64: 1, 2, 2, 1>,
+    window_strides = array<i64: 1, 2, 2, 1>
+  } : (tensor<10x24x24x64xf64>, tensor<10x12x12x64xf64>, tensor<f64>) -> tensor<10x24x24x64xf64>
+  "hlo_test_speculatability.is_recursively_speculatable"(%0) : (tensor<10x24x24x64xf64>) -> ()
+
+  %1 = "stablehlo.select_and_scatter"(%static_arg, %source, %init) ({
+    ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
+      %c0 = stablehlo.constant dense<true> : tensor<i1>
+      stablehlo.return %c0 : tensor<i1>
+  }, {
+    ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
+      stablehlo.return %arg0 : tensor<f64>
+  }) {
+    window_dimensions = array<i64: 1, 2, 2, 1>,
+    window_strides = array<i64: 1, 2, 2, 1>
+  } : (tensor<10x24x24x64xf64>, tensor<10x12x12x64xf64>, tensor<f64>) -> tensor<?x?x?x?xf64>
+  "hlo_test_speculatability.is_recursively_speculatable"(%1) : (tensor<?x?x?x?xf64>) -> ()
+
+  %2 = "stablehlo.select_and_scatter"(%dynamic_arg, %source, %init) ({
+    ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
+      %c0 = stablehlo.constant dense<true> : tensor<i1>
+      stablehlo.return %c0 : tensor<i1>
+  }, {
+    ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
+      stablehlo.return %arg0 : tensor<f64>
+  }) {
+    window_dimensions = array<i64: 1, 2, 2, 1>,
+    window_strides = array<i64: 1, 2, 2, 1>
+  } : (tensor<?x?x?x?xf64>, tensor<10x12x12x64xf64>, tensor<f64>) -> tensor<10x24x24x64xf64>
+  "hlo_test_speculatability.is_not_speculatable"(%2) : (tensor<10x24x24x64xf64>) -> ()
+
+  %3 = "stablehlo.select_and_scatter"(%dynamic_arg, %source, %init) ({
+    ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
+      %c0 = stablehlo.constant dense<true> : tensor<i1>
+      stablehlo.return %c0 : tensor<i1>
+  }, {
+    ^bb0(%arg0: tensor<f64>, %arg1: tensor<f64>):
+      stablehlo.return %arg0 : tensor<f64>
+  }) {
+    window_dimensions = array<i64: 1, 2, 2, 1>,
+    window_strides = array<i64: 1, 2, 2, 1>
+  } : (tensor<?x?x?x?xf64>, tensor<10x12x12x64xf64>, tensor<f64>) -> tensor<?x?x?x?xf64>
+  "hlo_test_speculatability.is_recursively_speculatable"(%3) : (tensor<?x?x?x?xf64>) -> ()
+
+  return
+}
+
+// -----
+
 // CHECK-LABEL: func @sort
 // CHECK-NEXT: return
 func.func @sort(%static_arg: tensor<2x4xf64>, %dynamic_arg: tensor<?x?xf64>) {