Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement ConditionallySpeculatable for remaining dynamic ops #2242

Merged
merged 3 commits into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions stablehlo/dialect/Base.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -611,5 +611,22 @@ bool isSplatArray(ArrayRef<int64_t> arr, int64_t val) {
[val](int64_t x) { return x == val; });
}

mlir::Speculation::Speculatability getShapedSpeculatability(
Operation* op, int64_t shapeCount) {
// If all inputs are static and the shape-related operands are constant
// then any relationship between the input, the shapes and the output can be
// verified statically.
bool allInputsStatic = llvm::all_of(op->getOperandTypes(), [](Type t) {
return cast<ShapedType>(t).hasStaticShape();
});
bool allShapesConstant = llvm::all_of(llvm::seq(shapeCount), [&](int64_t i) {
return matchPattern(op->getOperand(op->getNumOperands() - 1 - i),
m_Constant());
});
return allInputsStatic && allShapesConstant
? mlir::Speculation::Speculatable
: mlir::Speculation::NotSpeculatable;
}

} // namespace hlo
} // namespace mlir
17 changes: 17 additions & 0 deletions stablehlo/dialect/Base.h
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,13 @@ void writeEnumAttribute(EnumTypeAttr val, DialectBytecodeWriter &writer) {
}
} // namespace bytecode

// Determines the speculatability for a shaped operation `op` with `shapeCount`
// shape operands. The last `count` operands are assumed to be shape operands.
// To be speculatable, such an op must have only static inputs and constant
// shape operands.
mlir::Speculation::Speculatability getShapedSpeculatability(Operation *op,
int64_t shapeCount);

namespace OpTrait {

template <typename ConcreteType>
Expand Down Expand Up @@ -470,6 +477,16 @@ struct RecursivelySpeculatableIfAllInputsStaticImplTrait
}
};

template <typename ConcreteType>
struct SpeculatableIfAllInputsStaticAndShapeConstantImplTrait
: public mlir::OpTrait::TraitBase<
ConcreteType,
SpeculatableIfAllInputsStaticAndShapeConstantImplTrait> {
mlir::Speculation::Speculatability getSpeculatability() {
return getShapedSpeculatability(this->getOperation(), 1);
}
};

} // namespace OpTrait
} // namespace hlo
} // namespace mlir
Expand Down
9 changes: 9 additions & 0 deletions stablehlo/dialect/Base.td
Original file line number Diff line number Diff line change
Expand Up @@ -368,4 +368,13 @@ def HLO_RecursivelySpeculatableIfAllInputsStaticImplTrait
def HLO_RecursivelySpeculatableIfAllInputsStatic : TraitList<[
ConditionallySpeculatable, HLO_RecursivelySpeculatableIfAllInputsStaticImplTrait]>;

def HLO_SpeculatableIfAllInputsStaticAndShapeConstantImplTrait
: HLO_NativeOpTrait<"SpeculatableIfAllInputsStaticAndShapeConstantImplTrait">;

// This trait is the same as HLO_SpeculatableIfAllInputsStatic, but for ops that
// take a shape as their last operand. Such ops are speculatable if all inputs
// are static and the shape is constant.
def HLO_SpeculatableIfAllInputsStaticAndShapeConstant : TraitList<[
ConditionallySpeculatable, HLO_SpeculatableIfAllInputsStaticAndShapeConstantImplTrait]>;

#endif // STABLEHLO_DIALECT_BASE
20 changes: 8 additions & 12 deletions stablehlo/dialect/StablehloOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1532,18 +1532,6 @@ LogicalResult DynamicReshapeOp::reifyReturnTypeShapes(
return success();
}

mlir::Speculation::Speculatability DynamicReshapeOp::getSpeculatability() {
// If the input is static and the shape operand is constant, the output
// shape can be inferred and any mismatch will be caught statically.
// If any dimension in the input is dynamic, or if the shape is not known,
// the number of elements may disagree at runtime.
if (getOperand().getType().hasStaticShape() &&
matchPattern(getOutputShape(), m_Constant()))
return mlir::Speculation::Speculatable;

return mlir::Speculation::NotSpeculatable;
}

//===----------------------------------------------------------------------===//
// DynamicSliceOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -1613,6 +1601,10 @@ LogicalResult RealDynamicSliceOp::reifyReturnTypeShapes(
return success();
}

