From 5545d5e2b453a58d32f70000df91082875dfdd34 Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Tue, 4 Feb 2025 19:38:56 +0000 Subject: [PATCH] Add target independent optimization pass --- BUILD.bazel | 82 +- ...lehlo_target_independent_optimization.mlir | 12 + stablehlo/tools/StablehloOptMain.cpp | 2 + stablehlo/transforms/CMakeLists.txt | 7 +- stablehlo/transforms/Passes.h | 16 - stablehlo/transforms/Passes.td | 21 - .../transforms/StablehloAggressiveFolder.cpp | 941 ---------- .../StablehloAggressiveSimplification.cpp | 1537 ----------------- ...ablehloAggressiveSimplificationPatterns.td | 427 ----- .../transforms/StablehloRefineShapes.cpp | 1 + .../transforms/optimization/CMakeLists.txt | 45 + stablehlo/transforms/optimization/Passes.h | 57 + stablehlo/transforms/optimization/Passes.td | 45 + .../StablehloAggressiveFolder.cpp | 943 ++++++++++ .../StablehloAggressiveSimplification.cpp | 1537 +++++++++++++++++ ...ablehloAggressiveSimplificationPatterns.td | 427 +++++ ...StablehloTargetIndependentOptimization.cpp | 66 + 17 files changed, 3214 insertions(+), 2952 deletions(-) create mode 100644 stablehlo/tests/transforms/stablehlo_target_independent_optimization.mlir create mode 100644 stablehlo/transforms/optimization/CMakeLists.txt create mode 100644 stablehlo/transforms/optimization/Passes.h create mode 100644 stablehlo/transforms/optimization/Passes.td create mode 100644 stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp create mode 100644 stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp create mode 100644 stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td create mode 100644 stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp diff --git a/BUILD.bazel b/BUILD.bazel index 40cc04aedd..6ee6e89dc8 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -316,11 +316,11 @@ gentbl_cc_library( tbl_outs = [ ( ["--gen-rewriters"], - "stablehlo/transforms/StablehloAggressiveSimplificationPatterns.h.inc", + "stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.h.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td", + td_file = "stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td", deps = [ ":stablehlo_ops_td_files", ], @@ -1125,15 +1125,33 @@ gentbl_cc_library( deps = ["@llvm-project//mlir:PassBaseTdFiles"], ) +cc_library( + name = "stablehlo_pass_utils", + srcs = [ + "stablehlo/transforms/PassUtils.cpp", + ], + hdrs = [ + "stablehlo/transforms/PassUtils.h", + ], + strip_include_prefix = ".", + deps = [ + ":base", + ":chlo_ops", + ":stablehlo_ops", + ":stablehlo_pass_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "stablehlo_passes", srcs = [ "stablehlo/transforms/ChloLegalizeToStablehlo.cpp", "stablehlo/transforms/PassPipelines.cpp", - "stablehlo/transforms/PassUtils.cpp", "stablehlo/transforms/ShapeLegalizeToStablehlo.cpp", - "stablehlo/transforms/StablehloAggressiveFolder.cpp", - "stablehlo/transforms/StablehloAggressiveSimplification.cpp", "stablehlo/transforms/StablehloCanonicalizeDynamism.cpp", "stablehlo/transforms/StablehloCompatibilityExpander.cpp", "stablehlo/transforms/StablehloComplexMathExpander.cpp", @@ -1163,13 +1181,14 @@ cc_library( ":chlo_ops", ":chlo_rewriters_inc_gen", ":linalg_passes", - ":stablehlo_aggressive_simplification_inc_gen", ":stablehlo_create_compatibility_expander_inc_gen", ":stablehlo_create_complex_math_expander_inc_gen", ":stablehlo_legalize_deprecated_ops_inc_gen", ":stablehlo_ops", ":stablehlo_ops_inc_gen", ":stablehlo_pass_inc_gen", + ":stablehlo_pass_utils", + ":stablehlo_passes_optimization", ":stablehlo_type_inference", ":version", ":vhlo_ops", @@ -1198,6 +1217,56 @@ cc_library( ], ) +gentbl_cc_library( + name = "stablehlo_passes_optimization_inc_gen", + compatible_with = get_compatible_with_portable(), + strip_include_prefix = ".", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=Optimization", + ], + "stablehlo/transforms/optimization/Passes.h.inc", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/optimization/Passes.td", + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +) + +cc_library( + name = "stablehlo_passes_optimization", + srcs = [ + "stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp", + "stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp", + "stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp", + ], + hdrs = [ + "stablehlo/transforms/optimization/Passes.h", + ], + strip_include_prefix = ".", + deps = [ + ":base", + ":stablehlo_aggressive_simplification_inc_gen", + ":stablehlo_ops", + ":stablehlo_pass_inc_gen", + ":stablehlo_pass_utils", + ":stablehlo_passes_optimization_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:CommonFolders", + "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Rewrite", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + cc_library( name = "stablehlo_portable_api", srcs = [ @@ -1364,6 +1433,7 @@ cc_binary( ":linalg_passes", ":register", ":stablehlo_passes", + ":stablehlo_passes_optimization", ":tosa_passes", "//stablehlo/tests:check_ops", "//stablehlo/tests:test_utils", diff --git a/stablehlo/tests/transforms/stablehlo_target_independent_optimization.mlir b/stablehlo/tests/transforms/stablehlo_target_independent_optimization.mlir new file mode 100644 index 0000000000..069d59a764 --- /dev/null +++ b/stablehlo/tests/transforms/stablehlo_target_independent_optimization.mlir @@ -0,0 +1,12 @@ +// RUN: stablehlo-opt --stablehlo-target-independent-optimization --split-input-file %s | FileCheck %s + +// Check that simplificaiton and folding are both applied. + +// CHECK-LABEL: @add_cst_on_rhs +func.func @add_cst_on_rhs(%arg0: tensor) -> tensor { + %cst = stablehlo.constant dense<1.0> : tensor + %0 = stablehlo.add %cst, %cst : tensor + // CHECK: stablehlo.add %arg0, %cst : tensor + %1 = stablehlo.add %0, %arg0 : tensor + return %1 : tensor +} diff --git a/stablehlo/tools/StablehloOptMain.cpp b/stablehlo/tools/StablehloOptMain.cpp index 8562a46462..34f90d60e8 100644 --- a/stablehlo/tools/StablehloOptMain.cpp +++ b/stablehlo/tools/StablehloOptMain.cpp @@ -26,12 +26,14 @@ limitations under the License. #include "stablehlo/tests/CheckOps.h" #include "stablehlo/tests/TestUtils.h" #include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/optimization/Passes.h" int main(int argc, char **argv) { mlir::registerAllPasses(); mlir::hlo::registerAllTestPasses(); mlir::stablehlo::registerPassPipelines(); mlir::stablehlo::registerPasses(); + mlir::stablehlo::registerOptimizationPasses(); mlir::stablehlo::registerStablehloLinalgTransformsPasses(); mlir::stablehlo::registerInterpreterTransformsPasses(); mlir::tosa::registerStablehloTOSATransformsPasses(); diff --git a/stablehlo/transforms/CMakeLists.txt b/stablehlo/transforms/CMakeLists.txt index f5193cb61e..088f873fe7 100644 --- a/stablehlo/transforms/CMakeLists.txt +++ b/stablehlo/transforms/CMakeLists.txt @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +add_subdirectory(optimization) + set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls) add_public_tablegen_target(PassesIncGen) @@ -20,10 +22,6 @@ set(LLVM_TARGET_DEFINITIONS ChloDecompositionPatterns.td) mlir_tablegen(ChloDecompositionPatterns.h.inc --gen-rewriters) add_public_tablegen_target(ChloDecompositionPatternsIncGen) -set(LLVM_TARGET_DEFINITIONS StablehloAggressiveSimplificationPatterns.td) -mlir_tablegen(StablehloAggressiveSimplificationPatterns.h.inc --gen-rewriters) -add_public_tablegen_target(StablehloAggressiveSimplificationPatternsIncGen) - set(LLVM_TARGET_DEFINITIONS StablehloCompatibilityExpanderPatterns.td) mlir_tablegen(StablehloCompatibilityExpanderPatterns.h.inc --gen-rewriters) add_public_tablegen_target(StablehloCompatibilityExpanderPatternsIncGen) @@ -93,6 +91,7 @@ add_mlir_dialect_library(StablehloPasses StablehloBroadcastUtils StablehloLinalgTransforms StablehloOps + StablehloOptimizationPasses StablehloTypeInference VhloOps ) diff --git a/stablehlo/transforms/Passes.h b/stablehlo/transforms/Passes.h index 3fede6e9eb..055768cacd 100644 --- a/stablehlo/transforms/Passes.h +++ b/stablehlo/transforms/Passes.h @@ -72,11 +72,6 @@ void populateChloToStablehloPatterns(MLIRContext *context, void populateChloConstantLikePattern(MLIRContext *context, RewritePatternSet *patterns); -/// Collection of folding patterns for StableHLO. -void populateStablehloAggressiveFolderPatterns(RewritePatternSet *patterns, - MLIRContext *context, - bool foldFloat); - /// Collection of rewrite patterns for lowering quantized StableHLO operations /// using uniform dequantize/quantize operations. void populateStablehloLegalizeQuantizedOpToQDQPatterns( @@ -88,17 +83,6 @@ void populateStablehloLegalizeQuantizedOpToQDQPatterns( void populateStablehloLegalizeQDQToQuantizedOpPatterns( RewritePatternSet *patterns, MLIRContext *context); -/// A subset of folding patterns for StableHLO that is necessary for shape -/// refinement. -void populateStablehloShapeFolderPatterns(RewritePatternSet *patterns, - MLIRContext *context, - bool foldFloat = false); - -/// Collection of canonicalization patterns for StableHLO. -void populateStablehloCanonicalizationPatterns(MLIRContext *context, - RewritePatternSet *patterns, - PatternBenefit benefit = 1); - /// Collection of patterns to upgrade deprecated ops to long-term supported ops. void populateStablehloLegalizeDeprecatedOpsPatterns( MLIRContext *context, RewritePatternSet *patterns); diff --git a/stablehlo/transforms/Passes.td b/stablehlo/transforms/Passes.td index aa9d696664..e0d9f317e1 100644 --- a/stablehlo/transforms/Passes.td +++ b/stablehlo/transforms/Passes.td @@ -36,27 +36,6 @@ def ShapeLegalizeToStablehloPass : Pass<"shape-legalize-to-stablehlo", "func::Fu let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; } -def StablehloAggressiveFolderPass - : Pass<"stablehlo-aggressive-folder", "func::FuncOp"> { - let summary = "Folds StableHLO operations"; - let dependentDialects = [ - "mlir::stablehlo::StablehloDialect", - "mlir::tensor::TensorDialect", - ]; - let options = [ - Option<"foldFloat", "fold-float", "bool", /*default=*/"true", - "Allow for potentially lossy computations using float type.">, - ]; -} - -def StablehloAggressiveSimplificationPass - : Pass<"stablehlo-aggressive-simplification", "func::FuncOp"> { - let summary = "Canonicalizes StableHLO operations"; - let dependentDialects = [ - "mlir::stablehlo::StablehloDialect", - ]; -} - def StablehloCanonicalizeDynamismPass : Pass<"stablehlo-canonicalize-dynamism", "func::FuncOp"> { let summary = "Canonicalizes dynamic StableHLO ops into static ops."; let description = [{ diff --git a/stablehlo/transforms/StablehloAggressiveFolder.cpp b/stablehlo/transforms/StablehloAggressiveFolder.cpp index 49942c4241..e69de29bb2 100644 --- a/stablehlo/transforms/StablehloAggressiveFolder.cpp +++ b/stablehlo/transforms/StablehloAggressiveFolder.cpp @@ -1,941 +0,0 @@ -/* Copyright 2024 The StableHLO Authors. -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include -#include -#include -#include - -#include "llvm/ADT/APInt.h" -#include "llvm/ADT/APSInt.h" -#include "llvm/ADT/FloatingPointMode.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringRef.h" -#include "llvm/Support/ErrorHandling.h" -#include "mlir/Dialect/CommonFolders.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/Operation.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Types.h" -#include "mlir/IR/Value.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Rewrite/FrozenRewritePatternSet.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "stablehlo/dialect/Base.h" -#include "stablehlo/dialect/StablehloOps.h" -#include "stablehlo/transforms/Passes.h" - -namespace mlir { -namespace stablehlo { - -#define GEN_PASS_DEF_STABLEHLOAGGRESSIVEFOLDERPASS -#include "stablehlo/transforms/Passes.h.inc" - -namespace { - -// This is an upper limit on how many elements can be folded by an op folder. -// This limit doesn't apply to some special cases like adding a zero, -// multiplying by one, doing many operations with splats. -constexpr int64_t kFoldOpEltLimit = 65536; - -// DenseElementsAttr can be constructed from ArrayRef but not from -// ArrayRef. This helper bridges the gap. -DenseIntElementsAttr getTensorAttr(ShapedType type, ArrayRef values) { - SmallVector supportedValues(values); - return DenseIntElementsAttr::get(type, supportedValues); -} - -APSInt getAPSInt(Type type, uint64_t value) { - unsigned numBits; - bool isUnsigned; - if (auto integerType = dyn_cast(type)) { - numBits = integerType.getWidth(); - // Signless types are treated as signed, per StableHLO convention. - isUnsigned = integerType.isUnsignedInteger(); - } else { - llvm::report_fatal_error("expected integer type"); - } - return APSInt( - {/*numBits=*/numBits, value, /*isSigned=*/false, /*implicitTrunc=*/true}, - /*isUnsigned=*/isUnsigned); -} - -LogicalResult validateResultTypeForEval(PatternRewriter& rewriter, - Operation* op, ShapedType resultType) { - if (!resultType.hasStaticShape()) - return rewriter.notifyMatchFailure( - op, "unable to fold dynamically shaped result type to constant"); - return success(); -} - -/// Binary constant folder that used a generic folder function to handle both -/// ints and floats. -template -static TypedAttr foldBinaryOpIntOrFloat(TypedAttr lhs, TypedAttr rhs, - Fn&& folder) { - Attribute operands[2] = {lhs, rhs}; - Type elemTy = getElementTypeOrSelf(lhs); - - Attribute res; - if (isa(elemTy)) - res = constFoldBinaryOp(operands, - folder); - if (isa(elemTy)) - res = constFoldBinaryOp(operands, - folder); - if (res) return cast(res); - - return nullptr; -} - -template -LogicalResult evalConvertHelper(PatternRewriter& rewriter, OpType op, - DenseIntOrFPElementsAttr elements, Type resType, - CalculationT&& calculate) { - auto result = constFoldCastOp( - elements, resType, calculate); - - if (!result) - return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { - diag << "cast of " << elements.getElementType() << " to " << resType - << " failed"; - }); - - rewriter.replaceOpWithNewOp(op, result); - return success(); -} - -template -LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, - DenseIntOrFPElementsAttr elements, - RankedTensorType resultType) { - auto oldType = getElementTypeOrSelf(elements); - auto newType = getElementTypeOrSelf(resultType); - size_t newBitWidth = newType.getIntOrFloatBitWidth(); - - bool isOldTypeUnsigned = oldType.isInteger(1) || oldType.isUnsignedInteger(); - bool isNewTypeUnsigned = newType.isInteger(1) || newType.isUnsignedInteger(); - - if (isa(oldType)) { - if (auto newFloatType = dyn_cast(newType)) { - // Float -> Float - const auto& targetSemantics = newFloatType.getFloatSemantics(); - return evalConvertHelper( - rewriter, op, elements, resultType, - [&targetSemantics](const APFloat& operand, bool& castStatus) { - bool losesInfo; - APFloat newValue = operand; - castStatus = APFloat::opInvalidOp != - newValue.convert(targetSemantics, - llvm::RoundingMode::NearestTiesToEven, - &losesInfo); - return newValue; - }); - } - - // Float -> Int - return evalConvertHelper( - rewriter, op, elements, resultType, - [&newBitWidth, &isNewTypeUnsigned](const APFloat& operand, - bool& castStatus) { - APSInt api(newBitWidth, isNewTypeUnsigned); - if (operand.isInfinity() || operand.isNegZero()) { - castStatus = false; - return api; - } - bool ignored; - castStatus = - APFloat::opInvalidOp != - operand.convertToInteger(api, APFloat::rmTowardZero, &ignored); - return api; - }); - } - - if (auto newFloatType = dyn_cast(newType)) { - // Int -> Float - return evalConvertHelper( - rewriter, op, elements, resultType, - [&newFloatType, &isOldTypeUnsigned](const APInt& operand, - bool& /*castStatus*/) { - APFloat apf(newFloatType.getFloatSemantics(), - APInt::getZero(newFloatType.getWidth())); - apf.convertFromAPInt(operand, !isOldTypeUnsigned, - APFloat::rmNearestTiesToEven); - return apf; - }); - } - - // Int -> Int - return evalConvertHelper( - rewriter, op, elements, resultType, - [&newBitWidth, &isOldTypeUnsigned](const APInt& operand, - bool& /*castStatus*/) { - return APSInt(operand, isOldTypeUnsigned).extOrTrunc(newBitWidth); - }); -} - -// The patterns below implement partial evaluation of shape computations which -// is a critical part of implementing type refinement for ops like -// dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape -// depends on the value of their shape operands. - -template -LogicalResult evalElementwise(PatternRewriter& rewriter, OpType op, - FuncType fn) { - auto resultType = op.getType(); - if (failed(validateResultTypeForEval(rewriter, op, resultType))) - return failure(); - - if (!isa(resultType.getElementType())) - return rewriter.notifyMatchFailure(op, - "expected integer result tensor type"); - - SmallVector result; - if constexpr (OpType::template hasTrait()) { - SmallVector operand; - if (failed(hlo::matchInts(op.getOperand(), operand))) - return rewriter.notifyMatchFailure(op, "expected constant operand"); - for (const auto& operandEl : operand) { - result.push_back(fn(operandEl)); - } - } else if constexpr (OpType::template hasTrait< - OpTrait::NOperands<2>::Impl>()) { - SmallVector lhs, rhs; - if (failed(hlo::matchInts(op.getLhs(), lhs)) || - failed(hlo::matchInts(op.getRhs(), rhs))) - return rewriter.notifyMatchFailure(op, "expected constant operands"); - for (auto [lhsEl, rhsEl] : llvm::zip(lhs, rhs)) { - result.push_back(fn(lhsEl, rhsEl)); - } - } else if constexpr (OpType::template hasTrait< - OpTrait::NOperands<3>::Impl>()) { - SmallVector x, y, z; - if (failed(hlo::matchInts(op->getOperand(0), x)) || - failed(hlo::matchInts(op->getOperand(1), y)) || - failed(hlo::matchInts(op->getOperand(2), z))) - return rewriter.notifyMatchFailure(op, "expected constant operands"); - for (auto [xEl, yEl, zEl] : llvm::zip(x, y, z)) { - result.push_back(fn(xEl, yEl, zEl)); - } - } else { - llvm::report_fatal_error("unsupported number of operands"); - } - - rewriter.replaceOpWithNewOp(op, - getTensorAttr(resultType, result)); - return success(); -} - -struct FoldAddOpPattern final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op, - PatternRewriter& rewriter) const override { - Value lhs = op.getLhs(); - Value rhs = op.getRhs(); - - // Pattern: add(cst,cst) -> cst - TypedAttr lhsAttr, rhsAttr; - matchPattern(lhs, m_Constant(&lhsAttr)); - matchPattern(rhs, m_Constant(&rhsAttr)); - - if (TypedAttr res; - lhsAttr && rhsAttr && - (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::plus<>{}))) { - rewriter.replaceOpWithNewOp(op, res); - return success(); - } - - return failure(); - } -}; - -struct EvalAddOpShapePattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AddOp op, - PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt lhs, APSInt rhs) { return lhs + rhs; }); - } -}; - -struct EvalAndOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(AndOp op, - PatternRewriter& rewriter) const override { - auto resultType = op.getType(); - if (!resultType.getElementType().isInteger(1)) - return rewriter.notifyMatchFailure(op, "expected boolean element type"); - - return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { - return getAPSInt(resultType.getElementType(), lhsInt != 0 && rhsInt != 0); - }); - } -}; - -// Pattern: broadcast_in_dim(splat, _) -> constant(splat) -struct FoldBroadcastInDimSplatPattern final - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::BroadcastInDimOp op, - PatternRewriter& rewriter) const override { - TypedValue operand = op.getOperand(); - - if (SplatElementsAttr cstAttr; - matchPattern(operand, m_Constant(&cstAttr))) { - rewriter.replaceOpWithNewOp( - op, SplatElementsAttr::get(op.getType(), - cstAttr.getSplatValue())); - return success(); - } - return failure(); - } -}; - -struct EvalBroadcastInDimOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(BroadcastInDimOp op, - PatternRewriter& rewriter) const override { - auto resultType = op.getType(); - if (failed(validateResultTypeForEval(rewriter, op, resultType))) - return failure(); - - auto operandType = op.getOperand().getType(); - if (operandType.getRank() != 0) - return rewriter.notifyMatchFailure(op, "expected 0-dimensional type"); - - SmallVector operand; - if (failed(hlo::matchInts(op.getOperand(), operand))) - return rewriter.notifyMatchFailure(op, "expected constant operands"); - auto scalar = operand[0]; - - rewriter.replaceOpWithNewOp( - op, getTensorAttr(op.getType(), scalar)); - return success(); - } -}; - -struct EvalClampOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ClampOp op, - PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt min, APSInt operand, APSInt max) { - if (operand < min) return min; - if (max < operand) return max; - return operand; - }); - } -}; - -struct EvalCompareOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(CompareOp op, - PatternRewriter& rewriter) const override { - auto resultType = op.getType(); - auto kind = op.getCompareType(); - return evalElementwise(rewriter, op, [&](APInt lhs, APInt rhs) { - bool result = false; - switch (op.getComparisonDirection()) { - case ComparisonDirection::EQ: - result = lhs == rhs; - break; - case ComparisonDirection::NE: - result = lhs != rhs; - break; - case ComparisonDirection::GE: - result = kind == ComparisonType::SIGNED ? lhs.sge(rhs) : lhs.uge(rhs); - break; - case ComparisonDirection::GT: - result = kind == ComparisonType::SIGNED ? lhs.sgt(rhs) : lhs.ugt(rhs); - break; - case ComparisonDirection::LE: - result = kind == ComparisonType::SIGNED ? lhs.sle(rhs) : lhs.ule(rhs); - break; - case ComparisonDirection::LT: - result = kind == ComparisonType::SIGNED ? lhs.slt(rhs) : lhs.ult(rhs); - break; - } - return getAPSInt(resultType.getElementType(), result); - }); - } -}; - -////////////////////////////////// -// ConcatenateOp -///////////////////////////////// - -struct FoldConcatenateOpPattern final - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op, - PatternRewriter& rewriter) const override { - RankedTensorType type = op.getType(); - if (!type.hasStaticShape()) return failure(); - - size_t numElems = type.getNumElements(); - if (numElems > kFoldOpEltLimit) return failure(); - - // Fold concatenate when all inputs are constants. - OperandRange inputs = op.getInputs(); - SmallVector constants(inputs.size()); - for (auto [input, constant] : llvm::zip_equal(inputs, constants)) { - if (!matchPattern(input, m_Constant(&constant))) return failure(); - } - - uint64_t dim = op.getDimension(); - ArrayRef shape = type.getShape(); - int64_t topSize = std::accumulate(shape.begin(), shape.begin() + dim, - int64_t{1}, std::multiplies<>{}); - - SmallVector newElems; - newElems.reserve(numElems); - - for (int64_t i = 0; i != topSize; ++i) { - for (ElementsAttr attr : constants) { - size_t bottomSize = attr.getNumElements() / topSize; - auto begin = attr.value_begin() + (i * bottomSize); - newElems.append(begin, begin + bottomSize); - } - } - - assert(newElems.size() == numElems); - rewriter.replaceOpWithNewOp( - op, DenseElementsAttr::get(op.getType(), newElems)); - return success(); - } -}; - -struct EvalConcatenateOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ConcatenateOp op, - PatternRewriter& rewriter) const override { - auto resultType = op.getType(); - if (failed(validateResultTypeForEval(rewriter, op, resultType))) - return failure(); - - if (op.getDimension() != 0) - return rewriter.notifyMatchFailure(op, "expected dimension = 0"); - - SmallVector result; - for (Value operand : op->getOperands()) { - if (failed(hlo::matchInts(operand, result))) - return rewriter.notifyMatchFailure(op, "expected constant operands"); - } - - rewriter.replaceOpWithNewOp(op, - getTensorAttr(resultType, result)); - return success(); - } -}; - -struct EvalConvertOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - EvalConvertOpPattern(MLIRContext* context, bool foldFloat_) - : OpRewritePattern(context), foldFloat{foldFloat_} {} - - LogicalResult matchAndRewrite(ConvertOp op, - PatternRewriter& rewriter) const override { - auto operand = op.getOperand(); - RankedTensorType resultType = op.getType(); - - if (failed(validateResultTypeForEval(rewriter, op, resultType))) - return failure(); - - auto operandElemType = getElementTypeOrSelf(operand.getType()); - auto resultElemType = getElementTypeOrSelf(resultType); - if (!(operandElemType.isInteger() && resultElemType.isInteger()) && - !foldFloat) - return rewriter.notifyMatchFailure(op, - "lossy computations are not allowed"); - - if (!resultElemType.isIntOrFloat()) - return rewriter.notifyMatchFailure( - op, "expected integer or float result tensor type"); - - DenseIntOrFPElementsAttr elements; - if (!matchPattern(operand, m_Constant(&elements))) - return rewriter.notifyMatchFailure( - op, "expected constant integer or float operand"); - - return evalConvert(rewriter, op, elements, resultType); - } - - private: - bool foldFloat; -}; - -struct EvalDivOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(DivOp op, - PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt lhs, APSInt rhs) { return lhs / rhs; }); - } -}; - -struct EvalGetDimensionSizeOpPattern - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(GetDimensionSizeOp op, - PatternRewriter& rewriter) const override { - auto resultType = op.getType(); - if (failed(validateResultTypeForEval(rewriter, op, resultType))) - return failure(); - - auto operandType = op.getOperand().getType(); - if (operandType.isDynamicDim(op.getDimension())) - return rewriter.notifyMatchFailure(op, "expected static dimension"); - - auto result = operandType.getDimSize(op.getDimension()); - rewriter.replaceOpWithNewOp( - op, DenseIntElementsAttr::get(resultType, result)); - return success(); - } -}; - -struct EvalMaxOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(MaxOp op, - PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { - return lhs >= rhs ? lhs : rhs; - }); - } -}; - -struct EvalMinOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(MinOp op, - PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { - return lhs <= rhs ? lhs : rhs; - }); - } -}; - -struct FoldMulOpPattern final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, - PatternRewriter& rewriter) const override { - auto elemType = op.getType().getElementType(); - Value lhs = op.getLhs(); - Value rhs = op.getRhs(); - - TypedAttr lhsAttr; - matchPattern(lhs, m_Constant(&lhsAttr)); - - TypedAttr rhsAttr; - matchPattern(rhs, m_Constant(&rhsAttr)); - - // The canonical form has the constant operand as the RHS. - if (isa(elemType) && lhsAttr && !rhsAttr) { - rewriter.modifyOpInPlace(op, [op, lhs, rhs] { - op->setOperands(ValueRange{rhs, lhs}); - }); - return success(); - } - - if (TypedAttr res; - lhsAttr && rhsAttr && - (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::multiplies<>{}))) { - rewriter.replaceOpWithNewOp(op, res); - return success(); - } - - return failure(); - } -}; - -struct EvalMulOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(MulOp op, - PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt lhs, APSInt rhs) { return lhs * rhs; }); - } -}; - -struct EvalOrOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(OrOp op, - PatternRewriter& rewriter) const override { - auto resultType = op.getType(); - if (!resultType.getElementType().isInteger(1)) - return rewriter.notifyMatchFailure(op, "expected boolean element type"); - - return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { - return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); - }); - } -}; - -struct EvalRemOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(RemOp op, - PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt lhs, APSInt rhs) { return lhs % rhs; }); - } -}; - -struct EvalReshapeOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ReshapeOp op, - PatternRewriter& rewriter) const override { - auto resultType = op.getType(); - if (failed(validateResultTypeForEval(rewriter, op, resultType))) - return failure(); - - // Pattern: reshape(cst, shape) -> cst - DenseIntElementsAttr attr; - if (!matchPattern(op.getOperand(), m_Constant(&attr))) - return rewriter.notifyMatchFailure(op, "expected constant operand"); - rewriter.replaceOpWithNewOp(op, attr.reshape(resultType)); - return success(); - } -}; - -struct EvalSelectOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(SelectOp op, - PatternRewriter& rewriter) const override { - auto resultType = op.getType(); - if (failed(validateResultTypeForEval(rewriter, op, resultType))) - return failure(); - - SmallVector pred, onTrue, onFalse; - if (failed(hlo::matchInts(op.getPred(), pred)) || - failed(hlo::matchInts(op.getOnTrue(), onTrue)) || - failed(hlo::matchInts(op.getOnFalse(), onFalse))) - return rewriter.notifyMatchFailure(op, "expected constant operands"); - - SmallVector result; - for (auto [predEl, onTrueEl, onFalseEl] : - llvm::zip(pred, onTrue, onFalse)) { - result.push_back(predEl != 0 ? onTrueEl : onFalseEl); - } - - rewriter.replaceOpWithNewOp( - op, getTensorAttr(op.getType(), result)); - return success(); - } -}; - -struct EvalSignOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(SignOp op, - PatternRewriter& rewriter) const override { - auto resultType = op.getType(); - if (!isa(resultType.getElementType())) - return rewriter.notifyMatchFailure(op, - "expected integer result tensor type"); - return evalElementwise(rewriter, op, [&](APSInt operand) { - int64_t result; - if (operand.isNegative()) - result = -1; - else if (operand.isZero()) - result = 0; - else - result = 1; - return getAPSInt(resultType.getElementType(), result); - }); - } -}; - -template -DenseElementsAttr sliceType(SliceOp& op, const RangeType& data) { - using ElementType = std::decay_t; - - RankedTensorType operandType = op.getOperand().getType(); - RankedTensorType resultType = op.getResult().getType(); - - const auto dimOffsets = computeStrides(operandType.getShape()); - auto startIndices = op.getStartIndices(); - auto limitIndices = op.getLimitIndices(); - auto strides = op.getStrides(); - - const SmallVector startIndex(startIndices); - const SmallVector endIndex(limitIndices); - - SmallVector result; - result.reserve(resultType.getNumElements()); - - SmallVector srcIndex(startIndex); - for (int64_t i = 0; i < resultType.getNumElements(); ++i) { - auto srcLinearIndex = linearize(srcIndex, dimOffsets); - result.push_back(data[srcLinearIndex]); - for (int64_t dim = srcIndex.size() - 1; dim >= 0; --dim) { - srcIndex[dim] += strides[dim]; - if (srcIndex[dim] >= endIndex[dim]) - srcIndex[dim] = startIndex[dim]; - else - break; - } - } - - return DenseElementsAttr::get(op.getResult().getType(), - ArrayRef(result)); -} - -struct EvalSliceOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(SliceOp op, - PatternRewriter& rewriter) const override { - auto resultType = op.getType(); - if (failed(validateResultTypeForEval(rewriter, op, resultType))) - return failure(); - - auto operand = op.getOperand(); - RankedTensorType operandType = operand.getType(); - if (!operandType.hasStaticShape()) - return rewriter.notifyMatchFailure( - op, "expected operand with static ranked tensor type"); - - ElementsAttr els; - if (!matchPattern(operand, m_Constant(&els))) - return rewriter.notifyMatchFailure( - op, "expected constant integer or float operand"); - - DenseElementsAttr resAttr; - if (auto data = els.tryGetValues()) - resAttr = sliceType(op, *data); - else if (auto data = els.tryGetValues()) - resAttr = sliceType(op, *data); - else - return rewriter.notifyMatchFailure(op.getLoc(), - "unsupported element type"); - - rewriter.replaceOpWithNewOp(op, resAttr); - return success(); - } -}; - -struct FoldSubtractOpPattern final - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(mlir::stablehlo::SubtractOp op, - PatternRewriter& rewriter) const override { - Value lhs = op.getLhs(); - Value rhs = op.getRhs(); - - TypedAttr lhsAttr, rhsAttr; - matchPattern(lhs, m_Constant(&lhsAttr)); - matchPattern(rhs, m_Constant(&rhsAttr)); - - if (TypedAttr res; - lhsAttr && rhsAttr && - (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::minus<>{}))) { - rewriter.replaceOpWithNewOp(op, res); - return success(); - } - - return failure(); - } -}; - -struct EvalSubtractOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(SubtractOp op, - PatternRewriter& rewriter) const override { - return evalElementwise(rewriter, op, - [&](APSInt lhs, APSInt rhs) { return lhs - rhs; }); - } -}; - -struct EvalIotaOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(IotaOp op, - PatternRewriter& rewriter) const override { - auto resultType = cast(op.getType()); - auto elementType = resultType.getElementType(); - - if (!elementType.isInteger()) - return rewriter.notifyMatchFailure(op, "expected integer result type"); - - auto outputSize = resultType.getNumElements(); - auto resultBitWidth = elementType.getIntOrFloatBitWidth(); - int64_t dimension = op.getIotaDimension(); - - llvm::SmallVector values; - values.reserve(outputSize); - - if (outputSize == 0) { - rewriter.replaceOpWithNewOp( - op, DenseIntElementsAttr::get(resultType, values)); - return success(); - } - - int64_t sequences = 1; - int64_t sequenceMax = resultType.getDimSize(dimension); - int64_t elementRepetitions = 1; - for (int64_t i = 0; i < resultType.getRank(); i++) { - sequences *= i < dimension ? resultType.getDimSize(i) : 1; - elementRepetitions *= i > dimension ? resultType.getDimSize(i) : 1; - } - - for (int64_t i = 0; i < sequences; ++i) { - for (int64_t value = 0; value < sequenceMax; ++value) { - for (int64_t k = 0; k < elementRepetitions; ++k) { - values.push_back(APInt(resultBitWidth, value)); - } - } - } - - rewriter.replaceOpWithNewOp( - op, DenseIntElementsAttr::get(resultType, values)); - return success(); - } -}; - -template -DenseElementsAttr transposeType(TransposeOp& op, const RangeType& data) { - using ElementType = std::decay_t; - - RankedTensorType operandType = op.getOperand().getType(); - RankedTensorType resultType = op.getResult().getType(); - - const auto operandStrides = computeStrides(operandType.getShape()); - const auto resultStrides = computeStrides(resultType.getShape()); - const auto inversePermutation = invertPermutationVector(op.getPermutation()); - - SmallVector result; - result.reserve(resultType.getNumElements()); - - for (int64_t i = 0; i < resultType.getNumElements(); ++i) { - auto dstDimIndex = delinearize(i, resultStrides); - auto srcDimIndex = applyPermutation(dstDimIndex, inversePermutation); - auto srcLinearIndex = linearize(srcDimIndex, operandStrides); - result.push_back(data[srcLinearIndex]); - } - - return DenseElementsAttr::get(resultType, ArrayRef(result)); -} - -// transpose(constant) => constant with permuted dimensions -// This covers ranked tensor types with 0 dimensions(zero elements) and 0 -// rank(scalar), as well as splat values. -struct EvalTransposeOpPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(TransposeOp op, - PatternRewriter& rewriter) const override { - auto resultType = op.getType(); - if (failed(validateResultTypeForEval(rewriter, op, resultType))) - return failure(); - - ElementsAttr els; - if (!matchPattern(op.getOperand(), m_Constant(&els))) - return rewriter.notifyMatchFailure( - op, "expected constant integer or float operand"); - - DenseElementsAttr resAttr; - if (auto data = els.tryGetValues()) - resAttr = transposeType(op, *data); - else if (auto data = els.tryGetValues()) - resAttr = transposeType(op, *data); - else - return rewriter.notifyMatchFailure(op.getLoc(), - "unsupported element type"); - - rewriter.replaceOpWithNewOp(op, resAttr); - return success(); - } -}; - -struct StablehloAggressiveFolderPass - : public impl::StablehloAggressiveFolderPassBase< - StablehloAggressiveFolderPass> { - using StablehloAggressiveFolderPassBase::StablehloAggressiveFolderPassBase; - - LogicalResult initialize(MLIRContext* context) override { - RewritePatternSet patterns_(context); - populateStablehloAggressiveFolderPatterns(&patterns_, context, foldFloat); - patterns = std::move(patterns_); - - return success(); - } - - void runOnOperation() override { - if (failed(applyPatternsGreedily(getOperation(), patterns))) - signalPassFailure(); - } - - private: - FrozenRewritePatternSet patterns; -}; - -} // namespace - -void populateStablehloAggressiveFolderPatterns(RewritePatternSet* patterns, - MLIRContext* context, - bool foldFloat) { - populateStablehloShapeFolderPatterns(patterns, context, foldFloat); - patterns->add(context); - patterns->add(context); - - // TODO: Consolidate FoldOp patterns - // One is used by Shape Refinement, the other is a generic folder. - patterns - ->add( - context); -} - -void populateStablehloShapeFolderPatterns(RewritePatternSet* patterns, - MLIRContext* context, - bool foldFloat) { - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context, foldFloat); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); -} - -} // namespace stablehlo -} // namespace mlir diff --git a/stablehlo/transforms/StablehloAggressiveSimplification.cpp b/stablehlo/transforms/StablehloAggressiveSimplification.cpp index ebb6c027a3..e69de29bb2 100644 --- a/stablehlo/transforms/StablehloAggressiveSimplification.cpp +++ b/stablehlo/transforms/StablehloAggressiveSimplification.cpp @@ -1,1537 +0,0 @@ -// Copyright 2023 The IREE Authors -// -// Licensed under the Apache License, Version 2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// Implements optional canonicalization patterns for StableHLO ops. - -#include -#include -#include -#include -#include -#include -#include -#include - -#include "llvm/ADT/APInt.h" -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/Sequence.h" -#include "llvm/ADT/SmallBitVector.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/SmallVectorExtras.h" -#include "llvm/Support/ErrorHandling.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/Block.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/BuiltinAttributeInterfaces.h" -#include "mlir/IR/BuiltinAttributes.h" -#include "mlir/IR/BuiltinTypeInterfaces.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Diagnostics.h" -#include "mlir/IR/IRMapping.h" -#include "mlir/IR/MLIRContext.h" -#include "mlir/IR/Matchers.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/IR/Region.h" -#include "mlir/IR/TypeUtilities.h" -#include "mlir/IR/Value.h" -#include "mlir/IR/ValueRange.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Rewrite/FrozenRewritePatternSet.h" -#include "mlir/Support/LLVM.h" -#include "mlir/Support/LogicalResult.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "stablehlo/dialect/Base.h" -#include "stablehlo/dialect/StablehloOps.h" -#include "stablehlo/transforms/PassUtils.h" -#include "stablehlo/transforms/Passes.h" - -using llvm::SmallBitVector; - -namespace mlir { -namespace stablehlo { - -#define GEN_PASS_DEF_STABLEHLOAGGRESSIVESIMPLIFICATIONPASS -#include "stablehlo/transforms/Passes.h.inc" - -namespace { -// This is an upper limit on how many elements can be folded by an op folder. -// This limit doesn't apply to some special cases like adding a zero, -// multiplying by one, doing many operations with splats. -constexpr int64_t kFoldOpEltLimit = 65536; - -static bool isIotaRange(ArrayRef dims) { - return llvm::all_of(llvm::enumerate(dims), [](const auto &it) { - return static_cast(it.index()) == it.value(); - }); -} - -/// Matches when either of the submatchers match. -template -struct m_AnyOf { - m_AnyOf(MatcherA a, MatcherB b) : matcherA(a), matcherB(b) {} - - bool match(Operation *op) { return matcherA.match(op) || matcherB.match(op); } - - MatcherA matcherA; - MatcherB matcherB; -}; - -template -m_AnyOf(MatcherA, MatcherB) -> m_AnyOf; - -/// Matches when either of the submatchers match. -template -struct m_AnyAttrOf { - m_AnyAttrOf(MatcherA a, MatcherB b) : matcherA(a), matcherB(b) {} - - bool match(Attribute attr) { - return matcherA.match(attr) || matcherB.match(attr); - } - - MatcherA matcherA; - MatcherB matcherB; -}; - -template -m_AnyAttrOf(MatcherA, MatcherB) -> m_AnyAttrOf; - -////////////////////////////////// -// CompareOp -///////////////////////////////// - -static ComparisonDirection invertDirection(ComparisonDirection direction) { - switch (direction) { - case ComparisonDirection::EQ: - case ComparisonDirection::NE: - return direction; - case ComparisonDirection::GE: - return ComparisonDirection::LE; - case ComparisonDirection::GT: - return ComparisonDirection::LT; - case ComparisonDirection::LE: - return ComparisonDirection::GE; - case ComparisonDirection::LT: - return ComparisonDirection::GT; - } - - llvm::report_fatal_error("Unhandled case"); -} - -struct CompareOpCanon final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(CompareOp op, - PatternRewriter &rewriter) const override { - RankedTensorType type = op.getType(); - - // Bail out on non-integer comparison. - // TODO: Support more comparison types. - std::optional compType = op.getCompareType(); - if (!compType || - !llvm::is_contained({ComparisonType::SIGNED, ComparisonType::UNSIGNED}, - *compType)) { - return failure(); - } - - ComparisonDirection direction = op.getComparisonDirection(); - Value lhs = op.getLhs(); - Value rhs = op.getRhs(); - - // Pattern: compare(X, X, [EQ,GE,LE]) -> true - // Pattern: compare(X, X, [NE,GT,LT]) -> false - if (lhs == rhs) { - switch (direction) { - case ComparisonDirection::EQ: - case ComparisonDirection::GE: - case ComparisonDirection::LE: { - rewriter.replaceOpWithNewOp( - op, SplatElementsAttr::get(type, rewriter.getBoolAttr(true))); - return success(); - } - case ComparisonDirection::GT: - case ComparisonDirection::LT: - case ComparisonDirection::NE: { - rewriter.replaceOpWithNewOp(op, - rewriter.getZeroAttr(type)); - return success(); - } - } - llvm_unreachable("Unhandled case"); - } - - // Pattern: compare(cst, X, comparator) -> compare(X, cst, inv(comparator)) - TypedAttr lhsAttr, rhsAttr; - matchPattern(lhs, m_Constant(&lhsAttr)); - matchPattern(rhs, m_Constant(&rhsAttr)); - - // The canonical form has the constant operand as the RHS. - if (lhsAttr && !rhsAttr) { - rewriter.modifyOpInPlace(op, [&op, direction, lhs, rhs] { - op.setComparisonDirection(invertDirection(direction)); - op->setOperands(ValueRange{rhs, lhs}); - }); - return success(); - } - - return failure(); - } -}; - -////////////////////////////////// -// ConcatenateOp -///////////////////////////////// - -// Pattern: concatenate(X) -> X -class ConcatenateOpNoop : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ConcatenateOp op, - PatternRewriter &rewriter) const override { - if (op.getInputs().size() != 1 || - op.getInputs().front().getType() != op.getType()) - return rewriter.notifyMatchFailure(op, "not single operand noop-concat"); - - rewriter.replaceOp(op, op.getInputs().front()); - return success(); - } -}; - -// Pattern: concatenate(X, Y, []) -> concatenate(X, Y) -class ConcatenateOpRemoveEmpty : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ConcatenateOp op, - PatternRewriter &rewriter) const override { - auto axis = op.getDimension(); - llvm::SmallVector newOperands = llvm::to_vector( - llvm::make_filter_range(op.getOperands(), [&](Value operand) { - return cast(operand.getType()).getDimSize(axis) != 0; - })); - - // Only handle nonempty new operands, empty handled by - // ZeroExtentToEmptyConstant pattern. - if (!newOperands.empty() && newOperands.size() < op.getNumOperands()) { - rewriter.modifyOpInPlace(op, [&] { op->setOperands(newOperands); }); - return success(); - } - - return failure(); - } -}; - -// Pattern: concatenate(concatenate(X, Y), Z) -> concatenate(X, Y, Z) -class ConcatenateOpFlatten : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - LogicalResult matchAndRewrite(ConcatenateOp op, - PatternRewriter &rewriter) const override { - auto getFlattenedOperands = [&](const Value &val) -> ValueRange { - auto definingOp = dyn_cast_or_null(val.getDefiningOp()); - // To avoid inflate the memory footprint, only flatten the - // ConcatenateOp when it has only one use. - if (definingOp && definingOp->hasOneUse() && - definingOp.getDimension() == op.getDimension()) - return definingOp.getInputs(); - return val; - }; - - bool needToFlatten = false; - int operandCount = 0; - llvm::for_each(op.getInputs(), [&](Value val) { - auto result = getFlattenedOperands(val); - if (result.size() != 1 || result[0] != val) needToFlatten = true; - operandCount += result.size(); - }); - - if (!needToFlatten) - return rewriter.notifyMatchFailure(op, "no need to flatten"); - - llvm::SmallVector newOperands; - newOperands.reserve(operandCount); - - for (auto operand : op.getInputs()) { - auto flattenedOperands = getFlattenedOperands(operand); - newOperands.append(flattenedOperands.begin(), flattenedOperands.end()); - } - - rewriter.modifyOpInPlace(op, [&] { op->setOperands(newOperands); }); - return success(); - } -}; - -////////////////////////////////// -// BroadcastInDimOp -///////////////////////////////// - -// Used in DRR file. -DenseI64ArrayAttr getMergedBroadcastDimensions(OpBuilder &b, - ArrayRef dims, - ArrayRef dimsParent) { - auto mergedDims = llvm::map_to_vector( - dimsParent, [&dims](int64_t dim) { return dims[dim]; }); - return b.getDenseI64ArrayAttr(mergedDims); -} - -////////////////////////////////// -// DynamicBroadcastInDimOp -///////////////////////////////// - -/// Does the same as PatternRewriter::replaceOpWithNewOp, but with a twist. -/// -/// Sometimes, we want to replace an op with a new op and simultaneously refine -/// the result type from a dynamically-shaped type to a statically-shaped type. -/// (Search for usages of this function for examples). -// -/// Oftentimes, this works just fine because HLO is designed to accommodate -/// this kind of type refinements. But sometimes, this doesn't work - when -/// the op is used outside of the HLO dialect (e.g. in func.return). In these -/// cases, we insert a stablehlo.convert to smooth things out. -template -static OpTy refineOpWithNewOp(PatternRewriter &rewriter, Operation *op, - Args &&...args) { - auto newOp = rewriter.create(op->getLoc(), std::forward(args)...); - - llvm::SmallVector replacementResults; - assert(op->getNumResults() == newOp->getNumResults() && - "replacement op doesn't match results of original op"); - for (auto [opResult, newOpResult] : - llvm::zip(op->getResults(), newOp->getResults())) { - Value replacementResult = newOpResult; - if (llvm::any_of(opResult.getUsers(), [&](Operation *user) { - return user->getDialect() != op->getDialect(); - })) - replacementResult = rewriter.create( - op->getLoc(), opResult.getType(), newOpResult); - replacementResults.push_back(replacementResult); - } - - rewriter.replaceOp(op, replacementResults); - return newOp; -} - -/// If a DynamicBroadCastInDimOp is not actually dynamic, use an ordinary -/// BroadcastInDimOp. -struct DynamicBroadcastInDimOpNotActuallyDynamic final - : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op, - PatternRewriter &rewriter) const override { - RankedTensorType operandType = op.getOperand().getType(); - if (!operandType.hasStaticShape()) - return rewriter.notifyMatchFailure(op, "requires operand static shape"); - - RankedTensorType type = op.getType(); - // output has static shape, replace with broadcast_in_dim - if (type.hasStaticShape()) { - rewriter.replaceOpWithNewOp( - op, type, op.getOperand(), op.getBroadcastDimensionsAttr()); - return success(); - } - - // output_dimensions are constant, set output shape with output_dimensions, - // then replace with broadcast_in_dim - if (llvm::SmallVector shape; - succeeded(hlo::matchInts(op.getOutputDimensions(), shape))) { - refineOpWithNewOp( - rewriter, op, RankedTensorType::get(shape, type.getElementType()), - op.getOperand(), op.getBroadcastDimensionsAttr()); - return success(); - } - return rewriter.notifyMatchFailure( - op, "requires output static shape or constant broadcast dimensions"); - } -}; - -////////////////////////////////// -// DynamicGatherOp -///////////////////////////////// - -DenseI64ArrayAttr convertToI64Array(OpBuilder &b, Attribute attr) { - auto denseAttr = cast(attr); - SmallVector result; - result.reserve(denseAttr.getNumElements()); - for (auto elem : denseAttr.getValues()) - result.push_back(elem.getSExtValue()); - return b.getDenseI64ArrayAttr(result); -} - -////////////////////////////////// -// DynamicIotaOp -///////////////////////////////// - -struct DynamicIotaIsStatic : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DynamicIotaOp iota, - PatternRewriter &rewriter) const override { - // Result type has static shape, replace with iota. - auto resultTy = cast(iota.getType()); - if (!resultTy.hasStaticShape()) - return rewriter.notifyMatchFailure(iota, "requires output static shape"); - rewriter.replaceOpWithNewOp(iota, resultTy, - iota.getIotaDimension()); - return success(); - } -}; - -// Dynamic Iota operations across multiple dimensions can be reduced to an iota -// and a ranked broadcast. -// Pattern: dynamic_iota(shape, dim) -> -// dynamic_broadcast_in_dim(dynamic_iota(slice(shape), dim), shape) -struct DynamicIotaOpToBroadcast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DynamicIotaOp iota, - PatternRewriter &rewriter) const override { - auto resultTy = cast(iota.getType()); - if (resultTy.getRank() < 2) - return rewriter.notifyMatchFailure(iota, "requires rank >= 2"); - - auto iotaDimension = static_cast(iota.getIotaDimension()); - - // Handle case where iota dimension is index, need to convert to/from i64 - // to interop with slice. These canonicalize away if input is i64. - auto convertedShape = rewriter.create( - iota.getLoc(), - RankedTensorType::get( - cast(iota.getOutputShape().getType()).getShape(), - rewriter.getI64Type()), - iota.getOutputShape()); - - auto slicedShape = rewriter.create( - iota.getLoc(), convertedShape, - rewriter.getDenseI64ArrayAttr(iotaDimension), - rewriter.getDenseI64ArrayAttr(iotaDimension + 1), - rewriter.getDenseI64ArrayAttr(1)); - - auto convertedSlicedShape = rewriter.create( - iota.getLoc(), - RankedTensorType::get( - {1}, - cast(iota.getOutputShape().getType()).getElementType()), - slicedShape); - - auto iotaType = RankedTensorType::get({resultTy.getDimSize(iotaDimension)}, - resultTy.getElementType()); - - auto newIota = rewriter.create( - iota.getLoc(), iotaType, convertedSlicedShape, - rewriter.getI64IntegerAttr(0)); - - rewriter.replaceOpWithNewOp( - iota, resultTy, newIota, iota.getOutputShape(), - rewriter.getDenseI64ArrayAttr(iotaDimension)); - return success(); - } -}; - -////////////////////////////////// -// DynamicReshapeOp -///////////////////////////////// - -struct DynamicReshapeOpIsStatic final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DynamicReshapeOp op, - PatternRewriter &rewriter) const override { - // This is a noop when the output type is already a static shape. - RankedTensorType type = op.getType(); - if (!type.hasStaticShape()) - return rewriter.notifyMatchFailure(op, "dynamic reshape not static"); - - rewriter.replaceOpWithNewOp(op, type, op.getOperand()); - return success(); - } -}; - -// Pattern: dynamic_reshape(op(dynamic_reshape(X, shape)), shape) -// -> op(dynamic_reshape(X, shape)) -// [if op has same operand and result shape] -class DynamicReshapeOpSameOperandAndResultShape - : public OpRewritePattern { - public: - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DynamicReshapeOp op, - PatternRewriter &rewriter) const override { - Operation *defOp = op.getOperand().getDefiningOp(); - if (!defOp || - !defOp->hasTrait()) { - return rewriter.notifyMatchFailure( - op, "dynamic reshape parent not same operand and result shape"); - } - DynamicReshapeOp reshape = - defOp->getOperand(0).getDefiningOp(); - if (!reshape) - return rewriter.notifyMatchFailure( - op, "dynamic reshape not wrapping same operand and result shape"); - if (reshape.getOutputShape() == op.getOutputShape()) { - rewriter.replaceOp(op, {defOp->getResult(0)}); - return success(); - } - return failure(); - } -}; - -////////////////////////////////// -// DynamicSliceOp -///////////////////////////////// - -// Canonicalizes DynamicSlice ops that can be replaced instead with Slice ops. -// This canonicalization is applied the case when the `begin` input values are -// compile time constants and thus can be made into a tensor. -// -// Pattern: dynamic_slice(X, begin, slice_sizes) -> slice(X, begin, slice_sizes) -struct DynamicSliceOpToSlice : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(DynamicSliceOp dynamicSlice, - PatternRewriter &rewriter) const override { - Value input = dynamicSlice.getOperand(); - auto inputType = cast(input.getType()); - if (!inputType.hasStaticShape()) - return rewriter.notifyMatchFailure(dynamicSlice, - "dynamic slice input not static"); - - auto sliceSizes = dynamicSlice.getSliceSizes(); - SmallVector tempStartIndices; - for (const auto &indexAndSliceStart : - llvm::enumerate(dynamicSlice.getStartIndices())) { - APInt val; - Value start = indexAndSliceStart.value(); - int64_t index = indexAndSliceStart.index(); - if (!matchPattern(start, m_ConstantInt(&val))) - return rewriter.notifyMatchFailure(dynamicSlice, - "dynamic slice input not constant"); - - // Clamp the indices within bounds to faithfully mirror dynamic slice - // semantics. - int64_t clampedStart = - std::clamp(val.getSExtValue(), static_cast(0), - inputType.getDimSize(index) - sliceSizes[index]); - tempStartIndices.push_back(clampedStart); - } - - // At this point we've determined that the start indices are all constants; - // pack them into a single tensor. - auto sliceStartIndices = rewriter.getDenseI64ArrayAttr(tempStartIndices); - SmallVector tempSliceLimits; - for (const auto &[start, size] : llvm::zip(tempStartIndices, sliceSizes)) { - tempSliceLimits.push_back(start + size); - } - auto sliceLimits = rewriter.getDenseI64ArrayAttr(tempSliceLimits); - - auto sliceStrides = rewriter.getDenseI64ArrayAttr( - SmallVector(inputType.getRank(), 1)); - - rewriter.replaceOpWithNewOp(dynamicSlice, input, sliceStartIndices, - sliceLimits, sliceStrides); - return success(); - } -}; - -////////////////////////////////// -// RealDynamicSliceOp -///////////////////////////////// - -// Pattern: real_dynamic_slice(X, start, limit, strides) -// -> dynamic_slice(X, start, limit, strides) -// [if strides, start are constants, limit = start + constant] -struct RealDynamicSliceOpToDynamicSlice - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(RealDynamicSliceOp op, - PatternRewriter &rewriter) const override { - // This rewrite only works for unit strides because DynamicSliceOp - // doesn't support strides (i.e. it implicitly has unit strides). - DenseIntElementsAttr stridesAttr; - if (!matchPattern(op.getStrides(), m_Constant(&stridesAttr))) - return rewriter.notifyMatchFailure(op, "requires constant strides"); - if (!llvm::all_of(stridesAttr.getValues(), - [&](APInt stride) { return stride == 1; })) - return rewriter.notifyMatchFailure(op, "requires unit strides"); - - // Check that slice sizes are fully static (DynamicSliceOp style). - // To detect that, we check whether `limit_indices` is defined as - // `start_indices + constant` or `constant + start_indices`. - DenseIntElementsAttr sliceSizesAttr; - auto m_startIndices = matchers::m_Val(op.getStartIndices()); - // Only handle the AddOp case, if all constant we fold to SliceOp. - if (!matchPattern( - op.getLimitIndices(), - m_Op(m_startIndices, m_Constant(&sliceSizesAttr))) && - !matchPattern(op.getLimitIndices(), - m_Op(m_Constant(&sliceSizesAttr), m_startIndices))) - return rewriter.notifyMatchFailure( - op, "requires limit indices equal to start indices plus constant"); - - // RealDynamicSliceOp can take tensors of integer or index element types. - // DynamicSliceOp::slice_sizes only supports i64 element type. - // Adapt accordingly in order to be compatible with DynamicSliceOp. - SmallVector sliceSizes; - for (auto element : sliceSizesAttr.getValues()) { - sliceSizes.push_back(element.getSExtValue()); - } - - // RealDynamicSliceOp::start_indices is a 1-dimensional tensor. - // DynamicSliceOp::start_indices is a vararg of 0-dimensional tensors. - // Adapt accordingly in order to be compatible with DynamicSliceOp. - SmallVector startIndices; - for (auto i = 0; i < static_cast(sliceSizes.size()); ++i) { - auto startIndex1D = rewriter.create( - op.getLoc(), op.getStartIndices(), rewriter.getDenseI64ArrayAttr(i), - rewriter.getDenseI64ArrayAttr(i + 1), - rewriter.getDenseI64ArrayAttr(1)); - auto startIndex0DType = RankedTensorType::get( - {}, - cast(op.getStartIndices().getType()).getElementType()); - auto startIndex0D = rewriter.create( - op.getLoc(), startIndex0DType, startIndex1D); - startIndices.push_back(startIndex0D); - } - - rewriter.replaceOpWithNewOp( - op, op.getOperand(), startIndices, - rewriter.getDenseI64ArrayAttr(sliceSizes)); - return success(); - } -}; - -////////////////////////////////// -// ReduceOp -///////////////////////////////// - -// Pattern: reduce[A](_, _, fn:return A) -> A... -struct ReduceOpNoopVariableReturn final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ReduceOp op, - PatternRewriter &rewriter) const override { - // If all returned values in the ReduceOp region exists outside the - // region, replace the ReduceOp with those values. - if (auto retOp = dyn_cast(op.getBody().front().getTerminator())) { - Region *retRegion = retOp->getParentRegion(); - if (llvm::any_of(retOp.getResults(), [retRegion](Value result) { - return result.getParentRegion() == retRegion; - })) - return failure(); - - rewriter.replaceOp(op, retOp.getResults()); - return success(); - } - - return failure(); - } -}; - -// Pattern: reduce(empty_0, empty_1, ...) -> [broadcast_in_dim(empty_i)...] -struct ReduceOpEmptyCanon final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ReduceOp op, - PatternRewriter &rewriter) const override { - // We require all reduce shapes to be the same, up to the element types, so - // we can just use the first operand and the first result as - // representatives. - auto elemTy = cast(op.getInputs().getType().front()); - - if (!llvm::is_contained(elemTy.getShape(), 0)) return failure(); - - Location loc = op.getLoc(); - DenseI64ArrayAttr empty = rewriter.getDenseI64ArrayAttr({}); - if (elemTy.hasStaticShape()) { - SmallVector broadcasts(op.getNumResults()); - for (auto [bcast, init, outTy] : llvm::zip_equal( - broadcasts, op.getInitValues(), op.getResultTypes())) { - bcast = rewriter.create(loc, outTy, init, empty); - } - rewriter.replaceOp(op, broadcasts); - return success(); - } - - SmallVector shapes; - if (failed(op.reifyReturnTypeShapes(rewriter, op.getOperands(), shapes))) - return failure(); - - SmallVector broadcasts(op.getNumResults()); - for (auto [bcast, init, shape, outTy] : llvm::zip_equal( - broadcasts, op.getInitValues(), shapes, op.getResultTypes())) { - bcast = rewriter.create(loc, outTy, init, shape, - empty); - } - rewriter.replaceOp(op, broadcasts); - return success(); - } -}; - -// Pattern: reduce(in_1, in_2, _, _) -> reduce(in_1, _, _) [if unused(in_2)] -struct ReduceOpUnusedResultCanon final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(ReduceOp op, - PatternRewriter &rewriter) const override { - SmallVector usedResults; - llvm::copy_if(op.getResults(), std::back_inserter(usedResults), - [](OpResult result) { return !result.use_empty(); }); - - if (usedResults.size() == op.getNumResults()) - return rewriter.notifyMatchFailure(op, "all operation results have uses"); - - const auto pairSize = 2; - const auto numOperands = op.getNumOperands(); - const auto numOperandPairs = numOperands / pairSize; - - Block &reducerBlock = op.getBody().front(); - auto retOp = cast(reducerBlock.getTerminator()); - - assert(numOperandPairs == op.getNumResults() && - numOperandPairs == retOp.getNumOperands()); - - SmallVector workList; - auto addToWorkList = [&workList, - reducerBody = retOp->getParentRegion()](Value v) { - if (v.getParentRegion() == reducerBody) workList.push_back(v); - }; - - SmallPtrSet usedOps; - SmallBitVector usedArgs(numOperands); - SmallBitVector usedReturnOperands(numOperandPairs); - for (const auto &usedResult : usedResults) { - auto resultNo = usedResult.getResultNumber(); - usedReturnOperands.set(resultNo); - - // Follow the def-use chain starting from return operand to identify - // which argument pairs are used to compute it. - addToWorkList(retOp.getOperand(resultNo)); - while (!workList.empty()) { - auto definition = workList.pop_back_val(); - if (auto blockArg = dyn_cast(definition)) { - // using one argument implies using the whole argument pair - const auto pairNo = blockArg.getArgNumber() % numOperandPairs; - usedArgs.set(pairNo); - usedArgs.set(pairNo + numOperandPairs); - } else if (auto *defOp = definition.getDefiningOp()) { - usedOps.insert(defOp); - for (const auto &operand : defOp->getOperands()) - addToWorkList(operand); - } - } - } - - const auto newNumOperandPairs = usedResults.size(); - const auto newNumOperands = newNumOperandPairs * pairSize; - if (newNumOperands != usedArgs.count()) - return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { - diag << "non-conservative case: " << newNumOperandPairs - << " return results should be matched with " << newNumOperands - << " operands, but got " << usedArgs.count(); - }); - - SmallVector newInputs; - SmallVector newInitVals; - SmallVector newElementTypes; - for (auto i : llvm::seq(0u, numOperandPairs)) { - if (usedReturnOperands[i]) - newElementTypes.push_back( - getElementTypeOrSelf(retOp.getOperand(i).getType())); - - if (!usedArgs[i]) continue; - - newInputs.push_back(op.getOperand(i)); - newInitVals.push_back(op.getOperand(i + numOperandPairs)); - } - - auto newOp = - rewriter.create(op.getLoc(), newInputs, newInitVals, - op.getDimensionsAttr(), newElementTypes); - Block *newReducerBlock = rewriter.createBlock(&newOp.getBody()); - - IRMapping mapper; - for (auto arg : reducerBlock.getArguments()) - if (usedArgs[arg.getArgNumber()]) - mapper.map(arg, - newReducerBlock->addArgument(arg.getType(), arg.getLoc())); - - rewriter.setInsertionPointToStart(newReducerBlock); - for (Operation &op : reducerBlock.getOperations()) - if (usedOps.contains(&op)) rewriter.clone(op, mapper); - - SmallVector newReturnOperands; - for (const auto &en : llvm::enumerate(retOp.getOperands())) - if (usedReturnOperands[en.index()]) - newReturnOperands.push_back(mapper.lookup(en.value())); - - rewriter.create(retOp.getLoc(), newReturnOperands); - - // Build new results list (unused entries will be null). - SmallVector newResults(op.getNumResults()); - for (const auto &[i, result] : llvm::enumerate(usedResults)) { - newResults[result.getResultNumber()] = newOp.getResult(i); - } - - rewriter.replaceOp(op, newResults); - return success(); - } -}; - -///////////////////////////////// -// GetDimensionSizeOp -///////////////////////////////// - -// TODO: This is duplicated with a pattern in shape refinement, consider -// consolidating. -// Pattern: get_dimension_size(X, i) -> X.shape[i] -struct GetDimensionSizeOpCanon final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(GetDimensionSizeOp op, - PatternRewriter &rewriter) const override { - // Fold get_dimension_size when the queried dim is statically known. - RankedTensorType operandTy = op.getOperand().getType(); - - int64_t dimSize = operandTy.getDimSize(op.getDimension()); - if (dimSize < 0) return failure(); - - auto elemTy = cast(op.getType().getElementType()); - IntegerAttr elemVal = rewriter.getIntegerAttr(elemTy, dimSize); - rewriter.replaceOpWithNewOp( - op, DenseElementsAttr::get(op.getType(), elemVal)); - return success(); - } -}; - -////////////////////////////////// -// GatherOp -///////////////////////////////// - -/// Converts gather ops to slice ops in case we have a single set of constant -/// indices. -// Pattern: gather(X, cst_start_indices) -> slice(X, slice_start, slice_end) -struct GatherOpCanon final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(GatherOp gather, - PatternRewriter &rewriter) const override { - DenseIntElementsAttr index; - if (!matchPattern(gather.getStartIndices(), m_Constant(&index))) - return failure(); - - GatherDimensionNumbersAttr dnums = gather.getDimensionNumbers(); - if (dnums.getIndexVectorDim() != 0 || index.getType().getRank() > 1) - return failure(); - - // TODO: Remove when the verifier catches this case that is - // invalid if all previous condition holds. - if (index.getNumElements() != - static_cast(dnums.getStartIndexMap().size())) { - return failure(); - } - - auto operandType = cast(gather->getOperand(0).getType()); - if (!operandType.hasStaticShape()) return failure(); - - auto sliceEnd = llvm::to_vector(gather.getSliceSizes()); - SmallVector sliceStart(sliceEnd.size(), 0); - for (auto [mapIndex, value] : - llvm::zip_equal(dnums.getStartIndexMap(), index.getValues())) { - // Clamp the indices within bounds to faithfully mirror gather semantics. - int64_t offset = - std::clamp(value.getSExtValue(), static_cast(0), - operandType.getDimSize(mapIndex) - sliceEnd[mapIndex]); - sliceStart[mapIndex] += offset; - sliceEnd[mapIndex] += offset; - } - - SmallVector sliceStride(sliceEnd.size(), 1); - SmallVector sliceShape(sliceEnd.size()); - for (auto [shapeElem, startElem, endElem] : - llvm::zip_equal(sliceShape, sliceStart, sliceEnd)) { - shapeElem = endElem - startElem; - } - - Type elementType = gather.getType().getElementType(); - auto sliceType = RankedTensorType::get(sliceShape, elementType); - Value result = rewriter.create( - gather.getLoc(), sliceType, gather.getOperand(), - rewriter.getDenseI64ArrayAttr(sliceStart), - rewriter.getDenseI64ArrayAttr(sliceEnd), - rewriter.getDenseI64ArrayAttr(sliceStride)); - - ArrayRef collapsedSliceDims = dnums.getCollapsedSliceDims(); - if (!collapsedSliceDims.empty()) { - llvm::SmallVector reshapeShape; - for (auto [idx, dim] : llvm::enumerate(sliceShape)) { - if (!llvm::is_contained(collapsedSliceDims, idx)) - reshapeShape.push_back(dim); - } - auto reshapeType = RankedTensorType::get(reshapeShape, elementType); - result = rewriter.create(gather.getLoc(), reshapeType, result); - } - - result.setType(gather.getType()); - rewriter.replaceOp(gather, result); - return success(); - } -}; - -////////////////////////////////// -// IotaOp -///////////////////////////////// - -// Iota operations across multiple dimensions can be reduced to an iota and a -// ranked broadcast. -// Pattern: iota(dim) : multi_rank -// -> broadcast_in_dim(iota(dim) : array, multi_rank) -struct IotaOpBroadcast : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(IotaOp iota, - PatternRewriter &rewriter) const override { - auto resultTy = cast(iota.getType()); - if (resultTy.getRank() < 2) - return rewriter.notifyMatchFailure(iota, "itoa not broadcastable"); - - auto iotaDim = iota.getIotaDimension(); - auto iotaDimSize = resultTy.getDimSize(iotaDim); - auto iota1D = rewriter.create( - iota.getLoc(), - RankedTensorType::get({iotaDimSize}, resultTy.getElementType()), - rewriter.getI64IntegerAttr(0)); - - auto broadcastAttr = - rewriter.getDenseI64ArrayAttr({static_cast(iotaDim)}); - rewriter.replaceOpWithNewOp(iota, resultTy, iota1D, - broadcastAttr); - return success(); - } -}; - -////////////////////////////////// -// PadOp -///////////////////////////////// - -// If the input tensor has a dimension of length-0, the input tensor is -// irrelevant. Instead we can broadcast the pad value to the output size rather -// than pad the input tensor. - -// If the input tensor has a dimension of length-0, the input tensor is -// irrelevant. Instead we can broadcast the pad value to the output size rather -// than pad the input tensor. - -// Pattern: pad(empty_tensor, _) -> broadcast_in_dim(empty_tensor, _) -struct PadOpBroadcastEmptyTensor : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(PadOp op, - PatternRewriter &rewriter) const override { - auto operand = op.getOperand(); - auto padVal = op.getPaddingValue(); - - auto resultTy = cast(op.getType()); - - if (cast(operand.getType()).getNumElements() != 0) - return rewriter.notifyMatchFailure(op, "operand is not empty tensor"); - - if (resultTy.hasStaticShape()) { - rewriter.replaceOpWithNewOp( - op, resultTy, padVal, rewriter.getDenseI64ArrayAttr({})); - return success(); - } - - llvm::SmallVector reifiedShapes; - if (failed(op.reifyReturnTypeShapes(rewriter, op.getOperands(), - reifiedShapes))) - return rewriter.notifyMatchFailure(op, "failed to reify return type"); - - rewriter.replaceOpWithNewOp( - op, op.getType(), padVal, reifiedShapes.front(), - rewriter.getDenseI64ArrayAttr({})); - return success(); - } -}; - -////////////////////////////////// -// SelectOp -///////////////////////////////// - -struct SelectOpCanon final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(SelectOp op, - PatternRewriter &rewriter) const override { - RankedTensorType type = op.getType(); - - Value trueVal = op.getOnTrue(); - Value falseVal = op.getOnFalse(); - - // Eliminate select with two identical outcomes. - if (trueVal == falseVal) { - rewriter.replaceOp(op, trueVal); - return success(); - } - - // Simplify when the condition is a constant. - Value pred = op.getPred(); - ElementsAttr cond; - if (!matchPattern(pred, m_Constant(&cond))) return failure(); - - // Handle splat predicate and select either `trueVal` or `falseVal`. - if (cond.isSplat()) { - rewriter.replaceOp(op, cond.getSplatValue() ? trueVal : falseVal); - return success(); - } - - // Handle elementwise selection when both outcomes are also constants. This - // will create a new, likely non-splat constant. - if (cond.getNumElements() > kFoldOpEltLimit) return failure(); - - ElementsAttr trueAttr; - if (!matchPattern(trueVal, m_Constant(&trueAttr))) return failure(); - - ElementsAttr falseAttr; - if (!matchPattern(falseVal, m_Constant(&falseAttr))) return failure(); - - SmallVector newValues; - newValues.reserve(cond.getNumElements()); - for (auto [condElem, trueElem, falseElem] : llvm::zip_equal( - cond.getValues(), trueAttr.getValues(), - falseAttr.getValues())) { - newValues.push_back(condElem ? trueElem : falseElem); - } - - rewriter.replaceOpWithNewOp( - op, DenseElementsAttr::get(type, newValues)); - return success(); - } -}; - -struct CompareSelectIntoMinMax final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(SelectOp op, - PatternRewriter &rewriter) const override { - Value pred = op.getPred(); - Value trueVal = op.getOnTrue(); - Value falseVal = op.getOnFalse(); - - auto cmpOp = pred.getDefiningOp(); - if (!cmpOp) return failure(); - - ComparisonDirection direction = cmpOp.getComparisonDirection(); - Value cmpLhs = cmpOp.getLhs(); - Value cmpRhs = cmpOp.getRhs(); - - // Turn into canonical form: - // b <= a ? a : b ---> a >= b ? a : b - // b < a ? a : b ---> a > b ? a : b - // b >= a ? a : b ---> a <= b ? a : b - // b > a ? a : b ---> a < b ? a : b - if (cmpLhs == falseVal && cmpRhs == trueVal) { - direction = invertDirection(direction); - } else if (!(cmpLhs == trueVal && cmpRhs == falseVal)) { - return failure(); - } - - switch (direction) { - case ComparisonDirection::GE: - case ComparisonDirection::GT: { - rewriter.replaceOpWithNewOp(op, trueVal, falseVal); - return success(); - } - case ComparisonDirection::LE: - case ComparisonDirection::LT: { - rewriter.replaceOpWithNewOp(op, trueVal, falseVal); - return success(); - } - default: { - return failure(); - } - } - } -}; - -////////////////////////////////// -// SliceOp -///////////////////////////////// - -// In cases where a concat is fed into a slice, it is possible the concat -// can be simplified or bypassed. This checks which inputs to the concat are -// used by the slice, either reducing the number of concatenated values or -// entirely removes the concat. -// Pattern: slice(concat(X,Y,Z,...),...) -> concat(slice(X),slice(Y),slice(Z)) -struct SliceOpConcatSimplify : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(SliceOp slice, - PatternRewriter &rewriter) const override { - auto resultTy = cast(slice.getType()); - if (!resultTy.hasStaticShape()) - return rewriter.notifyMatchFailure(slice, "result shape not static"); - - auto concat = slice.getOperand().getDefiningOp(); - if (!concat) - return rewriter.notifyMatchFailure(slice, "slice input not concat"); - - auto concatType = cast(concat.getType()); - auto dimension = concat.getDimension(); - - auto start = slice.getStartIndices(); - auto limit = slice.getLimitIndices(); - - int64_t sliceStart = start[dimension]; - int64_t sliceLimit = limit[dimension]; - - // We need to determine what inputs from the concat affect the slice, and - // how the bounds of the slice need to be updated for the minimally required - // inputs. - int64_t runningSize = 0; - int64_t frontOffset = concatType.getShape()[dimension]; - - auto subsetStart = concat.operand_end(); - auto subsetEnd = concat.operand_end(); - for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) { - auto input = *it; - ShapedType inputTy = cast(input.getType()); - if (inputTy.isDynamicDim(dimension)) - return rewriter.notifyMatchFailure( - slice, "concat input has dynamic dimension"); - - auto dimSize = inputTy.getShape()[dimension]; - - // If this position is in the slice its the start of the subset and we - // need to update the start and limit values. - if (runningSize + dimSize > sliceStart && - subsetStart == concat.operand_end()) { - subsetStart = it; - frontOffset = runningSize; - } - - // Determine the last required offset. - if (runningSize < sliceLimit) { - subsetEnd = it + 1; - } - - runningSize += dimSize; - } - - auto subsetSize = subsetEnd - subsetStart; - // We need all inputs so no optimization. - if (subsetSize == concat.getNumOperands()) - return rewriter.notifyMatchFailure(slice, - "slice needs all concat inputs"); - - // If there's nothing to slice that means the output is an empty tensor and - // there is dead code. We do nothing here and rely on other passes to clean - // this up. - if (subsetSize == 0) - return rewriter.notifyMatchFailure(slice, "slice is empty"); - - if (subsetSize > 1 && !concat.getResult().hasOneUse()) - return rewriter.notifyMatchFailure(slice, - "slice is not the only concat user"); - - auto concatRange = OperandRange(subsetStart, subsetEnd); - auto newConcat = rewriter.create( - concat.getLoc(), concatRange, concat.getDimension()); - - SmallVector newStart(start); - SmallVector newLimit(limit); - newStart[dimension] -= frontOffset; - newLimit[dimension] -= frontOffset; - - rewriter.replaceOpWithNewOp( - slice, newConcat, rewriter.getDenseI64ArrayAttr(newStart), - rewriter.getDenseI64ArrayAttr(newLimit), slice.getStrides()); - return success(); - } -}; - -////////////////////////////////// -// SortOp -///////////////////////////////// - -/// Drops the operands if the results are not used and they are not used in -/// op.comparator(). - -// Pattern: sort(X,Y) -> sort(X) [if Y unused and unused in comparator] -struct SortOpDropUnusedArgs : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(SortOp op, - PatternRewriter &rewriter) const override { - DenseSet erasedArgs; - unsigned numOperands = op.getNumOperands(); - for (unsigned i = 0; i < numOperands; ++i) { - if (!op.getResult(i).use_empty()) continue; - Block &block = op.getComparator().front(); - if (!block.getArgument(i * 2).use_empty()) continue; - if (!block.getArgument(i * 2 + 1).use_empty()) continue; - erasedArgs.insert(i); - } - if (erasedArgs.empty()) return failure(); - - SmallVector newOperands; - BitVector erasedBlockArgs(op.getNumOperands() * 2); - for (const auto &en : llvm::enumerate(op.getInputs())) { - if (erasedArgs.contains(en.index())) { - erasedBlockArgs.set(en.index() * 2); - erasedBlockArgs.set(en.index() * 2 + 1); - } else { - newOperands.push_back(en.value()); - } - } - - auto newOp = rewriter.create(op.getLoc(), newOperands, - op.getDimension(), op.getIsStable()); - Region ®ion = newOp.getComparator(); - rewriter.inlineRegionBefore(op.getComparator(), region, region.end()); - region.front().eraseArguments(erasedBlockArgs); - - SmallVector results; - for (unsigned i = 0, j = 0; i < numOperands; ++i) { - if (erasedArgs.contains(i)) { - results.push_back({}); - } else { - results.push_back(newOp.getResult(j++)); - } - } - rewriter.replaceOp(op, results); - - return success(); - } -}; - -/// Set the sorting dimension to the last dimension if it's not set and the rank -/// is known. -// Pattern: sort(X) -> sort(X, dim = N) [when dim can be inferred] -struct SortOpSetDimension : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(SortOp op, - PatternRewriter &rewriter) const override { - if (op.getResults().empty() || - static_cast(op.getDimension()) != -1) - return rewriter.notifyMatchFailure(op, - "dimension already set or no results"); - - auto type = cast(op.getResultTypes()[0]); - IntegerAttr dim = rewriter.getI64IntegerAttr(type.getRank() - 1); - auto newOp = - rewriter.create(op.getLoc(), op.getResultTypes(), - op.getInputs(), dim, op.getIsStableAttr()); - newOp.getComparator().takeBody(op.getComparator()); - rewriter.replaceOp(op, newOp.getResults()); - return success(); - } -}; - -////////////////////////////////// -// TransposeOp -///////////////////////////////// - -// Pattern: transpose(X, [no_mem_layout_change...]) -> reshape(X) -struct TransposeIsReshape final : OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TransposeOp op, - PatternRewriter &rewriter) const override { - auto input = op.getOperand(); - auto permutation = op.getPermutation(); - - RankedTensorType inputTy = input.getType(); - if (!inputTy.hasStaticShape() || !op.getType().hasStaticShape()) - return rewriter.notifyMatchFailure( - op, - "requires input and output to be of a statically-shaped ranked " - "tensor type"); - - // Check that the permutation is a valid memory layout change. - // All non-zero/one dimensions must be in increasing order. - SmallVector nonZeroPerms; - nonZeroPerms.reserve(permutation.size()); - for (auto idx : permutation) - if (inputTy.getDimSize(idx) != 1) nonZeroPerms.push_back(idx); - - for (size_t i = 1; i < nonZeroPerms.size(); ++i) - if (nonZeroPerms[i - 1] > nonZeroPerms[i]) - return rewriter.notifyMatchFailure(op, "memory layout change"); - - rewriter.replaceOpWithNewOp(op, op.getType(), input); - return success(); - } -}; - -////////////////////////////////// -// TupleOp -///////////////////////////////// - -// Pattern: tuple(get_tuple_element(X, 0), get_tuple_element(X, 1), ...) -> X -struct TupleIsRepacking : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(TupleOp op, - PatternRewriter &rewriter) const override { - if (op.getVal().empty()) - return rewriter.notifyMatchFailure(op, "empty tuple"); - - // Get parent tuple - Value firstElement = op.getVal().front(); - auto firstElementOp = firstElement.getDefiningOp(); - if (!firstElementOp) - return rewriter.notifyMatchFailure(op, "parent not get_tuple_element"); - - Value tuplePredecessor = firstElementOp.getOperand(); - if (tuplePredecessor.getType() != op.getType()) - return rewriter.notifyMatchFailure( - op, "tuple predecessor type does not match"); - - // Check that this is a repacking of the parent tuple. - for (const auto &elementAndIdx : llvm::enumerate(op.getVal())) { - auto elementOp = elementAndIdx.value().getDefiningOp(); - if (!elementOp || - elementOp.getIndexAttr().getInt() != - static_cast(elementAndIdx.index()) || - elementOp.getOperand() != tuplePredecessor) - return rewriter.notifyMatchFailure( - op, "not a repacking of the parent tuple"); - } - - rewriter.replaceOp(op, tuplePredecessor); - return success(); - } -}; - -///////////////////////////////// -// WhileOp -///////////////////////////////// - -// Turn loop invariant values into implicit capture. -// Check if there is at least one value is forwarded from one iteration to -// the next, or one of the yielded value is an implicit capture already. -// Otherwise there is nothing to do here. - -// Pattern: while -> while (loop invariants as implicit captures) -struct WhileOpImplicitCapture : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(WhileOp whileOp, - PatternRewriter &rewriter) const override { - Block *cond = whileOp.SingleBlock::getBody(0); - Block *body = whileOp.SingleBlock::getBody(1); - auto bodyReturnOp = cast(body->getTerminator()); - if (!llvm::any_of(llvm::zip(whileOp->getOperands(), body->getArguments(), - bodyReturnOp->getOperands()), - [&](auto zip) { - return (std::get<0>(zip) == std::get<2>(zip) || - std::get<1>(zip) == std::get<2>(zip)); - })) - return rewriter.notifyMatchFailure(whileOp, "no loop invariant found"); - - SmallVector newOperands, resultsToReplace; - SmallVector invariantArgIdxs; - BitVector invariantArgIdxBitVector(cond->getNumArguments()); - for (const auto &enumeratedOperands : llvm::enumerate(llvm::zip( - whileOp.getOperands(), cond->getArguments(), body->getArguments(), - bodyReturnOp->getOperands(), whileOp->getResults()))) { - const auto &operands = enumeratedOperands.value(); - Value whileOperand = std::get<0>(operands); - BlockArgument condBlockArg = std::get<1>(operands); - BlockArgument bodyBlockArg = std::get<2>(operands); - Value bodyReturnOperand = std::get<3>(operands); - Value whileResult = std::get<4>(operands); - - bool forwarded = (whileOperand == bodyReturnOperand || - bodyBlockArg == bodyReturnOperand); - if (forwarded) { - invariantArgIdxs.push_back(enumeratedOperands.index()); - invariantArgIdxBitVector.set(enumeratedOperands.index()); - condBlockArg.replaceAllUsesWith(whileOperand); - bodyBlockArg.replaceAllUsesWith(whileOperand); - whileResult.replaceAllUsesWith(whileOperand); - continue; - } - newOperands.push_back(whileOperand); - resultsToReplace.push_back(whileResult); - } - cond->eraseArguments(invariantArgIdxBitVector); - body->eraseArguments(invariantArgIdxBitVector); - for (int idx : llvm::reverse(invariantArgIdxs)) - bodyReturnOp->eraseOperand(idx); - - WhileOp newWhileOp = rewriter.create( - whileOp.getLoc(), bodyReturnOp->getOperandTypes(), newOperands); - newWhileOp.getBodyRegion(0).takeBody(whileOp.getBodyRegion(0)); - newWhileOp.getBodyRegion(1).takeBody(whileOp.getBodyRegion(1)); - for (auto results : llvm::zip(resultsToReplace, newWhileOp->getResults())) - std::get<0>(results).replaceAllUsesWith(std::get<1>(results)); - rewriter.eraseOp(whileOp); - return success(); - } -}; - -////////////////////////////////// -// Generic and Elementwise Ops -///////////////////////////////// - -/// Check if a `t` is a tensor with zero extents. -static std::optional getMaybeZeroExtentType(Type t) { - auto type = dyn_cast(t); - if (type && type.hasStaticShape() && type.getNumElements() == 0) return type; - return std::nullopt; -} - -// Replace instances of zero extent tensors with empty tensors -// Pattern: op(X : zero_extent_tensor) -> constant([]) -struct ZeroExtentToEmptyConstant final : RewritePattern { - ZeroExtentToEmptyConstant(MLIRContext *context, PatternBenefit benefit) - : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {} - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - auto loc = op->getLoc(); - - if (!isa_and_present(op->getDialect())) - return rewriter.notifyMatchFailure(op, "not stablehlo"); - if (isa(op)) - return rewriter.notifyMatchFailure(op, "op is empty constant"); - - // If the result is a zero-extent tensor, replace the whole op with an empty - // constant. - bool didUpdate = false; - for (auto result : op->getResults()) { - auto resultType = getMaybeZeroExtentType(result.getType()); - if (!resultType || result.use_empty()) continue; - rewriter.replaceAllUsesWith( - result, rewriter.create( - loc, result.getType(), - DenseElementsAttr::get(resultType.value(), - ArrayRef()))); - didUpdate = true; - } - - // If one of the operands is a zero-extent tensor, replace the operand with - // an empty tensor. - for (OpOperand &operand : op->getOpOperands()) { - auto operandType = getMaybeZeroExtentType(operand.get().getType()); - if (!operandType || operand.get().getDefiningOp()) continue; - Operation *owner = operand.getOwner(); - int operandNum = operand.getOperandNumber(); - auto emptyConstantOp = rewriter.create( - loc, operandType.value(), - DenseElementsAttr::get(operandType.value(), ArrayRef())); - rewriter.modifyOpInPlace( - owner, [&]() { owner->setOperand(operandNum, emptyConstantOp); }); - didUpdate = true; - } - return success(didUpdate); - } -}; - -struct ReorderElementwiseAndShapeOp final - : OpTraitRewritePattern { - using OpTraitRewritePattern::OpTraitRewritePattern; - - LogicalResult matchAndRewrite(Operation *op, - PatternRewriter &rewriter) const override { - if (op->getOperands().size() != 1) - return rewriter.notifyMatchFailure(op, "expected to be unary"); - - auto definingOp = op->getOperand(0).getDefiningOp(); - if (!definingOp) - return rewriter.notifyMatchFailure( - op, "expected to have an op before elementise op"); - - if (!isa(definingOp)) - return rewriter.notifyMatchFailure( - op, "defining operation of unexpected type"); - - // Reshape and broadcast are not allowed to have dynamic shape. - Value result = op->getResult(0); - if (isa(definingOp) && - !cast(result.getType()).hasStaticShape()) - return rewriter.notifyMatchFailure( - op, "cannot reorder around reshape/broadcast with dynamic shape"); - - // Only reorder if the defining op has no other uses. - if (!llvm::hasSingleElement(definingOp->getResult(0).getUses())) - return rewriter.notifyMatchFailure(op, "operation has more than one use"); - - Value input = definingOp->getOperand(0); - auto intermediateType = cast(input.getType()) - .clone(getElementTypeOrSelf(result.getType())); - - // Reorder the operation and rewire the inputs/outputs. - op->moveBefore(definingOp); - definingOp->getResult(0).setType(result.getType()); - rewriter.replaceAllUsesWith(result, definingOp->getResult(0)); - result.setType(intermediateType); - op->setOperands(input); - definingOp->setOperands(result); - return success(); - } -}; - -struct StablehloAggressiveSimplificationPass final - : impl::StablehloAggressiveSimplificationPassBase< - StablehloAggressiveSimplificationPass> { - StablehloAggressiveSimplificationPass() = default; - StablehloAggressiveSimplificationPass(GreedyRewriteConfig config) - : config(config) {} - LogicalResult initialize(MLIRContext *context) override { - RewritePatternSet patterns_(context); - populateStablehloCanonicalizationPatterns(context, &patterns_); - patterns = std::move(patterns_); - return success(); - } - - void runOnOperation() override { - if (failed(applyPatternsGreedily(getOperation(), patterns, config))) - signalPassFailure(); - } - - private: - GreedyRewriteConfig config; - FrozenRewritePatternSet patterns; -}; - -#include "stablehlo/transforms/StablehloAggressiveSimplificationPatterns.h.inc" -} // namespace - -void populateStablehloCanonicalizationPatterns(MLIRContext *context, - RewritePatternSet *patterns, - PatternBenefit benefit) { - populateWithGenerated(*patterns); - patterns->add(context); - patterns->add< - CompareOpCanon, CompareSelectIntoMinMax, ConcatenateOpFlatten, - ConcatenateOpNoop, ConcatenateOpRemoveEmpty, DynamicIotaOpToBroadcast, - DynamicReshapeOpSameOperandAndResultShape, DynamicSliceOpToSlice, - GatherOpCanon, IotaOpBroadcast, PadOpBroadcastEmptyTensor, - RealDynamicSliceOpToDynamicSlice, ReduceOpEmptyCanon, - ReduceOpNoopVariableReturn, ReduceOpUnusedResultCanon, SelectOpCanon, - SliceOpConcatSimplify, SortOpDropUnusedArgs, SortOpSetDimension, - TransposeIsReshape, TupleIsRepacking, WhileOpImplicitCapture>(context, - benefit); - - // Generic patterns - patterns->add( - context, benefit); - - // TODO: Dynamism Refinements, consider merging with canonicalize dynamism - patterns - ->add(context); -} - -std::unique_ptr createStablehloAggressiveSimplificationPass( - GreedyRewriteConfig config) { - return std::make_unique(config); -} - -} // namespace stablehlo -} // namespace mlir diff --git a/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td b/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td index 9cbcc07ca6..e69de29bb2 100644 --- a/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td +++ b/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td @@ -1,427 +0,0 @@ -// Copyright 2020 The IREE Authors -// -// Licensed under the Apache License, Version 2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -// This is the legalization pattern definition file for CHLO to StableHLO. -// These are included in the populateDecompositionPatterns factory -// and should only include canonical expansions which are not actually -// ambiguous/different for various backends. Avoid patterns that are actually -// lowering to non-canonical forms. - -include "mlir/IR/OpBase.td" -include "stablehlo/dialect/StablehloOps.td" -include "mlir/Dialect/Shape/IR/ShapeOps.td" - -/////////// -//// Op & Type Constraints - -class DimSizeEquals : Constraint< - CPred<"llvm::cast($0.getType()).getDimSize($1.getInt()) == " # dimSize>, - "dim size is " # dimSize>; - -def AllDimsNonExpanding : Constraint< - CPred<"$0 && cast($0).size() == llvm::cast($1.getType()).getRank()">, - "all dims are non-expanding">; - -def AllZero : Constraint< - CPred<"llvm::all_of($0, [](Value operand) {return matchPattern(operand, m_Zero()); })">, - "is all zero">; - -def CommutativeOp : Constraint< - CPred<"$0.getDefiningOp()->hasTrait()">, - "op is commutative">; - -def HasOneUse : Constraint>; - -def NotConstantOp : Constraint< - CPred<"llvm::isa($0) || !llvm::isa($0.getDefiningOp())">, - "is not a constant.">; - -def NumberOfElementsEqual : Constraint< - CPred<"llvm::cast($0.getType()).getNumElements() == llvm::cast($1.getType()).getNumElements()">, - "same number of elements">; - -def OperandsEqual : Constraint, "operands are equal">; - -def RankEqual : Constraint< - CPred<"llvm::cast($0.getType()).getRank() == llvm::cast($1.getType()).getRank()">, - "same rank">; - -def TypesEqual : Constraint, "operands are equal">; - -/////////// -//// Attribute Constraints - -def AnySplat : AttrConstraint, "is any splat">; - -def AnyZero : AttrConstraint< - CPred<"::mlir::matchPattern($_self, m_AnyAttrOf(m_Zero(), m_AnyZeroFloat()))">, - "is int or float zero">; - -def DenseIntElementsAttr : AttrConstraint< - CPred<"llvm::isa($_self)">, - "is dense int elements attr">; - -def EmptyI64Array : AttrConstraint< - CPred<"cast($_self).empty()">, - "is empty i64 array">; - -def IntOne : AttrConstraint< - CPred<"::mlir::matchPattern($_self, m_One())">, - "is integer one">; - -def IntAllOnes : AttrConstraint< - CPred<[{ - ::mlir::matchPattern($_self, - ::mlir::detail::constant_int_predicate_matcher{ - [](const llvm::APInt &val) { - return val.isAllOnes(); - }}) - }]>, - "is integer with all bits set to 1">; - -def IntZero : AttrConstraint< - CPred<"::mlir::matchPattern($_self, m_Zero())">,"is integer zero">; - -def IotaDims : AttrConstraint< - CPred<"isIotaRange(cast($_self).asArrayRef())">, - "is iota dimensions">; - -def SortedDims : AttrConstraint< - CPred<"llvm::is_sorted(cast($_self).asArrayRef())">, - "is sorted dimensions">; - -def ZeroExtent : AttrConstraint< - CPred<"cast($_self).getNumElements() == 0">, - "is zero extent">; - -/////////// -//// Native Code Call Utilities - -def CastIntElementsAttr : NativeCodeCall<"cast($0)">; - -def ConvertToI64Array : NativeCodeCall<"convertToI64Array($_builder, $0)">; - -def GetOperandN : NativeCodeCall<"$0.getDefiningOp()->getOperand($1.getInt())">; - -def GetEmptyI64Array : NativeCodeCall<"$_builder.getDenseI64ArrayAttr({})">; - -def MergeBroadcastDims : NativeCodeCall<"getMergedBroadcastDimensions($_builder, $0, $1)">; - -def StableHLO_ConvertOpWithShape : NativeCodeCall< - "$_builder.create($_loc, $0.getType(), $1)">; - -def StableHLO_ReshapeOpWithShape : NativeCodeCall< - "$_builder.create($_loc, $0.getType(), $1)">; - -class StableHLO_ConstantLike : NativeCodeCall< - "::mlir::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; - -//////////////////////////// -// Generic BinaryOp Patterns - -// op(cst, X) -> op(X, cst) -class CanonicalizeConstantToRhs - : Pat<(StableHLO_OpType:$op (StableHLO_ConstantOp:$lhs $value), $rhs), - (StableHLO_OpType $rhs, $lhs), - [(NotConstantOp $rhs), (CommutativeOp $op)]>; - -//////// -// AddOp - -// Pattern: add(cst, X) -> add(X, cst) -def : CanonicalizeConstantToRhs; - -// Pattern: add(X, 0) -> X -def : Pat<(StableHLO_AddOp $lhs, (ConstantLikeMatcher AnyZero:$value)), - (replaceWithValue $lhs)>; - -//////// -// AndOp - -// Pattern: and(cst, X) -> and(X, cst) -def : CanonicalizeConstantToRhs; - -// Pattern: and(X, 0) -> 0 -def : Pat<(StableHLO_AndOp $lhs, (StableHLO_ConstantOp:$zero IntZero:$value)), - (replaceWithValue $zero)>; - -// Pattern: and(X, 1) -> X -def : Pat<(StableHLO_AndOp $lhs, (StableHLO_ConstantOp:$one IntAllOnes:$value)), - (replaceWithValue $lhs)>; - -//////// -// BroadcastInDimOp - -// Pattern: broadcast_in_dim(X, [iota...]) -> X -def : Pat<(StableHLO_BroadcastInDimOp:$op $operand, IotaDims:$dims), - (replaceWithValue $operand), - [(TypesEqual $op, $operand)]>; - -// Pattern: broadcast_in_dim(broadcast_in_dim(X, [dimsA...]), [dimsB...]) -> broadcast_in_dim(X, merge(dimsA, dimsB)) -def : Pat<(StableHLO_BroadcastInDimOp - (StableHLO_BroadcastInDimOp $operand, $dims_parent), $dims), - (StableHLO_BroadcastInDimOp $operand, (MergeBroadcastDims $dims, $dims_parent))>; - -// Pattern: broadcast_in_dim(X, [sorted...]) -> reshape(X, [sorted...]) [if same numel] -def : Pat<(StableHLO_BroadcastInDimOp:$op $operand, SortedDims:$dims), - (StableHLO_ReshapeOpWithShape $op, $operand), - [(NumberOfElementsEqual $op, $operand)]>; - -// Pattern: broadcast_in_dim(X, [dims...]) -> transpose(X, [dims...]) [if same numel & rank] -def : Pat<(StableHLO_BroadcastInDimOp:$op $operand, $dims), - (StableHLO_TransposeOp $operand, $dims), - [(NumberOfElementsEqual $op, $operand), (RankEqual $op, $operand)]>; - -//////// -// ConvertOp - -// Pattern: convert(X, [X.type]) -> X -def : Pat<(StableHLO_ConvertOp:$convert $operand), - (replaceWithValue $operand), - [(TypesEqual $convert, $operand)]>; - -//////// -// DynamicBroadcastInDimOp - -// Pattern: dynamic_broadcast_in_dim(dynamic_broadcast_in_dim(X, _, [dimsA...]), shape, [dimsB...]) -> dynamic_broadcast_in_dim(X, shape, merge(dimsA, dimsB)) -// TODO: Think more if the values of known_expanding_dimensions and known_non_expanding_dimensions can be preserved. -def : Pat<(StableHLO_DynamicBroadcastInDimOp - (StableHLO_DynamicBroadcastInDimOp $operand, $shape_p, $dims_p, $expanding_p, $nonexpanding_p), - $shape, $dims, $expanding, $nonexpanding), - (StableHLO_DynamicBroadcastInDimOp $operand, $shape, (MergeBroadcastDims $dims, $dims_p), (GetEmptyI64Array), (GetEmptyI64Array))>; - -// Pattern: dynamic_broadcast_in_dim(X, _, _, [all_nonexpanding...]) -> convert(X) -// No-op, but wrap in ConvertOp to preserve dynamic output shape, can be -// important if this result is returned, where refining type would require -// also updating the funciton signature. -def : Pat<(StableHLO_DynamicBroadcastInDimOp:$op $operand, $shape, IotaDims:$dims, $expanding, $nonexpanding), - (StableHLO_ConvertOpWithShape $op, $operand), - [(AllDimsNonExpanding $nonexpanding, $op)]>; - -// Pattern: dynamic_broadcast_in_dim(dynamic_reshape(X, shape), shape) -> dynamic_reshape(X, shape) -// If sharing same shape operand, is dynamic reshape. -def : Pat<(StableHLO_DynamicBroadcastInDimOp - (StableHLO_DynamicReshapeOp $operand, $shape), $shape, IotaDims:$dims, $expanding, $nonexpanding), - (StableHLO_DynamicReshapeOp $operand, $shape)>; - -// Pattern: dynamic_broadcast_in_dim(X, shape_of(X)) -> X -def : Pat<(StableHLO_DynamicBroadcastInDimOp - $operand, (Shape_ShapeOfOp $operand), IotaDims:$dims, $expanding, $nonexpanding), - (replaceWithValue $operand)>; - -//////// -// DynamicGatherOp - -// Pattern: dynamic_gather(x, constant(slice_sizes)) -> gather(x, slice_sizes) -def : Pat<(StableHLO_DynamicGatherOp $operand, $start_indices, (StableHLO_ConstantOp DenseIntElementsAttr:$slice_sizes), $dimension_numbers, $indices_are_sorted), - (StableHLO_GatherOp $operand, $start_indices, $dimension_numbers, (ConvertToI64Array $slice_sizes), $indices_are_sorted)>; - -//////// -// DynamicPadOp - -// Pattern: dynamic_pad(X, low, high, interior) -> pad(X, low, high, interior) -// [if low, high, interior are all constants] -def : Pat<(StableHLO_DynamicPadOp $input, - $padding_value, - (ConstantLikeMatcher AnyIntElementsAttr:$edge_padding_low), - (ConstantLikeMatcher AnyIntElementsAttr:$edge_padding_high), - (ConstantLikeMatcher AnyIntElementsAttr:$interior_padding)), - (StableHLO_PadOp $input, $padding_value, - (ConvertToI64Array $edge_padding_low), - (ConvertToI64Array $edge_padding_high), - (ConvertToI64Array $interior_padding))>; - -//////// -// DynamicReshapeOp - -// Pattern: dynamic_reshape(dynamic_reshape(X, _), shape)) -> dynamic_reshape(X, shape) -def : Pat<(StableHLO_DynamicReshapeOp (StableHLO_DynamicReshapeOp $operand, $shape_p), $shape), - (StableHLO_DynamicReshapeOp $operand, $shape)>; - -// Pattern: shape_of(dynamic_reshape(X, shape)) -> shape -def : Pat<(Shape_ShapeOfOp:$op (StableHLO_DynamicReshapeOp $x, $shape)), - (replaceWithValue $shape), - [(TypesEqual $shape, $op)]>; - -//////// -// DynamicUpdateSliceOp - -// Pattern: dynamic_update_slice(X, update : zero_extent)) -> X -def : Pat<(StableHLO_DynamicUpdateSliceOp $operand, (ConstantLikeMatcher ZeroExtent:$update), $start_indices), - (replaceWithValue $operand)>; - -// Pattern: dynamic_update_slice(X, update, start_indices : zero)) -> update -def : Pat<(StableHLO_DynamicUpdateSliceOp AnyStaticShapeTensor:$operand, AnyStaticShapeTensor:$update, $start_indices), - (replaceWithValue $update), - [(TypesEqual $operand, $update), (AllZero $start_indices)]>; - - -//////// -// ComplexOp - -// Pattern: complex(real(X), imag(X))) -> X -def : Pat<(StableHLO_ComplexOp (StableHLO_RealOp $operand), (StableHLO_ImagOp $operand)), - (replaceWithValue $operand)>; - - -//////// -// ImagOp - -// Pattern: imag(complex(R,I)) -> I -def : Pat<(StableHLO_ImagOp (StableHLO_ComplexOp $lhs, $rhs)), - (replaceWithValue $rhs)>; - -//////// -// IotaOp - -// Pattern: iota(dim) : type -> constant(0) : type [if type[dim] == 1] -def : Pat<(StableHLO_IotaOp:$iota $dim), - (StableHLO_ConstantLike<"0"> $iota), - [(DimSizeEquals<1> $iota, $dim)]>; - - -//////// -// MaxOp - -// Pattern: max(cst, X) -> max(X, cst) -def : CanonicalizeConstantToRhs; - -//////// -// MinOp - -// Pattern: minimum(cst, X) -> minimum(X, cst) -def : CanonicalizeConstantToRhs; - -//////// -// MulOp - -// Pattern: multiply(cst, X) -> multiply(X, cst) -def : CanonicalizeConstantToRhs; - -// Pattern: multiply(X, 0i) -> 0i -// Multiplication by 0. This fold is not trivial for floats in presence of NaNs -def : Pat<(StableHLO_MulOp $lhs, (StableHLO_ConstantOp:$zero IntZero:$value)), - (replaceWithValue $zero)>; - -// Pattern: multiply(X, 1i) -> X -def : Pat<(StableHLO_MulOp $lhs, (StableHLO_ConstantOp IntOne:$value)), - (replaceWithValue $lhs)>; - -//////// -// OrOp - -// Pattern: or(cst, X) -> or(X, cst) -def : CanonicalizeConstantToRhs; - -// Pattern: or(X, 1) -> 1 -def : Pat<(StableHLO_OrOp $lhs, (StableHLO_ConstantOp:$one IntAllOnes:$value)), - (replaceWithValue $one)>; - -// Pattern: or(X, 0) -> X -def : Pat<(StableHLO_OrOp $lhs, (StableHLO_ConstantOp:$zero IntZero:$value)), - (replaceWithValue $lhs)>; - -//////// -// RealDynamicSliceOp - -// Pattern: real_dynamic_slice(X, start, limit, strides) -// -> slice(X, start, limit, strides) -// [if start, limit, strides are all constants] -def : Pat<(StableHLO_RealDynamicSliceOp $operand, - (ConstantLikeMatcher DenseIntElementsAttr:$start_indices), - (ConstantLikeMatcher DenseIntElementsAttr:$limit_indices), - (ConstantLikeMatcher DenseIntElementsAttr:$strides)), - (StableHLO_SliceOp $operand, - (ConvertToI64Array $start_indices), - (ConvertToI64Array $limit_indices), - (ConvertToI64Array $strides))>; - -//////// -// RealOp - -// Pattern: real(complex(R,I)) -> X -def : Pat<(StableHLO_RealOp (StableHLO_ComplexOp $lhs, $rhs)), - (replaceWithValue $lhs)>; - -//////// -// ReduceOp -// Note: If modifying region is required, must write pattern in C++ - -// Pattern: reduce(X..., dims=[], add) -> X... -def : Pat<(StableHLO_ReduceOp $operands, $init, EmptyI64Array:$dims), - (replaceWithValue $operands)>; - -//////// -// ReshapeOp - -// Pattern: reshape(reshape(X, _), [shape]) -> reshape(X, [shape]) -def : Pat<(StableHLO_ReshapeOp:$reshape (StableHLO_ReshapeOp $operand)), - (StableHLO_ReshapeOpWithShape $reshape, $operand)>; - -// Pattern: reshape(X, [X.shape]) -> X -def : Pat<(StableHLO_ReshapeOp:$reshape $operand), - (replaceWithValue $operand), - [(TypesEqual $reshape, $operand)]>; - - -//////// -// SelectOp - -// Pattern: select(not(p), t, f) => select(p, f, t) -def : Pat< - (StableHLO_SelectOp (StableHLO_NotOp $pred), $on_true, $on_false), - (StableHLO_SelectOp $pred, $on_false, $on_true)>; - -// Pattern: select(broadcast(not(p)), t, f) => select(broadcast(p), f, t) -def : Pat<(StableHLO_SelectOp (StableHLO_BroadcastInDimOp:$b (StableHLO_NotOp $pred), $broadcast_dimensions), $on_true, $on_false), - (StableHLO_SelectOp (StableHLO_BroadcastInDimOp $pred, $broadcast_dimensions, (returnType $b)), $on_false, $on_true), - [(HasOneUse $b)]>; - -//////// -// SubtractOp - -// Pattern: subtract(X, X) -> 0 -// Must be static shape, otherwise would require broadcasting via CHLO_ConstantLike -def : Pat<(StableHLO_SubtractOp AnyStaticShapeTensor:$operand, $operand), - (StableHLO_ConstantLike<"0"> $operand)>; - -// Pattern: subtract(X, 0) -> X -def : Pat<(StableHLO_SubtractOp $lhs, (StableHLO_ConstantOp AnyZero:$value)), - (replaceWithValue $lhs)>; - -//////// -// SliceOp - -// Pattern: slice(X, [A:A], [B:B], ...) -> X -def : Pat<(StableHLO_SliceOp:$op AnyStaticShapeTensor:$operand, $start_indices, $limit_indices, $strides), - (replaceWithValue $operand), - [(TypesEqual $operand, $op)]>; - -//////// -// TransposeOp - -// Pattern: transpose(X, [iota...]) -> X -def : Pat<(StableHLO_TransposeOp $lhs, IotaDims:$dims), - (replaceWithValue $lhs)>; - -//////// -// GetTupleElementOp - -// Pattern: get_tuple_element(tuple(X_0, X_1, ...), i) -> X_i -def : Pat<(StableHLO_GetTupleElementOp (StableHLO_TupleOp:$tuple $operands), $idx), - (GetOperandN $tuple, $idx)>; - -//////// -// XorOp - -// Pattern: xor(cst, X) -> xor(X, cst) -def : CanonicalizeConstantToRhs; - -// To consider: xor(X, X) -> 0 -// Unclear if this is beneficial on hardware vs adding another constant -// -// def : Pat<(StableHLO_XorOp AnyStaticShapeTensor:$operand, $operand), -// (StableHLO_ConstantLike<"0"> $operand)>; diff --git a/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/transforms/StablehloRefineShapes.cpp index 4274c09413..c8fb1e515b 100644 --- a/stablehlo/transforms/StablehloRefineShapes.cpp +++ b/stablehlo/transforms/StablehloRefineShapes.cpp @@ -53,6 +53,7 @@ limitations under the License. #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/dialect/TypeInference.h" #include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/optimization/Passes.h" #define DEBUG_TYPE "stablehlo-refine-shapes" diff --git a/stablehlo/transforms/optimization/CMakeLists.txt b/stablehlo/transforms/optimization/CMakeLists.txt new file mode 100644 index 0000000000..f52054aec2 --- /dev/null +++ b/stablehlo/transforms/optimization/CMakeLists.txt @@ -0,0 +1,45 @@ +# Copyright 2022 The StableHLO Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name=Optimization) +add_public_tablegen_target(OptimizationPassesIncGen) + +set(LLVM_TARGET_DEFINITIONS StablehloAggressiveSimplificationPatterns.td) +mlir_tablegen(StablehloAggressiveSimplificationPatterns.h.inc --gen-rewriters) +add_public_tablegen_target(StablehloAggressiveSimplificationPatternsIncGen) + +add_mlir_dialect_library(StablehloOptimizationPasses + PARTIAL_SOURCES_INTENDED + StablehloAggressiveFolder.cpp + StablehloAggressiveSimplification.cpp + StablehloTargetIndependentOptimization.cpp + + DEPENDS + OptimizationPassesIncGen + StablehloAggressiveSimplificationPatternsIncGen + + LINK_LIBS PUBLIC + ChloOps + MLIRArithDialect + MLIRDialectUtils + MLIRFuncDialect + MLIRIR + MLIRRewrite + MLIRSupport + MLIRTransformUtils + StablehloBase + StablehloOps + StablehloTypeInference +) diff --git a/stablehlo/transforms/optimization/Passes.h b/stablehlo/transforms/optimization/Passes.h new file mode 100644 index 0000000000..0b1f060d7d --- /dev/null +++ b/stablehlo/transforms/optimization/Passes.h @@ -0,0 +1,57 @@ +/* Copyright 2022 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef STABLEHLO_TRANSFORMS_OPTIMIZATION_PASSES_H +#define STABLEHLO_TRANSFORMS_OPTIMIZATION_PASSES_H + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/StablehloOps.h" + +namespace mlir { +namespace stablehlo { + +#define GEN_PASS_DECL +#define GEN_PASS_REGISTRATION +#include "stablehlo/transforms/optimization/Passes.h.inc" + + +/// Collection of canonicalization patterns for StableHLO. +void populateStablehloCanonicalizationPatterns(MLIRContext *context, + RewritePatternSet *patterns, + PatternBenefit benefit = 1); + +/// Collection of folding patterns for StableHLO. +void populateStablehloAggressiveFolderPatterns(RewritePatternSet *patterns, + MLIRContext *context, + bool foldFloat, + PatternBenefit benefit = 1); + +/// A subset of folding patterns for StableHLO that is necessary for shape +/// refinement. +void populateStablehloShapeFolderPatterns(RewritePatternSet *patterns, + MLIRContext *context, + bool foldFloat = false, + PatternBenefit benefit = 1); +} // namespace stablehlo +} // namespace mlir + +#endif // STABLEHLO_TRANSFORMS_OPTIMIZATION_PASSES_H diff --git a/stablehlo/transforms/optimization/Passes.td b/stablehlo/transforms/optimization/Passes.td new file mode 100644 index 0000000000..b3cbe9371e --- /dev/null +++ b/stablehlo/transforms/optimization/Passes.td @@ -0,0 +1,45 @@ +/* Copyright 2022 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/Pass/PassBase.td" + +def StablehloAggressiveFolderPass + : Pass<"stablehlo-aggressive-folder", "func::FuncOp"> { + let summary = "Folds StableHLO operations"; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + ]; + let options = [ + Option<"foldFloat", "fold-float", "bool", /*default=*/"true", + "Allow for potentially lossy computations using float type.">, + ]; +} + +def StablehloAggressiveSimplificationPass + : Pass<"stablehlo-aggressive-simplification", "func::FuncOp"> { + let summary = "Canonicalizes StableHLO operations"; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + "mlir::arith::ArithDialect", + ]; +} + +def StablehloTargetIndependentOptimizationPass + : Pass<"stablehlo-target-independent-optimization", "func::FuncOp"> { + let summary = "Runs canonicalizers, folders, and other target-independent optimizations."; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + ]; +} diff --git a/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp b/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp new file mode 100644 index 0000000000..3487501aa4 --- /dev/null +++ b/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp @@ -0,0 +1,943 @@ +/* Copyright 2024 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include +#include + +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/APSInt.h" +#include "llvm/ADT/FloatingPointMode.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/CommonFolders.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Utils/IndexingUtils.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/Operation.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Types.h" +#include "mlir/IR/Value.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/Base.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/optimization/Passes.h" + + +namespace mlir { +namespace stablehlo { + +#define GEN_PASS_DEF_STABLEHLOAGGRESSIVEFOLDERPASS +#include "stablehlo/transforms/optimization/Passes.h.inc" + +namespace { + +// This is an upper limit on how many elements can be folded by an op folder. +// This limit doesn't apply to some special cases like adding a zero, +// multiplying by one, doing many operations with splats. +constexpr int64_t kFoldOpEltLimit = 65536; + +// DenseElementsAttr can be constructed from ArrayRef but not from +// ArrayRef. This helper bridges the gap. +DenseIntElementsAttr getTensorAttr(ShapedType type, ArrayRef values) { + SmallVector supportedValues(values); + return DenseIntElementsAttr::get(type, supportedValues); +} + +APSInt getAPSInt(Type type, uint64_t value) { + unsigned numBits; + bool isUnsigned; + if (auto integerType = dyn_cast(type)) { + numBits = integerType.getWidth(); + // Signless types are treated as signed, per StableHLO convention. + isUnsigned = integerType.isUnsignedInteger(); + } else { + llvm::report_fatal_error("expected integer type"); + } + return APSInt( + {/*numBits=*/numBits, value, /*isSigned=*/false, /*implicitTrunc=*/true}, + /*isUnsigned=*/isUnsigned); +} + +LogicalResult validateResultTypeForEval(PatternRewriter& rewriter, + Operation* op, ShapedType resultType) { + if (!resultType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "unable to fold dynamically shaped result type to constant"); + return success(); +} + +/// Binary constant folder that used a generic folder function to handle both +/// ints and floats. +template +static TypedAttr foldBinaryOpIntOrFloat(TypedAttr lhs, TypedAttr rhs, + Fn&& folder) { + Attribute operands[2] = {lhs, rhs}; + Type elemTy = getElementTypeOrSelf(lhs); + + Attribute res; + if (isa(elemTy)) + res = constFoldBinaryOp(operands, + folder); + if (isa(elemTy)) + res = constFoldBinaryOp(operands, + folder); + if (res) return cast(res); + + return nullptr; +} + +template +LogicalResult evalConvertHelper(PatternRewriter& rewriter, OpType op, + DenseIntOrFPElementsAttr elements, Type resType, + CalculationT&& calculate) { + auto result = constFoldCastOp( + elements, resType, calculate); + + if (!result) + return rewriter.notifyMatchFailure(op, [&](Diagnostic& diag) { + diag << "cast of " << elements.getElementType() << " to " << resType + << " failed"; + }); + + rewriter.replaceOpWithNewOp(op, result); + return success(); +} + +template +LogicalResult evalConvert(PatternRewriter& rewriter, OpType op, + DenseIntOrFPElementsAttr elements, + RankedTensorType resultType) { + auto oldType = getElementTypeOrSelf(elements); + auto newType = getElementTypeOrSelf(resultType); + size_t newBitWidth = newType.getIntOrFloatBitWidth(); + + bool isOldTypeUnsigned = oldType.isInteger(1) || oldType.isUnsignedInteger(); + bool isNewTypeUnsigned = newType.isInteger(1) || newType.isUnsignedInteger(); + + if (isa(oldType)) { + if (auto newFloatType = dyn_cast(newType)) { + // Float -> Float + const auto& targetSemantics = newFloatType.getFloatSemantics(); + return evalConvertHelper( + rewriter, op, elements, resultType, + [&targetSemantics](const APFloat& operand, bool& castStatus) { + bool losesInfo; + APFloat newValue = operand; + castStatus = APFloat::opInvalidOp != + newValue.convert(targetSemantics, + llvm::RoundingMode::NearestTiesToEven, + &losesInfo); + return newValue; + }); + } + + // Float -> Int + return evalConvertHelper( + rewriter, op, elements, resultType, + [&newBitWidth, &isNewTypeUnsigned](const APFloat& operand, + bool& castStatus) { + APSInt api(newBitWidth, isNewTypeUnsigned); + if (operand.isInfinity() || operand.isNegZero()) { + castStatus = false; + return api; + } + bool ignored; + castStatus = + APFloat::opInvalidOp != + operand.convertToInteger(api, APFloat::rmTowardZero, &ignored); + return api; + }); + } + + if (auto newFloatType = dyn_cast(newType)) { + // Int -> Float + return evalConvertHelper( + rewriter, op, elements, resultType, + [&newFloatType, &isOldTypeUnsigned](const APInt& operand, + bool& /*castStatus*/) { + APFloat apf(newFloatType.getFloatSemantics(), + APInt::getZero(newFloatType.getWidth())); + apf.convertFromAPInt(operand, !isOldTypeUnsigned, + APFloat::rmNearestTiesToEven); + return apf; + }); + } + + // Int -> Int + return evalConvertHelper( + rewriter, op, elements, resultType, + [&newBitWidth, &isOldTypeUnsigned](const APInt& operand, + bool& /*castStatus*/) { + return APSInt(operand, isOldTypeUnsigned).extOrTrunc(newBitWidth); + }); +} + +// The patterns below implement partial evaluation of shape computations which +// is a critical part of implementing type refinement for ops like +// dynamic_broadcast_in_dim, dynamic_iota and dynamic_reshape whose shape +// depends on the value of their shape operands. + +template +LogicalResult evalElementwise(PatternRewriter& rewriter, OpType op, + FuncType fn) { + auto resultType = op.getType(); + if (failed(validateResultTypeForEval(rewriter, op, resultType))) + return failure(); + + if (!isa(resultType.getElementType())) + return rewriter.notifyMatchFailure(op, + "expected integer result tensor type"); + + SmallVector result; + if constexpr (OpType::template hasTrait()) { + SmallVector operand; + if (failed(hlo::matchInts(op.getOperand(), operand))) + return rewriter.notifyMatchFailure(op, "expected constant operand"); + for (const auto& operandEl : operand) { + result.push_back(fn(operandEl)); + } + } else if constexpr (OpType::template hasTrait< + OpTrait::NOperands<2>::Impl>()) { + SmallVector lhs, rhs; + if (failed(hlo::matchInts(op.getLhs(), lhs)) || + failed(hlo::matchInts(op.getRhs(), rhs))) + return rewriter.notifyMatchFailure(op, "expected constant operands"); + for (auto [lhsEl, rhsEl] : llvm::zip(lhs, rhs)) { + result.push_back(fn(lhsEl, rhsEl)); + } + } else if constexpr (OpType::template hasTrait< + OpTrait::NOperands<3>::Impl>()) { + SmallVector x, y, z; + if (failed(hlo::matchInts(op->getOperand(0), x)) || + failed(hlo::matchInts(op->getOperand(1), y)) || + failed(hlo::matchInts(op->getOperand(2), z))) + return rewriter.notifyMatchFailure(op, "expected constant operands"); + for (auto [xEl, yEl, zEl] : llvm::zip(x, y, z)) { + result.push_back(fn(xEl, yEl, zEl)); + } + } else { + llvm::report_fatal_error("unsupported number of operands"); + } + + rewriter.replaceOpWithNewOp(op, + getTensorAttr(resultType, result)); + return success(); +} + +struct FoldAddOpPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::AddOp op, + PatternRewriter& rewriter) const override { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + + // Pattern: add(cst,cst) -> cst + TypedAttr lhsAttr, rhsAttr; + matchPattern(lhs, m_Constant(&lhsAttr)); + matchPattern(rhs, m_Constant(&rhsAttr)); + + if (TypedAttr res; + lhsAttr && rhsAttr && + (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::plus<>{}))) { + rewriter.replaceOpWithNewOp(op, res); + return success(); + } + + return failure(); + } +}; + +struct EvalAddOpShapePattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AddOp op, + PatternRewriter& rewriter) const override { + return evalElementwise(rewriter, op, + [&](APSInt lhs, APSInt rhs) { return lhs + rhs; }); + } +}; + +struct EvalAndOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(AndOp op, + PatternRewriter& rewriter) const override { + auto resultType = op.getType(); + if (!resultType.getElementType().isInteger(1)) + return rewriter.notifyMatchFailure(op, "expected boolean element type"); + + return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { + return getAPSInt(resultType.getElementType(), lhsInt != 0 && rhsInt != 0); + }); + } +}; + +// Pattern: broadcast_in_dim(splat, _) -> constant(splat) +struct FoldBroadcastInDimSplatPattern final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::BroadcastInDimOp op, + PatternRewriter& rewriter) const override { + TypedValue operand = op.getOperand(); + + if (SplatElementsAttr cstAttr; + matchPattern(operand, m_Constant(&cstAttr))) { + rewriter.replaceOpWithNewOp( + op, SplatElementsAttr::get(op.getType(), + cstAttr.getSplatValue())); + return success(); + } + return failure(); + } +}; + +struct EvalBroadcastInDimOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(BroadcastInDimOp op, + PatternRewriter& rewriter) const override { + auto resultType = op.getType(); + if (failed(validateResultTypeForEval(rewriter, op, resultType))) + return failure(); + + auto operandType = op.getOperand().getType(); + if (operandType.getRank() != 0) + return rewriter.notifyMatchFailure(op, "expected 0-dimensional type"); + + SmallVector operand; + if (failed(hlo::matchInts(op.getOperand(), operand))) + return rewriter.notifyMatchFailure(op, "expected constant operands"); + auto scalar = operand[0]; + + rewriter.replaceOpWithNewOp( + op, getTensorAttr(op.getType(), scalar)); + return success(); + } +}; + +struct EvalClampOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ClampOp op, + PatternRewriter& rewriter) const override { + return evalElementwise(rewriter, op, + [&](APSInt min, APSInt operand, APSInt max) { + if (operand < min) return min; + if (max < operand) return max; + return operand; + }); + } +}; + +struct EvalCompareOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(CompareOp op, + PatternRewriter& rewriter) const override { + auto resultType = op.getType(); + auto kind = op.getCompareType(); + return evalElementwise(rewriter, op, [&](APInt lhs, APInt rhs) { + bool result = false; + switch (op.getComparisonDirection()) { + case ComparisonDirection::EQ: + result = lhs == rhs; + break; + case ComparisonDirection::NE: + result = lhs != rhs; + break; + case ComparisonDirection::GE: + result = kind == ComparisonType::SIGNED ? lhs.sge(rhs) : lhs.uge(rhs); + break; + case ComparisonDirection::GT: + result = kind == ComparisonType::SIGNED ? lhs.sgt(rhs) : lhs.ugt(rhs); + break; + case ComparisonDirection::LE: + result = kind == ComparisonType::SIGNED ? lhs.sle(rhs) : lhs.ule(rhs); + break; + case ComparisonDirection::LT: + result = kind == ComparisonType::SIGNED ? lhs.slt(rhs) : lhs.ult(rhs); + break; + } + return getAPSInt(resultType.getElementType(), result); + }); + } +}; + +////////////////////////////////// +// ConcatenateOp +///////////////////////////////// + +struct FoldConcatenateOpPattern final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::ConcatenateOp op, + PatternRewriter& rewriter) const override { + RankedTensorType type = op.getType(); + if (!type.hasStaticShape()) return failure(); + + size_t numElems = type.getNumElements(); + if (numElems > kFoldOpEltLimit) return failure(); + + // Fold concatenate when all inputs are constants. + OperandRange inputs = op.getInputs(); + SmallVector constants(inputs.size()); + for (auto [input, constant] : llvm::zip_equal(inputs, constants)) { + if (!matchPattern(input, m_Constant(&constant))) return failure(); + } + + uint64_t dim = op.getDimension(); + ArrayRef shape = type.getShape(); + int64_t topSize = std::accumulate(shape.begin(), shape.begin() + dim, + int64_t{1}, std::multiplies<>{}); + + SmallVector newElems; + newElems.reserve(numElems); + + for (int64_t i = 0; i != topSize; ++i) { + for (ElementsAttr attr : constants) { + size_t bottomSize = attr.getNumElements() / topSize; + auto begin = attr.value_begin() + (i * bottomSize); + newElems.append(begin, begin + bottomSize); + } + } + + assert(newElems.size() == numElems); + rewriter.replaceOpWithNewOp( + op, DenseElementsAttr::get(op.getType(), newElems)); + return success(); + } +}; + +struct EvalConcatenateOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ConcatenateOp op, + PatternRewriter& rewriter) const override { + auto resultType = op.getType(); + if (failed(validateResultTypeForEval(rewriter, op, resultType))) + return failure(); + + if (op.getDimension() != 0) + return rewriter.notifyMatchFailure(op, "expected dimension = 0"); + + SmallVector result; + for (Value operand : op->getOperands()) { + if (failed(hlo::matchInts(operand, result))) + return rewriter.notifyMatchFailure(op, "expected constant operands"); + } + + rewriter.replaceOpWithNewOp(op, + getTensorAttr(resultType, result)); + return success(); + } +}; + +struct EvalConvertOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + EvalConvertOpPattern(MLIRContext* context, PatternBenefit benefit, + bool foldFloat_) + : OpRewritePattern(context, benefit), foldFloat{foldFloat_} {} + + LogicalResult matchAndRewrite(ConvertOp op, + PatternRewriter& rewriter) const override { + auto operand = op.getOperand(); + RankedTensorType resultType = op.getType(); + + if (failed(validateResultTypeForEval(rewriter, op, resultType))) + return failure(); + + auto operandElemType = getElementTypeOrSelf(operand.getType()); + auto resultElemType = getElementTypeOrSelf(resultType); + if (!(operandElemType.isInteger() && resultElemType.isInteger()) && + !foldFloat) + return rewriter.notifyMatchFailure(op, + "lossy computations are not allowed"); + + if (!resultElemType.isIntOrFloat()) + return rewriter.notifyMatchFailure( + op, "expected integer or float result tensor type"); + + DenseIntOrFPElementsAttr elements; + if (!matchPattern(operand, m_Constant(&elements))) + return rewriter.notifyMatchFailure( + op, "expected constant integer or float operand"); + + return evalConvert(rewriter, op, elements, resultType); + } + + private: + bool foldFloat; +}; + +struct EvalDivOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(DivOp op, + PatternRewriter& rewriter) const override { + return evalElementwise(rewriter, op, + [&](APSInt lhs, APSInt rhs) { return lhs / rhs; }); + } +}; + +struct EvalGetDimensionSizeOpPattern + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(GetDimensionSizeOp op, + PatternRewriter& rewriter) const override { + auto resultType = op.getType(); + if (failed(validateResultTypeForEval(rewriter, op, resultType))) + return failure(); + + auto operandType = op.getOperand().getType(); + if (operandType.isDynamicDim(op.getDimension())) + return rewriter.notifyMatchFailure(op, "expected static dimension"); + + auto result = operandType.getDimSize(op.getDimension()); + rewriter.replaceOpWithNewOp( + op, DenseIntElementsAttr::get(resultType, result)); + return success(); + } +}; + +struct EvalMaxOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(MaxOp op, + PatternRewriter& rewriter) const override { + return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { + return lhs >= rhs ? lhs : rhs; + }); + } +}; + +struct EvalMinOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(MinOp op, + PatternRewriter& rewriter) const override { + return evalElementwise(rewriter, op, [&](APSInt lhs, APSInt rhs) { + return lhs <= rhs ? lhs : rhs; + }); + } +}; + +struct FoldMulOpPattern final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::MulOp op, + PatternRewriter& rewriter) const override { + auto elemType = op.getType().getElementType(); + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + + TypedAttr lhsAttr; + matchPattern(lhs, m_Constant(&lhsAttr)); + + TypedAttr rhsAttr; + matchPattern(rhs, m_Constant(&rhsAttr)); + + // The canonical form has the constant operand as the RHS. + if (isa(elemType) && lhsAttr && !rhsAttr) { + rewriter.modifyOpInPlace(op, [op, lhs, rhs] { + op->setOperands(ValueRange{rhs, lhs}); + }); + return success(); + } + + if (TypedAttr res; + lhsAttr && rhsAttr && + (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::multiplies<>{}))) { + rewriter.replaceOpWithNewOp(op, res); + return success(); + } + + return failure(); + } +}; + +struct EvalMulOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(MulOp op, + PatternRewriter& rewriter) const override { + return evalElementwise(rewriter, op, + [&](APSInt lhs, APSInt rhs) { return lhs * rhs; }); + } +}; + +struct EvalOrOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(OrOp op, + PatternRewriter& rewriter) const override { + auto resultType = op.getType(); + if (!resultType.getElementType().isInteger(1)) + return rewriter.notifyMatchFailure(op, "expected boolean element type"); + + return evalElementwise(rewriter, op, [&](APSInt lhsInt, APSInt rhsInt) { + return getAPSInt(resultType.getElementType(), lhsInt != 0 || rhsInt != 0); + }); + } +}; + +struct EvalRemOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(RemOp op, + PatternRewriter& rewriter) const override { + return evalElementwise(rewriter, op, + [&](APSInt lhs, APSInt rhs) { return lhs % rhs; }); + } +}; + +struct EvalReshapeOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ReshapeOp op, + PatternRewriter& rewriter) const override { + auto resultType = op.getType(); + if (failed(validateResultTypeForEval(rewriter, op, resultType))) + return failure(); + + // Pattern: reshape(cst, shape) -> cst + DenseIntElementsAttr attr; + if (!matchPattern(op.getOperand(), m_Constant(&attr))) + return rewriter.notifyMatchFailure(op, "expected constant operand"); + rewriter.replaceOpWithNewOp(op, attr.reshape(resultType)); + return success(); + } +}; + +struct EvalSelectOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(SelectOp op, + PatternRewriter& rewriter) const override { + auto resultType = op.getType(); + if (failed(validateResultTypeForEval(rewriter, op, resultType))) + return failure(); + + SmallVector pred, onTrue, onFalse; + if (failed(hlo::matchInts(op.getPred(), pred)) || + failed(hlo::matchInts(op.getOnTrue(), onTrue)) || + failed(hlo::matchInts(op.getOnFalse(), onFalse))) + return rewriter.notifyMatchFailure(op, "expected constant operands"); + + SmallVector result; + for (auto [predEl, onTrueEl, onFalseEl] : + llvm::zip(pred, onTrue, onFalse)) { + result.push_back(predEl != 0 ? onTrueEl : onFalseEl); + } + + rewriter.replaceOpWithNewOp( + op, getTensorAttr(op.getType(), result)); + return success(); + } +}; + +struct EvalSignOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(SignOp op, + PatternRewriter& rewriter) const override { + auto resultType = op.getType(); + if (!isa(resultType.getElementType())) + return rewriter.notifyMatchFailure(op, + "expected integer result tensor type"); + return evalElementwise(rewriter, op, [&](APSInt operand) { + int64_t result; + if (operand.isNegative()) + result = -1; + else if (operand.isZero()) + result = 0; + else + result = 1; + return getAPSInt(resultType.getElementType(), result); + }); + } +}; + +template +DenseElementsAttr sliceType(SliceOp& op, const RangeType& data) { + using ElementType = std::decay_t; + + RankedTensorType operandType = op.getOperand().getType(); + RankedTensorType resultType = op.getResult().getType(); + + const auto dimOffsets = computeStrides(operandType.getShape()); + auto startIndices = op.getStartIndices(); + auto limitIndices = op.getLimitIndices(); + auto strides = op.getStrides(); + + const SmallVector startIndex(startIndices); + const SmallVector endIndex(limitIndices); + + SmallVector result; + result.reserve(resultType.getNumElements()); + + SmallVector srcIndex(startIndex); + for (int64_t i = 0; i < resultType.getNumElements(); ++i) { + auto srcLinearIndex = linearize(srcIndex, dimOffsets); + result.push_back(data[srcLinearIndex]); + for (int64_t dim = srcIndex.size() - 1; dim >= 0; --dim) { + srcIndex[dim] += strides[dim]; + if (srcIndex[dim] >= endIndex[dim]) + srcIndex[dim] = startIndex[dim]; + else + break; + } + } + + return DenseElementsAttr::get(op.getResult().getType(), + ArrayRef(result)); +} + +struct EvalSliceOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(SliceOp op, + PatternRewriter& rewriter) const override { + auto resultType = op.getType(); + if (failed(validateResultTypeForEval(rewriter, op, resultType))) + return failure(); + + auto operand = op.getOperand(); + RankedTensorType operandType = operand.getType(); + if (!operandType.hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "expected operand with static ranked tensor type"); + + ElementsAttr els; + if (!matchPattern(operand, m_Constant(&els))) + return rewriter.notifyMatchFailure( + op, "expected constant integer or float operand"); + + DenseElementsAttr resAttr; + if (auto data = els.tryGetValues()) + resAttr = sliceType(op, *data); + else if (auto data = els.tryGetValues()) + resAttr = sliceType(op, *data); + else + return rewriter.notifyMatchFailure(op.getLoc(), + "unsupported element type"); + + rewriter.replaceOpWithNewOp(op, resAttr); + return success(); + } +}; + +struct FoldSubtractOpPattern final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(mlir::stablehlo::SubtractOp op, + PatternRewriter& rewriter) const override { + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + + TypedAttr lhsAttr, rhsAttr; + matchPattern(lhs, m_Constant(&lhsAttr)); + matchPattern(rhs, m_Constant(&rhsAttr)); + + if (TypedAttr res; + lhsAttr && rhsAttr && + (res = foldBinaryOpIntOrFloat(lhsAttr, rhsAttr, std::minus<>{}))) { + rewriter.replaceOpWithNewOp(op, res); + return success(); + } + + return failure(); + } +}; + +struct EvalSubtractOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(SubtractOp op, + PatternRewriter& rewriter) const override { + return evalElementwise(rewriter, op, + [&](APSInt lhs, APSInt rhs) { return lhs - rhs; }); + } +}; + +struct EvalIotaOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(IotaOp op, + PatternRewriter& rewriter) const override { + auto resultType = cast(op.getType()); + auto elementType = resultType.getElementType(); + + if (!elementType.isInteger()) + return rewriter.notifyMatchFailure(op, "expected integer result type"); + + auto outputSize = resultType.getNumElements(); + auto resultBitWidth = elementType.getIntOrFloatBitWidth(); + int64_t dimension = op.getIotaDimension(); + + llvm::SmallVector values; + values.reserve(outputSize); + + if (outputSize == 0) { + rewriter.replaceOpWithNewOp( + op, DenseIntElementsAttr::get(resultType, values)); + return success(); + } + + int64_t sequences = 1; + int64_t sequenceMax = resultType.getDimSize(dimension); + int64_t elementRepetitions = 1; + for (int64_t i = 0; i < resultType.getRank(); i++) { + sequences *= i < dimension ? resultType.getDimSize(i) : 1; + elementRepetitions *= i > dimension ? resultType.getDimSize(i) : 1; + } + + for (int64_t i = 0; i < sequences; ++i) { + for (int64_t value = 0; value < sequenceMax; ++value) { + for (int64_t k = 0; k < elementRepetitions; ++k) { + values.push_back(APInt(resultBitWidth, value)); + } + } + } + + rewriter.replaceOpWithNewOp( + op, DenseIntElementsAttr::get(resultType, values)); + return success(); + } +}; + +template +DenseElementsAttr transposeType(TransposeOp& op, const RangeType& data) { + using ElementType = std::decay_t; + + RankedTensorType operandType = op.getOperand().getType(); + RankedTensorType resultType = op.getResult().getType(); + + const auto operandStrides = computeStrides(operandType.getShape()); + const auto resultStrides = computeStrides(resultType.getShape()); + const auto inversePermutation = invertPermutationVector(op.getPermutation()); + + SmallVector result; + result.reserve(resultType.getNumElements()); + + for (int64_t i = 0; i < resultType.getNumElements(); ++i) { + auto dstDimIndex = delinearize(i, resultStrides); + auto srcDimIndex = applyPermutation(dstDimIndex, inversePermutation); + auto srcLinearIndex = linearize(srcDimIndex, operandStrides); + result.push_back(data[srcLinearIndex]); + } + + return DenseElementsAttr::get(resultType, ArrayRef(result)); +} + +// transpose(constant) => constant with permuted dimensions +// This covers ranked tensor types with 0 dimensions(zero elements) and 0 +// rank(scalar), as well as splat values. +struct EvalTransposeOpPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(TransposeOp op, + PatternRewriter& rewriter) const override { + auto resultType = op.getType(); + if (failed(validateResultTypeForEval(rewriter, op, resultType))) + return failure(); + + ElementsAttr els; + if (!matchPattern(op.getOperand(), m_Constant(&els))) + return rewriter.notifyMatchFailure( + op, "expected constant integer or float operand"); + + DenseElementsAttr resAttr; + if (auto data = els.tryGetValues()) + resAttr = transposeType(op, *data); + else if (auto data = els.tryGetValues()) + resAttr = transposeType(op, *data); + else + return rewriter.notifyMatchFailure(op.getLoc(), + "unsupported element type"); + + rewriter.replaceOpWithNewOp(op, resAttr); + return success(); + } +}; + +struct StablehloAggressiveFolderPass + : public impl::StablehloAggressiveFolderPassBase< + StablehloAggressiveFolderPass> { + using StablehloAggressiveFolderPassBase::StablehloAggressiveFolderPassBase; + + LogicalResult initialize(MLIRContext* context) override { + RewritePatternSet patterns_(context); + populateStablehloAggressiveFolderPatterns(&patterns_, context, foldFloat); + patterns = std::move(patterns_); + + return success(); + } + + void runOnOperation() override { + if (failed(applyPatternsGreedily(getOperation(), patterns))) + signalPassFailure(); + } + + private: + FrozenRewritePatternSet patterns; +}; + +} // namespace + +void populateStablehloAggressiveFolderPatterns(RewritePatternSet* patterns, + MLIRContext* context, + bool foldFloat, + PatternBenefit benefit) { + populateStablehloShapeFolderPatterns(patterns, context, foldFloat, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + + // TODO: Consolidate FoldOp patterns + // One is used by Shape Refinement, the other is a generic folder. + patterns + ->add( + context); +} + +void populateStablehloShapeFolderPatterns(RewritePatternSet* patterns, + MLIRContext* context, bool foldFloat, + PatternBenefit benefit) { + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit, foldFloat); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); +} + +} // namespace stablehlo +} // namespace mlir diff --git a/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp b/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp new file mode 100644 index 0000000000..f32f8d66b6 --- /dev/null +++ b/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp @@ -0,0 +1,1537 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Implements optional canonicalization patterns for StableHLO ops. + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "llvm/ADT/APInt.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Sequence.h" +#include "llvm/ADT/SmallBitVector.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/SmallVectorExtras.h" +#include "llvm/Support/ErrorHandling.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Block.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/BuiltinAttributeInterfaces.h" +#include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/BuiltinTypes.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Matchers.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/IR/Region.h" +#include "mlir/IR/TypeUtilities.h" +#include "mlir/IR/Value.h" +#include "mlir/IR/ValueRange.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/Base.h" +#include "stablehlo/dialect/StablehloOps.h" +#include "stablehlo/transforms/PassUtils.h" +#include "stablehlo/transforms/optimization/Passes.h" + +using llvm::SmallBitVector; + +namespace mlir { +namespace stablehlo { + +#define GEN_PASS_DEF_STABLEHLOAGGRESSIVESIMPLIFICATIONPASS +#include "stablehlo/transforms/optimization/Passes.h.inc" + +namespace { +// This is an upper limit on how many elements can be folded by an op folder. +// This limit doesn't apply to some special cases like adding a zero, +// multiplying by one, doing many operations with splats. +constexpr int64_t kFoldOpEltLimit = 65536; + +static bool isIotaRange(ArrayRef dims) { + return llvm::all_of(llvm::enumerate(dims), [](const auto &it) { + return static_cast(it.index()) == it.value(); + }); +} + +/// Matches when either of the submatchers match. +template +struct m_AnyOf { + m_AnyOf(MatcherA a, MatcherB b) : matcherA(a), matcherB(b) {} + + bool match(Operation *op) { return matcherA.match(op) || matcherB.match(op); } + + MatcherA matcherA; + MatcherB matcherB; +}; + +template +m_AnyOf(MatcherA, MatcherB) -> m_AnyOf; + +/// Matches when either of the submatchers match. +template +struct m_AnyAttrOf { + m_AnyAttrOf(MatcherA a, MatcherB b) : matcherA(a), matcherB(b) {} + + bool match(Attribute attr) { + return matcherA.match(attr) || matcherB.match(attr); + } + + MatcherA matcherA; + MatcherB matcherB; +}; + +template +m_AnyAttrOf(MatcherA, MatcherB) -> m_AnyAttrOf; + +////////////////////////////////// +// CompareOp +///////////////////////////////// + +static ComparisonDirection invertDirection(ComparisonDirection direction) { + switch (direction) { + case ComparisonDirection::EQ: + case ComparisonDirection::NE: + return direction; + case ComparisonDirection::GE: + return ComparisonDirection::LE; + case ComparisonDirection::GT: + return ComparisonDirection::LT; + case ComparisonDirection::LE: + return ComparisonDirection::GE; + case ComparisonDirection::LT: + return ComparisonDirection::GT; + } + + llvm::report_fatal_error("Unhandled case"); +} + +struct CompareOpCanon final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(CompareOp op, + PatternRewriter &rewriter) const override { + RankedTensorType type = op.getType(); + + // Bail out on non-integer comparison. + // TODO: Support more comparison types. + std::optional compType = op.getCompareType(); + if (!compType || + !llvm::is_contained({ComparisonType::SIGNED, ComparisonType::UNSIGNED}, + *compType)) { + return failure(); + } + + ComparisonDirection direction = op.getComparisonDirection(); + Value lhs = op.getLhs(); + Value rhs = op.getRhs(); + + // Pattern: compare(X, X, [EQ,GE,LE]) -> true + // Pattern: compare(X, X, [NE,GT,LT]) -> false + if (lhs == rhs) { + switch (direction) { + case ComparisonDirection::EQ: + case ComparisonDirection::GE: + case ComparisonDirection::LE: { + rewriter.replaceOpWithNewOp( + op, SplatElementsAttr::get(type, rewriter.getBoolAttr(true))); + return success(); + } + case ComparisonDirection::GT: + case ComparisonDirection::LT: + case ComparisonDirection::NE: { + rewriter.replaceOpWithNewOp(op, + rewriter.getZeroAttr(type)); + return success(); + } + } + llvm_unreachable("Unhandled case"); + } + + // Pattern: compare(cst, X, comparator) -> compare(X, cst, inv(comparator)) + TypedAttr lhsAttr, rhsAttr; + matchPattern(lhs, m_Constant(&lhsAttr)); + matchPattern(rhs, m_Constant(&rhsAttr)); + + // The canonical form has the constant operand as the RHS. + if (lhsAttr && !rhsAttr) { + rewriter.modifyOpInPlace(op, [&op, direction, lhs, rhs] { + op.setComparisonDirection(invertDirection(direction)); + op->setOperands(ValueRange{rhs, lhs}); + }); + return success(); + } + + return failure(); + } +}; + +////////////////////////////////// +// ConcatenateOp +///////////////////////////////// + +// Pattern: concatenate(X) -> X +class ConcatenateOpNoop : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ConcatenateOp op, + PatternRewriter &rewriter) const override { + if (op.getInputs().size() != 1 || + op.getInputs().front().getType() != op.getType()) + return rewriter.notifyMatchFailure(op, "not single operand noop-concat"); + + rewriter.replaceOp(op, op.getInputs().front()); + return success(); + } +}; + +// Pattern: concatenate(X, Y, []) -> concatenate(X, Y) +class ConcatenateOpRemoveEmpty : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ConcatenateOp op, + PatternRewriter &rewriter) const override { + auto axis = op.getDimension(); + llvm::SmallVector newOperands = llvm::to_vector( + llvm::make_filter_range(op.getOperands(), [&](Value operand) { + return cast(operand.getType()).getDimSize(axis) != 0; + })); + + // Only handle nonempty new operands, empty handled by + // ZeroExtentToEmptyConstant pattern. + if (!newOperands.empty() && newOperands.size() < op.getNumOperands()) { + rewriter.modifyOpInPlace(op, [&] { op->setOperands(newOperands); }); + return success(); + } + + return failure(); + } +}; + +// Pattern: concatenate(concatenate(X, Y), Z) -> concatenate(X, Y, Z) +class ConcatenateOpFlatten : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + LogicalResult matchAndRewrite(ConcatenateOp op, + PatternRewriter &rewriter) const override { + auto getFlattenedOperands = [&](const Value &val) -> ValueRange { + auto definingOp = dyn_cast_or_null(val.getDefiningOp()); + // To avoid inflate the memory footprint, only flatten the + // ConcatenateOp when it has only one use. + if (definingOp && definingOp->hasOneUse() && + definingOp.getDimension() == op.getDimension()) + return definingOp.getInputs(); + return val; + }; + + bool needToFlatten = false; + int operandCount = 0; + llvm::for_each(op.getInputs(), [&](Value val) { + auto result = getFlattenedOperands(val); + if (result.size() != 1 || result[0] != val) needToFlatten = true; + operandCount += result.size(); + }); + + if (!needToFlatten) + return rewriter.notifyMatchFailure(op, "no need to flatten"); + + llvm::SmallVector newOperands; + newOperands.reserve(operandCount); + + for (auto operand : op.getInputs()) { + auto flattenedOperands = getFlattenedOperands(operand); + newOperands.append(flattenedOperands.begin(), flattenedOperands.end()); + } + + rewriter.modifyOpInPlace(op, [&] { op->setOperands(newOperands); }); + return success(); + } +}; + +////////////////////////////////// +// BroadcastInDimOp +///////////////////////////////// + +// Used in DRR file. +DenseI64ArrayAttr getMergedBroadcastDimensions(OpBuilder &b, + ArrayRef dims, + ArrayRef dimsParent) { + auto mergedDims = llvm::map_to_vector( + dimsParent, [&dims](int64_t dim) { return dims[dim]; }); + return b.getDenseI64ArrayAttr(mergedDims); +} + +////////////////////////////////// +// DynamicBroadcastInDimOp +///////////////////////////////// + +/// Does the same as PatternRewriter::replaceOpWithNewOp, but with a twist. +/// +/// Sometimes, we want to replace an op with a new op and simultaneously refine +/// the result type from a dynamically-shaped type to a statically-shaped type. +/// (Search for usages of this function for examples). +// +/// Oftentimes, this works just fine because HLO is designed to accommodate +/// this kind of type refinements. But sometimes, this doesn't work - when +/// the op is used outside of the HLO dialect (e.g. in func.return). In these +/// cases, we insert a stablehlo.convert to smooth things out. +template +static OpTy refineOpWithNewOp(PatternRewriter &rewriter, Operation *op, + Args &&...args) { + auto newOp = rewriter.create(op->getLoc(), std::forward(args)...); + + llvm::SmallVector replacementResults; + assert(op->getNumResults() == newOp->getNumResults() && + "replacement op doesn't match results of original op"); + for (auto [opResult, newOpResult] : + llvm::zip(op->getResults(), newOp->getResults())) { + Value replacementResult = newOpResult; + if (llvm::any_of(opResult.getUsers(), [&](Operation *user) { + return user->getDialect() != op->getDialect(); + })) + replacementResult = rewriter.create( + op->getLoc(), opResult.getType(), newOpResult); + replacementResults.push_back(replacementResult); + } + + rewriter.replaceOp(op, replacementResults); + return newOp; +} + +/// If a DynamicBroadCastInDimOp is not actually dynamic, use an ordinary +/// BroadcastInDimOp. +struct DynamicBroadcastInDimOpNotActuallyDynamic final + : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicBroadcastInDimOp op, + PatternRewriter &rewriter) const override { + RankedTensorType operandType = op.getOperand().getType(); + if (!operandType.hasStaticShape()) + return rewriter.notifyMatchFailure(op, "requires operand static shape"); + + RankedTensorType type = op.getType(); + // output has static shape, replace with broadcast_in_dim + if (type.hasStaticShape()) { + rewriter.replaceOpWithNewOp( + op, type, op.getOperand(), op.getBroadcastDimensionsAttr()); + return success(); + } + + // output_dimensions are constant, set output shape with output_dimensions, + // then replace with broadcast_in_dim + if (llvm::SmallVector shape; + succeeded(hlo::matchInts(op.getOutputDimensions(), shape))) { + refineOpWithNewOp( + rewriter, op, RankedTensorType::get(shape, type.getElementType()), + op.getOperand(), op.getBroadcastDimensionsAttr()); + return success(); + } + return rewriter.notifyMatchFailure( + op, "requires output static shape or constant broadcast dimensions"); + } +}; + +////////////////////////////////// +// DynamicGatherOp +///////////////////////////////// + +DenseI64ArrayAttr convertToI64Array(OpBuilder &b, Attribute attr) { + auto denseAttr = cast(attr); + SmallVector result; + result.reserve(denseAttr.getNumElements()); + for (auto elem : denseAttr.getValues()) + result.push_back(elem.getSExtValue()); + return b.getDenseI64ArrayAttr(result); +} + +////////////////////////////////// +// DynamicIotaOp +///////////////////////////////// + +struct DynamicIotaIsStatic : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicIotaOp iota, + PatternRewriter &rewriter) const override { + // Result type has static shape, replace with iota. + auto resultTy = cast(iota.getType()); + if (!resultTy.hasStaticShape()) + return rewriter.notifyMatchFailure(iota, "requires output static shape"); + rewriter.replaceOpWithNewOp(iota, resultTy, + iota.getIotaDimension()); + return success(); + } +}; + +// Dynamic Iota operations across multiple dimensions can be reduced to an iota +// and a ranked broadcast. +// Pattern: dynamic_iota(shape, dim) -> +// dynamic_broadcast_in_dim(dynamic_iota(slice(shape), dim), shape) +struct DynamicIotaOpToBroadcast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicIotaOp iota, + PatternRewriter &rewriter) const override { + auto resultTy = cast(iota.getType()); + if (resultTy.getRank() < 2) + return rewriter.notifyMatchFailure(iota, "requires rank >= 2"); + + auto iotaDimension = static_cast(iota.getIotaDimension()); + + // Handle case where iota dimension is index, need to convert to/from i64 + // to interop with slice. These canonicalize away if input is i64. + auto convertedShape = rewriter.create( + iota.getLoc(), + RankedTensorType::get( + cast(iota.getOutputShape().getType()).getShape(), + rewriter.getI64Type()), + iota.getOutputShape()); + + auto slicedShape = rewriter.create( + iota.getLoc(), convertedShape, + rewriter.getDenseI64ArrayAttr(iotaDimension), + rewriter.getDenseI64ArrayAttr(iotaDimension + 1), + rewriter.getDenseI64ArrayAttr(1)); + + auto convertedSlicedShape = rewriter.create( + iota.getLoc(), + RankedTensorType::get( + {1}, + cast(iota.getOutputShape().getType()).getElementType()), + slicedShape); + + auto iotaType = RankedTensorType::get({resultTy.getDimSize(iotaDimension)}, + resultTy.getElementType()); + + auto newIota = rewriter.create( + iota.getLoc(), iotaType, convertedSlicedShape, + rewriter.getI64IntegerAttr(0)); + + rewriter.replaceOpWithNewOp( + iota, resultTy, newIota, iota.getOutputShape(), + rewriter.getDenseI64ArrayAttr(iotaDimension)); + return success(); + } +}; + +////////////////////////////////// +// DynamicReshapeOp +///////////////////////////////// + +struct DynamicReshapeOpIsStatic final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicReshapeOp op, + PatternRewriter &rewriter) const override { + // This is a noop when the output type is already a static shape. + RankedTensorType type = op.getType(); + if (!type.hasStaticShape()) + return rewriter.notifyMatchFailure(op, "dynamic reshape not static"); + + rewriter.replaceOpWithNewOp(op, type, op.getOperand()); + return success(); + } +}; + +// Pattern: dynamic_reshape(op(dynamic_reshape(X, shape)), shape) +// -> op(dynamic_reshape(X, shape)) +// [if op has same operand and result shape] +class DynamicReshapeOpSameOperandAndResultShape + : public OpRewritePattern { + public: + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicReshapeOp op, + PatternRewriter &rewriter) const override { + Operation *defOp = op.getOperand().getDefiningOp(); + if (!defOp || + !defOp->hasTrait()) { + return rewriter.notifyMatchFailure( + op, "dynamic reshape parent not same operand and result shape"); + } + DynamicReshapeOp reshape = + defOp->getOperand(0).getDefiningOp(); + if (!reshape) + return rewriter.notifyMatchFailure( + op, "dynamic reshape not wrapping same operand and result shape"); + if (reshape.getOutputShape() == op.getOutputShape()) { + rewriter.replaceOp(op, {defOp->getResult(0)}); + return success(); + } + return failure(); + } +}; + +////////////////////////////////// +// DynamicSliceOp +///////////////////////////////// + +// Canonicalizes DynamicSlice ops that can be replaced instead with Slice ops. +// This canonicalization is applied the case when the `begin` input values are +// compile time constants and thus can be made into a tensor. +// +// Pattern: dynamic_slice(X, begin, slice_sizes) -> slice(X, begin, slice_sizes) +struct DynamicSliceOpToSlice : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(DynamicSliceOp dynamicSlice, + PatternRewriter &rewriter) const override { + Value input = dynamicSlice.getOperand(); + auto inputType = cast(input.getType()); + if (!inputType.hasStaticShape()) + return rewriter.notifyMatchFailure(dynamicSlice, + "dynamic slice input not static"); + + auto sliceSizes = dynamicSlice.getSliceSizes(); + SmallVector tempStartIndices; + for (const auto &indexAndSliceStart : + llvm::enumerate(dynamicSlice.getStartIndices())) { + APInt val; + Value start = indexAndSliceStart.value(); + int64_t index = indexAndSliceStart.index(); + if (!matchPattern(start, m_ConstantInt(&val))) + return rewriter.notifyMatchFailure(dynamicSlice, + "dynamic slice input not constant"); + + // Clamp the indices within bounds to faithfully mirror dynamic slice + // semantics. + int64_t clampedStart = + std::clamp(val.getSExtValue(), static_cast(0), + inputType.getDimSize(index) - sliceSizes[index]); + tempStartIndices.push_back(clampedStart); + } + + // At this point we've determined that the start indices are all constants; + // pack them into a single tensor. + auto sliceStartIndices = rewriter.getDenseI64ArrayAttr(tempStartIndices); + SmallVector tempSliceLimits; + for (const auto &[start, size] : llvm::zip(tempStartIndices, sliceSizes)) { + tempSliceLimits.push_back(start + size); + } + auto sliceLimits = rewriter.getDenseI64ArrayAttr(tempSliceLimits); + + auto sliceStrides = rewriter.getDenseI64ArrayAttr( + SmallVector(inputType.getRank(), 1)); + + rewriter.replaceOpWithNewOp(dynamicSlice, input, sliceStartIndices, + sliceLimits, sliceStrides); + return success(); + } +}; + +////////////////////////////////// +// RealDynamicSliceOp +///////////////////////////////// + +// Pattern: real_dynamic_slice(X, start, limit, strides) +// -> dynamic_slice(X, start, limit, strides) +// [if strides, start are constants, limit = start + constant] +struct RealDynamicSliceOpToDynamicSlice + : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(RealDynamicSliceOp op, + PatternRewriter &rewriter) const override { + // This rewrite only works for unit strides because DynamicSliceOp + // doesn't support strides (i.e. it implicitly has unit strides). + DenseIntElementsAttr stridesAttr; + if (!matchPattern(op.getStrides(), m_Constant(&stridesAttr))) + return rewriter.notifyMatchFailure(op, "requires constant strides"); + if (!llvm::all_of(stridesAttr.getValues(), + [&](APInt stride) { return stride == 1; })) + return rewriter.notifyMatchFailure(op, "requires unit strides"); + + // Check that slice sizes are fully static (DynamicSliceOp style). + // To detect that, we check whether `limit_indices` is defined as + // `start_indices + constant` or `constant + start_indices`. + DenseIntElementsAttr sliceSizesAttr; + auto m_startIndices = matchers::m_Val(op.getStartIndices()); + // Only handle the AddOp case, if all constant we fold to SliceOp. + if (!matchPattern( + op.getLimitIndices(), + m_Op(m_startIndices, m_Constant(&sliceSizesAttr))) && + !matchPattern(op.getLimitIndices(), + m_Op(m_Constant(&sliceSizesAttr), m_startIndices))) + return rewriter.notifyMatchFailure( + op, "requires limit indices equal to start indices plus constant"); + + // RealDynamicSliceOp can take tensors of integer or index element types. + // DynamicSliceOp::slice_sizes only supports i64 element type. + // Adapt accordingly in order to be compatible with DynamicSliceOp. + SmallVector sliceSizes; + for (auto element : sliceSizesAttr.getValues()) { + sliceSizes.push_back(element.getSExtValue()); + } + + // RealDynamicSliceOp::start_indices is a 1-dimensional tensor. + // DynamicSliceOp::start_indices is a vararg of 0-dimensional tensors. + // Adapt accordingly in order to be compatible with DynamicSliceOp. + SmallVector startIndices; + for (auto i = 0; i < static_cast(sliceSizes.size()); ++i) { + auto startIndex1D = rewriter.create( + op.getLoc(), op.getStartIndices(), rewriter.getDenseI64ArrayAttr(i), + rewriter.getDenseI64ArrayAttr(i + 1), + rewriter.getDenseI64ArrayAttr(1)); + auto startIndex0DType = RankedTensorType::get( + {}, + cast(op.getStartIndices().getType()).getElementType()); + auto startIndex0D = rewriter.create( + op.getLoc(), startIndex0DType, startIndex1D); + startIndices.push_back(startIndex0D); + } + + rewriter.replaceOpWithNewOp( + op, op.getOperand(), startIndices, + rewriter.getDenseI64ArrayAttr(sliceSizes)); + return success(); + } +}; + +////////////////////////////////// +// ReduceOp +///////////////////////////////// + +// Pattern: reduce[A](_, _, fn:return A) -> A... +struct ReduceOpNoopVariableReturn final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ReduceOp op, + PatternRewriter &rewriter) const override { + // If all returned values in the ReduceOp region exists outside the + // region, replace the ReduceOp with those values. + if (auto retOp = dyn_cast(op.getBody().front().getTerminator())) { + Region *retRegion = retOp->getParentRegion(); + if (llvm::any_of(retOp.getResults(), [retRegion](Value result) { + return result.getParentRegion() == retRegion; + })) + return failure(); + + rewriter.replaceOp(op, retOp.getResults()); + return success(); + } + + return failure(); + } +}; + +// Pattern: reduce(empty_0, empty_1, ...) -> [broadcast_in_dim(empty_i)...] +struct ReduceOpEmptyCanon final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ReduceOp op, + PatternRewriter &rewriter) const override { + // We require all reduce shapes to be the same, up to the element types, so + // we can just use the first operand and the first result as + // representatives. + auto elemTy = cast(op.getInputs().getType().front()); + + if (!llvm::is_contained(elemTy.getShape(), 0)) return failure(); + + Location loc = op.getLoc(); + DenseI64ArrayAttr empty = rewriter.getDenseI64ArrayAttr({}); + if (elemTy.hasStaticShape()) { + SmallVector broadcasts(op.getNumResults()); + for (auto [bcast, init, outTy] : llvm::zip_equal( + broadcasts, op.getInitValues(), op.getResultTypes())) { + bcast = rewriter.create(loc, outTy, init, empty); + } + rewriter.replaceOp(op, broadcasts); + return success(); + } + + SmallVector shapes; + if (failed(op.reifyReturnTypeShapes(rewriter, op.getOperands(), shapes))) + return failure(); + + SmallVector broadcasts(op.getNumResults()); + for (auto [bcast, init, shape, outTy] : llvm::zip_equal( + broadcasts, op.getInitValues(), shapes, op.getResultTypes())) { + bcast = rewriter.create(loc, outTy, init, shape, + empty); + } + rewriter.replaceOp(op, broadcasts); + return success(); + } +}; + +// Pattern: reduce(in_1, in_2, _, _) -> reduce(in_1, _, _) [if unused(in_2)] +struct ReduceOpUnusedResultCanon final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(ReduceOp op, + PatternRewriter &rewriter) const override { + SmallVector usedResults; + llvm::copy_if(op.getResults(), std::back_inserter(usedResults), + [](OpResult result) { return !result.use_empty(); }); + + if (usedResults.size() == op.getNumResults()) + return rewriter.notifyMatchFailure(op, "all operation results have uses"); + + const auto pairSize = 2; + const auto numOperands = op.getNumOperands(); + const auto numOperandPairs = numOperands / pairSize; + + Block &reducerBlock = op.getBody().front(); + auto retOp = cast(reducerBlock.getTerminator()); + + assert(numOperandPairs == op.getNumResults() && + numOperandPairs == retOp.getNumOperands()); + + SmallVector workList; + auto addToWorkList = [&workList, + reducerBody = retOp->getParentRegion()](Value v) { + if (v.getParentRegion() == reducerBody) workList.push_back(v); + }; + + SmallPtrSet usedOps; + SmallBitVector usedArgs(numOperands); + SmallBitVector usedReturnOperands(numOperandPairs); + for (const auto &usedResult : usedResults) { + auto resultNo = usedResult.getResultNumber(); + usedReturnOperands.set(resultNo); + + // Follow the def-use chain starting from return operand to identify + // which argument pairs are used to compute it. + addToWorkList(retOp.getOperand(resultNo)); + while (!workList.empty()) { + auto definition = workList.pop_back_val(); + if (auto blockArg = dyn_cast(definition)) { + // using one argument implies using the whole argument pair + const auto pairNo = blockArg.getArgNumber() % numOperandPairs; + usedArgs.set(pairNo); + usedArgs.set(pairNo + numOperandPairs); + } else if (auto *defOp = definition.getDefiningOp()) { + usedOps.insert(defOp); + for (const auto &operand : defOp->getOperands()) + addToWorkList(operand); + } + } + } + + const auto newNumOperandPairs = usedResults.size(); + const auto newNumOperands = newNumOperandPairs * pairSize; + if (newNumOperands != usedArgs.count()) + return rewriter.notifyMatchFailure(op, [&](Diagnostic &diag) { + diag << "non-conservative case: " << newNumOperandPairs + << " return results should be matched with " << newNumOperands + << " operands, but got " << usedArgs.count(); + }); + + SmallVector newInputs; + SmallVector newInitVals; + SmallVector newElementTypes; + for (auto i : llvm::seq(0u, numOperandPairs)) { + if (usedReturnOperands[i]) + newElementTypes.push_back( + getElementTypeOrSelf(retOp.getOperand(i).getType())); + + if (!usedArgs[i]) continue; + + newInputs.push_back(op.getOperand(i)); + newInitVals.push_back(op.getOperand(i + numOperandPairs)); + } + + auto newOp = + rewriter.create(op.getLoc(), newInputs, newInitVals, + op.getDimensionsAttr(), newElementTypes); + Block *newReducerBlock = rewriter.createBlock(&newOp.getBody()); + + IRMapping mapper; + for (auto arg : reducerBlock.getArguments()) + if (usedArgs[arg.getArgNumber()]) + mapper.map(arg, + newReducerBlock->addArgument(arg.getType(), arg.getLoc())); + + rewriter.setInsertionPointToStart(newReducerBlock); + for (Operation &op : reducerBlock.getOperations()) + if (usedOps.contains(&op)) rewriter.clone(op, mapper); + + SmallVector newReturnOperands; + for (const auto &en : llvm::enumerate(retOp.getOperands())) + if (usedReturnOperands[en.index()]) + newReturnOperands.push_back(mapper.lookup(en.value())); + + rewriter.create(retOp.getLoc(), newReturnOperands); + + // Build new results list (unused entries will be null). + SmallVector newResults(op.getNumResults()); + for (const auto &[i, result] : llvm::enumerate(usedResults)) { + newResults[result.getResultNumber()] = newOp.getResult(i); + } + + rewriter.replaceOp(op, newResults); + return success(); + } +}; + +///////////////////////////////// +// GetDimensionSizeOp +///////////////////////////////// + +// TODO: This is duplicated with a pattern in shape refinement, consider +// consolidating. +// Pattern: get_dimension_size(X, i) -> X.shape[i] +struct GetDimensionSizeOpCanon final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GetDimensionSizeOp op, + PatternRewriter &rewriter) const override { + // Fold get_dimension_size when the queried dim is statically known. + RankedTensorType operandTy = op.getOperand().getType(); + + int64_t dimSize = operandTy.getDimSize(op.getDimension()); + if (dimSize < 0) return failure(); + + auto elemTy = cast(op.getType().getElementType()); + IntegerAttr elemVal = rewriter.getIntegerAttr(elemTy, dimSize); + rewriter.replaceOpWithNewOp( + op, DenseElementsAttr::get(op.getType(), elemVal)); + return success(); + } +}; + +////////////////////////////////// +// GatherOp +///////////////////////////////// + +/// Converts gather ops to slice ops in case we have a single set of constant +/// indices. +// Pattern: gather(X, cst_start_indices) -> slice(X, slice_start, slice_end) +struct GatherOpCanon final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(GatherOp gather, + PatternRewriter &rewriter) const override { + DenseIntElementsAttr index; + if (!matchPattern(gather.getStartIndices(), m_Constant(&index))) + return failure(); + + GatherDimensionNumbersAttr dnums = gather.getDimensionNumbers(); + if (dnums.getIndexVectorDim() != 0 || index.getType().getRank() > 1) + return failure(); + + // TODO: Remove when the verifier catches this case that is + // invalid if all previous condition holds. + if (index.getNumElements() != + static_cast(dnums.getStartIndexMap().size())) { + return failure(); + } + + auto operandType = cast(gather->getOperand(0).getType()); + if (!operandType.hasStaticShape()) return failure(); + + auto sliceEnd = llvm::to_vector(gather.getSliceSizes()); + SmallVector sliceStart(sliceEnd.size(), 0); + for (auto [mapIndex, value] : + llvm::zip_equal(dnums.getStartIndexMap(), index.getValues())) { + // Clamp the indices within bounds to faithfully mirror gather semantics. + int64_t offset = + std::clamp(value.getSExtValue(), static_cast(0), + operandType.getDimSize(mapIndex) - sliceEnd[mapIndex]); + sliceStart[mapIndex] += offset; + sliceEnd[mapIndex] += offset; + } + + SmallVector sliceStride(sliceEnd.size(), 1); + SmallVector sliceShape(sliceEnd.size()); + for (auto [shapeElem, startElem, endElem] : + llvm::zip_equal(sliceShape, sliceStart, sliceEnd)) { + shapeElem = endElem - startElem; + } + + Type elementType = gather.getType().getElementType(); + auto sliceType = RankedTensorType::get(sliceShape, elementType); + Value result = rewriter.create( + gather.getLoc(), sliceType, gather.getOperand(), + rewriter.getDenseI64ArrayAttr(sliceStart), + rewriter.getDenseI64ArrayAttr(sliceEnd), + rewriter.getDenseI64ArrayAttr(sliceStride)); + + ArrayRef collapsedSliceDims = dnums.getCollapsedSliceDims(); + if (!collapsedSliceDims.empty()) { + llvm::SmallVector reshapeShape; + for (auto [idx, dim] : llvm::enumerate(sliceShape)) { + if (!llvm::is_contained(collapsedSliceDims, idx)) + reshapeShape.push_back(dim); + } + auto reshapeType = RankedTensorType::get(reshapeShape, elementType); + result = rewriter.create(gather.getLoc(), reshapeType, result); + } + + result.setType(gather.getType()); + rewriter.replaceOp(gather, result); + return success(); + } +}; + +////////////////////////////////// +// IotaOp +///////////////////////////////// + +// Iota operations across multiple dimensions can be reduced to an iota and a +// ranked broadcast. +// Pattern: iota(dim) : multi_rank +// -> broadcast_in_dim(iota(dim) : array, multi_rank) +struct IotaOpBroadcast : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(IotaOp iota, + PatternRewriter &rewriter) const override { + auto resultTy = cast(iota.getType()); + if (resultTy.getRank() < 2) + return rewriter.notifyMatchFailure(iota, "itoa not broadcastable"); + + auto iotaDim = iota.getIotaDimension(); + auto iotaDimSize = resultTy.getDimSize(iotaDim); + auto iota1D = rewriter.create( + iota.getLoc(), + RankedTensorType::get({iotaDimSize}, resultTy.getElementType()), + rewriter.getI64IntegerAttr(0)); + + auto broadcastAttr = + rewriter.getDenseI64ArrayAttr({static_cast(iotaDim)}); + rewriter.replaceOpWithNewOp(iota, resultTy, iota1D, + broadcastAttr); + return success(); + } +}; + +////////////////////////////////// +// PadOp +///////////////////////////////// + +// If the input tensor has a dimension of length-0, the input tensor is +// irrelevant. Instead we can broadcast the pad value to the output size rather +// than pad the input tensor. + +// If the input tensor has a dimension of length-0, the input tensor is +// irrelevant. Instead we can broadcast the pad value to the output size rather +// than pad the input tensor. + +// Pattern: pad(empty_tensor, _) -> broadcast_in_dim(empty_tensor, _) +struct PadOpBroadcastEmptyTensor : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(PadOp op, + PatternRewriter &rewriter) const override { + auto operand = op.getOperand(); + auto padVal = op.getPaddingValue(); + + auto resultTy = cast(op.getType()); + + if (cast(operand.getType()).getNumElements() != 0) + return rewriter.notifyMatchFailure(op, "operand is not empty tensor"); + + if (resultTy.hasStaticShape()) { + rewriter.replaceOpWithNewOp( + op, resultTy, padVal, rewriter.getDenseI64ArrayAttr({})); + return success(); + } + + llvm::SmallVector reifiedShapes; + if (failed(op.reifyReturnTypeShapes(rewriter, op.getOperands(), + reifiedShapes))) + return rewriter.notifyMatchFailure(op, "failed to reify return type"); + + rewriter.replaceOpWithNewOp( + op, op.getType(), padVal, reifiedShapes.front(), + rewriter.getDenseI64ArrayAttr({})); + return success(); + } +}; + +////////////////////////////////// +// SelectOp +///////////////////////////////// + +struct SelectOpCanon final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SelectOp op, + PatternRewriter &rewriter) const override { + RankedTensorType type = op.getType(); + + Value trueVal = op.getOnTrue(); + Value falseVal = op.getOnFalse(); + + // Eliminate select with two identical outcomes. + if (trueVal == falseVal) { + rewriter.replaceOp(op, trueVal); + return success(); + } + + // Simplify when the condition is a constant. + Value pred = op.getPred(); + ElementsAttr cond; + if (!matchPattern(pred, m_Constant(&cond))) return failure(); + + // Handle splat predicate and select either `trueVal` or `falseVal`. + if (cond.isSplat()) { + rewriter.replaceOp(op, cond.getSplatValue() ? trueVal : falseVal); + return success(); + } + + // Handle elementwise selection when both outcomes are also constants. This + // will create a new, likely non-splat constant. + if (cond.getNumElements() > kFoldOpEltLimit) return failure(); + + ElementsAttr trueAttr; + if (!matchPattern(trueVal, m_Constant(&trueAttr))) return failure(); + + ElementsAttr falseAttr; + if (!matchPattern(falseVal, m_Constant(&falseAttr))) return failure(); + + SmallVector newValues; + newValues.reserve(cond.getNumElements()); + for (auto [condElem, trueElem, falseElem] : llvm::zip_equal( + cond.getValues(), trueAttr.getValues(), + falseAttr.getValues())) { + newValues.push_back(condElem ? trueElem : falseElem); + } + + rewriter.replaceOpWithNewOp( + op, DenseElementsAttr::get(type, newValues)); + return success(); + } +}; + +struct CompareSelectIntoMinMax final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SelectOp op, + PatternRewriter &rewriter) const override { + Value pred = op.getPred(); + Value trueVal = op.getOnTrue(); + Value falseVal = op.getOnFalse(); + + auto cmpOp = pred.getDefiningOp(); + if (!cmpOp) return failure(); + + ComparisonDirection direction = cmpOp.getComparisonDirection(); + Value cmpLhs = cmpOp.getLhs(); + Value cmpRhs = cmpOp.getRhs(); + + // Turn into canonical form: + // b <= a ? a : b ---> a >= b ? a : b + // b < a ? a : b ---> a > b ? a : b + // b >= a ? a : b ---> a <= b ? a : b + // b > a ? a : b ---> a < b ? a : b + if (cmpLhs == falseVal && cmpRhs == trueVal) { + direction = invertDirection(direction); + } else if (!(cmpLhs == trueVal && cmpRhs == falseVal)) { + return failure(); + } + + switch (direction) { + case ComparisonDirection::GE: + case ComparisonDirection::GT: { + rewriter.replaceOpWithNewOp(op, trueVal, falseVal); + return success(); + } + case ComparisonDirection::LE: + case ComparisonDirection::LT: { + rewriter.replaceOpWithNewOp(op, trueVal, falseVal); + return success(); + } + default: { + return failure(); + } + } + } +}; + +////////////////////////////////// +// SliceOp +///////////////////////////////// + +// In cases where a concat is fed into a slice, it is possible the concat +// can be simplified or bypassed. This checks which inputs to the concat are +// used by the slice, either reducing the number of concatenated values or +// entirely removes the concat. +// Pattern: slice(concat(X,Y,Z,...),...) -> concat(slice(X),slice(Y),slice(Z)) +struct SliceOpConcatSimplify : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SliceOp slice, + PatternRewriter &rewriter) const override { + auto resultTy = cast(slice.getType()); + if (!resultTy.hasStaticShape()) + return rewriter.notifyMatchFailure(slice, "result shape not static"); + + auto concat = slice.getOperand().getDefiningOp(); + if (!concat) + return rewriter.notifyMatchFailure(slice, "slice input not concat"); + + auto concatType = cast(concat.getType()); + auto dimension = concat.getDimension(); + + auto start = slice.getStartIndices(); + auto limit = slice.getLimitIndices(); + + int64_t sliceStart = start[dimension]; + int64_t sliceLimit = limit[dimension]; + + // We need to determine what inputs from the concat affect the slice, and + // how the bounds of the slice need to be updated for the minimally required + // inputs. + int64_t runningSize = 0; + int64_t frontOffset = concatType.getShape()[dimension]; + + auto subsetStart = concat.operand_end(); + auto subsetEnd = concat.operand_end(); + for (auto it = concat.operand_begin(); it < concat.operand_end(); ++it) { + auto input = *it; + ShapedType inputTy = cast(input.getType()); + if (inputTy.isDynamicDim(dimension)) + return rewriter.notifyMatchFailure( + slice, "concat input has dynamic dimension"); + + auto dimSize = inputTy.getShape()[dimension]; + + // If this position is in the slice its the start of the subset and we + // need to update the start and limit values. + if (runningSize + dimSize > sliceStart && + subsetStart == concat.operand_end()) { + subsetStart = it; + frontOffset = runningSize; + } + + // Determine the last required offset. + if (runningSize < sliceLimit) { + subsetEnd = it + 1; + } + + runningSize += dimSize; + } + + auto subsetSize = subsetEnd - subsetStart; + // We need all inputs so no optimization. + if (subsetSize == concat.getNumOperands()) + return rewriter.notifyMatchFailure(slice, + "slice needs all concat inputs"); + + // If there's nothing to slice that means the output is an empty tensor and + // there is dead code. We do nothing here and rely on other passes to clean + // this up. + if (subsetSize == 0) + return rewriter.notifyMatchFailure(slice, "slice is empty"); + + if (subsetSize > 1 && !concat.getResult().hasOneUse()) + return rewriter.notifyMatchFailure(slice, + "slice is not the only concat user"); + + auto concatRange = OperandRange(subsetStart, subsetEnd); + auto newConcat = rewriter.create( + concat.getLoc(), concatRange, concat.getDimension()); + + SmallVector newStart(start); + SmallVector newLimit(limit); + newStart[dimension] -= frontOffset; + newLimit[dimension] -= frontOffset; + + rewriter.replaceOpWithNewOp( + slice, newConcat, rewriter.getDenseI64ArrayAttr(newStart), + rewriter.getDenseI64ArrayAttr(newLimit), slice.getStrides()); + return success(); + } +}; + +////////////////////////////////// +// SortOp +///////////////////////////////// + +/// Drops the operands if the results are not used and they are not used in +/// op.comparator(). + +// Pattern: sort(X,Y) -> sort(X) [if Y unused and unused in comparator] +struct SortOpDropUnusedArgs : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SortOp op, + PatternRewriter &rewriter) const override { + DenseSet erasedArgs; + unsigned numOperands = op.getNumOperands(); + for (unsigned i = 0; i < numOperands; ++i) { + if (!op.getResult(i).use_empty()) continue; + Block &block = op.getComparator().front(); + if (!block.getArgument(i * 2).use_empty()) continue; + if (!block.getArgument(i * 2 + 1).use_empty()) continue; + erasedArgs.insert(i); + } + if (erasedArgs.empty()) return failure(); + + SmallVector newOperands; + BitVector erasedBlockArgs(op.getNumOperands() * 2); + for (const auto &en : llvm::enumerate(op.getInputs())) { + if (erasedArgs.contains(en.index())) { + erasedBlockArgs.set(en.index() * 2); + erasedBlockArgs.set(en.index() * 2 + 1); + } else { + newOperands.push_back(en.value()); + } + } + + auto newOp = rewriter.create(op.getLoc(), newOperands, + op.getDimension(), op.getIsStable()); + Region ®ion = newOp.getComparator(); + rewriter.inlineRegionBefore(op.getComparator(), region, region.end()); + region.front().eraseArguments(erasedBlockArgs); + + SmallVector results; + for (unsigned i = 0, j = 0; i < numOperands; ++i) { + if (erasedArgs.contains(i)) { + results.push_back({}); + } else { + results.push_back(newOp.getResult(j++)); + } + } + rewriter.replaceOp(op, results); + + return success(); + } +}; + +/// Set the sorting dimension to the last dimension if it's not set and the rank +/// is known. +// Pattern: sort(X) -> sort(X, dim = N) [when dim can be inferred] +struct SortOpSetDimension : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(SortOp op, + PatternRewriter &rewriter) const override { + if (op.getResults().empty() || + static_cast(op.getDimension()) != -1) + return rewriter.notifyMatchFailure(op, + "dimension already set or no results"); + + auto type = cast(op.getResultTypes()[0]); + IntegerAttr dim = rewriter.getI64IntegerAttr(type.getRank() - 1); + auto newOp = + rewriter.create(op.getLoc(), op.getResultTypes(), + op.getInputs(), dim, op.getIsStableAttr()); + newOp.getComparator().takeBody(op.getComparator()); + rewriter.replaceOp(op, newOp.getResults()); + return success(); + } +}; + +////////////////////////////////// +// TransposeOp +///////////////////////////////// + +// Pattern: transpose(X, [no_mem_layout_change...]) -> reshape(X) +struct TransposeIsReshape final : OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TransposeOp op, + PatternRewriter &rewriter) const override { + auto input = op.getOperand(); + auto permutation = op.getPermutation(); + + RankedTensorType inputTy = input.getType(); + if (!inputTy.hasStaticShape() || !op.getType().hasStaticShape()) + return rewriter.notifyMatchFailure( + op, + "requires input and output to be of a statically-shaped ranked " + "tensor type"); + + // Check that the permutation is a valid memory layout change. + // All non-zero/one dimensions must be in increasing order. + SmallVector nonZeroPerms; + nonZeroPerms.reserve(permutation.size()); + for (auto idx : permutation) + if (inputTy.getDimSize(idx) != 1) nonZeroPerms.push_back(idx); + + for (size_t i = 1; i < nonZeroPerms.size(); ++i) + if (nonZeroPerms[i - 1] > nonZeroPerms[i]) + return rewriter.notifyMatchFailure(op, "memory layout change"); + + rewriter.replaceOpWithNewOp(op, op.getType(), input); + return success(); + } +}; + +////////////////////////////////// +// TupleOp +///////////////////////////////// + +// Pattern: tuple(get_tuple_element(X, 0), get_tuple_element(X, 1), ...) -> X +struct TupleIsRepacking : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(TupleOp op, + PatternRewriter &rewriter) const override { + if (op.getVal().empty()) + return rewriter.notifyMatchFailure(op, "empty tuple"); + + // Get parent tuple + Value firstElement = op.getVal().front(); + auto firstElementOp = firstElement.getDefiningOp(); + if (!firstElementOp) + return rewriter.notifyMatchFailure(op, "parent not get_tuple_element"); + + Value tuplePredecessor = firstElementOp.getOperand(); + if (tuplePredecessor.getType() != op.getType()) + return rewriter.notifyMatchFailure( + op, "tuple predecessor type does not match"); + + // Check that this is a repacking of the parent tuple. + for (const auto &elementAndIdx : llvm::enumerate(op.getVal())) { + auto elementOp = elementAndIdx.value().getDefiningOp(); + if (!elementOp || + elementOp.getIndexAttr().getInt() != + static_cast(elementAndIdx.index()) || + elementOp.getOperand() != tuplePredecessor) + return rewriter.notifyMatchFailure( + op, "not a repacking of the parent tuple"); + } + + rewriter.replaceOp(op, tuplePredecessor); + return success(); + } +}; + +///////////////////////////////// +// WhileOp +///////////////////////////////// + +// Turn loop invariant values into implicit capture. +// Check if there is at least one value is forwarded from one iteration to +// the next, or one of the yielded value is an implicit capture already. +// Otherwise there is nothing to do here. + +// Pattern: while -> while (loop invariants as implicit captures) +struct WhileOpImplicitCapture : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(WhileOp whileOp, + PatternRewriter &rewriter) const override { + Block *cond = whileOp.SingleBlock::getBody(0); + Block *body = whileOp.SingleBlock::getBody(1); + auto bodyReturnOp = cast(body->getTerminator()); + if (!llvm::any_of(llvm::zip(whileOp->getOperands(), body->getArguments(), + bodyReturnOp->getOperands()), + [&](auto zip) { + return (std::get<0>(zip) == std::get<2>(zip) || + std::get<1>(zip) == std::get<2>(zip)); + })) + return rewriter.notifyMatchFailure(whileOp, "no loop invariant found"); + + SmallVector newOperands, resultsToReplace; + SmallVector invariantArgIdxs; + BitVector invariantArgIdxBitVector(cond->getNumArguments()); + for (const auto &enumeratedOperands : llvm::enumerate(llvm::zip( + whileOp.getOperands(), cond->getArguments(), body->getArguments(), + bodyReturnOp->getOperands(), whileOp->getResults()))) { + const auto &operands = enumeratedOperands.value(); + Value whileOperand = std::get<0>(operands); + BlockArgument condBlockArg = std::get<1>(operands); + BlockArgument bodyBlockArg = std::get<2>(operands); + Value bodyReturnOperand = std::get<3>(operands); + Value whileResult = std::get<4>(operands); + + bool forwarded = (whileOperand == bodyReturnOperand || + bodyBlockArg == bodyReturnOperand); + if (forwarded) { + invariantArgIdxs.push_back(enumeratedOperands.index()); + invariantArgIdxBitVector.set(enumeratedOperands.index()); + condBlockArg.replaceAllUsesWith(whileOperand); + bodyBlockArg.replaceAllUsesWith(whileOperand); + whileResult.replaceAllUsesWith(whileOperand); + continue; + } + newOperands.push_back(whileOperand); + resultsToReplace.push_back(whileResult); + } + cond->eraseArguments(invariantArgIdxBitVector); + body->eraseArguments(invariantArgIdxBitVector); + for (int idx : llvm::reverse(invariantArgIdxs)) + bodyReturnOp->eraseOperand(idx); + + WhileOp newWhileOp = rewriter.create( + whileOp.getLoc(), bodyReturnOp->getOperandTypes(), newOperands); + newWhileOp.getBodyRegion(0).takeBody(whileOp.getBodyRegion(0)); + newWhileOp.getBodyRegion(1).takeBody(whileOp.getBodyRegion(1)); + for (auto results : llvm::zip(resultsToReplace, newWhileOp->getResults())) + std::get<0>(results).replaceAllUsesWith(std::get<1>(results)); + rewriter.eraseOp(whileOp); + return success(); + } +}; + +////////////////////////////////// +// Generic and Elementwise Ops +///////////////////////////////// + +/// Check if a `t` is a tensor with zero extents. +static std::optional getMaybeZeroExtentType(Type t) { + auto type = dyn_cast(t); + if (type && type.hasStaticShape() && type.getNumElements() == 0) return type; + return std::nullopt; +} + +// Replace instances of zero extent tensors with empty tensors +// Pattern: op(X : zero_extent_tensor) -> constant([]) +struct ZeroExtentToEmptyConstant final : RewritePattern { + ZeroExtentToEmptyConstant(MLIRContext *context, PatternBenefit benefit) + : RewritePattern(MatchAnyOpTypeTag(), benefit, context) {} + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + auto loc = op->getLoc(); + + if (!isa_and_present(op->getDialect())) + return rewriter.notifyMatchFailure(op, "not stablehlo"); + if (isa(op)) + return rewriter.notifyMatchFailure(op, "op is empty constant"); + + // If the result is a zero-extent tensor, replace the whole op with an empty + // constant. + bool didUpdate = false; + for (auto result : op->getResults()) { + auto resultType = getMaybeZeroExtentType(result.getType()); + if (!resultType || result.use_empty()) continue; + rewriter.replaceAllUsesWith( + result, rewriter.create( + loc, result.getType(), + DenseElementsAttr::get(resultType.value(), + ArrayRef()))); + didUpdate = true; + } + + // If one of the operands is a zero-extent tensor, replace the operand with + // an empty tensor. + for (OpOperand &operand : op->getOpOperands()) { + auto operandType = getMaybeZeroExtentType(operand.get().getType()); + if (!operandType || operand.get().getDefiningOp()) continue; + Operation *owner = operand.getOwner(); + int operandNum = operand.getOperandNumber(); + auto emptyConstantOp = rewriter.create( + loc, operandType.value(), + DenseElementsAttr::get(operandType.value(), ArrayRef())); + rewriter.modifyOpInPlace( + owner, [&]() { owner->setOperand(operandNum, emptyConstantOp); }); + didUpdate = true; + } + return success(didUpdate); + } +}; + +struct ReorderElementwiseAndShapeOp final + : OpTraitRewritePattern { + using OpTraitRewritePattern::OpTraitRewritePattern; + + LogicalResult matchAndRewrite(Operation *op, + PatternRewriter &rewriter) const override { + if (op->getOperands().size() != 1) + return rewriter.notifyMatchFailure(op, "expected to be unary"); + + auto definingOp = op->getOperand(0).getDefiningOp(); + if (!definingOp) + return rewriter.notifyMatchFailure( + op, "expected to have an op before elementise op"); + + if (!isa(definingOp)) + return rewriter.notifyMatchFailure( + op, "defining operation of unexpected type"); + + // Reshape and broadcast are not allowed to have dynamic shape. + Value result = op->getResult(0); + if (isa(definingOp) && + !cast(result.getType()).hasStaticShape()) + return rewriter.notifyMatchFailure( + op, "cannot reorder around reshape/broadcast with dynamic shape"); + + // Only reorder if the defining op has no other uses. + if (!llvm::hasSingleElement(definingOp->getResult(0).getUses())) + return rewriter.notifyMatchFailure(op, "operation has more than one use"); + + Value input = definingOp->getOperand(0); + auto intermediateType = cast(input.getType()) + .clone(getElementTypeOrSelf(result.getType())); + + // Reorder the operation and rewire the inputs/outputs. + op->moveBefore(definingOp); + definingOp->getResult(0).setType(result.getType()); + rewriter.replaceAllUsesWith(result, definingOp->getResult(0)); + result.setType(intermediateType); + op->setOperands(input); + definingOp->setOperands(result); + return success(); + } +}; + +struct StablehloAggressiveSimplificationPass final + : impl::StablehloAggressiveSimplificationPassBase< + StablehloAggressiveSimplificationPass> { + StablehloAggressiveSimplificationPass() = default; + StablehloAggressiveSimplificationPass(GreedyRewriteConfig config) + : config(config) {} + LogicalResult initialize(MLIRContext *context) override { + RewritePatternSet patterns_(context); + populateStablehloCanonicalizationPatterns(context, &patterns_); + patterns = std::move(patterns_); + return success(); + } + + void runOnOperation() override { + if (failed(applyPatternsGreedily(getOperation(), patterns, config))) + signalPassFailure(); + } + + private: + GreedyRewriteConfig config; + FrozenRewritePatternSet patterns; +}; + +#include "stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.h.inc" +} // namespace + +void populateStablehloCanonicalizationPatterns(MLIRContext *context, + RewritePatternSet *patterns, + PatternBenefit benefit) { + populateWithGenerated(*patterns); + patterns->add(context); + patterns->add< + CompareOpCanon, CompareSelectIntoMinMax, ConcatenateOpFlatten, + ConcatenateOpNoop, ConcatenateOpRemoveEmpty, DynamicIotaOpToBroadcast, + DynamicReshapeOpSameOperandAndResultShape, DynamicSliceOpToSlice, + GatherOpCanon, IotaOpBroadcast, PadOpBroadcastEmptyTensor, + RealDynamicSliceOpToDynamicSlice, ReduceOpEmptyCanon, + ReduceOpNoopVariableReturn, ReduceOpUnusedResultCanon, SelectOpCanon, + SliceOpConcatSimplify, SortOpDropUnusedArgs, SortOpSetDimension, + TransposeIsReshape, TupleIsRepacking, WhileOpImplicitCapture>(context, + benefit); + + // Generic patterns + patterns->add( + context, benefit); + + // TODO: Dynamism Refinements, consider merging with canonicalize dynamism + patterns + ->add(context); +} + +std::unique_ptr createStablehloAggressiveSimplificationPass( + GreedyRewriteConfig config) { + return std::make_unique(config); +} + +} // namespace stablehlo +} // namespace mlir diff --git a/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td b/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td new file mode 100644 index 0000000000..9cbcc07ca6 --- /dev/null +++ b/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td @@ -0,0 +1,427 @@ +// Copyright 2020 The IREE Authors +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// This is the legalization pattern definition file for CHLO to StableHLO. +// These are included in the populateDecompositionPatterns factory +// and should only include canonical expansions which are not actually +// ambiguous/different for various backends. Avoid patterns that are actually +// lowering to non-canonical forms. + +include "mlir/IR/OpBase.td" +include "stablehlo/dialect/StablehloOps.td" +include "mlir/Dialect/Shape/IR/ShapeOps.td" + +/////////// +//// Op & Type Constraints + +class DimSizeEquals : Constraint< + CPred<"llvm::cast($0.getType()).getDimSize($1.getInt()) == " # dimSize>, + "dim size is " # dimSize>; + +def AllDimsNonExpanding : Constraint< + CPred<"$0 && cast($0).size() == llvm::cast($1.getType()).getRank()">, + "all dims are non-expanding">; + +def AllZero : Constraint< + CPred<"llvm::all_of($0, [](Value operand) {return matchPattern(operand, m_Zero()); })">, + "is all zero">; + +def CommutativeOp : Constraint< + CPred<"$0.getDefiningOp()->hasTrait()">, + "op is commutative">; + +def HasOneUse : Constraint>; + +def NotConstantOp : Constraint< + CPred<"llvm::isa($0) || !llvm::isa($0.getDefiningOp())">, + "is not a constant.">; + +def NumberOfElementsEqual : Constraint< + CPred<"llvm::cast($0.getType()).getNumElements() == llvm::cast($1.getType()).getNumElements()">, + "same number of elements">; + +def OperandsEqual : Constraint, "operands are equal">; + +def RankEqual : Constraint< + CPred<"llvm::cast($0.getType()).getRank() == llvm::cast($1.getType()).getRank()">, + "same rank">; + +def TypesEqual : Constraint, "operands are equal">; + +/////////// +//// Attribute Constraints + +def AnySplat : AttrConstraint, "is any splat">; + +def AnyZero : AttrConstraint< + CPred<"::mlir::matchPattern($_self, m_AnyAttrOf(m_Zero(), m_AnyZeroFloat()))">, + "is int or float zero">; + +def DenseIntElementsAttr : AttrConstraint< + CPred<"llvm::isa($_self)">, + "is dense int elements attr">; + +def EmptyI64Array : AttrConstraint< + CPred<"cast($_self).empty()">, + "is empty i64 array">; + +def IntOne : AttrConstraint< + CPred<"::mlir::matchPattern($_self, m_One())">, + "is integer one">; + +def IntAllOnes : AttrConstraint< + CPred<[{ + ::mlir::matchPattern($_self, + ::mlir::detail::constant_int_predicate_matcher{ + [](const llvm::APInt &val) { + return val.isAllOnes(); + }}) + }]>, + "is integer with all bits set to 1">; + +def IntZero : AttrConstraint< + CPred<"::mlir::matchPattern($_self, m_Zero())">,"is integer zero">; + +def IotaDims : AttrConstraint< + CPred<"isIotaRange(cast($_self).asArrayRef())">, + "is iota dimensions">; + +def SortedDims : AttrConstraint< + CPred<"llvm::is_sorted(cast($_self).asArrayRef())">, + "is sorted dimensions">; + +def ZeroExtent : AttrConstraint< + CPred<"cast($_self).getNumElements() == 0">, + "is zero extent">; + +/////////// +//// Native Code Call Utilities + +def CastIntElementsAttr : NativeCodeCall<"cast($0)">; + +def ConvertToI64Array : NativeCodeCall<"convertToI64Array($_builder, $0)">; + +def GetOperandN : NativeCodeCall<"$0.getDefiningOp()->getOperand($1.getInt())">; + +def GetEmptyI64Array : NativeCodeCall<"$_builder.getDenseI64ArrayAttr({})">; + +def MergeBroadcastDims : NativeCodeCall<"getMergedBroadcastDimensions($_builder, $0, $1)">; + +def StableHLO_ConvertOpWithShape : NativeCodeCall< + "$_builder.create($_loc, $0.getType(), $1)">; + +def StableHLO_ReshapeOpWithShape : NativeCodeCall< + "$_builder.create($_loc, $0.getType(), $1)">; + +class StableHLO_ConstantLike : NativeCodeCall< + "::mlir::stablehlo::getConstantLike($_builder, $_loc, " # value # ", $0)">; + +//////////////////////////// +// Generic BinaryOp Patterns + +// op(cst, X) -> op(X, cst) +class CanonicalizeConstantToRhs + : Pat<(StableHLO_OpType:$op (StableHLO_ConstantOp:$lhs $value), $rhs), + (StableHLO_OpType $rhs, $lhs), + [(NotConstantOp $rhs), (CommutativeOp $op)]>; + +//////// +// AddOp + +// Pattern: add(cst, X) -> add(X, cst) +def : CanonicalizeConstantToRhs; + +// Pattern: add(X, 0) -> X +def : Pat<(StableHLO_AddOp $lhs, (ConstantLikeMatcher AnyZero:$value)), + (replaceWithValue $lhs)>; + +//////// +// AndOp + +// Pattern: and(cst, X) -> and(X, cst) +def : CanonicalizeConstantToRhs; + +// Pattern: and(X, 0) -> 0 +def : Pat<(StableHLO_AndOp $lhs, (StableHLO_ConstantOp:$zero IntZero:$value)), + (replaceWithValue $zero)>; + +// Pattern: and(X, 1) -> X +def : Pat<(StableHLO_AndOp $lhs, (StableHLO_ConstantOp:$one IntAllOnes:$value)), + (replaceWithValue $lhs)>; + +//////// +// BroadcastInDimOp + +// Pattern: broadcast_in_dim(X, [iota...]) -> X +def : Pat<(StableHLO_BroadcastInDimOp:$op $operand, IotaDims:$dims), + (replaceWithValue $operand), + [(TypesEqual $op, $operand)]>; + +// Pattern: broadcast_in_dim(broadcast_in_dim(X, [dimsA...]), [dimsB...]) -> broadcast_in_dim(X, merge(dimsA, dimsB)) +def : Pat<(StableHLO_BroadcastInDimOp + (StableHLO_BroadcastInDimOp $operand, $dims_parent), $dims), + (StableHLO_BroadcastInDimOp $operand, (MergeBroadcastDims $dims, $dims_parent))>; + +// Pattern: broadcast_in_dim(X, [sorted...]) -> reshape(X, [sorted...]) [if same numel] +def : Pat<(StableHLO_BroadcastInDimOp:$op $operand, SortedDims:$dims), + (StableHLO_ReshapeOpWithShape $op, $operand), + [(NumberOfElementsEqual $op, $operand)]>; + +// Pattern: broadcast_in_dim(X, [dims...]) -> transpose(X, [dims...]) [if same numel & rank] +def : Pat<(StableHLO_BroadcastInDimOp:$op $operand, $dims), + (StableHLO_TransposeOp $operand, $dims), + [(NumberOfElementsEqual $op, $operand), (RankEqual $op, $operand)]>; + +//////// +// ConvertOp + +// Pattern: convert(X, [X.type]) -> X +def : Pat<(StableHLO_ConvertOp:$convert $operand), + (replaceWithValue $operand), + [(TypesEqual $convert, $operand)]>; + +//////// +// DynamicBroadcastInDimOp + +// Pattern: dynamic_broadcast_in_dim(dynamic_broadcast_in_dim(X, _, [dimsA...]), shape, [dimsB...]) -> dynamic_broadcast_in_dim(X, shape, merge(dimsA, dimsB)) +// TODO: Think more if the values of known_expanding_dimensions and known_non_expanding_dimensions can be preserved. +def : Pat<(StableHLO_DynamicBroadcastInDimOp + (StableHLO_DynamicBroadcastInDimOp $operand, $shape_p, $dims_p, $expanding_p, $nonexpanding_p), + $shape, $dims, $expanding, $nonexpanding), + (StableHLO_DynamicBroadcastInDimOp $operand, $shape, (MergeBroadcastDims $dims, $dims_p), (GetEmptyI64Array), (GetEmptyI64Array))>; + +// Pattern: dynamic_broadcast_in_dim(X, _, _, [all_nonexpanding...]) -> convert(X) +// No-op, but wrap in ConvertOp to preserve dynamic output shape, can be +// important if this result is returned, where refining type would require +// also updating the funciton signature. +def : Pat<(StableHLO_DynamicBroadcastInDimOp:$op $operand, $shape, IotaDims:$dims, $expanding, $nonexpanding), + (StableHLO_ConvertOpWithShape $op, $operand), + [(AllDimsNonExpanding $nonexpanding, $op)]>; + +// Pattern: dynamic_broadcast_in_dim(dynamic_reshape(X, shape), shape) -> dynamic_reshape(X, shape) +// If sharing same shape operand, is dynamic reshape. +def : Pat<(StableHLO_DynamicBroadcastInDimOp + (StableHLO_DynamicReshapeOp $operand, $shape), $shape, IotaDims:$dims, $expanding, $nonexpanding), + (StableHLO_DynamicReshapeOp $operand, $shape)>; + +// Pattern: dynamic_broadcast_in_dim(X, shape_of(X)) -> X +def : Pat<(StableHLO_DynamicBroadcastInDimOp + $operand, (Shape_ShapeOfOp $operand), IotaDims:$dims, $expanding, $nonexpanding), + (replaceWithValue $operand)>; + +//////// +// DynamicGatherOp + +// Pattern: dynamic_gather(x, constant(slice_sizes)) -> gather(x, slice_sizes) +def : Pat<(StableHLO_DynamicGatherOp $operand, $start_indices, (StableHLO_ConstantOp DenseIntElementsAttr:$slice_sizes), $dimension_numbers, $indices_are_sorted), + (StableHLO_GatherOp $operand, $start_indices, $dimension_numbers, (ConvertToI64Array $slice_sizes), $indices_are_sorted)>; + +//////// +// DynamicPadOp + +// Pattern: dynamic_pad(X, low, high, interior) -> pad(X, low, high, interior) +// [if low, high, interior are all constants] +def : Pat<(StableHLO_DynamicPadOp $input, + $padding_value, + (ConstantLikeMatcher AnyIntElementsAttr:$edge_padding_low), + (ConstantLikeMatcher AnyIntElementsAttr:$edge_padding_high), + (ConstantLikeMatcher AnyIntElementsAttr:$interior_padding)), + (StableHLO_PadOp $input, $padding_value, + (ConvertToI64Array $edge_padding_low), + (ConvertToI64Array $edge_padding_high), + (ConvertToI64Array $interior_padding))>; + +//////// +// DynamicReshapeOp + +// Pattern: dynamic_reshape(dynamic_reshape(X, _), shape)) -> dynamic_reshape(X, shape) +def : Pat<(StableHLO_DynamicReshapeOp (StableHLO_DynamicReshapeOp $operand, $shape_p), $shape), + (StableHLO_DynamicReshapeOp $operand, $shape)>; + +// Pattern: shape_of(dynamic_reshape(X, shape)) -> shape +def : Pat<(Shape_ShapeOfOp:$op (StableHLO_DynamicReshapeOp $x, $shape)), + (replaceWithValue $shape), + [(TypesEqual $shape, $op)]>; + +//////// +// DynamicUpdateSliceOp + +// Pattern: dynamic_update_slice(X, update : zero_extent)) -> X +def : Pat<(StableHLO_DynamicUpdateSliceOp $operand, (ConstantLikeMatcher ZeroExtent:$update), $start_indices), + (replaceWithValue $operand)>; + +// Pattern: dynamic_update_slice(X, update, start_indices : zero)) -> update +def : Pat<(StableHLO_DynamicUpdateSliceOp AnyStaticShapeTensor:$operand, AnyStaticShapeTensor:$update, $start_indices), + (replaceWithValue $update), + [(TypesEqual $operand, $update), (AllZero $start_indices)]>; + + +//////// +// ComplexOp + +// Pattern: complex(real(X), imag(X))) -> X +def : Pat<(StableHLO_ComplexOp (StableHLO_RealOp $operand), (StableHLO_ImagOp $operand)), + (replaceWithValue $operand)>; + + +//////// +// ImagOp + +// Pattern: imag(complex(R,I)) -> I +def : Pat<(StableHLO_ImagOp (StableHLO_ComplexOp $lhs, $rhs)), + (replaceWithValue $rhs)>; + +//////// +// IotaOp + +// Pattern: iota(dim) : type -> constant(0) : type [if type[dim] == 1] +def : Pat<(StableHLO_IotaOp:$iota $dim), + (StableHLO_ConstantLike<"0"> $iota), + [(DimSizeEquals<1> $iota, $dim)]>; + + +//////// +// MaxOp + +// Pattern: max(cst, X) -> max(X, cst) +def : CanonicalizeConstantToRhs; + +//////// +// MinOp + +// Pattern: minimum(cst, X) -> minimum(X, cst) +def : CanonicalizeConstantToRhs; + +//////// +// MulOp + +// Pattern: multiply(cst, X) -> multiply(X, cst) +def : CanonicalizeConstantToRhs; + +// Pattern: multiply(X, 0i) -> 0i +// Multiplication by 0. This fold is not trivial for floats in presence of NaNs +def : Pat<(StableHLO_MulOp $lhs, (StableHLO_ConstantOp:$zero IntZero:$value)), + (replaceWithValue $zero)>; + +// Pattern: multiply(X, 1i) -> X +def : Pat<(StableHLO_MulOp $lhs, (StableHLO_ConstantOp IntOne:$value)), + (replaceWithValue $lhs)>; + +//////// +// OrOp + +// Pattern: or(cst, X) -> or(X, cst) +def : CanonicalizeConstantToRhs; + +// Pattern: or(X, 1) -> 1 +def : Pat<(StableHLO_OrOp $lhs, (StableHLO_ConstantOp:$one IntAllOnes:$value)), + (replaceWithValue $one)>; + +// Pattern: or(X, 0) -> X +def : Pat<(StableHLO_OrOp $lhs, (StableHLO_ConstantOp:$zero IntZero:$value)), + (replaceWithValue $lhs)>; + +//////// +// RealDynamicSliceOp + +// Pattern: real_dynamic_slice(X, start, limit, strides) +// -> slice(X, start, limit, strides) +// [if start, limit, strides are all constants] +def : Pat<(StableHLO_RealDynamicSliceOp $operand, + (ConstantLikeMatcher DenseIntElementsAttr:$start_indices), + (ConstantLikeMatcher DenseIntElementsAttr:$limit_indices), + (ConstantLikeMatcher DenseIntElementsAttr:$strides)), + (StableHLO_SliceOp $operand, + (ConvertToI64Array $start_indices), + (ConvertToI64Array $limit_indices), + (ConvertToI64Array $strides))>; + +//////// +// RealOp + +// Pattern: real(complex(R,I)) -> X +def : Pat<(StableHLO_RealOp (StableHLO_ComplexOp $lhs, $rhs)), + (replaceWithValue $lhs)>; + +//////// +// ReduceOp +// Note: If modifying region is required, must write pattern in C++ + +// Pattern: reduce(X..., dims=[], add) -> X... +def : Pat<(StableHLO_ReduceOp $operands, $init, EmptyI64Array:$dims), + (replaceWithValue $operands)>; + +//////// +// ReshapeOp + +// Pattern: reshape(reshape(X, _), [shape]) -> reshape(X, [shape]) +def : Pat<(StableHLO_ReshapeOp:$reshape (StableHLO_ReshapeOp $operand)), + (StableHLO_ReshapeOpWithShape $reshape, $operand)>; + +// Pattern: reshape(X, [X.shape]) -> X +def : Pat<(StableHLO_ReshapeOp:$reshape $operand), + (replaceWithValue $operand), + [(TypesEqual $reshape, $operand)]>; + + +//////// +// SelectOp + +// Pattern: select(not(p), t, f) => select(p, f, t) +def : Pat< + (StableHLO_SelectOp (StableHLO_NotOp $pred), $on_true, $on_false), + (StableHLO_SelectOp $pred, $on_false, $on_true)>; + +// Pattern: select(broadcast(not(p)), t, f) => select(broadcast(p), f, t) +def : Pat<(StableHLO_SelectOp (StableHLO_BroadcastInDimOp:$b (StableHLO_NotOp $pred), $broadcast_dimensions), $on_true, $on_false), + (StableHLO_SelectOp (StableHLO_BroadcastInDimOp $pred, $broadcast_dimensions, (returnType $b)), $on_false, $on_true), + [(HasOneUse $b)]>; + +//////// +// SubtractOp + +// Pattern: subtract(X, X) -> 0 +// Must be static shape, otherwise would require broadcasting via CHLO_ConstantLike +def : Pat<(StableHLO_SubtractOp AnyStaticShapeTensor:$operand, $operand), + (StableHLO_ConstantLike<"0"> $operand)>; + +// Pattern: subtract(X, 0) -> X +def : Pat<(StableHLO_SubtractOp $lhs, (StableHLO_ConstantOp AnyZero:$value)), + (replaceWithValue $lhs)>; + +//////// +// SliceOp + +// Pattern: slice(X, [A:A], [B:B], ...) -> X +def : Pat<(StableHLO_SliceOp:$op AnyStaticShapeTensor:$operand, $start_indices, $limit_indices, $strides), + (replaceWithValue $operand), + [(TypesEqual $operand, $op)]>; + +//////// +// TransposeOp + +// Pattern: transpose(X, [iota...]) -> X +def : Pat<(StableHLO_TransposeOp $lhs, IotaDims:$dims), + (replaceWithValue $lhs)>; + +//////// +// GetTupleElementOp + +// Pattern: get_tuple_element(tuple(X_0, X_1, ...), i) -> X_i +def : Pat<(StableHLO_GetTupleElementOp (StableHLO_TupleOp:$tuple $operands), $idx), + (GetOperandN $tuple, $idx)>; + +//////// +// XorOp + +// Pattern: xor(cst, X) -> xor(X, cst) +def : CanonicalizeConstantToRhs; + +// To consider: xor(X, X) -> 0 +// Unclear if this is beneficial on hardware vs adding another constant +// +// def : Pat<(StableHLO_XorOp AnyStaticShapeTensor:$operand, $operand), +// (StableHLO_ConstantLike<"0"> $operand)>; diff --git a/stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp b/stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp new file mode 100644 index 0000000000..bae782bd27 --- /dev/null +++ b/stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp @@ -0,0 +1,66 @@ +// Copyright 2023 The IREE Authors +// +// Licensed under the Apache License, Version 2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// Implements optional canonicalization patterns for StableHLO ops. + +#include +#include +#include + +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/transforms/optimization/Passes.h" + +namespace mlir { +namespace stablehlo { + +#define GEN_PASS_DEF_STABLEHLOTARGETINDEPENDENTOPTIMIZATIONPASS +#include "stablehlo/transforms/optimization/Passes.h.inc" + +// This is an upper limit on how many elements can be folded by an op folder. +// This limit doesn't apply to some special cases like adding a zero, +// multiplying by one, doing many operations with splats. +constexpr int64_t kFoldOpEltLimit = 65536; + +struct StablehloTargetIndependentOptimizationPass + : public impl::StablehloTargetIndependentOptimizationPassBase< + StablehloTargetIndependentOptimizationPass> { + using StablehloTargetIndependentOptimizationPassBase:: + StablehloTargetIndependentOptimizationPassBase; + + LogicalResult initialize(MLIRContext* context) override { + RewritePatternSet patterns_(context); + bool foldFloat = false; + populateStablehloCanonicalizationPatterns(context, &patterns_); + populateStablehloAggressiveFolderPatterns(&patterns_, context, foldFloat, + /*benefit=*/2); + patterns = std::move(patterns_); + + return success(); + } + + void runOnOperation() override { + GreedyRewriteConfig config; + config.fold = true; + config.cseConstants = true; + config.maxIterations = kFoldOpEltLimit; + config.useTopDownTraversal = false; + if (failed(applyPatternsGreedily(getOperation(), patterns, config))) + signalPassFailure(); + } + + private: + FrozenRewritePatternSet patterns; +}; + +} // namespace stablehlo +} // namespace mlir