diff --git a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td index 364655f8..d3c26094 100644 --- a/include/xten/Dialect/XTenNN/IR/XTenNNOps.td +++ b/include/xten/Dialect/XTenNN/IR/XTenNNOps.td @@ -113,59 +113,82 @@ def XTenNN_QuantizeOp: XTenNN_Op<"quantize", [ Elementwise, Pure, SameOperandsAndResultShape]> { - let summary = "Quantizes a float32 tensor to a signless or unsigned integer tensor of given width."; + let summary = "Quantizes a float tensor to a signless or unsigned integer tensor of given width."; let description = [{ - Quantizes a given float32 tensor into a signless or unsigned integer tensor of given width. + Quantizes a given float tensor into a signless or unsigned integer tensor of given width. Since tosa is using signless/unsigned types currently, we also consider signless integer types for signed ones when the type is not unsigned until tosa support signed integers. Applies the following linear quantization to the input tensor x: - y = round( x / 2^shift ) + y = round((x / scale) + zero_point) - Where 2^shift is equal to the scale of the quantize operation and - the shift is an attribute of the operation in si32. + Iff log2(scale) is representable as si32 and zero_point == 0, shift is set to log2(scale). + In this case the quantization is equal to: + y = round( x / 2^shift ) Round will saturate to the range of the output type and the rounding mode is set to half to nearest even. }]; let arguments = (ins - F32Tensor:$input, - SI32Attr:$shift + XTenNN_AnyFloatTensor:$input, + OptionalAttr:$shift, + F32Attr:$scale, // Restricted to F32 for now, but may be relaxed in the future + XTenNN_AnyIntegerAttr:$zero_point ); let results = (outs XTenNN_AnySignlessOrUnsignedIntegerTensor:$output); + let builders = [ + OpBuilder<(ins "::mlir::Type":$output, "::mlir::Value":$input, "::mlir::FloatAttr":$scale, "::mlir::IntegerAttr":$zeroPoint)>, + OpBuilder<(ins "::mlir::Type":$output, "::mlir::Value":$input, "int32_t":$shift)> + ]; - let assemblyFormat = [{ `(`$input `:` type($input)`)` attr-dict `->` type($output) }]; + // `(`$input `:` type($input)`)` attr-dict `->` type($output); + // Zero point is optional if scale is set, defaults to 0 + // If shift is set but not scale, scale is based on the shift. Setting a zero point is not allowed in this case + let hasCustomAssemblyFormat = 1; let hasFolder = 1; + let hasVerifier = 1; } def XTenNN_DequantizeOp: XTenNN_Op<"dequantize", [ Elementwise, Pure, SameOperandsAndResultShape]> { - let summary = "Dequantizes a signless/unsigned integer tensor of given bitwidth to a float32 tensor."; + let summary = "Dequantizes a signless/unsigned integer tensor of given bitwidth to a float tensor."; let description = [{ - Dequantizes a signless/unsigned integer tensor of given bitwidth to a float32 tensor. + Dequantizes a signless/unsigned integer tensor of given bitwidth to a float tensor. Since tosa is using signless/unsigned types currently, we also consider signless integer types for signed ones when the type is not unsigned until tosa support signed integers. Applies the following linear dequantization to the input tensor x: - y = x * ( 2^shift ) + y = (x - zero_point) * scale - Where 2^shift is equal to scale of the dequantize operation and - the shift is an attribute of the operation in si32. + Iff log2(scale) is representable as si32 and zero_point == 0, shift is set to log2(scale). 
+ In this case the dequantization is equal to: + y = x * ( 2^shift ) }]; let arguments = (ins XTenNN_AnySignlessOrUnsignedIntegerTensor:$input, - SI32Attr:$shift + OptionalAttr:$shift, + F32Attr:$scale, // Restricted to F32 for now, but may be relaxed in the future + XTenNN_AnyIntegerAttr:$zero_point ); - let results = (outs F32Tensor:$output); + let results = (outs XTenNN_AnyFloatTensor:$output); + + let builders = [ + OpBuilder<(ins "::mlir::Type":$output, "::mlir::Value":$input, "::mlir::FloatAttr":$scale, "::mlir::IntegerAttr":$zeroPoint)>, + OpBuilder<(ins "::mlir::Type":$output, "::mlir::Value":$input, "int32_t":$shift)> + ]; - let assemblyFormat = [{ `(`$input `:` type($input)`)` attr-dict `->` type($output) }]; + // `(`$input `:` type($input)`)` attr-dict `->` type($output); + // Zero point is optional if scale is set, defaults to 0 + // If shift is set but not scale, scale is based on the shift. Setting a zero point is not allowed in this case + let hasCustomAssemblyFormat = 1; + let hasVerifier = 1; } def XTenNN_GroupQuantizeOp: XTenNN_Op<"group_quantize", [ diff --git a/include/xten/Dialect/XTenNN/IR/XTenNNTypes.td b/include/xten/Dialect/XTenNN/IR/XTenNNTypes.td index cb060fc1..d6e725ed 100644 --- a/include/xten/Dialect/XTenNN/IR/XTenNNTypes.td +++ b/include/xten/Dialect/XTenNN/IR/XTenNNTypes.td @@ -22,6 +22,15 @@ def XTenNN_AnySignlessOrUnsignedIntegerTensor : TensorOf< def XTenNN_AnyFloatTensor : TensorOf<[AnyFloat]>; +def XTenNN_AnyIntegerAttr: + TypedAttrBase< + AnyInteger, "IntegerAttr", + CPred<"::llvm::isa<::mlir::IntegerAttr>($_self)">, + "Any integer attr"> { + let returnType = [{ ::llvm::APInt }]; + let constBuilderCall = ?; +} + def XTenNN_AnyIntegerOrFloat : AnyTypeOf<[AnyInteger, AnyFloat], "Integer or Float">; #endif // XTENNN_TYPES \ No newline at end of file diff --git a/lib/Conversion/TosaToXTenNN.cpp b/lib/Conversion/TosaToXTenNN.cpp index 19eef273..c0412319 100644 --- a/lib/Conversion/TosaToXTenNN.cpp +++ b/lib/Conversion/TosaToXTenNN.cpp @@ -224,7 +224,11 @@ class FoldMulsToQDQOps : public OpRewritePattern { } // Sum the shifts of the quantize, dequantize and update the operations - llvm::APInt shiftSum(32, dequantizeOp.getShift(), true); + if (!dequantizeOp.getShift()) { + return rewriter.notifyMatchFailure(dequantizeOp.getLoc(), + "Dequantize op has no shift"); + } + llvm::APInt shiftSum(32, *dequantizeOp.getShift(), true); bool overflow = false; shiftSum = shiftSum.sadd_ov(dequantizeShift, overflow); if (overflow) { diff --git a/lib/Conversion/XTenNNToTosa.cpp b/lib/Conversion/XTenNNToTosa.cpp index 81f266f1..29830411 100644 --- a/lib/Conversion/XTenNNToTosa.cpp +++ b/lib/Conversion/XTenNNToTosa.cpp @@ -81,6 +81,22 @@ IntegerMinMax calculateMinMaxOfElementType(TensorType type) { return IntegerMinMax{minValue.getSExtValue(), maxValue.getSExtValue()}; } +namespace { +APFloat convertF32AttrToFloatTy(FloatAttr attr, Type typeToConvertTo) { + // Convert from f32 to the float type that is actually used + assert(attr.getType().isF32()); + assert(isa(typeToConvertTo)); + auto floatResultType = cast(typeToConvertTo); + APFloat scale = attr.getValue(); + bool losesInfo; + // Ignore inaccuracies, there is nothing we can do. 
+ [[maybe_unused]] const auto conversionResult = + scale.convert(floatResultType.getFloatSemantics(), + llvm::RoundingMode::NearestTiesToEven, &losesInfo); + return scale; +} +} // namespace + class QuantizeOp : public OpRewritePattern { public: using OpRewritePattern::OpRewritePattern; @@ -90,37 +106,52 @@ class QuantizeOp : public OpRewritePattern { // The QDQ operations only work on tensors, if they are not, then the // verifiers should find the error. At the moment, only signless tensors // are supported. - auto outputType = cast(quantizeOp->getResult(0).getType()); + const auto outputType = + cast(quantizeOp->getResult(0).getType()); if (!outputType.getElementType().isSignlessInteger()) { return rewriter.notifyMatchFailure( quantizeOp.getLoc(), "only signless integer tensor types are supported."); } - auto inputType = dyn_cast(quantizeOp->getOperand(0).getType()); - - // Calculate (1 / 2 ^ shift) - llvm::APFloat scale(std::pow(static_cast(2.0), - static_cast(-quantizeOp.getShift()))); - - // Create a constant that represents the (1 / 2 ^ shift) - RankedTensorType constType = - createSplatType(inputType.getRank(), rewriter.getF32Type()); + const auto inputType = + cast(quantizeOp->getOperand(0).getType()); + const auto inputElementType = inputType.getElementType(); + + // Convert the scale from f32 to the float type that is actually used + const llvm::APFloat scale = + convertF32AttrToFloatTy(quantizeOp.getScaleAttr(), inputElementType); + const llvm::APFloat scaleReciprocal = + llvm::APFloat::getOne(scale.getSemantics()) / scale; + + const RankedTensorType constType = + createSplatType(inputType.getRank(), inputElementType); auto constOp = rewriter.create( quantizeOp->getLoc(), constType, - DenseFPElementsAttr::get(constType, {scale})); + DenseFPElementsAttr::get(constType, {scaleReciprocal})); - // Calculate (x / 2 ^ shift) auto mulOp = rewriter.create( quantizeOp.getLoc(), inputType, quantizeOp->getOperand(0), constOp->getResult(0), rewriter.getI8IntegerAttr(0)); + const auto constAddType = + createSplatType(inputType.getRank(), outputType.getElementType()); + auto constAddOp = rewriter.create( + quantizeOp.getLoc(), constAddType, + DenseIntElementsAttr::get(constAddType, {quantizeOp.getZeroPoint()})); + auto constAddCastOp = rewriter.create( + quantizeOp.getLoc(), inputType, constAddOp.getResult()); + auto zeroPointAdd = rewriter.create( + quantizeOp.getLoc(), inputType, mulOp.getResult(), + constAddCastOp.getResult()); + // TOSA only supports signed integers of i8, i16 or i32 here we convert our // si to this types and add a clamp to mimic arbitrary bit width. TensorType newIntegerStorageType = getNewStorageType(outputType); // Cast from fp32 -> i where bit width is the supported storage // bit width. Either i8, i16 or i32 - auto castOp = rewriter.create( - quantizeOp->getLoc(), newIntegerStorageType, mulOp->getResult(0)); + auto castOp = rewriter.create(quantizeOp->getLoc(), + newIntegerStorageType, + zeroPointAdd->getResult(0)); // Find the max and min of the signed integer type. IntegerMinMax intLimits = calculateMinMaxOfElementType(outputType); @@ -156,7 +187,10 @@ class DequantizeOp : public OpRewritePattern { // The QDQ operations only work on tensors, if they are not, then the // verifiers should find the error. At the moment, only signless tensors // are supported. 
- auto inputType = cast(dequantizeOp->getOperand(0).getType()); + const auto resultElementType = + dequantizeOp.getResult().getType().getElementType(); + const auto inputType = + cast(dequantizeOp->getOperand(0).getType()); if (!inputType.getElementType().isSignlessInteger()) { return rewriter.notifyMatchFailure( dequantizeOp.getLoc(), @@ -170,18 +204,30 @@ class DequantizeOp : public OpRewritePattern { dequantizeOp.getLoc(), newIntegerStorageType, dequantizeOp->getOperand(0)); - // We can then cast from i<8,16,32> -> fp32 + // We can then cast from i<8,16,32> -> fp auto castOp = rewriter.create( dequantizeOp->getLoc(), dequantizeOp->getResult(0).getType(), unrealizedCast.getResult(0)); - // Calculate the (x * 2 ^ shift) for the dequantize part - llvm::APFloat scale(std::pow(static_cast(2.0), - static_cast(dequantizeOp.getShift()))); + // Do the zero_point sub on the float type to to avoid underflows + const auto constSubType = + createSplatType(inputType.getRank(), inputType.getElementType()); + auto constSubOp = rewriter.create( + dequantizeOp.getLoc(), constSubType, + DenseIntElementsAttr::get(constSubType, {dequantizeOp.getZeroPoint()})); + auto constSubCastOp = rewriter.create( + dequantizeOp.getLoc(), dequantizeOp.getResult().getType(), + constSubOp.getResult()); + auto zeroPointSub = rewriter.create( + dequantizeOp.getLoc(), dequantizeOp.getResult().getType(), + castOp.getResult(), constSubCastOp.getResult()); + + // Convert the scale from f32 to the float type that is actually used + const llvm::APFloat scale = + convertF32AttrToFloatTy(dequantizeOp.getScaleAttr(), resultElementType); // Create a constant to hold the floating point scale we just calculated - auto constType = - createSplatType(inputType.getRank(), rewriter.getF32Type()); + auto constType = createSplatType(inputType.getRank(), resultElementType); auto constOp = rewriter.create( dequantizeOp->getLoc(), constType, DenseFPElementsAttr::get(constType, {scale})); @@ -189,7 +235,7 @@ class DequantizeOp : public OpRewritePattern { // Replace the dequantize op with the new operations we just created. rewriter.replaceOpWithNewOp( dequantizeOp, dequantizeOp->getResult(0).getType(), - castOp->getResult(0), constOp->getResult(0), + zeroPointSub->getResult(0), constOp->getResult(0), rewriter.getI8IntegerAttr(0)); return success(); } diff --git a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp index e6582841..6fbb751a 100644 --- a/lib/Dialect/XTenNN/IR/XTenNNOps.cpp +++ b/lib/Dialect/XTenNN/IR/XTenNNOps.cpp @@ -35,6 +35,7 @@ #include "xten/Dialect/XTenNN/IR/XTenNNBase.h" #include "xten/Dialect/XTenNN/Interfaces/EnclaveOpInterfaces.h" +#include #include using namespace mlir; @@ -444,25 +445,323 @@ LogicalResult SubgraphOp::inferReturnTypeComponents( // XTenNNDialect //===----------------------------------------------------------------------===// -OpFoldResult amd::xten_nn::QuantizeOp::fold(FoldAdaptor adaptor) { - // Fold away cases where a xten_nn.quantize is preceeded by xten_nn.dequantize - // that uses the same shift factor and has same types. +namespace { +std::optional getShiftValue(float constValue) { + const float log2Value = std::log2f(constValue); - auto dequantizeOp = - dyn_cast_or_null(getInput().getDefiningOp()); - if (!dequantizeOp) + // The log2 of the value must not have fractions. 
+ if (std::roundf(log2Value) != log2Value) + return {}; + + return static_cast(log2Value); +} + +std::optional getScaleFromShift(int32_t shift) { + std::feclearexcept(FE_ALL_EXCEPT); + errno = 0; + const float scale = exp2f(shift); + if (std::fetestexcept(FE_OVERFLOW) || errno) { return {}; + } + return scale; +} + +template +mlir::ParseResult parseQuantizeDequantizeLikeOp( + mlir::OpAsmParser &parser, mlir::OperationState &result, + mlir::StringAttr scaleAttrName, mlir::StringAttr zeroPointAttrName, + mlir::StringAttr shiftAttrName) { + // Parse operands + if (parser.parseLParen()) + return mlir::failure(); + + const SMLoc inputOperandsLoc = parser.getCurrentLocation(); + mlir::OpAsmParser::UnresolvedOperand inputRawOperand{}; + ArrayRef inputOperands(&inputRawOperand, + 1); + if (parser.parseOperand(inputRawOperand)) + return mlir::failure(); + if (parser.parseColon()) + return mlir::failure(); + + mlir::TensorType operandType; + ArrayRef inputTypes(&operandType, 1); + { + if (parser.parseCustomTypeWithFallback(operandType)) + return mlir::failure(); + } + if (parser.parseRParen()) + return mlir::failure(); + + // Parse attributes + auto attrLoc = parser.getCurrentLocation(); + { + + if (parser.parseOptionalAttrDict(result.attributes)) + return mlir::failure(); + + if (failed(Op::verifyInherentAttrs(result.name, result.attributes, [&]() { + return parser.emitError(attrLoc) + << "'" << result.name.getStringRef() << "' op "; + }))) + return mlir::failure(); + } + + // Parse return type + if (parser.parseArrow()) + return mlir::failure(); + mlir::TensorType resultType; + ArrayRef outputTypes(&resultType, 1); + { + if (parser.parseCustomTypeWithFallback(resultType)) + return mlir::failure(); + } + result.addTypes(outputTypes); + + auto builder = parser.getBuilder(); + + // Handle missing scale + if (!result.attributes.getNamed(scaleAttrName)) { + if (result.attributes.getNamed(zeroPointAttrName)) { + return parser.emitError(attrLoc) << "'" << result.name.getStringRef() + << "' op: It is only allowed to set a " + "zero point if scale is set too"; + } + const auto shiftAttr = result.attributes.getNamed(shiftAttrName); + if (!shiftAttr) { + return parser.emitError(attrLoc) + << "'" << result.name.getStringRef() + << "' op: Shift and scale are both missing"; + } else { + const auto scale = + getScaleFromShift(cast(shiftAttr->getValue()).getSInt()); + if (!scale) { + return parser.emitError(attrLoc) + << "'" << result.name.getStringRef() + << "' op: Could not calculate scale from shift"; + } + auto scaleAttr = builder.getF32FloatAttr(*scale); + result.addAttribute(scaleAttrName, scaleAttr); + } + } + + // Handle missing zeroPoint -> default to 0 + if (!result.attributes.getNamed(zeroPointAttrName)) { + mlir::Type zeroPointType; + if constexpr (ZeroPointTypeIsSameAsOperand) { + zeroPointType = operandType.getElementType(); + } else { + zeroPointType = resultType.getElementType(); + } + auto zeroPointAttr = builder.getIntegerAttr(zeroPointType, 0); + result.addAttribute(zeroPointAttrName, zeroPointAttr); + } + + // Try to populate shift form scale, but only if the zero point is zero + if (!result.attributes.getNamed(shiftAttrName)) { + const auto scaleAttr = result.attributes.getNamed(scaleAttrName); + if (scaleAttr) { + const auto calculatedShift = + getShiftValue(cast(scaleAttr->getValue()) + .getValue() + .convertToFloat()); + if (calculatedShift && + cast( + result.attributes.getNamed(zeroPointAttrName)->getValue()) + .getValue() + .isZero()) { + result.addAttribute(shiftAttrName, + 
builder.getSI32IntegerAttr(*calculatedShift)); + } + } + } + + if (parser.resolveOperands(inputOperands, inputTypes, inputOperandsLoc, + result.operands)) + return mlir::failure(); + return mlir::success(); +} + +void printQuantizeDequantizeLikeOp(mlir::OpAsmPrinter &p, + mlir::TypedValue input, + mlir::TypedValue output, + ArrayRef attrs, + StringRef zeroPointName, + bool zeroPointIsZero, StringRef scaleName, + bool shiftIsSet) { + p << "("; + p << input; + p << ' ' << ":"; + p << ' '; + { + auto type = input.getType(); + if (auto validType = dyn_cast(type)) + p.printStrippedAttrOrType(validType); + else + p << type; + } + p << ")"; + SmallVector elidedAttrs; + // Skip printing scale and zero point if zero point is 0 and shift is set + if (zeroPointIsZero && shiftIsSet) { + elidedAttrs.push_back(zeroPointName); + elidedAttrs.push_back(scaleName); + } + p.printOptionalAttrDict(attrs, elidedAttrs); + p << ' ' << "->"; + p << ' '; + { + auto type = output.getType(); + if (auto validType = ::llvm::dyn_cast(type)) + p.printStrippedAttrOrType(validType); + else + p << type; + } +} +} // namespace - if (!dequantizeOp->hasOneUse() || dequantizeOp.getShift() != getShift()) +mlir::ParseResult QuantizeOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseQuantizeDequantizeLikeOp( + parser, result, getScaleAttrName(result.name), + getZeroPointAttrName(result.name), getShiftAttrName(result.name)); +} + +void QuantizeOp::print(mlir::OpAsmPrinter &p) { + return printQuantizeDequantizeLikeOp( + p, getInput(), getOutput(), (*this)->getAttrs(), getZeroPointAttrName(), + getZeroPointAttr() && getZeroPoint().isZero(), getScaleAttrName(), + getShift().has_value()); +} + +void QuantizeOp::build(mlir::OpBuilder &odsBuilder, + mlir::OperationState &odsState, mlir::Type output, + mlir::Value input, mlir::FloatAttr scale, + mlir::IntegerAttr zeroPoint) { + const bool zeroPointIsZero = zeroPoint.getValue().isZero(); + assert(scale.getType().isF32()); + const auto shiftValue = getShiftValue(scale.getValue().convertToFloat()); + if (zeroPointIsZero && shiftValue) { + return build(odsBuilder, odsState, output, input, + odsBuilder.getSI32IntegerAttr(*shiftValue), scale, zeroPoint); + } + return build(odsBuilder, odsState, output, input, nullptr, scale, zeroPoint); +} + +void QuantizeOp::build(mlir::OpBuilder &odsBuilder, + mlir::OperationState &odsState, mlir::Type output, + mlir::Value input, int32_t shift) { + const auto outputElemType = cast(output).getElementType(); + const auto scale = getScaleFromShift(shift); + assert(scale && "Could not calculate scale from shift"); + return build(odsBuilder, odsState, output, input, + odsBuilder.getSI32IntegerAttr(shift), + odsBuilder.getF32FloatAttr(*scale), + odsBuilder.getIntegerAttr(outputElemType, 0)); +} + +LogicalResult QuantizeOp::verify() { + if (getResult().getType().getElementType() != getZeroPointAttr().getType()) { + return emitOpError("Result elem type needs to match match zero point type"); + } + // if shift is set, zero point needs to be zero and scale needs to match + if (getShift()) { + const auto computedShift = getShiftValue(getScale().convertToFloat()); + if (!computedShift || computedShift != *getShift()) { + return emitOpError( + "Shift set, but does not match shift calculated from scale"); + } + if (!getZeroPoint().isZero()) { + return emitOpError("Shift set, but zero_point not zero"); + } + } + + return success(); +} + +OpFoldResult QuantizeOp::fold(FoldAdaptor /*adaptor*/) { + // Fold away cases where a xten_nn.quantize 
is preceeded by + // xten_nn.dequantize that uses the same scale factor, zeroPoint and has + // same types. + + auto dequantizeOp = + dyn_cast_or_null(getInput().getDefiningOp()); + if (!dequantizeOp) return {}; auto dequantizeInput = dequantizeOp.getInput(); if (dequantizeInput.getType() != getType()) return {}; + if (!dequantizeOp->hasOneUse() || dequantizeOp.getScale() != getScale() || + dequantizeOp.getZeroPoint() != getZeroPoint()) + return {}; + return dequantizeInput; } +mlir::ParseResult DequantizeOp::parse(mlir::OpAsmParser &parser, + mlir::OperationState &result) { + return parseQuantizeDequantizeLikeOp( + parser, result, getScaleAttrName(result.name), + getZeroPointAttrName(result.name), getShiftAttrName(result.name)); +} + +void DequantizeOp::print(mlir::OpAsmPrinter &p) { + return printQuantizeDequantizeLikeOp( + p, getInput(), getOutput(), (*this)->getAttrs(), getZeroPointAttrName(), + getZeroPointAttr() && getZeroPoint().isZero(), getScaleAttrName(), + getShift().has_value()); +} + +void DequantizeOp::build(mlir::OpBuilder &odsBuilder, + mlir::OperationState &odsState, mlir::Type output, + mlir::Value input, mlir::FloatAttr scale, + mlir::IntegerAttr zeroPoint) { + const bool zeroPointIsZero = zeroPoint.getValue().isZero(); + assert(scale.getType().isF32()); + const auto shiftValue = getShiftValue(scale.getValue().convertToFloat()); + if (zeroPointIsZero && shiftValue) { + return build(odsBuilder, odsState, output, input, + odsBuilder.getSI32IntegerAttr(*shiftValue), scale, zeroPoint); + } + return build(odsBuilder, odsState, output, input, nullptr, scale, zeroPoint); +} + +void DequantizeOp::build(mlir::OpBuilder &odsBuilder, + mlir::OperationState &odsState, mlir::Type output, + mlir::Value input, int32_t shift) { + const auto inputElemType = cast(input.getType()).getElementType(); + const auto scale = getScaleFromShift(shift); + assert(scale && "Could not calculate scale from shift"); + return build(odsBuilder, odsState, output, input, + odsBuilder.getSI32IntegerAttr(shift), + odsBuilder.getF32FloatAttr(*scale), + odsBuilder.getIntegerAttr(inputElemType, 0)); +} + +LogicalResult DequantizeOp::verify() { + // Input elem type should match zero point type + if (cast(getOperand().getType()).getElementType() != + getZeroPointAttr().getType()) { + return emitOpError( + "Operand elem type needs to match match zero point type"); + } + // if shift is set, zero point needs to be zero and scale needs to match + if (getShift()) { + const auto computedShift = getShiftValue(getScale().convertToFloat()); + if (!computedShift || computedShift != *getShift()) { + return emitOpError( + "Shift set, but does not match shift calculated from scale"); + } + if (!getZeroPoint().isZero()) { + return emitOpError("Shift set, but zero_point not zero"); + } + } + + return success(); +} + OpFoldResult amd::xten_nn::GroupQuantizeOp::fold(FoldAdaptor adaptor) { // Fold away cases where a xten_nn.group_quantize is preceeded by // xten_nn.group_dequantize that uses the same shift factor and has same diff --git a/lib/Dialect/XTenNN/Transforms/QDQConcat.cpp b/lib/Dialect/XTenNN/Transforms/QDQConcat.cpp index 6a3fd5d0..c61fc9b2 100644 --- a/lib/Dialect/XTenNN/Transforms/QDQConcat.cpp +++ b/lib/Dialect/XTenNN/Transforms/QDQConcat.cpp @@ -36,7 +36,8 @@ struct RemoveQDQBetweenConcat : public OpRewritePattern { PatternRewriter &rewriter) const override { // Match concat->QDQ->concat and remove QDQ, if concats would be foldable. // Removing a QDQ is already destructive. 
Try to be a little-less - // destructive by checking that the QDQ nodes have the same shift. + // destructive by checking that the QDQ nodes have the same scale + + // zeropoint. auto quantize = llvm::dyn_cast_or_null(op.getInput().getDefiningOp()); if (!quantize) { @@ -44,9 +45,16 @@ struct RemoveQDQBetweenConcat : public OpRewritePattern { op, "DequantizeOp input not produced by QuantizeOp."); } - if (quantize.getShift() != op.getShift()) { - return rewriter.notifyMatchFailure( - op, "DequantizeOp and QuantizeOp do not share the same shift value."); + if (quantize.getScale() != op.getScale()) { + return rewriter.notifyMatchFailure(op, + "DequantizeOp and QuantizeOp do not " + "share the same scale"); + } + + if (quantize.getZeroPoint() != op.getZeroPoint()) { + return rewriter.notifyMatchFailure(op, + "DequantizeOp and QuantizeOp do not " + "share the same zero_point."); } // Try to match an incoming concat diff --git a/test/Conversion/TosaToXTenNN/quantization.mlir b/test/Conversion/TosaToXTenNN/quantization.mlir index 558ccb95..b32ebac5 100644 --- a/test/Conversion/TosaToXTenNN/quantization.mlir +++ b/test/Conversion/TosaToXTenNN/quantization.mlir @@ -229,8 +229,32 @@ module attributes {} { %0 = "tosa.const"() {value = dense<3.200000e+01> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %1 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> %2 = "tosa.mul"(%arg0, %0) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> - %3 = xten_nn.quantize(%2 : tensor<1x3x4x4xf32>) {shift = 3 : si32} -> tensor<1x3x4x4xi8> - %4 = xten_nn.dequantize(%3 : tensor<1x3x4x4xi8>) {shift = 3 : si32} -> tensor<1x3x4x4xf32> + %3 = xten_nn.quantize(%2 : tensor<1x3x4x4xf32>) {shift = 3 : si32, scale = 8.0 : f32, zero_point = 0: i8} -> tensor<1x3x4x4xi8> + %4 = xten_nn.dequantize(%3 : tensor<1x3x4x4xi8>) {shift = 3 : si32, scale = 8.0 : f32, zero_point = 0: i8} -> tensor<1x3x4x4xf32> + %5 = "tosa.mul"(%4, %1) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + return %5 : tensor<1x3x4x4xf32> + } +} + +// -- + +module attributes {} { +// CHECK-LABEL: func.func @sum_shifts_no_shifts +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK: [[VAR_2_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_3_:%.+]] = xten_nn.quantize([[VAR_2_]] : tensor<1x3x4x4xf32>) {scale = 7.000000e+00 : f32, zero_point = 0 : i8} -> tensor<1x3x4x4xi8> +// CHECK: [[VAR_4_:%.+]] = xten_nn.dequantize([[VAR_3_]] : tensor<1x3x4x4xi8>) {scale = 7.000000e+00 : f32, zero_point = 0 : i8} -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_5_:%.+]] = tosa.mul [[VAR_4_]], [[VAR_1_]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> +// CHECK: return [[VAR_5_]] : tensor<1x3x4x4xf32> +// CHECK: } + func.func @sum_shifts_no_shifts(%arg0: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { + %0 = "tosa.const"() {value = dense<3.200000e+01> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> + %1 = "tosa.const"() {value = dense<3.125000e-02> : tensor<1x1x1x1xf32>} : () -> tensor<1x1x1x1xf32> + %2 = "tosa.mul"(%arg0, %0) {shift = 0 : i8} : 
(tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> + %3 = xten_nn.quantize(%2 : tensor<1x3x4x4xf32>) {scale = 7.0 : f32, zero_point = 0: i8} -> tensor<1x3x4x4xi8> + %4 = xten_nn.dequantize(%3 : tensor<1x3x4x4xi8>) {scale = 7.0 : f32, zero_point = 0: i8} -> tensor<1x3x4x4xf32> %5 = "tosa.mul"(%4, %1) {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> return %5 : tensor<1x3x4x4xf32> } diff --git a/test/Conversion/XTenNNToTosa/quantization.mlir b/test/Conversion/XTenNNToTosa/quantization.mlir index 620d2eab..dbad4781 100644 --- a/test/Conversion/XTenNNToTosa/quantization.mlir +++ b/test/Conversion/XTenNNToTosa/quantization.mlir @@ -13,24 +13,51 @@ // RUN: aten-opt %s --xten-nn-to-tosa --split-input-file | FileCheck %s module attributes{} { -// CHECK-LABEL: func.func @explicit_case( -// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { -// CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> -// CHECK-DAG: %[[VAL_3:.*]] = tosa.mul %[[VAL_0]], %[[VAL_1]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK-DAG: %[[VAL_4:.*]] = tosa.cast %[[VAL_3]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> -// CHECK-DAG: %[[VAL_5:.*]] = tosa.cast %[[VAL_4]] : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> -// CHECK-DAG: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_2]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> -// CHECK-DAG: return %[[VAL_6]] : tensor<1x3x4x4xf32> +// CHECK-LABEL: func.func @explicit_case +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x3x4x4xf32>}> : () -> tensor<1x3x4x4xf32> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xf32>}> : () -> tensor<1x1x1x1xf32> +// CHECK: [[VAR_3_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_2_]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_4_:%.+]] = tosa.add [[VAR_3_]], [[VAR_1_]] : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_5_:%.+]] = tosa.cast [[VAR_4_]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> +// CHECK: [[VAR_6_:%.+]] = tosa.cast [[VAR_5_]] : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_7_:%.+]] = tosa.sub [[VAR_6_]], [[VAR_1_]] : (tensor<1x3x4x4xf32>, tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> +// CHECK: [[VAR_8_:%.+]] = tosa.mul [[VAR_7_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> +// CHECK: return [[VAR_8_]] : tensor<1x3x4x4xf32> +// CHECK: } func.func @explicit_case(%arg0: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { - %0 = xten_nn.quantize(%arg0 : tensor<1x3x4x4xf32>) {shift = -5 : si32} -> tensor<1x3x4x4xi8> - %1 = xten_nn.dequantize(%0 : tensor<1x3x4x4xi8>) {shift = -5 : si32} -> tensor<1x3x4x4xf32> + %0 = xten_nn.quantize(%arg0 : tensor<1x3x4x4xf32>) {scale = 3.125000e-02 : f32, zero_point = 1 : i8} -> tensor<1x3x4x4xi8> + %1 = xten_nn.dequantize(%0 : tensor<1x3x4x4xi8>) {scale = 3.125000e-02 : f32, zero_point = 1 : i8} -> tensor<1x3x4x4xf32> return 
%1 : tensor<1x3x4x4xf32> } } // -- +module attributes{} { +// CHECK-LABEL: func.func @explicit_case_bf16 +// CHECK-SAME: ([[PARAM_0_:%.+]]: tensor<1x3x4x4xbf16>) -> tensor<1x3x4x4xbf16> { +// CHECK-DAG: [[VAR_0_:%.+]] = "tosa.const"() <{value = dense<3.125000e-02> : tensor<1x1x1x1xbf16>}> : () -> tensor<1x1x1x1xbf16> +// CHECK-DAG: [[VAR_1_:%.+]] = "tosa.const"() <{value = dense<1.000000e+00> : tensor<1x3x4x4xbf16>}> : () -> tensor<1x3x4x4xbf16> +// CHECK-DAG: [[VAR_2_:%.+]] = "tosa.const"() <{value = dense<3.200000e+01> : tensor<1x1x1x1xbf16>}> : () -> tensor<1x1x1x1xbf16> +// CHECK: [[VAR_3_:%.+]] = tosa.mul [[PARAM_0_]], [[VAR_2_]] {shift = 0 : i8} : (tensor<1x3x4x4xbf16>, tensor<1x1x1x1xbf16>) -> tensor<1x3x4x4xbf16> +// CHECK: [[VAR_4_:%.+]] = tosa.add [[VAR_3_]], [[VAR_1_]] : (tensor<1x3x4x4xbf16>, tensor<1x3x4x4xbf16>) -> tensor<1x3x4x4xbf16> +// CHECK: [[VAR_5_:%.+]] = tosa.cast [[VAR_4_]] : (tensor<1x3x4x4xbf16>) -> tensor<1x3x4x4xi8> +// CHECK: [[VAR_6_:%.+]] = tosa.cast [[VAR_5_]] : (tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xbf16> +// CHECK: [[VAR_7_:%.+]] = tosa.sub [[VAR_6_]], [[VAR_1_]] : (tensor<1x3x4x4xbf16>, tensor<1x3x4x4xbf16>) -> tensor<1x3x4x4xbf16> +// CHECK: [[VAR_8_:%.+]] = tosa.mul [[VAR_7_]], [[VAR_0_]] {shift = 0 : i8} : (tensor<1x3x4x4xbf16>, tensor<1x1x1x1xbf16>) -> tensor<1x3x4x4xbf16> +// CHECK: return [[VAR_8_]] : tensor<1x3x4x4xbf16> +// CHECK: } + func.func @explicit_case_bf16(%arg0: tensor<1x3x4x4xbf16>) -> tensor<1x3x4x4xbf16> { + %0 = xten_nn.quantize(%arg0 : tensor<1x3x4x4xbf16>) {scale = 3.125000e-02 : f32, zero_point = 1 : i8} -> tensor<1x3x4x4xi8> + %1 = xten_nn.dequantize(%0 : tensor<1x3x4x4xi8>) {scale = 3.125000e-02 : f32, zero_point = 1 : i8} -> tensor<1x3x4x4xbf16> + return %1 : tensor<1x3x4x4xbf16> + } +} + +// -- + module attributes{} { // CHECK-LABEL: func.func @small_tensors( // CHECK-SAME: %[[VAL_0:.*]]: tensor<2x3xf32>) -> tensor<2x3xf32> { @@ -43,8 +70,8 @@ module attributes{} { // CHECK-DAG: %[[VAL_7:.*]] = tosa.mul %[[VAL_6]], %[[VAL_2]] {shift = 0 : i8} : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32> // CHECK-DAG: return %[[VAL_7]] : tensor<2x3xf32> func.func @small_tensors(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> { - %0 = xten_nn.quantize(%arg0 : tensor<2x3xf32>) {shift = 3 : si32} -> tensor<2x3xi4> - %1 = xten_nn.dequantize(%0 : tensor<2x3xi4>) {shift = 3 : si32} -> tensor<2x3xf32> + %0 = xten_nn.quantize(%arg0 : tensor<2x3xf32>) {scale = 8.0 : f32, shift = 3 : si32, zero_point = 0 : i4} -> tensor<2x3xi4> + %1 = xten_nn.dequantize(%0 : tensor<2x3xi4>) {scale = 8.0 : f32, shift = 3 : si32, zero_point = 0 : i4} -> tensor<2x3xf32> return %1 : tensor<2x3xf32> } } @@ -59,7 +86,7 @@ module attributes{} { // CHECK-DAG: %[[VAL_3:.*]] = tosa.cast %[[VAL_2]] : (tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> // CHECK-DAG: return %[[VAL_3]] : tensor<1x3x4x4xi8> func.func @quantize_case(%arg0: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xi8> { - %0 = xten_nn.quantize(%arg0 : tensor<1x3x4x4xf32>) {shift = -5 : si32} -> tensor<1x3x4x4xi8> + %0 = xten_nn.quantize(%arg0 : tensor<1x3x4x4xf32>) {scale = 3.125000e-02 : f32, shift = -5 : si32, zero_point = 0 : i8} -> tensor<1x3x4x4xi8> return %0 : tensor<1x3x4x4xi8> } } @@ -74,7 +101,7 @@ module attributes{} { // CHECK-DAG: %[[VAL_4:.*]] = tosa.mul %[[VAL_3]], %[[VAL_1]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> // CHECK-DAG: return %[[VAL_4]] : tensor<1x3x4x4xf32> func.func @dequantize_case(%arg0: tensor<1x3x4x4xi8>) -> tensor<1x3x4x4xf32> { - %0 = 
xten_nn.dequantize(%arg0 : tensor<1x3x4x4xi8>) {shift = -5 : si32} -> tensor<1x3x4x4xf32> + %0 = xten_nn.dequantize(%arg0 : tensor<1x3x4x4xi8>) {scale = 3.125000e-02 : f32, shift = -5 : si32, zero_point = 0 : i8} -> tensor<1x3x4x4xf32> return %0 : tensor<1x3x4x4xf32> } } @@ -92,8 +119,8 @@ module attributes{} { // CHECK-DAG: %[[VAL_6:.*]] = tosa.mul %[[VAL_5]], %[[VAL_2]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> // CHECK-DAG: return %[[VAL_6]] : tensor<1x3x4x4xf32> func.func @i16_case(%arg0: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { - %0 = xten_nn.quantize(%arg0 : tensor<1x3x4x4xf32>) {shift = -5 : si32} -> tensor<1x3x4x4xi16> - %1 = xten_nn.dequantize(%0 : tensor<1x3x4x4xi16>) {shift = -5 : si32} -> tensor<1x3x4x4xf32> + %0 = xten_nn.quantize(%arg0 : tensor<1x3x4x4xf32>) {scale = 3.125000e-02 : f32, shift = -5 : si32, zero_point = 0 : i16} -> tensor<1x3x4x4xi16> + %1 = xten_nn.dequantize(%0 : tensor<1x3x4x4xi16>) {scale = 3.125000e-02 : f32, shift = -5 : si32, zero_point = 0 : i16} -> tensor<1x3x4x4xf32> return %1 : tensor<1x3x4x4xf32> } } @@ -112,8 +139,8 @@ module attributes{} { // CHECK-DAG: %[[VAL_7:.*]] = tosa.mul %[[VAL_6]], %[[VAL_2]] {shift = 0 : i8} : (tensor<1x3x4x4xf32>, tensor<1x1x1x1xf32>) -> tensor<1x3x4x4xf32> // CHECK-DAG: return %[[VAL_7]] : tensor<1x3x4x4xf32> func.func @i12_case(%arg0: tensor<1x3x4x4xf32>) -> tensor<1x3x4x4xf32> { - %0 = xten_nn.quantize(%arg0 : tensor<1x3x4x4xf32>) {shift = -5 : si32} -> tensor<1x3x4x4xi12> - %1 = xten_nn.dequantize(%0 : tensor<1x3x4x4xi12>) {shift = -5 : si32} -> tensor<1x3x4x4xf32> + %0 = xten_nn.quantize(%arg0 : tensor<1x3x4x4xf32>) {scale = 3.125000e-02 : f32, shift = -5 : si32, zero_point = 0 : i12} -> tensor<1x3x4x4xi12> + %1 = xten_nn.dequantize(%0 : tensor<1x3x4x4xi12>) {scale = 3.125000e-02 : f32, shift = -5 : si32, zero_point = 0 : i12} -> tensor<1x3x4x4xf32> return %1 : tensor<1x3x4x4xf32> } } diff --git a/test/Dialect/XTenNN/Transform/CanonicalizePass/remove_qdq_between_concats.mlir b/test/Dialect/XTenNN/Transform/CanonicalizePass/remove_qdq_between_concats.mlir index 9eaeeb16..72de5d9e 100644 --- a/test/Dialect/XTenNN/Transform/CanonicalizePass/remove_qdq_between_concats.mlir +++ b/test/Dialect/XTenNN/Transform/CanonicalizePass/remove_qdq_between_concats.mlir @@ -4,8 +4,8 @@ // RUN: aten-opt %s -xtennn-canonicalize -split-input-file | FileCheck %s --check-prefix=SANE func.func @single_qdq(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x1x7x7xf32> { - %0 = "xten_nn.quantize"(%arg0) {shift = -3 : si32} : (tensor<1x1x7x7xf32>) -> tensor<1x1x7x7xi8> - %1 = "xten_nn.dequantize"(%0) {shift = -3 : si32} : (tensor<1x1x7x7xi8>) -> tensor<1x1x7x7xf32> + %0 = xten_nn.quantize(%arg0 : tensor<1x1x7x7xf32>) {shift = -3 : si32} -> tensor<1x1x7x7xi8> + %1 = xten_nn.dequantize(%0 : tensor<1x1x7x7xi8>) {shift = -3 : si32} -> tensor<1x1x7x7xf32> return %1 : tensor<1x1x7x7xf32> } @@ -27,8 +27,8 @@ func.func @single_qdq(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x1x7x7xf32> { func.func @single_concat_at_input(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> { %0 = "tosa.concat"(%arg0, %arg0) {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> - %1 = "xten_nn.quantize"(%0) {shift = -3 : si32} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> - %2 = "xten_nn.dequantize"(%1) {shift = -3 : si32} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> + %1 = xten_nn.quantize(%0 : tensor<1x2x7x7xf32>) {shift = -3 : si32} -> tensor<1x2x7x7xi8> + %2 = xten_nn.dequantize(%1 : tensor<1x2x7x7xi8>) 
{shift = -3 : si32} -> tensor<1x2x7x7xf32> return %2 : tensor<1x2x7x7xf32> } @@ -51,8 +51,8 @@ func.func @single_concat_at_input(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7x // ----- func.func @single_concat_at_output(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> { - %0 = "xten_nn.quantize"(%arg0) {shift = -3 : si32} : (tensor<1x1x7x7xf32>) -> tensor<1x1x7x7xi8> - %1 = "xten_nn.dequantize"(%0) {shift = -3 : si32} : (tensor<1x1x7x7xi8>) -> tensor<1x1x7x7xf32> + %0 = xten_nn.quantize(%arg0 : tensor<1x1x7x7xf32>) {shift = -3 : si32} -> tensor<1x1x7x7xi8> + %1 = xten_nn.dequantize(%0 : tensor<1x1x7x7xi8>) {shift = -3 : si32} -> tensor<1x1x7x7xf32> %2 = "tosa.concat"(%1, %1) {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> return %2 : tensor<1x2x7x7xf32> } @@ -77,8 +77,8 @@ func.func @single_concat_at_output(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x2x7x7 func.func @non_foldable_concats(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf32> { %0 = "tosa.concat"(%arg0, %arg0) {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> - %1 = "xten_nn.quantize"(%0) {shift = -3 : si32} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> - %2 = "xten_nn.dequantize"(%1) {shift = -3 : si32} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> + %1 = xten_nn.quantize(%0 : tensor<1x2x7x7xf32>) {shift = -3 : si32} -> tensor<1x2x7x7xi8> + %2 = xten_nn.dequantize(%1 : tensor<1x2x7x7xi8>) {shift = -3 : si32} -> tensor<1x2x7x7xf32> %3 = "tosa.concat"(%2, %2) {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32> return %3 : tensor<2x2x7x7xf32> } @@ -105,8 +105,8 @@ func.func @non_foldable_concats(%arg0: tensor<1x1x7x7xf32>) -> tensor<2x2x7x7xf3 func.func @foldable_concats(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> { %0 = "tosa.concat"(%arg0, %arg0) {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> - %1 = "xten_nn.quantize"(%0) {shift = -3 : si32} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> - %2 = "xten_nn.dequantize"(%1) {shift = -3 : si32} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> + %1 = xten_nn.quantize(%0 : tensor<1x2x7x7xf32>) {shift = -3 : si32} -> tensor<1x2x7x7xi8> + %2 = xten_nn.dequantize(%1 : tensor<1x2x7x7xi8>) {shift = -3 : si32} -> tensor<1x2x7x7xf32> %3 = "tosa.concat"(%2, %2) {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32> return %3 : tensor<1x4x7x7xf32> } @@ -130,8 +130,8 @@ func.func @foldable_concats(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> { func.func @multiple_foldable_user_concats(%arg0: tensor<1x1x7x7xf32>) -> (tensor<1x4x7x7xf32>, tensor<1x4x7x7xf32>) { %0 = "tosa.concat"(%arg0, %arg0) {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> - %1 = "xten_nn.quantize"(%0) {shift = -3 : si32} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> - %2 = "xten_nn.dequantize"(%1) {shift = -3 : si32} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> + %1 = xten_nn.quantize(%0 : tensor<1x2x7x7xf32>) {shift = -3 : si32} -> tensor<1x2x7x7xi8> + %2 = xten_nn.dequantize(%1 : tensor<1x2x7x7xi8>) {shift = -3 : si32} -> tensor<1x2x7x7xf32> %3 = "tosa.concat"(%2, %2) {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32> %4 = "tosa.concat"(%2, %2) {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32> return %3, %4 : tensor<1x4x7x7xf32>, tensor<1x4x7x7xf32> @@ -158,8 +158,8 @@ func.func @multiple_foldable_user_concats(%arg0: tensor<1x1x7x7xf32>) -> 
(tensor func.func @partially_foldable_user_concats(%arg0: tensor<1x1x7x7xf32>) -> (tensor<1x4x7x7xf32>, tensor<2x2x7x7xf32>) { %0 = "tosa.concat"(%arg0, %arg0) {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> - %1 = "xten_nn.quantize"(%0) {shift = -3 : si32} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> - %2 = "xten_nn.dequantize"(%1) {shift = -3 : si32} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> + %1 = xten_nn.quantize(%0 : tensor<1x2x7x7xf32>) {shift = -3 : si32} -> tensor<1x2x7x7xi8> + %2 = xten_nn.dequantize(%1 : tensor<1x2x7x7xi8>) {shift = -3 : si32} -> tensor<1x2x7x7xf32> %3 = "tosa.concat"(%2, %2) {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32> // This is foldable %4 = "tosa.concat"(%2, %2) {axis = 0 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<2x2x7x7xf32> // This is not foldable return %3, %4 : tensor<1x4x7x7xf32>, tensor<2x2x7x7xf32> @@ -189,8 +189,8 @@ func.func @partially_foldable_user_concats(%arg0: tensor<1x1x7x7xf32>) -> (tenso func.func @qdq_different_shift(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> { %0 = "tosa.concat"(%arg0, %arg0) {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> - %1 = "xten_nn.quantize"(%0) {shift = -5 : si32} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> - %2 = "xten_nn.dequantize"(%1) {shift = -3 : si32} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> + %1 = xten_nn.quantize(%0 : tensor<1x2x7x7xf32>) {shift = -5 : si32} -> tensor<1x2x7x7xi8> + %2 = xten_nn.dequantize(%1 : tensor<1x2x7x7xi8>) {shift = -3 : si32} -> tensor<1x2x7x7xf32> %3 = "tosa.concat"(%2, %2) {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32> return %3 : tensor<1x4x7x7xf32> } @@ -212,3 +212,61 @@ func.func @qdq_different_shift(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32 // SANE: %[[VAL_4:.*]] = tosa.concat %[[VAL_3]], %[[VAL_3]] {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32> // SANE: return %[[VAL_4]] : tensor<1x4x7x7xf32> // SANE: } + +// ----- + +func.func @qdq_different_zero(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> { + %0 = "tosa.concat"(%arg0, %arg0) {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> + %1 = xten_nn.quantize(%0 : tensor<1x2x7x7xf32>) {shift = -5 : si32} -> tensor<1x2x7x7xi8> + %2 = xten_nn.dequantize(%1 : tensor<1x2x7x7xi8>) {scale = 3.125000e-02 : f32, zero_point = 1 : i8} -> tensor<1x2x7x7xf32> + %3 = "tosa.concat"(%2, %2) {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32> + return %3 : tensor<1x4x7x7xf32> +} + +// CHECK-LABEL: func.func @qdq_different_zero( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> +// CHECK: %[[VAL_2:.*]] = xten_nn.quantize(%[[VAL_1]] : tensor<1x2x7x7xf32>) {shift = -5 : si32} -> tensor<1x2x7x7xi8> +// CHECK: %[[VAL_3:.*]] = xten_nn.dequantize(%[[VAL_2]] : tensor<1x2x7x7xi8>) {scale = 3.125000e-02 : f32, zero_point = 1 : i8} -> tensor<1x2x7x7xf32> +// CHECK: %[[VAL_4:.*]] = tosa.concat %[[VAL_3]], %[[VAL_3]] {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32> +// CHECK: return %[[VAL_4]] : tensor<1x4x7x7xf32> +// CHECK: } + +// SANE-LABEL: func.func @qdq_different_zero( +// SANE-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> { 
+// SANE: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> +// SANE: %[[VAL_2:.*]] = xten_nn.quantize(%[[VAL_1]] : tensor<1x2x7x7xf32>) {shift = -5 : si32} -> tensor<1x2x7x7xi8> +// SANE: %[[VAL_3:.*]] = xten_nn.dequantize(%[[VAL_2]] : tensor<1x2x7x7xi8>) {scale = 3.125000e-02 : f32, zero_point = 1 : i8} -> tensor<1x2x7x7xf32> +// SANE: %[[VAL_4:.*]] = tosa.concat %[[VAL_3]], %[[VAL_3]] {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32> +// SANE: return %[[VAL_4]] : tensor<1x4x7x7xf32> +// SANE: } + + +// ----- + +func.func @qdq_different_scale(%arg0: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> { + %0 = "tosa.concat"(%arg0, %arg0) {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> + %1 = xten_nn.quantize(%0 : tensor<1x2x7x7xf32>) {shift = -5 : si32} -> tensor<1x2x7x7xi8> + %2 = xten_nn.dequantize(%1 : tensor<1x2x7x7xi8>) {shift = -3 : si32} -> tensor<1x2x7x7xf32> + %3 = "tosa.concat"(%2, %2) {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32> + return %3 : tensor<1x4x7x7xf32> +} + +// CHECK-LABEL: func.func @qdq_different_scale( +// CHECK-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> { +// CHECK: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> +// CHECK: %[[VAL_2:.*]] = xten_nn.quantize(%[[VAL_1]] : tensor<1x2x7x7xf32>) {shift = -5 : si32} -> tensor<1x2x7x7xi8> +// CHECK: %[[VAL_3:.*]] = xten_nn.dequantize(%[[VAL_2]] : tensor<1x2x7x7xi8>) {shift = -3 : si32} -> tensor<1x2x7x7xf32> +// CHECK: %[[VAL_4:.*]] = tosa.concat %[[VAL_3]], %[[VAL_3]] {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32> +// CHECK: return %[[VAL_4]] : tensor<1x4x7x7xf32> +// CHECK: } + +// SANE-LABEL: func.func @qdq_different_scale( +// SANE-SAME: %[[VAL_0:.*]]: tensor<1x1x7x7xf32>) -> tensor<1x4x7x7xf32> { +// SANE: %[[VAL_1:.*]] = tosa.concat %[[VAL_0]], %[[VAL_0]] {axis = 1 : i32} : (tensor<1x1x7x7xf32>, tensor<1x1x7x7xf32>) -> tensor<1x2x7x7xf32> +// SANE: %[[VAL_2:.*]] = xten_nn.quantize(%[[VAL_1]] : tensor<1x2x7x7xf32>) {shift = -5 : si32} -> tensor<1x2x7x7xi8> +// SANE: %[[VAL_3:.*]] = xten_nn.dequantize(%[[VAL_2]] : tensor<1x2x7x7xi8>) {shift = -3 : si32} -> tensor<1x2x7x7xf32> +// SANE: %[[VAL_4:.*]] = tosa.concat %[[VAL_3]], %[[VAL_3]] {axis = 1 : i32} : (tensor<1x2x7x7xf32>, tensor<1x2x7x7xf32>) -> tensor<1x4x7x7xf32> +// SANE: return %[[VAL_4]] : tensor<1x4x7x7xf32> +// SANE: } + diff --git a/test/Dialect/XTenNN/dequantize.mlir b/test/Dialect/XTenNN/dequantize.mlir index fbb6a08c..e6dd1a44 100644 --- a/test/Dialect/XTenNN/dequantize.mlir +++ b/test/Dialect/XTenNN/dequantize.mlir @@ -3,21 +3,21 @@ // RUN: aten-opt %s -split-input-file -verify-diagnostics func.func @valid_dequantize_op_signed(%arg0: tensor<1x2xi8>) -> tensor<1x2xf32> { - %result = xten_nn.dequantize (%arg0: tensor<1x2xi8>) {shift = -3: si32} -> tensor<1x2xf32> + %result = xten_nn.dequantize (%arg0: tensor<1x2xi8>) {shift = -3: si32, scale = 0.125 : f32, zero_point = 0 : i8} -> tensor<1x2xf32> return %result : tensor<1x2xf32> } // ----- func.func @valid_dequantize_op_unsigned(%arg0: tensor<1x2xui8>) -> tensor<1x2xf32> { - %result = xten_nn.dequantize (%arg0: tensor<1x2xui8>) {shift = -3: si32} -> tensor<1x2xf32> + %result = xten_nn.dequantize (%arg0: tensor<1x2xui8>) {shift = -3: si32, scale = 0.125 : f32, zero_point = 0 
: ui8} -> tensor<1x2xf32> return %result : tensor<1x2xf32> } // ----- func.func @valid_dequantize_op_large_scale(%arg0: tensor<1x2xui8>) -> tensor<1x2xf32> { - %result = xten_nn.dequantize (%arg0: tensor<1x2xui8>) {shift = 5: si32} -> tensor<1x2xf32> + %result = xten_nn.dequantize (%arg0: tensor<1x2xui8>) {shift = 5: si32, scale = 32.0 : f32, zero_point = 0 : ui8} -> tensor<1x2xf32> return %result : tensor<1x2xf32> } @@ -26,7 +26,7 @@ func.func @valid_dequantize_op_large_scale(%arg0: tensor<1x2xui8>) -> tensor<1x2 func.func @invalid_shift(%arg0: tensor<1x2xi8>) -> tensor<1x2xf32> { // expected-error@+1 {{'xten_nn.dequantize' op attribute 'shift' failed to satisfy constraint: 32-bit signed integer attribute}} - %result = xten_nn.dequantize (%arg0: tensor<1x2xi8>) {shift = 0.135: f32} -> tensor<1x2xf32> + %result = xten_nn.dequantize (%arg0: tensor<1x2xi8>) {shift = 0.135: f32, scale = 0.5 : f32, zero_point = 0 : i8} -> tensor<1x2xf32> return %result : tensor<1x2xf32> } @@ -34,7 +34,7 @@ func.func @invalid_shift(%arg0: tensor<1x2xi8>) -> tensor<1x2xf32> { func.func @invalid_tensor_signed(%arg0: tensor<1x2xsi8>) -> tensor<1x2xf32> { // expected-error@+1 {{op operand #0 must be signless-or-unsigned-tensor of signless integer or unsigned integer values, but got 'tensor<1x2xsi8>}} - %result = xten_nn.dequantize (%arg0: tensor<1x2xsi8>) {shift = -1: si32} -> tensor<1x2xf32> + %result = xten_nn.dequantize (%arg0: tensor<1x2xsi8>) {shift = -1: si32, scale = 0.5 : f32, zero_point = 0 : si8} -> tensor<1x2xf32> return %result : tensor<1x2xf32> } @@ -42,28 +42,95 @@ func.func @invalid_tensor_signed(%arg0: tensor<1x2xsi8>) -> tensor<1x2xf32> { func.func @invalid_io_shapes(%arg0: tensor<1x3xi8>) -> tensor<1x2xf32> { // expected-error@+1 {{op all non-scalar operands/results must have the same shape and base type}} - %result = xten_nn.dequantize (%arg0: tensor<1x3xi8>) {shift = -1: si32} -> tensor<1x2xf32> + %result = xten_nn.dequantize (%arg0: tensor<1x3xi8>) {shift = -1: si32, scale = 0.5 : f32, zero_point = 0 : i8} -> tensor<1x2xf32> return %result : tensor<1x2xf32> } // ----- func.func @invalid_output_type(%arg0: tensor<1x2xi8>) -> tensor<1x2xi32> { - // expected-error@+1 {{op result #0 must be tensor of 32-bit float values, but got 'tensor<1x2xi32>'}} - %result = xten_nn.dequantize (%arg0: tensor<1x2xi8>) {shift = -1: si32} -> tensor<1x2xi32> + // expected-error@+1 {{op result #0 must be tensor of floating-point values, but got 'tensor<1x2xi32>'}} + %result = xten_nn.dequantize (%arg0: tensor<1x2xi8>) {shift = -1: si32, scale = 0.5 : f32, zero_point = 0 : i8} -> tensor<1x2xi32> return %result : tensor<1x2xi32> } // ----- func.func @different_bitwidth(%arg0: tensor<1x2xi3>) -> tensor<1x2xf32> { - %result = xten_nn.dequantize (%arg0: tensor<1x2xi3>) {shift = -1: si32} -> tensor<1x2xf32> + %result = xten_nn.dequantize (%arg0: tensor<1x2xi3>) {shift = -1: si32, scale = 0.5 : f32, zero_point = 0 : i3} -> tensor<1x2xf32> return %result : tensor<1x2xf32> } // ----- func.func @sixteen_bitwidth(%arg0: tensor<1x2xi16>) -> tensor<1x2xf32> { - %result = xten_nn.dequantize (%arg0: tensor<1x2xi16>) {shift = -1: si32} -> tensor<1x2xf32> + %result = xten_nn.dequantize (%arg0: tensor<1x2xi16>) {shift = -1: si32, scale = 0.5 : f32, zero_point = 0 : i16} -> tensor<1x2xf32> return %result : tensor<1x2xf32> -} \ No newline at end of file +} + +// ----- + +func.func @valid_dequantize_no_shift(%arg0: tensor<1x2xi16>) -> tensor<1x2xf32> { + %result = xten_nn.dequantize (%arg0: tensor<1x2xi16>) {scale = 0.5 : f32, zero_point 
= 0 : i16} -> tensor<1x2xf32> + return %result : tensor<1x2xf32> +} + +// ----- + +func.func @valid_dequantize_no_zero_point(%arg0: tensor<1x2xi8>) -> tensor<1x2xf32> { + %result = xten_nn.dequantize (%arg0: tensor<1x2xi8>) {scale = 0.125 : f32} -> tensor<1x2xf32> + return %result : tensor<1x2xf32> +} + +// ----- + +func.func @valid_dequantize_only_shift(%arg0: tensor<1x2xi8>) -> tensor<1x2xf32> { + %result = xten_nn.dequantize (%arg0: tensor<1x2xi8>) {shift = -2 : si32} -> tensor<1x2xf32> + return %result : tensor<1x2xf32> +} + +// ----- + +func.func @invalid_dequantize_shift_and_zero(%arg0: tensor<1x2xi8>) -> tensor<1x2xf32> { + // expected-error@+1 {{It is only allowed to set a zero point if scale is set too}} + %result = xten_nn.dequantize (%arg0: tensor<1x2xi8>) {shift = -2 : si32, zero_point = 3 : i8} -> tensor<1x2xf32> + return %result : tensor<1x2xf32> +} + +// ----- + +func.func @invalid_no_shift_or_scale(%arg0: tensor<1x2xi8>) -> tensor<1x2xf32> { + // expected-error@+1 {{Shift and scale are both missing}} + %result = xten_nn.dequantize (%arg0: tensor<1x2xi8>) {} -> tensor<1x2xf32> + return %result : tensor<1x2xf32> +} + +// ----- + +func.func @valid_dqquantize_op_generic(%arg0: tensor<1x2xi8>) -> tensor<1x2xf32> { + %result = "xten_nn.dequantize" (%arg0) {shift = -3: si32, scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2xi8>) -> tensor<1x2xf32> + return %result : tensor<1x2xf32> +} +// ----- + +func.func @dequantize_op_zero_type_mismatch(%arg0: tensor<1x2xi8>) -> tensor<1x2xf32> { + // expected-error@+1 {{Operand elem type needs to match match zero point type}} + %result = xten_nn.dequantize (%arg0: tensor<1x2xi8>) {shift = -3: si32, scale = 0.125 : f32, zero_point = 0 : i7} -> tensor<1x2xf32> + return %result : tensor<1x2xf32> +} + +// ----- + +func.func @dequantize_op_scale_shift_mismatch(%arg0: tensor<1x2xi8>) -> tensor<1x2xf32> { + // expected-error@+1 {{Shift set, but does not match shift calculated from scale}} + %result = xten_nn.dequantize (%arg0: tensor<1x2xi8>) {shift = -3: si32, scale = 4.0 : f32, zero_point = 0 : i8} -> tensor<1x2xf32> + return %result : tensor<1x2xf32> +} + +// ----- + +func.func @quantize_op_zero_not_zero(%arg0: tensor<1x2xi8>) -> tensor<1x2xf32> { + // expected-error@+1 {{Shift set, but zero_point not zero}} + %result = xten_nn.dequantize (%arg0: tensor<1x2xi8>) {shift = -3: si32, scale = 0.125 : f32, zero_point = 3 : i8} -> tensor<1x2xf32> + return %result : tensor<1x2xf32> +} diff --git a/test/Dialect/XTenNN/folding.mlir b/test/Dialect/XTenNN/folding.mlir index 1ae8d6dd..14d6259f 100644 --- a/test/Dialect/XTenNN/folding.mlir +++ b/test/Dialect/XTenNN/folding.mlir @@ -3,10 +3,10 @@ // RUN: aten-opt --test-constant-fold --cse --split-input-file %s -o - | FileCheck %s func.func @simple_dqq_fold(%arg0: tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xf32> { - %1 = "xten_nn.quantize"(%arg0) {shift = -3 : si32} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> - %2 = "xten_nn.dequantize"(%1) {shift = -3 : si32} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> - %3 = "xten_nn.quantize"(%2) {shift = -3 : si32} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> - %4 = "xten_nn.dequantize"(%3) {shift = -3 : si32} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> + %1 = "xten_nn.quantize"(%arg0) {shift = -3: si32, scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> + %2 = "xten_nn.dequantize"(%1) {shift = -3: si32, scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> + %3 = "xten_nn.quantize"(%2) {shift = 
-3: si32, scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> + %4 = "xten_nn.dequantize"(%3) {shift = -3: si32, scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> return %4 : tensor<1x2x7x7xf32> } @@ -19,8 +19,8 @@ func.func @simple_dqq_fold(%arg0: tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xf32> { // ----- func.func @no_fold(%arg0: tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xf32> { - %0 = "xten_nn.quantize"(%arg0) {shift = -3 : si32} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> - %1 = "xten_nn.dequantize"(%0) {shift = -3 : si32} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> + %0 = "xten_nn.quantize"(%arg0) {shift = -3: si32, scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> + %1 = "xten_nn.dequantize"(%0) {shift = -3: si32, scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> return %1 : tensor<1x2x7x7xf32> } @@ -31,9 +31,9 @@ func.func @no_fold(%arg0: tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xf32> { // ----- func.func @no_dqq_fold_multiple_uses(%arg0: tensor<1x2x7x7xf32>) -> (tensor<1x2x7x7xf32>, tensor<1x2x7x7xi8>) { - %1 = "xten_nn.quantize"(%arg0) {shift = -3 : si32} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> - %2 = "xten_nn.dequantize"(%1) {shift = -3 : si32} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> - %3 = "xten_nn.quantize"(%2) {shift = -3 : si32} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> + %1 = "xten_nn.quantize"(%arg0) {shift = -3: si32, scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> + %2 = "xten_nn.dequantize"(%1) {shift = -3: si32, scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> + %3 = "xten_nn.quantize"(%2) {shift = -3: si32, scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> return %2, %3 : tensor<1x2x7x7xf32>, tensor<1x2x7x7xi8> } @@ -45,10 +45,10 @@ func.func @no_dqq_fold_multiple_uses(%arg0: tensor<1x2x7x7xf32>) -> (tensor<1x2x // ----- func.func @no_dqq_fold_different_type(%arg0: tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xf32> { - %1 = "xten_nn.quantize"(%arg0) {shift = -3 : si32} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi16> - %2 = "xten_nn.dequantize"(%1) {shift = -3 : si32} : (tensor<1x2x7x7xi16>) -> tensor<1x2x7x7xf32> - %3 = "xten_nn.quantize"(%2) {shift = -3 : si32} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> - %4 = "xten_nn.dequantize"(%3) {shift = -3 : si32} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> + %1 = "xten_nn.quantize"(%arg0) {shift = -3: si32, scale = 0.125 : f32, zero_point = 0 : i16} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi16> + %2 = "xten_nn.dequantize"(%1) {shift = -3: si32, scale = 0.125 : f32, zero_point = 0 : i16} : (tensor<1x2x7x7xi16>) -> tensor<1x2x7x7xf32> + %3 = "xten_nn.quantize"(%2) {shift = -3: si32, scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> + %4 = "xten_nn.dequantize"(%3) {shift = -3: si32, scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> return %4 : tensor<1x2x7x7xf32> } @@ -61,10 +61,10 @@ func.func @no_dqq_fold_different_type(%arg0: tensor<1x2x7x7xf32>) -> tensor<1x2x // ----- func.func @no_dqq_fold_different_shift(%arg0: tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xf32> { - %1 = "xten_nn.quantize"(%arg0) {shift = -4 : si32} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8> - %2 = "xten_nn.dequantize"(%1) {shift = -4 : si32} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32> - %3 = "xten_nn.quantize"(%2) 
@@ -72,4 +72,37 @@ func.func @no_dqq_fold_different_shift(%arg0: tensor<1x2x7x7xf32>) -> tensor<1x2
 // CHECK: xten_nn.quantize
 // CHECK: xten_nn.dequantize
 // CHECK: xten_nn.quantize
-// CHECK: xten_nn.dequantize
\ No newline at end of file
+// CHECK: xten_nn.dequantize
+
+// -----
+
+func.func @no_fold_different_zero(%arg0: tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xf32> {
+  %1 = "xten_nn.quantize"(%arg0) {scale = 0.125 : f32, zero_point = 1 : i8} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8>
+  %2 = "xten_nn.dequantize"(%1) {scale = 0.125 : f32, zero_point = 1 : i8} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32>
+  %3 = "xten_nn.quantize"(%2) {scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8>
+  %4 = "xten_nn.dequantize"(%3) {scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32>
+  return %4 : tensor<1x2x7x7xf32>
+}
+
+// CHECK-LABEL: no_fold_different_zero
+// CHECK: xten_nn.quantize
+// CHECK: xten_nn.dequantize
+// CHECK: xten_nn.quantize
+// CHECK: xten_nn.dequantize
+
+// -----
+
+func.func @no_fold_different_scale(%arg0: tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xf32> {
+  %1 = "xten_nn.quantize"(%arg0) {scale = 0.5 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8>
+  %2 = "xten_nn.dequantize"(%1) {scale = 0.5 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32>
+  %3 = "xten_nn.quantize"(%2) {scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xf32>) -> tensor<1x2x7x7xi8>
+  %4 = "xten_nn.dequantize"(%3) {scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2x7x7xi8>) -> tensor<1x2x7x7xf32>
+  return %4 : tensor<1x2x7x7xf32>
+}
+
+// CHECK-LABEL: no_fold_different_scale
+// CHECK: xten_nn.quantize
+// CHECK: xten_nn.dequantize
+// CHECK: xten_nn.quantize
+// CHECK: xten_nn.dequantize
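The two negative tests added above pin down the guard of the fold pattern: the middle dequantize→quantize pair may only be removed when both ops agree on scale, zero point, and storage type. A rough standalone sketch of that predicate (struct and function names are ours for illustration, not the pass's actual API):

#include <cstdint>

// Hypothetical mirror of the attributes the fold pattern compares.
struct QdqParams {
  float scale;          // e.g. 0.125f or 0.5f
  int64_t zeroPoint;    // e.g. 0 or 1
  unsigned storageBits; // e.g. 8 for i8, 16 for i16
};

bool canFoldPair(const QdqParams &dq, const QdqParams &q) {
  // Removing the pair is only value-preserving when both ops describe
  // the exact same quantization grid.
  return dq.scale == q.scale && dq.zeroPoint == q.zeroPoint &&
         dq.storageBits == q.storageBits;
}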
diff --git a/test/Dialect/XTenNN/quantize.mlir b/test/Dialect/XTenNN/quantize.mlir
index 014341ec..d34b0492 100644
--- a/test/Dialect/XTenNN/quantize.mlir
+++ b/test/Dialect/XTenNN/quantize.mlir
@@ -3,21 +3,21 @@
 // RUN: aten-opt %s -split-input-file -verify-diagnostics

 func.func @valid_quantize_op_signed(%arg0: tensor<1x2xf32>) -> tensor<1x2xi8> {
-  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = -3: si32} -> tensor<1x2xi8>
+  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = -3 : si32, scale = 0.125 : f32, zero_point = 0 : i8} -> tensor<1x2xi8>
   return %result : tensor<1x2xi8>
 }

 // -----

 func.func @valid_quantize_op_unsigned(%arg0: tensor<1x2xf32>) -> tensor<1x2xui8> {
-  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = -3: si32} -> tensor<1x2xui8>
+  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = -3 : si32, scale = 0.125 : f32, zero_point = 0 : ui8} -> tensor<1x2xui8>
   return %result : tensor<1x2xui8>
 }

 // -----

 func.func @valid_quantize_op_large_scale(%arg0: tensor<1x2xf32>) -> tensor<1x2xui8> {
-  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = 5: si32} -> tensor<1x2xui8>
+  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = 5 : si32, scale = 32.0 : f32, zero_point = 0 : ui8} -> tensor<1x2xui8>
   return %result : tensor<1x2xui8>
 }

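@valid_quantize_op_large_scale fixes the sign convention: shift = 5 must pair with scale = 2^5 = 32.0, i.e. positive shifts mean scales above one. Recovering the shift from a scale, and rejecting non-power-of-two scales such as the 31.0 used in a later test, can be done with std::frexp; a standalone sketch (the helper name is ours, not the dialect's):

#include <cmath>
#include <cstdint>
#include <optional>

// Returns the shift s with scale == 2^s, or nullopt if scale is not a
// power of two. Illustrative only; not the dialect's implementation.
std::optional<int32_t> shiftFromScale(float scale) {
  int exp = 0;
  float mantissa = std::frexp(scale, &exp); // scale = mantissa * 2^exp
  if (mantissa != 0.5f) // frexp yields mantissa in [0.5, 1); powers of two hit exactly 0.5
    return std::nullopt;
  return exp - 1; // 0.5 * 2^exp == 2^(exp - 1)
}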
@@ -25,7 +25,7 @@ func.func @invalid_shift(%arg0: tensor<1x2xf32>) -> tensor<1x2xi8> {
   // expected-error@+1 {{'xten_nn.quantize' op attribute 'shift' failed to satisfy constraint: 32-bit signed integer attribute}}
-  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = 0.135: f32} -> tensor<1x2xi8>
+  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = 0.135 : f32, scale = 0.5 : f32, zero_point = 0 : i8} -> tensor<1x2xi8>
   return %result : tensor<1x2xi8>
 }

@@ -33,7 +33,7 @@ func.func @invalid_tensor_signed(%arg0: tensor<1x2xf32>) -> tensor<1x2xsi8> {
   // expected-error@+1 {{op result #0 must be signless-or-unsigned-tensor of signless integer or unsigned integer values, but got 'tensor<1x2xsi8>'}}
-  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = -1: si32} -> tensor<1x2xsi8>
+  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = -1 : si32, scale = 0.5 : f32, zero_point = 0 : si8} -> tensor<1x2xsi8>
   return %result : tensor<1x2xsi8>
 }

@@ -41,27 +41,95 @@ func.func @invalid_io_shapes(%arg0: tensor<1x2xf32>) -> tensor<1x3xi8> {
   // expected-error@+1 {{op all non-scalar operands/results must have the same shape and base type}}
-  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = -1: si32} -> tensor<1x3xi8>
+  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = -1 : si32, scale = 0.5 : f32, zero_point = 0 : i8} -> tensor<1x3xi8>
   return %result : tensor<1x3xi8>
 }

 // -----

 func.func @invalid_input_type(%arg0: tensor<1x2xi32>) -> tensor<1x2xi8> {
-  // expected-error@+1 {{op operand #0 must be tensor of 32-bit float values, but got 'tensor<1x2xi32>'}}
-  %result = xten_nn.quantize (%arg0: tensor<1x2xi32>) {shift = -1: si32} -> tensor<1x2xi8>
+  // expected-error@+1 {{op operand #0 must be tensor of floating-point values, but got 'tensor<1x2xi32>'}}
+  %result = xten_nn.quantize (%arg0: tensor<1x2xi32>) {shift = -1 : si32, scale = 0.5 : f32, zero_point = 0 : i8} -> tensor<1x2xi8>
   return %result : tensor<1x2xi8>
 }

 // -----

 func.func @different_bitwidth(%arg0: tensor<1x2xf32>) -> tensor<1x2xi3> {
-  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = -1: si32} -> tensor<1x2xi3>
+  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = -1 : si32, scale = 0.5 : f32, zero_point = 0 : i3} -> tensor<1x2xi3>
   return %result : tensor<1x2xi3>
 }

 // -----

 func.func @sixteen_bitwidth(%arg0: tensor<1x2xf32>) -> tensor<1x2xi16> {
-  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = -1: si32} -> tensor<1x2xi16>
+  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = -1 : si32, scale = 0.5 : f32, zero_point = 0 : i16} -> tensor<1x2xi16>
   return %result : tensor<1x2xi16>
+}
+
+// -----
+
+func.func @valid_quantize_no_shift(%arg0: tensor<1x2xf32>) -> tensor<1x2xi8> {
+  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {scale = 0.125 : f32, zero_point = 3 : i8} -> tensor<1x2xi8>
+  return %result : tensor<1x2xi8>
+}
+
+// -----
+
+func.func @valid_quantize_no_zero_point(%arg0: tensor<1x2xf32>) -> tensor<1x2xi8> {
+  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {scale = 0.125 : f32} -> tensor<1x2xi8>
+  return %result : tensor<1x2xi8>
+}
+
+// -----
+
+func.func @valid_quantize_only_shift(%arg0: tensor<1x2xf32>) -> tensor<1x2xi8> {
+  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = -2 : si32} -> tensor<1x2xi8>
+  return %result : tensor<1x2xi8>
+}
+
+// -----
+
+func.func @invalid_quantize_shift_and_zero(%arg0: tensor<1x2xf32>) -> tensor<1x2xi8> {
+  // expected-error@+1 {{It is only allowed to set a zero point if scale is set too}}
+  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = -2 : si32, zero_point = 3 : i8} -> tensor<1x2xi8>
+  return %result : tensor<1x2xi8>
+}
+
+// -----
+
+func.func @invalid_no_shift_or_scale(%arg0: tensor<1x2xf32>) -> tensor<1x2xi8> {
+  // expected-error@+1 {{Shift and scale are both missing}}
+  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {} -> tensor<1x2xi8>
+  return %result : tensor<1x2xi8>
+}
+
+// -----
+
+func.func @valid_quantize_op_generic(%arg0: tensor<1x2xf32>) -> tensor<1x2xi8> {
+  %result = "xten_nn.quantize" (%arg0) {shift = -3 : si32, scale = 0.125 : f32, zero_point = 0 : i8} : (tensor<1x2xf32>) -> tensor<1x2xi8>
+  return %result : tensor<1x2xi8>
+}
+
+// -----
+
+func.func @quantize_op_zero_type_mismatch(%arg0: tensor<1x2xf32>) -> tensor<1x2xi8> {
+  // expected-error@+1 {{Result elem type needs to match zero point type}}
+  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = -3 : si32, scale = 0.125 : f32, zero_point = 0 : i7} -> tensor<1x2xi8>
+  return %result : tensor<1x2xi8>
+}
+
+// -----
+
+func.func @quantize_op_scale_shift_mismatch(%arg0: tensor<1x2xf32>) -> tensor<1x2xui8> {
+  // expected-error@+1 {{Shift set, but does not match shift calculated from scale}}
+  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = 5 : si32, scale = 31.0 : f32, zero_point = 0 : ui8} -> tensor<1x2xui8>
+  return %result : tensor<1x2xui8>
+}
+
+// -----
+
+func.func @quantize_op_zero_not_zero(%arg0: tensor<1x2xf32>) -> tensor<1x2xui8> {
+  // expected-error@+1 {{Shift set, but zero_point not zero}}
+  %result = xten_nn.quantize (%arg0: tensor<1x2xf32>) {shift = 5 : si32, scale = 32.0 : f32, zero_point = 1 : ui8} -> tensor<1x2xui8>
+  return %result : tensor<1x2xui8>
+}
\ No newline at end of file
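Taken together, the new verifier tests pin down four attribute rules: at least one of scale or shift must be present; zero_point requires scale; the zero point's type must match the result element type; and when shift accompanies scale it must equal log2(scale) with a zero zero_point. A condensed standalone sketch of those rules over plain types (our own names, not the op's actual verifier; the type-match rule needs MLIR types and is omitted):

#include <cmath>
#include <optional>
#include <string>

std::optional<std::string>
checkQuantizeAttrs(std::optional<int32_t> shift, std::optional<float> scale,
                   std::optional<int64_t> zeroPoint) {
  if (!shift && !scale)
    return "Shift and scale are both missing";
  if (zeroPoint && !scale)
    return "It is only allowed to set a zero point if scale is set too";
  if (shift && scale) {
    int exp = 0;
    // shift must satisfy scale == 2^shift (see shiftFromScale above).
    if (std::frexp(*scale, &exp) != 0.5f || exp - 1 != *shift)
      return "Shift set, but does not match shift calculated from scale";
    if (zeroPoint && *zeroPoint != 0)
      return "Shift set, but zero_point not zero";
  }
  return std::nullopt; // attributes are consistent
}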