Commit

AddOp: Remove HLO_CompatibleOperandsAndResultType Trait, add custom verifier (#2127)

AddOp allows a mix of `per-tensor` and `per-axis` quantized inputs. The existing
StableHLO AddOp definition used the `HLO_CompatibleOperandsAndResultType` trait,
which does not allow that mix.
This PR:
- removes the `HLO_CompatibleOperandsAndResultType` trait from AddOp
- adds a custom verifier to implement the op's constraints

Thank you @sdasgup3 for the help with debugging builder issues after the removal
of `HLO_CompatibleOperandsAndResultType`, and for the suggestion on
`inferReturnTypes` placement.
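
To make the new behavior concrete, here is a sketch of the kind of mix the custom
verifier accepts (the function name, shapes, and quantization parameters below are
illustrative, not taken from this PR): a `per-axis` quantized lhs and result paired
with a `per-tensor` quantized rhs, which `HLO_CompatibleOperandsAndResultType`
used to reject.

```mlir
// Illustrative sketch: lhs and result are per-axis quantized on the same
// dimension (add_c5/add_c6), rhs is per-tensor quantized, and all three types
// share the i8 storage type (add_c3) and f32 expressed type (add_c4).
func.func @add_mixed_quantization(
    %lhs: tensor<1x2x!quant.uniform<i8:f32:1, {1.0:17,2.0:17}>>,
    %rhs: tensor<1x2x!quant.uniform<i8:f32, 1.0:17>>)
    -> tensor<1x2x!quant.uniform<i8:f32:1, {1.0:17,2.0:17}>> {
  %0 = "stablehlo.add"(%lhs, %rhs) : (
      tensor<1x2x!quant.uniform<i8:f32:1, {1.0:17,2.0:17}>>,
      tensor<1x2x!quant.uniform<i8:f32, 1.0:17>>)
      -> tensor<1x2x!quant.uniform<i8:f32:1, {1.0:17,2.0:17}>>
  func.return %0 : tensor<1x2x!quant.uniform<i8:f32:1, {1.0:17,2.0:17}>>
}
```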
abhigunj authored Mar 29, 2024
1 parent 416b87d commit 3b953c6
Showing 8 changed files with 204 additions and 22 deletions.
27 changes: 26 additions & 1 deletion stablehlo/dialect/StablehloOps.cpp
@@ -146,7 +146,6 @@ LogicalResult ReduceScatterOp::verify() {
inferredReturnShapes); \
}

INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AddOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(AndOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(Atan2Op)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(CbrtOp)
@@ -186,6 +185,32 @@ INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(SubtractOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(TanhOp)
INFER_RETURN_TYPE_COMPONENTS_FROM_OPERANDS(XorOp)

//===----------------------------------------------------------------------===//
// AddOp
//===----------------------------------------------------------------------===//

LogicalResult AddOp::inferReturnTypeComponents(
MLIRContext* context, std::optional<Location> location,
ValueShapeRange operands, DictionaryAttr attributes,
OpaqueProperties properties, RegionRange regions,
SmallVectorImpl<ShapedTypeComponents>& inferredReturnShapes) {
SmallVector<Type> inferredReturnTypes;
if (failed(inferReturnTypes(context, location, operands.getValues(),
attributes, properties, regions,
inferredReturnTypes)))
return failure();
if (inferredReturnTypes.size() != 1) return failure();
auto inferredReturnType = inferredReturnTypes[0].dyn_cast<ShapedType>();
if (!inferredReturnType) return failure();
inferredReturnShapes.push_back(inferredReturnType);
return success();
}

LogicalResult AddOp::verify() {
return hlo::verifyAddOp(getLoc(), getOperation(), getLhs().getType(),
getRhs().getType(), getResult().getType());
}

//===----------------------------------------------------------------------===//
// AfterAllOp
//===----------------------------------------------------------------------===//
31 changes: 28 additions & 3 deletions stablehlo/dialect/StablehloOps.td
@@ -660,7 +660,7 @@ class StableHLO_BinaryElementwiseOp<string mnemonic, list<Trait> traits,
OperandType:$rhs
);

let extraClassDeclaration = commonClassDeclaration # [{
string binaryElementwiseOpCommonClassDeclaration = commonClassDeclaration # [{
LogicalResult reifyReturnTypeShapes(
OpBuilder& builder, ValueRange operands,
SmallVectorImpl<Value>& reifiedReturnShapes) {
@@ -670,6 +670,8 @@ class StableHLO_BinaryElementwiseOp<string mnemonic, list<Trait> traits,
}
}];

let extraClassDeclaration = binaryElementwiseOpCommonClassDeclaration;

let results = (outs ResultType:$result);

let assemblyFormat = [{
@@ -678,8 +680,9 @@ class StableHLO_BinaryElementwiseOp<string mnemonic, list<Trait> traits,
}];
}

def StableHLO_AddOp : StableHLO_BinaryElementwiseOp<"add",
[HLO_Commutative, Pure, HLO_CompatibleOperandsAndResultType],
def StableHLO_AddOp : StableHLO_BinaryElementwiseOp<"add", [HLO_Commutative, Pure,
InferTypeOpInterface,
DeclareOpInterfaceMethods<InferShapedTypeOpInterface, ["inferReturnTypeComponents"]>],
HLO_TensorOrPerAxisQuantizedTensor> {
let summary = "Add operation";
let description = [{
@@ -694,6 +697,28 @@ def StableHLO_AddOp : StableHLO_BinaryElementwiseOp<"add",
%result = stablehlo.add %lhs, %rhs : tensor<2x2xi32>
```
}];

let extraClassDeclaration = binaryElementwiseOpCommonClassDeclaration # [{
static LogicalResult inferReturnTypes(
MLIRContext * /*context*/, std::optional<Location> location,
ValueRange operands, DictionaryAttr /*attributes*/,
OpaqueProperties /*properties*/, RegionRange /*regions*/,
SmallVectorImpl<Type> &inferredReturnTypes) {
if (operands.empty())
return emitOptionalError(
location,
"Expected non-empty operands for AddOp::inferReturnTypes");

auto inferredTypeOrErr =
mlir::hlo::inferMostSpecificType(location, operands.getTypes());
if (failed(inferredTypeOrErr)) return failure();
inferredReturnTypes.emplace_back(*inferredTypeOrErr);
return success();
}
}];

let hasVerifier = 1;

}

def StableHLO_Atan2Op : StableHLO_BinaryElementwiseOp<"atan2",
105 changes: 105 additions & 0 deletions stablehlo/dialect/TypeInference.cpp
@@ -65,6 +65,95 @@ limitations under the License.

namespace mlir {
namespace hlo {
namespace {
//===----------------------------------------------------------------------===//
// Utils for quantization specific verifications
//===----------------------------------------------------------------------===//
template <typename T>
bool allQuantized(ArrayRef<Type> typeRange) {
return llvm::all_of(
typeRange, [&](Type val) { return getElementTypeOrSelf(val).isa<T>(); });
}

template <typename T>
bool noneQuantized(ArrayRef<Type> typeRange) {
return llvm::all_of(
typeRange, [&](Type val) { return !getElementTypeOrSelf(val).isa<T>(); });
}

template <typename T>
bool anyQuantized(ArrayRef<Type> typeRange) {
return llvm::any_of(
typeRange, [&](Type val) { return getElementTypeOrSelf(val).isa<T>(); });
}

LogicalResult verifyBinaryOpQuantizationConstraints(
std::optional<Location> location, Type lhsType, Type rhsType,
Type resultType) {
lhsType = getElementTypeOrSelf(lhsType);
rhsType = getElementTypeOrSelf(rhsType);
resultType = getElementTypeOrSelf(resultType);
llvm::SmallVector<Type, 3> typeEntries{lhsType, rhsType, resultType};

// add_c2
if (!allQuantized<quant::QuantizedType>(typeEntries)) {
return emitOptionalError(location,
"expects all operands and results to be either "
"quantized or non-quantized");
}
auto lhsQType = lhsType.dyn_cast<quant::QuantizedType>();
auto rhsQType = rhsType.dyn_cast<quant::QuantizedType>();
auto resultQType = resultType.dyn_cast<quant::QuantizedType>();
// add_c3
auto storageType = lhsQType.getStorageType();
if (storageType != rhsQType.getStorageType() ||
storageType != resultQType.getStorageType())
return emitOptionalError(
location, "mismatched operands and result quantization storage types");
// add_c4
auto expressedType = lhsQType.getExpressedType();
if (expressedType != rhsQType.getExpressedType() ||
expressedType != resultQType.getExpressedType())
return emitOptionalError(
location,
"mismatched operands and result quantization expressed types");

auto lhsQPAType = lhsType.dyn_cast<quant::UniformQuantizedPerAxisType>();
auto rhsQPAType = rhsType.dyn_cast<quant::UniformQuantizedPerAxisType>();
auto resultQPAType =
resultType.dyn_cast<quant::UniformQuantizedPerAxisType>();
if (lhsQPAType || rhsQPAType) {
// add_c5
if (!resultQPAType)
return emitOptionalError(
location, "result is not per_axis quantized but lhs or rhs are");
// add_c6
if (lhsQPAType) {
if (resultQPAType.getQuantizedDimension() !=
lhsQPAType.getQuantizedDimension())
return emitOptionalError(
location, "quantization_dimension of lhs and result are not same ",
lhsType, " vs ", resultType);
}
// add_c7
if (rhsQPAType) {
if (resultQPAType.getQuantizedDimension() !=
rhsQPAType.getQuantizedDimension())
return emitOptionalError(
location, "quantization_dimension of rhs and result are not same ",
rhsType, " vs ", resultType);
}
return success();
}

if (resultQPAType)
return emitOptionalError(location,
"result per_axis quantized but none from rhs "
"and lhs are per_axis quantized");
return success();
}

} // namespace

//===----------------------------------------------------------------------===//
// Utils for shape functions.
@@ -264,6 +353,22 @@ LogicalResult verifyPairwiseCompatibleShapes(TypeRange values) {
return success();
}

LogicalResult verifyAddOp(std::optional<Location> location, Operation* op,
Type lhsType, Type rhsType, Type resultType) {
llvm::SmallVector<Type, 3> typeEntries{lhsType, rhsType, resultType};
if (anyQuantized<quant::QuantizedType>(typeEntries))
return verifyBinaryOpQuantizationConstraints(location, lhsType, rhsType,
resultType);

if (getElementTypeOrSelf(lhsType) != getElementTypeOrSelf(rhsType) ||
getElementTypeOrSelf(lhsType) != getElementTypeOrSelf(resultType))
return emitOptionalError(
location,
"op requires the same element type for all operands and results");

return success();
}

LogicalResult verifyBatchNorm(std::optional<Location> location,
ValueRange multiDimOperands,
ValueRange singleDimOperands,
3 changes: 3 additions & 0 deletions stablehlo/dialect/TypeInference.h
@@ -374,6 +374,9 @@ LogicalResult inferWhileOp(std::optional<Location> location, ValueRange operand,
// Verifiers for ops.
//===----------------------------------------------------------------------===//

LogicalResult verifyAddOp(std::optional<Location> location, Operation* op,
Type lhsType, Type rhsType, Type resultType);

LogicalResult verifyAllGatherOp(std::optional<Location> location, Value operand,
int64_t allGatherDim,
DenseIntElementsAttr replicaGroups,
3 changes: 2 additions & 1 deletion stablehlo/tests/infer_stablehlo.mlir
@@ -1204,7 +1204,8 @@ func.func @add_bounds(
func.func @add_bounds_mismatch(
%arg0: tensor<3xf32, #stablehlo.bounds<?>>,
%arg1: tensor<?xf32, #stablehlo.bounds<2>>) -> tensor<?xindex> {
// expected-error@+1 {{requires compatible types for all operands and results}}
// expected-error@+2 {{op failed to infer returned types}}
// expected-error@+1 {{Mismatched dimension size 3 and bound 2 in dimension 0}}
%result = "stablehlo.add"(%arg0, %arg1) : (
tensor<3xf32, #stablehlo.bounds<?>>,
tensor<?xf32, #stablehlo.bounds<2>>) -> tensor<?xf32>
52 changes: 38 additions & 14 deletions stablehlo/tests/ops_stablehlo.mlir
@@ -5058,15 +5058,15 @@ func.func @is_compatible_dynamism_mix(%arg0: tensor<?xf32>, %arg1: tensor<1xf32>
// -----

func.func @is_compatible_dynamism_ranked_mismatch(%arg0: tensor<?xf32>) {
// expected-error@+1 {{op requires compatible types for all operands and results}}
// expected-error@+1 {{op requires the same shape for all operands and results}}
%0 = "stablehlo.add"(%arg0, %arg0) : (tensor<?xf32>, tensor<?xf32>) -> tensor<?x?xf32>
func.return
}

// -----

func.func @is_compatible_dynamism_dim_mismatch(%arg0: tensor<1x?xf32>) {
// expected-error@+1 {{op requires compatible types for all operands and results}}
// expected-error@+1 {{op requires the same shape for all operands and results}}
%0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x?xf32>, tensor<1x?xf32>) -> tensor<2x2xf32>
func.return
}
@@ -5077,43 +5077,66 @@ func.func @is_compatible_quant_mix_non_quant(%arg0: tensor<1xf32>, %arg1: tensor
%0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
%1 = "stablehlo.add"(%arg1, %arg1) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
%2 = "stablehlo.add"(%arg1, %arg1) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
%3 = "stablehlo.add"(%arg1, %arg1) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 2.0:17>>
%4 = "stablehlo.add"(%arg1, %arg1) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:18>>

func.return
}


// -----

func.func @is_compatible_quant_mix_scale(%arg0: tensor<1x!quant.uniform<i8:f32, 1.0:17>>) {
%0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 2.0:17>>
func.func @add_c4(%arg0: tensor<1x!quant.uniform<i8:f32, 1.0:17>>) {
// expected-error@+1 {{mismatched operands and result quantization expressed types}}
%0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:bf16, 1.0:17>>
func.return
}

// -----

func.func @is_compatible_quant_mix_zero_point(%arg0: tensor<1x!quant.uniform<i8:f32, 1.0:17>>) {
%0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:18>>
func.func @add_c3(%arg0: tensor<1x!quant.uniform<i8:f32, 1.0:17>>) {
// expected-error@+1 {{mismatched operands and result quantization storage types}}
%0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i4:f32, 1.0:17>>
func.return
}

// -----

func.func @is_compatible_quant_expressed_mismatch(%arg0: tensor<1x!quant.uniform<i8:f32, 1.0:17>>) {
// expected-error@+1 {{op requires compatible types for all operands and results}}
%0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i8:bf16, 1.0:17>>
func.func @add_c2(%arg0: tensor<1x!quant.uniform<i8:f32, 1.0:17>>) {
// expected-error@+1 {{all operands and results to be either quantized or non-quantized}}
%0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1xf32>
func.return
}

// -----

func.func @is_compatible_quant_storage_mismatch(%arg0: tensor<1x!quant.uniform<i8:f32, 1.0:17>>) {
// expected-error@+1 {{op requires compatible types for all operands and results}}
%0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<i4:f32, 1.0:17>>
func.func @add_c5(%arg0: tensor<1x!quant.uniform<i8:f32:0, {1.0:17}>>) {
// expected-error@+1 {{result is not per_axis quantized but lhs or rhs are}}
%0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x!quant.uniform<i8:f32:0, {1.0:17}>>, tensor<1x!quant.uniform<i8:f32:0, {1.0:17}>>) -> tensor<1x!quant.uniform<i8:f32, 1.0:17>>
func.return
}

// -----

func.func @add_c6(%arg0: tensor<1x2x!quant.uniform<i8:f32:0, {1.0:17}>>) {
// expected-error@+1 {{quantization_dimension of lhs and result are not same}}
%0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x2x!quant.uniform<i8:f32:0, {1.0:17}>>, tensor<1x2x!quant.uniform<i8:f32:0, {1.0:17}>>) -> tensor<1x2x!quant.uniform<i8:f32:2, {1.0:17}>>
func.return
}

// -----

func.func @add_c7(%arg0: tensor<1x2x!quant.uniform<i8:f32:0, {1.0:17}>>, %arg1: tensor<1x2x!quant.uniform<i8:f32:1, {1.0:17}>>) {
// expected-error@+1 {{quantization_dimension of rhs and result are not same}}
%0 = "stablehlo.add"(%arg0, %arg1) : (tensor<1x2x!quant.uniform<i8:f32:0, {1.0:17}>>, tensor<1x2x!quant.uniform<i8:f32:1, {1.0:17}>>) -> tensor<1x2x!quant.uniform<i8:f32:0, {1.0:17}>>
func.return
}

// -----

func.func @is_compatible_quant_signedness_mismatch(%arg0: tensor<1x!quant.uniform<i8:f32, 1.0:17>>) {
// expected-error@+1 {{op requires compatible types for all operands and results}}
// expected-error@+2 {{op failed to infer returned types}}
// expected-error@+1 {{op inferred type(s) 'tensor<1x!quant.uniform<i8:f32, 1.000000e+00:17>>' are incompatible with return type(s) of operation 'tensor<1x!quant.uniform<u8:f32, 1.000000e+00:17>>'}}
%0 = "stablehlo.add"(%arg0, %arg0) : (tensor<1x!quant.uniform<i8:f32, 1.0:17>>, tensor<1x!quant.uniform<i8:f32, 1.0:17>>) -> tensor<1x!quant.uniform<u8:f32, 1.0:17>>
func.return
}
@@ -5135,7 +5158,8 @@ func.func @is_compatible_dynamism_bounds_mismatch(
func.func @is_compatible_dynamism_bounds_mismatch(
%arg0: tensor<?xf32, #stablehlo.type_extensions<bounds = [4]>>,
%arg1: tensor<?xf32, #stablehlo.type_extensions<bounds = [4]>>) {
// expected-error@+1 {{requires compatible types for all operands and results}}
// expected-error@+2 {{op failed to infer returned types}}
// expected-error@+1 {{'stablehlo.add' op inferred type(s) 'tensor<?xf32, #stablehlo.bounds<4>>' are incompatible with return type(s) of operation 'tensor<5xf32>'}}
%0 = "stablehlo.add"(%arg0, %arg1) : (
tensor<?xf32, #stablehlo.type_extensions<bounds = [4]>>,
tensor<?xf32, #stablehlo.type_extensions<bounds = [4]>>) -> tensor<5xf32>
3 changes: 1 addition & 2 deletions stablehlo/tests/print_types_invalid.mlir
@@ -11,10 +11,9 @@ func.func @unary_eltwise_two_types(%arg0: tensor<?x?xf64>,

// -----

// TODO(ajcbik): error message is a bit too strict, should be "compatible" type?
func.func @binary_eltwise_type_mismatch(%arg0: tensor<?x?xf64>,
%arg1: tensor<?x?xf32>) -> tensor<?x?xf64> {
// expected-error @+1 {{'stablehlo.add' op requires compatible types for all operands and results}}
// expected-error @+1 {{op requires the same element type for all operands and results}}
%0 = stablehlo.add %arg0, %arg1 : (tensor<?x?xf64>, tensor<?x?xf32>) -> tensor<?x?xf64>
func.return %0 : tensor<?x?xf64>
}
2 changes: 1 addition & 1 deletion stablehlo/tests/stablehlo_refine_shapes.mlir
@@ -3,7 +3,7 @@
func.func @error_illformed(%arg0: tensor<3xf32>, %arg1: tensor<4xf32>) -> tensor<?xf32> {
%0 = stablehlo.abs %arg0 : (tensor<3xf32>) -> tensor<?xf32>
%1 = stablehlo.abs %arg1 : (tensor<4xf32>) -> tensor<?xf32>
// expected-error@+1{{requires compatible types for all operands and results}}
// expected-error@+1{{requires the same shape for all operands and results}}
%2 = stablehlo.add %0, %1 : (tensor<?xf32>, tensor<?xf32>) -> tensor<?xf32>
func.return %2 : tensor<?xf32>
}
