Implement ConditionallySpeculatable for ConvolutionOp #2228

Merged · 4 commits · Apr 17, 2024
52 changes: 52 additions & 0 deletions stablehlo/dialect/StablehloOps.cpp
@@ -1017,6 +1017,58 @@ LogicalResult ConvolutionOp::verify() {
getResult().getType());
}

mlir::Speculation::Speculatability ConvolutionOp::getSpeculatability() {
auto inputType = getLhs().getType();
auto kernelType = getRhs().getType();
auto resultType = getType();

auto dimNumbers = getDimensionNumbers();
auto inputBatchDim = dimNumbers.getInputBatchDimension();
auto inputFeatureDim = dimNumbers.getInputFeatureDimension();
auto inputSpatialDims = dimNumbers.getInputSpatialDimensions();
auto kernelInputFeatureDim = dimNumbers.getKernelInputFeatureDimension();
auto kernelOutputFeatureDim = dimNumbers.getKernelOutputFeatureDimension();
auto kernelSpatialDims = dimNumbers.getKernelSpatialDimensions();
auto outputBatchDim = dimNumbers.getOutputBatchDimension();
auto outputFeatureDim = dimNumbers.getOutputFeatureDimension();
auto outputSpatialDims = dimNumbers.getOutputSpatialDimensions();

auto batchGroupCount = getBatchGroupCount();
auto featureGroupCount = getFeatureGroupCount();

// input_feature_dimension and kernel_input_feature_dimension must be static
// (C14).
if (inputType.isDynamicDim(inputFeatureDim) ||
kernelType.isDynamicDim(kernelInputFeatureDim))
return mlir::Speculation::NotSpeculatable;

// input_batch_dimension must be static if batch_group_count > 1 (C10) or if
// output_batch_dimension is static (C25).
if (inputType.isDynamicDim(inputBatchDim) &&
(batchGroupCount > 1 || !resultType.isDynamicDim(outputBatchDim)))
return mlir::Speculation::NotSpeculatable;

// kernel_output_feature_dimension must be static if batch_group_count > 1
// (C15) or feature_group_count > 1 (C16) or if output_feature_dimension is
// static (C25).
if (kernelType.isDynamicDim(kernelOutputFeatureDim) &&
(batchGroupCount > 1 || featureGroupCount > 1 ||
!resultType.isDynamicDim(outputFeatureDim)))
return mlir::Speculation::NotSpeculatable;

// If a spatial dimension is static in the output, it must be static in the
// inputs (C25).
for (auto [inputDim, kernelDim, resultDim] :
llvm::zip(inputSpatialDims, kernelSpatialDims, outputSpatialDims)) {
if (!resultType.isDynamicDim(resultDim) &&
(inputType.isDynamicDim(inputDim) ||
kernelType.isDynamicDim(kernelDim)))
return mlir::Speculation::NotSpeculatable;
}

return mlir::Speculation::Speculatable;
}

//===----------------------------------------------------------------------===//
// ConvertOp
//===----------------------------------------------------------------------===//
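
For context (not part of this change): transforms such as loop-invariant code motion query this interface through the free functions in MLIR's SideEffectInterfaces. Below is a minimal C++ sketch of how the result of getSpeculatability() above is consumed; the helper name canHoist is hypothetical.

#include "mlir/IR/Operation.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"

// Hypothetical helper: an op can be hoisted past a guard only if executing it
// early cannot trap and it reads or writes no memory.
static bool canHoist(mlir::Operation *op) {
  // isSpeculatable() dispatches to ConvolutionOp::getSpeculatability() via the
  // ConditionallySpeculatable interface, so for stablehlo.convolution it
  // returns false exactly in the NotSpeculatable cases above, where a dynamic
  // dimension leaves constraints C10/C14/C15/C16/C25 checkable only at runtime.
  return mlir::isMemoryEffectFree(op) && mlir::isSpeculatable(op);
}
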
8 changes: 7 additions & 1 deletion stablehlo/dialect/StablehloOps.td
@@ -2204,7 +2204,8 @@ def StableHLO_CompositeOp : StableHLO_Op<"composite", [DeclareOpInterfaceMethods
let assemblyFormat = "$name $inputs attr-dict `:` functional-type(operands, results)";
}

def StableHLO_ConvolutionOp : StableHLO_Op<"convolution", [Pure]> {
def StableHLO_ConvolutionOp : StableHLO_Op<"convolution",
[ConditionallySpeculatable, NoMemoryEffect]> {
let summary = "Convolution operation";
let description = [{
Computes dot products between windows of `lhs` and slices of `rhs` and
@@ -2252,6 +2253,11 @@ def StableHLO_ConvolutionOp : StableHLO_Op<"convolution", [Pure]> {
$window_reversal) `}`
attr-dict `:` functional-type(operands, results)
}];

let extraClassDeclaration = commonClassDeclaration # [{
/// Interface method for ConditionallySpeculatable.
mlir::Speculation::Speculatability getSpeculatability();
}];
}

def StableHLO_CrossReplicaSumOp : StableHLO_Op<"cross-replica-sum",
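
Note on the trait change above (not part of this diff): MLIR's Pure shorthand effectively bundles AlwaysSpeculatable with NoMemoryEffect, so spelling out [ConditionallySpeculatable, NoMemoryEffect] keeps the op free of memory effects and only replaces the unconditional speculation answer with the shape-dependent one implemented in StablehloOps.cpp. For contrast, an op that keeps Pure behaves as if it had this trivial implementation (illustrative sketch, not generated code; AlwaysSpeculatableOp is a placeholder name):

mlir::Speculation::Speculatability AlwaysSpeculatableOp::getSpeculatability() {
  // A Pure/AlwaysSpeculatable op can be executed for any operand values, so
  // the answer never depends on dynamic dimensions.
  return mlir::Speculation::Speculatable;
}
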
105 changes: 105 additions & 0 deletions stablehlo/tests/ops_speculatability.mlir
@@ -1366,6 +1366,111 @@ func.func @concatenate(%static_arg: tensor<2x2xi64>, %first_dim_dynamic: tensor<

// -----

// CHECK-LABEL: func @convolution
// CHECK-NEXT: return
func.func @convolution(
%input_static: tensor<100x26x26x32xf64>, %kernel_static: tensor<3x3x2x32xf64>,
%input_feature_dim_dynamic: tensor<100x26x26x?xf64>, %input_batch_dim_dynamic: tensor<?x26x26x32xf64>,
%kernel_feature_dim_dynamic: tensor<3x3x2x?xf64>, %kernel_output_feature_dim_dynamic: tensor<3x3x?x32xf64>, %kernel_output_feature_dim_dynamic_2_feature_groups: tensor<3x3x?x16xf64>,
%input_spatial_dims_dynamic: tensor<100x?x?x32xf64>, %kernel_spatial_dims_dynamic: tensor<?x?x2x32xf64>
) {
// Inputs fully static
%0 = stablehlo.convolution(%input_static, %kernel_static)
dim_numbers = [b, 0, 1, f] x [0, 1, o, i] -> [b, 0, 1, f],
window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<100x26x26x32xf64>, tensor<3x3x2x32xf64>) -> tensor<100x24x24x2xf64>
"hlo_test_speculatability.is_speculatable"(%0) : (tensor<100x24x24x2xf64>) -> ()
%1 = stablehlo.convolution(%input_static, %kernel_static)
dim_numbers = [b, 0, 1, f] x [0, 1, o, i] -> [b, 0, 1, f],
window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<100x26x26x32xf64>, tensor<3x3x2x32xf64>) -> tensor<?x?x?x?xf64>
"hlo_test_speculatability.is_speculatable"(%1) : (tensor<?x?x?x?xf64>) -> ()

// input_feature_dimension is dynamic
%2 = stablehlo.convolution(%input_feature_dim_dynamic, %kernel_static)
dim_numbers = [b, 0, 1, f] x [0, 1, o, i] -> [b, 0, 1, f],
window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<100x26x26x?xf64>, tensor<3x3x2x32xf64>) -> tensor<?x?x?x?xf64>
"hlo_test_speculatability.is_not_speculatable"(%2) : (tensor<?x?x?x?xf64>) -> ()

// kernel_input_feature_dimension is dynamic
%3 = stablehlo.convolution(%input_static, %kernel_feature_dim_dynamic)
dim_numbers = [b, 0, 1, f] x [0, 1, o, i] -> [b, 0, 1, f],
window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<100x26x26x32xf64>, tensor<3x3x2x?xf64>) -> tensor<?x?x?x?xf64>
"hlo_test_speculatability.is_not_speculatable"(%3) : (tensor<?x?x?x?xf64>) -> ()

// input_batch_dimension is dynamic
%4 = stablehlo.convolution(%input_batch_dim_dynamic, %kernel_static)
dim_numbers = [b, 0, 1, f] x [0, 1, o, i] -> [b, 0, 1, f],
window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<?x26x26x32xf64>, tensor<3x3x2x32xf64>) -> tensor<?x?x?x?xf64>
"hlo_test_speculatability.is_speculatable"(%4) : (tensor<?x?x?x?xf64>) -> ()
// batch_group_count > 1
%5 = stablehlo.convolution(%input_batch_dim_dynamic, %kernel_static)
dim_numbers = [b, 0, 1, f] x [0, 1, o, i] -> [b, 0, 1, f],
window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []}
{batch_group_count = 2 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<?x26x26x32xf64>, tensor<3x3x2x32xf64>) -> tensor<?x?x?x?xf64>
"hlo_test_speculatability.is_not_speculatable"(%5) : (tensor<?x?x?x?xf64>) -> ()
// output_batch_dimension is static
%6 = stablehlo.convolution(%input_batch_dim_dynamic, %kernel_static)
dim_numbers = [b, 0, 1, f] x [0, 1, o, i] -> [b, 0, 1, f],
window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<?x26x26x32xf64>, tensor<3x3x2x32xf64>) -> tensor<100x?x?x?xf64>
"hlo_test_speculatability.is_not_speculatable"(%6) : (tensor<100x?x?x?xf64>) -> ()

// kernel_output_feature_dimension is dynamic
%7 = stablehlo.convolution(%input_static, %kernel_output_feature_dim_dynamic)
dim_numbers = [b, 0, 1, f] x [0, 1, o, i] -> [b, 0, 1, f],
window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<100x26x26x32xf64>, tensor<3x3x?x32xf64>) -> tensor<?x?x?x?xf64>
"hlo_test_speculatability.is_speculatable"(%7) : (tensor<?x?x?x?xf64>) -> ()
// batch_group_count > 1
%8 = stablehlo.convolution(%input_static, %kernel_output_feature_dim_dynamic)
dim_numbers = [b, 0, 1, f] x [0, 1, o, i] -> [b, 0, 1, f],
window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []}
{batch_group_count = 2 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<100x26x26x32xf64>, tensor<3x3x?x32xf64>) -> tensor<?x?x?x?xf64>
"hlo_test_speculatability.is_not_speculatable"(%8) : (tensor<?x?x?x?xf64>) -> ()
// feature_group_count > 1
%9 = stablehlo.convolution(%input_static, %kernel_output_feature_dim_dynamic_2_feature_groups)
dim_numbers = [b, 0, 1, f] x [0, 1, o, i] -> [b, 0, 1, f],
window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []}
{batch_group_count = 1 : i64, feature_group_count = 2 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<100x26x26x32xf64>, tensor<3x3x?x16xf64>) -> tensor<?x?x?x?xf64>
"hlo_test_speculatability.is_not_speculatable"(%9) : (tensor<?x?x?x?xf64>) -> ()
// output_feature_dimension is static
%10 = stablehlo.convolution(%input_static, %kernel_output_feature_dim_dynamic)
dim_numbers = [b, 0, 1, f] x [0, 1, o, i] -> [b, 0, 1, f],
window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<100x26x26x32xf64>, tensor<3x3x?x32xf64>) -> tensor<?x?x?x2xf64>
"hlo_test_speculatability.is_not_speculatable"(%10) : (tensor<?x?x?x2xf64>) -> ()

// Spatial dimensions dynamic
%11 = stablehlo.convolution(%input_spatial_dims_dynamic, %kernel_static)
dim_numbers = [b, 0, 1, f] x [0, 1, o, i] -> [b, 0, 1, f],
window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<100x?x?x32xf64>, tensor<3x3x2x32xf64>) -> tensor<100x24x24x2xf64>
"hlo_test_speculatability.is_not_speculatable"(%11) : (tensor<100x24x24x2xf64>) -> ()
%12 = stablehlo.convolution(%input_spatial_dims_dynamic, %kernel_static)
dim_numbers = [b, 0, 1, f] x [0, 1, o, i] -> [b, 0, 1, f],
window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<100x?x?x32xf64>, tensor<3x3x2x32xf64>) -> tensor<100x?x?x2xf64>
"hlo_test_speculatability.is_speculatable"(%12) : (tensor<100x?x?x2xf64>) -> ()
%13 = stablehlo.convolution(%input_static, %kernel_spatial_dims_dynamic)
dim_numbers = [b, 0, 1, f] x [0, 1, o, i] -> [b, 0, 1, f],
window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<100x26x26x32xf64>, tensor<?x?x2x32xf64>) -> tensor<100x24x24x2xf64>
"hlo_test_speculatability.is_not_speculatable"(%13) : (tensor<100x24x24x2xf64>) -> ()
%14 = stablehlo.convolution(%input_static, %kernel_spatial_dims_dynamic)
dim_numbers = [b, 0, 1, f] x [0, 1, o, i] -> [b, 0, 1, f],
window = {stride = [], pad = [], lhs_dilate = [], rhs_dilate = [], reverse = []}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} : (tensor<100x26x26x32xf64>, tensor<?x?x2x32xf64>) -> tensor<100x?x?x2xf64>
"hlo_test_speculatability.is_speculatable"(%14) : (tensor<100x?x?x2xf64>) -> ()

return
}

// -----

// CHECK-LABEL: func @dot_general
// CHECK-NEXT: return
func.func @dot_general(