mlir::Speculation::Speculatability RealDynamicSliceOp::getSpeculatability() {
return hlo::getShapedSpeculatability(getOperation(), /*shapeCount=*/3);
}

//===----------------------------------------------------------------------===//
// InfeedOp
//===----------------------------------------------------------------------===//
Expand Down Expand Up @@ -2160,6 +2152,10 @@ LogicalResult DynamicPadOp::reifyReturnTypeShapes(
return success();
}

mlir::Speculation::Speculatability DynamicPadOp::getSpeculatability() {
return hlo::getShapedSpeculatability(getOperation(), /*shapeCount=*/3);
}

//===----------------------------------------------------------------------===//
// ReshapeOp
//===----------------------------------------------------------------------===//
Expand Down
31 changes: 20 additions & 11 deletions stablehlo/dialect/StablehloOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def StableHLO_DynamicIotaOp: StableHLO_ShapedInterfaceOp<"dynamic_iota", [Condit
$output_shape `,` `dim` `=` $iota_dimension attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = [{
let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
Expand Down Expand Up @@ -2653,7 +2653,7 @@ def StableHLO_ReshapeOp: StableHLO_Op<"reshape",
}

def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape",
[ConditionallySpeculatable, NoMemoryEffect]> {
[HLO_SpeculatableIfAllInputsStaticAndShapeConstant, NoMemoryEffect]> {
let summary = "DynamicReshape operation";
let description = [{
This operation is a work in progress, so it is not yet included in
Expand All @@ -2675,11 +2675,6 @@ def StableHLO_DynamicReshapeOp: StableHLO_ShapedInterfaceOp<"dynamic_reshape",
let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_ScatterOp: StableHLO_Op<"scatter",
Expand Down Expand Up @@ -3375,7 +3370,8 @@ def StableHLO_ReducePrecisionOp : StableHLO_Op<"reduce_precision",

def StableHLO_RealDynamicSliceOp: StableHLO_ShapedInterfaceOp<
"real_dynamic_slice",
[Pure, AllElementTypesMatch<["operand", "result"]>,
[ConditionallySpeculatable, NoMemoryEffect,
AllElementTypesMatch<["operand", "result"]>,
AllTypesMatch<["start_indices", "limit_indices", "strides"]>]> {
let summary = "RealDynamicSlice operation";
let description = [{
Expand Down Expand Up @@ -3403,10 +3399,16 @@ def StableHLO_RealDynamicSliceOp: StableHLO_ShapedInterfaceOp<
let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_DynamicPadOp: StableHLO_ShapedInterfaceOp<"dynamic_pad",
[Pure, AllElementTypesMatch<["operand", "padding_value", "result"]>,
[ConditionallySpeculatable, NoMemoryEffect,
AllElementTypesMatch<["operand", "padding_value", "result"]>,
AllTypesMatch<["edge_padding_low", "edge_padding_high", "interior_padding"]>]> {
let summary = "DynamicPad operation";
let description = [{
Expand Down Expand Up @@ -3440,10 +3442,16 @@ def StableHLO_DynamicPadOp: StableHLO_ShapedInterfaceOp<"dynamic_pad",
let hasVerifier = 1;

let assemblyFormat = "operands attr-dict `:` functional-type(operands, results)";

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_DynamicGatherOp: StableHLO_Op<"dynamic_gather",
[InferTensorTypeWithReify, Pure]> {
[HLO_SpeculatableIfAllInputsStaticAndShapeConstant, NoMemoryEffect,
InferTensorTypeWithReify]> {
let summary = "DynamicGather operation";
let description = [{
This operation is a work in progress, so it is not yet included in
Expand Down Expand Up @@ -3476,7 +3484,8 @@ def StableHLO_DynamicGatherOp: StableHLO_Op<"dynamic_gather",
let results = (outs HLO_Tensor);
}

def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv", [Pure]> {
def StableHLO_DynamicConvOp : StableHLO_Op<"dynamic_conv",
[HLO_SpeculatableIfAllInputsStaticAndShapeConstant, NoMemoryEffect]> {
let summary = "DynamicConv operation";
let description = [{
This operation is a work in progress, so it is not yet included in
Expand Down
Loading
Loading