Update spec and verifier for convolution op according to the RFC for hybrid quantized op (#2171)

This PR implements the specification and constraints for convolution as
discussed in the [RFC](#1792).
Please let me know what you think.
Thanks!

cc @sdasgup3
doyeonkim0 authored Apr 12, 2024
1 parent c304904 commit 1f249ac
Showing 3 changed files with 201 additions and 9 deletions.
40 changes: 31 additions & 9 deletions docs/spec.md
@@ -2191,6 +2191,14 @@ For quantized types, performs `dequantize_op_quantize(
feature_group_count, batch_group_count, precision_config), lhs, rhs,
type(result))`.

For hybrid quantized types, performs `hybrid_dequantize_then_op(
lambda lhs, rhs: convolution(lhs, rhs, window_strides, padding, lhs_dilation, rhs_dilation,
window_reversal, input_batch_dimension, input_feature_dimension,
input_spatial_dimensions, kernel_input_feature_dimension,
kernel_output_feature_dimension, kernel_spatial_dimensions,
output_batch_dimension, output_feature_dimension, output_spatial_dimensions,
feature_group_count, batch_group_count, precision_config), lhs, rhs)`.
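
To make the hybrid case concrete, here is a minimal standalone Python sketch
(not part of the spec) of the weight-only path for the simplest configuration:
NHWC activations, HWIO weights, unit strides, no padding or dilation, and
per-tensor rhs quantization parameters. `float_conv2d_nhwc_hwio` and
`hybrid_convolution` are illustrative helpers, not spec functions.

```python
import numpy as np

def float_conv2d_nhwc_hwio(lhs, rhs):
    # Direct float convolution: NHWC activations, HWIO weights,
    # stride 1, no padding, no dilation.
    n, h, w, ci = lhs.shape
    kh, kw, ci2, co = rhs.shape
    assert ci == ci2
    out = np.zeros((n, h - kh + 1, w - kw + 1, co), dtype=np.float32)
    for y in range(out.shape[1]):
        for x in range(out.shape[2]):
            patch = lhs[:, y:y + kh, x:x + kw, :]  # [n, kh, kw, ci]
            out[:, y, x, :] = np.tensordot(patch, rhs, axes=([1, 2, 3], [0, 1, 2]))
    return out

def hybrid_convolution(lhs, rhs_storage, rhs_scale, rhs_zero_point):
    # Weight-only path: lhs stays in float, rhs is dequantized into its
    # expressed type, and the convolution itself runs entirely in float.
    rhs = (rhs_storage.astype(np.float32) - rhs_zero_point) * rhs_scale
    return float_conv2d_nhwc_hwio(lhs, rhs)

# Example: float activations with int8 weights (made-up values).
lhs = np.ones((1, 4, 4, 1), dtype=np.float32)
rhs_q = np.full((3, 3, 1, 1), 4, dtype=np.int8)
print(hybrid_convolution(lhs, rhs_q, rhs_scale=0.5, rhs_zero_point=0).shape)  # (1, 2, 2, 1)
```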

#### Inputs

| Label | Name | Type | Constraints |
@@ -2273,16 +2281,18 @@ For quantized types, performs `dequantize_op_quantize(
* If the operation uses non-quantized tensors:
* (C27) `element_type(lhs) = element_type(rhs) = element_type(result)`.
* If the operation uses quantized tensors:
  Removed by this commit:

  * (C28) `is_quantized_tensor(lhs) and is_quantized_tensor(rhs) and
    is_quantized_tensor(result)`.
  * (C29) `storage_type(lhs) = storage_type(rhs)`.
  * (C30) `expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)`.
  * (C31) If `is_per_tensor_quantized(rhs)`, then
    `is_per_tensor_quantized(result)`.
  * (C32) If `is_per_axis_quantized(rhs)`, then
    `quantization_dimension(rhs) = kernel_output_feature_dimension`.
  * (C33) If `is_per_axis_quantized(result)`, then
    `quantization_dimension(result) = output_feature_dimension`.

  Added by this commit:

  * (C28) `is_quantized(lhs) = is_quantized(result) and is_quantized(rhs)`.
  * (C29) If `is_per_axis_quantized(rhs)`, then
    `quantization_dimension(rhs) = kernel_output_feature_dimension`.
  * (C30) If `is_per_axis_quantized(result)`, then
    `quantization_dimension(result) = output_feature_dimension`.
  * If `is_quantized(lhs)`:
    * (C31) `storage_type(lhs) = storage_type(rhs)`.
    * (C32) `expressed_type(lhs) = expressed_type(rhs) = expressed_type(result)`.
    * (C33) If `is_per_tensor_quantized(rhs)`, then
      `is_per_tensor_quantized(result)`.
  * If `!is_quantized(lhs)`:
    * (C34) `element_type(lhs) = expressed_type(rhs) = element_type(result)`.
<!-- markdownlint-enable line-length -->
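
For illustration, the new quantization constraints (C28-C34) can be written as
a standalone Python sketch. `QuantizedType` and its field names are local
stand-ins chosen here for readability, not the spec's data model; the
authoritative implementation is the C++ verifier in `TypeInference.cpp` shown
further below.

```python
from dataclasses import dataclass
from typing import Optional, Union

@dataclass
class QuantizedType:
    storage_type: str                             # e.g. "i8"
    expressed_type: str                           # e.g. "f32"
    quantization_dimension: Optional[int] = None  # None => per-tensor

# Plain float element types are represented as strings like "f32".
ElementType = Union[str, QuantizedType]

def check_convolution_quantization(lhs: ElementType, rhs: ElementType,
                                   result: ElementType,
                                   kernel_output_feature_dimension: int,
                                   output_feature_dimension: int) -> None:
    is_quantized = lambda t: isinstance(t, QuantizedType)
    is_per_axis = lambda t: is_quantized(t) and t.quantization_dimension is not None
    # C28: rhs must be quantized; lhs and result are either both quantized or both not.
    assert is_quantized(rhs) and is_quantized(lhs) == is_quantized(result)
    # C29/C30: per-axis operands are quantized along the (kernel) output feature dimension.
    if is_per_axis(rhs):
        assert rhs.quantization_dimension == kernel_output_feature_dimension
    if is_per_axis(result):
        assert result.quantization_dimension == output_feature_dimension
    if is_quantized(lhs):
        # Fully quantized case: C31-C33.
        assert lhs.storage_type == rhs.storage_type
        assert lhs.expressed_type == rhs.expressed_type == result.expressed_type
        if not is_per_axis(rhs):
            assert not is_per_axis(result)
    else:
        # Hybrid (weight-only) case: C34.
        assert lhs == rhs.expressed_type == result

# A valid hybrid configuration: float lhs/result, per-tensor int8 weights.
check_convolution_quantization("f32", QuantizedType("i8", "f32"), "f32", 3, 3)
```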

#### Examples
@@ -6725,6 +6735,18 @@ def dequantize_select_quantize(pred, on_true, on_false, output_type):
return quantize(float_result, output_type)
```

* `hybrid_dequantize_then_op` is used to specify weight-only quantization for
hybrid ops, which accept `lhs` in floating-point and `rhs` in quantized types.
It dequantizes quantized inputs into their expressed types and performs the
computation in float. The element type of the float `lhs` tensor and the
expressed type of the quantized `rhs` tensor must be identical.

```python
def hybrid_dequantize_then_op(op, lhs, rhs):
assert(is_float(lhs) and is_quantized(rhs) and element_type(lhs) == expressed_type(rhs))
return op(lhs, dequantize(rhs))
```
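
For reference, below is a standalone numpy sketch of the dequantize step for a
per-axis quantized `rhs`; the helper name and signature are illustrative, not
spec functions. Each slice along the quantization dimension gets its own scale
and zero point.

```python
import numpy as np

def dequantize_per_axis(storage, scales, zero_points, axis):
    # Broadcast the per-axis scales/zero points along `axis` of the stored values:
    # expressed = (stored - zero_point) * scale, applied slice by slice.
    shape = [1] * storage.ndim
    shape[axis] = -1
    scales = np.asarray(scales, dtype=np.float32).reshape(shape)
    zero_points = np.asarray(zero_points, dtype=np.float32).reshape(shape)
    return (storage.astype(np.float32) - zero_points) * scales

# Example: 3x3x1x2 HWIO weights quantized per output feature (axis 3).
weights = np.ones((3, 3, 1, 2), dtype=np.int8)
print(dequantize_per_axis(weights, scales=[0.5, 2.0], zero_points=[0, 1], axis=3)[0, 0, 0])
# -> [0.5, 0.0]
```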

#### Grid computations

* `cross_partition(replica_groups: Value) -> Value`. See the "cross_replica"
86 changes: 86 additions & 0 deletions stablehlo/dialect/TypeInference.cpp
@@ -165,6 +165,53 @@ LogicalResult verifyBinaryOpQuantizationConstraints(
return success();
}

LogicalResult verifyConvolutionDotGeneralCommonQuantizationConstraints(
std::optional<Location> location, Type lhsElementType, Type rhsElementType,
Type resultElementType) {
// convolution_c28
if (!isa<quant::QuantizedType>(rhsElementType) ||
(isa<quant::QuantizedType>(lhsElementType) !=
isa<quant::QuantizedType>(resultElementType))) {
return emitOptionalError(
location,
"rhs should be quantized for quantized operations and "
"is_quantized(lhs)=is_quantized(result) should hold");
}

auto rhsQuantType = cast<quant::QuantizedType>(rhsElementType);
if (auto lhsQuantType = dyn_cast<quant::QuantizedType>(lhsElementType)) {
auto resultQuantType = cast<quant::QuantizedType>(resultElementType);
// convolution_c31
if (lhsQuantType.getStorageType() != rhsQuantType.getStorageType()) {
return emitOptionalError(
location, "mismatched lhs and rhs quantization storage types");
}
// convolution_c32
if (lhsQuantType.getExpressedType() != rhsQuantType.getExpressedType() ||
lhsQuantType.getExpressedType() != resultQuantType.getExpressedType()) {
return emitOptionalError(
location,
"mismatched lhs, rhs and result quantization expressed types");
}
// convolution_c33
if (isa<quant::UniformQuantizedType>(rhsQuantType) &&
!isa<quant::UniformQuantizedType>(resultQuantType)) {
return emitOptionalError(
location, "mismatched rhs and result quantization granularity");
}
} else {
Type rhsExpressedType = rhsQuantType.getExpressedType();
// convolution_c34
if (lhsElementType != rhsExpressedType ||
lhsElementType != resultElementType) {
return emitOptionalError(location,
"mismatched rhs quantization expressed type and "
"lhs and result element type");
}
}
return success();
}

bool isSameQuantPerAxisScaleZeroPoint(Type ty1, Type ty2) {
auto qty1 =
dyn_cast<quant::UniformQuantizedPerAxisType>(getElementTypeOrSelf(ty1));
@@ -3543,6 +3590,40 @@ LogicalResult verifyCompositeOp(std::optional<Location> loc, Operation* op,
return success();
}

LogicalResult verifyConvolutionOpQuantizationConstraints(
std::optional<Location> location, Type lhsType, Type rhsType,
Type resultType, int64_t kernelOutputFeatureDimension,
int64_t outputFeatureDimension) {
Type lhsElementType = getElementTypeOrSelf(lhsType);
Type rhsElementType = getElementTypeOrSelf(rhsType);
Type resultElementType = getElementTypeOrSelf(resultType);

// convolution_c29
if (auto rhsPerAxisType =
dyn_cast<quant::UniformQuantizedPerAxisType>(rhsElementType)) {
if (rhsPerAxisType.getQuantizedDimension() !=
kernelOutputFeatureDimension) {
return emitOptionalError(location,
"quantization dimension of rhs should be same "
"with kernel_output_feature_dimension");
}
}

// convolution_c30
if (auto resultPerAxisType =
dyn_cast<quant::UniformQuantizedPerAxisType>(resultElementType)) {
if (resultPerAxisType.getQuantizedDimension() != outputFeatureDimension) {
return emitOptionalError(location,
"quantization dimension of result should be "
"same with output_feature_dimension");
}
}

// convolution_c31 - convolution_c34
return verifyConvolutionDotGeneralCommonQuantizationConstraints(
location, lhsElementType, rhsElementType, resultElementType);
}

LogicalResult verifyConvolutionOp(
std::optional<Location> location, Type lhsType, Type rhsType,
std::optional<ArrayRef<int64_t>> windowStrides,
@@ -3576,6 +3657,11 @@ LogicalResult verifyConvolutionOp(
"is incompatible with return type of operation ",
shapedResultType, "");

if (anyQuantized<quant::QuantizedType>({lhsType, rhsType, resultType})) {
return verifyConvolutionOpQuantizationConstraints(
location, lhsType, rhsType, resultType, kernelOutputFeatureDimension,
outputFeatureDimension);
}
return success();
}

84 changes: 84 additions & 0 deletions stablehlo/tests/ops_stablehlo_quantized.mlir
@@ -967,3 +967,87 @@ func.func @reshape_c3_mismatch_product_before(%arg0: tensor<1x2x3x4x5x!quant.uni
%reshape = "stablehlo.reshape" (%arg0) : (tensor<1x2x3x4x5x!quant.uniform<i8:f32:0, {1.0:17}>>) -> tensor<2x1x3x20x!quant.uniform<i8:f32:1, {1.0:17}>>
func.return
}

// -----

func.func @convolution_c28(%arg0: tensor<1x8x8x207xf32>, %arg1: tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>> {
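// C28: rhs is quantized, but lhs is float while the result is quantized, so is_quantized(lhs) != is_quantized(result).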
// expected-error@+1 {{rhs should be quantized for quantized operations and is_quantized(lhs)=is_quantized(result) should hold}}
%0 = stablehlo.convolution(%arg0, %arg1)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
(tensor<1x8x8x207xf32>, tensor<3x3x207x16x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
func.return %0 : tensor<1x8x8x16x!quant.uniform<i8:f32, 10.0:50>>
}

// -----

func.func @convolution_c29(%arg0: tensor<1x4x4x1xf32>, %arg1: tensor<3x3x1x1x!quant.uniform<i8:f32:1, {5.0:20, 5.0:20, 5.0:20}>>) -> tensor<1x4x4x1xf32> {
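// C29: rhs is per-axis quantized along dimension 1, but the kernel layout [0, 1, i, o] puts kernel_output_feature_dimension at 3.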
// expected-error@+1 {{quantization dimension of rhs should be same with kernel_output_feature_dimension}}
%0 = stablehlo.convolution(%arg0, %arg1)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
(tensor<1x4x4x1xf32>, tensor<3x3x1x1x!quant.uniform<i8:f32:1, {5.0:20, 5.0:20, 5.0:20}>>) -> tensor<1x4x4x1xf32>
func.return %0 : tensor<1x4x4x1xf32>
}

// -----

func.func @convolution_c30(%arg0: tensor<1x4x4x1x!quant.uniform<i8:f32, 4.0:10>>, %arg1: tensor<3x3x1x1x!quant.uniform<i8:f32:3, {5.0:20}>>) -> tensor<1x4x4x1x!quant.uniform<i8:f32:0, {3.0:6}>> {
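// C30: result is per-axis quantized along dimension 0, but the output layout [b, 0, 1, f] puts output_feature_dimension at 3.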
// expected-error@+1 {{quantization dimension of result should be same with output_feature_dimension}}
%0 = stablehlo.convolution(%arg0, %arg1)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
(tensor<1x4x4x1x!quant.uniform<i8:f32, 4.0:10>>, tensor<3x3x1x1x!quant.uniform<i8:f32:3, {5.0:20}>>) -> tensor<1x4x4x1x!quant.uniform<i8:f32:0, {3.0:6}>>
func.return %0 : tensor<1x4x4x1x!quant.uniform<i8:f32:0, {3.0:6}>>
}

// -----

func.func @convolution_c31(%arg0: tensor<1x4x4x1x!quant.uniform<i8:f32, 4.0:10>>, %arg1: tensor<3x3x1x1x!quant.uniform<i16:f32:3, {5.0:20}>>) -> tensor<1x4x4x1x!quant.uniform<i8:f32:3, {3.0:6}>> {
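// C31: lhs uses i8 storage while rhs uses i16 storage.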
// expected-error@+1 {{mismatched lhs and rhs quantization storage types}}
%0 = stablehlo.convolution(%arg0, %arg1)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
(tensor<1x4x4x1x!quant.uniform<i8:f32, 4.0:10>>, tensor<3x3x1x1x!quant.uniform<i16:f32:3, {5.0:20}>>) -> tensor<1x4x4x1x!quant.uniform<i8:f32:3, {3.0:6}>>
func.return %0 : tensor<1x4x4x1x!quant.uniform<i8:f32:3, {3.0:6}>>
}

// -----

func.func @convolution_c32(%arg0: tensor<1x4x4x1x!quant.uniform<i8:f16, 4.0:10>>, %arg1: tensor<3x3x1x1x!quant.uniform<i8:f32:3, {5.0:20}>>) -> tensor<1x4x4x1x!quant.uniform<i8:f32:3, {3.0:6}>> {
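// C32: lhs is expressed in f16 while rhs and result are expressed in f32.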
// expected-error@+1 {{mismatched lhs, rhs and result quantization expressed types}}
%0 = stablehlo.convolution(%arg0, %arg1)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
(tensor<1x4x4x1x!quant.uniform<i8:f16, 4.0:10>>, tensor<3x3x1x1x!quant.uniform<i8:f32:3, {5.0:20}>>) -> tensor<1x4x4x1x!quant.uniform<i8:f32:3, {3.0:6}>>
func.return %0 : tensor<1x4x4x1x!quant.uniform<i8:f32:3, {3.0:6}>>
}

// -----

func.func @convolution_c33(%arg0: tensor<1x4x4x1x!quant.uniform<i8:f32, 4.0:10>>, %arg1: tensor<3x3x1x1x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x4x4x1x!quant.uniform<i8:f32:3, {3.0:6}>> {
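// C33: rhs is per-tensor quantized but the result is per-axis quantized.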
// expected-error@+1 {{mismatched rhs and result quantization granularity}}
%0 = stablehlo.convolution(%arg0, %arg1)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
(tensor<1x4x4x1x!quant.uniform<i8:f32, 4.0:10>>, tensor<3x3x1x1x!quant.uniform<i8:f32, 5.0:20>>) -> tensor<1x4x4x1x!quant.uniform<i8:f32:3, {3.0:6}>>
func.return %0 : tensor<1x4x4x1x!quant.uniform<i8:f32:3, {3.0:6}>>
}

// -----

func.func @convolution_c34(%arg0: tensor<1x4x4x1xf32>, %arg1: tensor<3x3x1x1x!quant.uniform<i8:f16:3, {5.0:20}>>) -> tensor<1x4x4x1xf32> {
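// C34: hybrid case (float lhs), but rhs is expressed in f16 while lhs and result have element type f32.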
// expected-error@+1 {{mismatched rhs quantization expressed type and lhs and result element type}}
%0 = stablehlo.convolution(%arg0, %arg1)
dim_numbers = [b, 0, 1, f]x[0, 1, i, o]->[b, 0, 1, f],
window = {stride = [1, 1], pad = [[1, 1], [1, 1]], lhs_dilate = [1, 1], rhs_dilate = [1, 1]}
{batch_group_count = 1 : i64, feature_group_count = 1 : i64, precision_config = [#stablehlo<precision DEFAULT>, #stablehlo<precision DEFAULT>]} :
(tensor<1x4x4x1xf32>, tensor<3x3x1x1x!quant.uniform<i8:f16:3, {5.0:20}>>) -> tensor<1x4x4x1xf32>
func.return %0 : tensor<1x4x4x1xf32>
}
