From 2720b9029944eb71af39937116699909ccaadd3b Mon Sep 17 00:00:00 2001 From: Kevin Gleason Date: Tue, 4 Feb 2025 17:08:34 -0600 Subject: [PATCH] Add target independent optimization pass (#2707) --- BUILD.bazel | 87 ++++++++++- build_tools/github_actions/ci_build_docs.sh | 1 + docs/_toc.yaml | 2 + .../stablehlo_optimization_passes.md | 110 ++++++++++++++ docs/generated/stablehlo_passes.md | 15 -- ...lehlo_target_independent_optimization.mlir | 12 ++ stablehlo/tools/StablehloOptMain.cpp | 2 + stablehlo/transforms/CMakeLists.txt | 10 +- stablehlo/transforms/Passes.h | 16 -- stablehlo/transforms/Passes.td | 21 --- .../transforms/StablehloRefineShapes.cpp | 1 + .../transforms/optimization/CMakeLists.txt | 45 ++++++ stablehlo/transforms/optimization/Passes.h | 56 +++++++ stablehlo/transforms/optimization/Passes.td | 141 ++++++++++++++++++ .../StablehloAggressiveFolder.cpp | 61 ++++---- .../StablehloAggressiveSimplification.cpp | 6 +- ...ablehloAggressiveSimplificationPatterns.td | 0 ...StablehloTargetIndependentOptimization.cpp | 72 +++++++++ 18 files changed, 560 insertions(+), 98 deletions(-) create mode 100755 docs/generated/stablehlo_optimization_passes.md create mode 100644 stablehlo/tests/transforms/stablehlo_target_independent_optimization.mlir create mode 100644 stablehlo/transforms/optimization/CMakeLists.txt create mode 100644 stablehlo/transforms/optimization/Passes.h create mode 100644 stablehlo/transforms/optimization/Passes.td rename stablehlo/transforms/{ => optimization}/StablehloAggressiveFolder.cpp (94%) rename stablehlo/transforms/{ => optimization}/StablehloAggressiveSimplification.cpp (99%) rename stablehlo/transforms/{ => optimization}/StablehloAggressiveSimplificationPatterns.td (100%) create mode 100644 stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp diff --git a/BUILD.bazel b/BUILD.bazel index 40cc04aedd..3ecaf3d19e 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -316,11 +316,11 @@ gentbl_cc_library( tbl_outs = [ ( ["--gen-rewriters"], - "stablehlo/transforms/StablehloAggressiveSimplificationPatterns.h.inc", + "stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.h.inc", ), ], tblgen = "@llvm-project//mlir:mlir-tblgen", - td_file = "stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td", + td_file = "stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td", deps = [ ":stablehlo_ops_td_files", ], @@ -1125,15 +1125,33 @@ gentbl_cc_library( deps = ["@llvm-project//mlir:PassBaseTdFiles"], ) +cc_library( + name = "stablehlo_pass_utils", + srcs = [ + "stablehlo/transforms/PassUtils.cpp", + ], + hdrs = [ + "stablehlo/transforms/PassUtils.h", + ], + strip_include_prefix = ".", + deps = [ + ":base", + ":chlo_ops", + ":stablehlo_ops", + ":stablehlo_pass_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ComplexDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Support", + ], +) + cc_library( name = "stablehlo_passes", srcs = [ "stablehlo/transforms/ChloLegalizeToStablehlo.cpp", "stablehlo/transforms/PassPipelines.cpp", - "stablehlo/transforms/PassUtils.cpp", "stablehlo/transforms/ShapeLegalizeToStablehlo.cpp", - "stablehlo/transforms/StablehloAggressiveFolder.cpp", - "stablehlo/transforms/StablehloAggressiveSimplification.cpp", "stablehlo/transforms/StablehloCanonicalizeDynamism.cpp", "stablehlo/transforms/StablehloCompatibilityExpander.cpp", "stablehlo/transforms/StablehloComplexMathExpander.cpp", @@ -1163,13 +1181,14 @@ cc_library( ":chlo_ops", ":chlo_rewriters_inc_gen", ":linalg_passes", - ":stablehlo_aggressive_simplification_inc_gen", ":stablehlo_create_compatibility_expander_inc_gen", ":stablehlo_create_complex_math_expander_inc_gen", ":stablehlo_legalize_deprecated_ops_inc_gen", ":stablehlo_ops", ":stablehlo_ops_inc_gen", ":stablehlo_pass_inc_gen", + ":stablehlo_pass_utils", + ":stablehlo_passes_optimization", ":stablehlo_type_inference", ":version", ":vhlo_ops", @@ -1198,6 +1217,61 @@ cc_library( ], ) +gentbl_cc_library( + name = "stablehlo_passes_optimization_inc_gen", + strip_include_prefix = ".", + tbl_outs = [ + ( + [ + "-gen-pass-decls", + "-name=Optimization", + ], + "stablehlo/transforms/optimization/Passes.h.inc", + ), + ( + [ + "-gen-pass-doc", + ], + "stablehlo/transforms/optimization/stablehlo_optimization_passes.md", + ), + ], + tblgen = "@llvm-project//mlir:mlir-tblgen", + td_file = "stablehlo/transforms/optimization/Passes.td", + deps = ["@llvm-project//mlir:PassBaseTdFiles"], +) + +cc_library( + name = "stablehlo_passes_optimization", + srcs = [ + "stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp", + "stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp", + "stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp", + ], + hdrs = [ + "stablehlo/transforms/optimization/Passes.h", + ], + strip_include_prefix = ".", + deps = [ + ":base", + ":stablehlo_aggressive_simplification_inc_gen", + ":stablehlo_ops", + ":stablehlo_pass_inc_gen", + ":stablehlo_pass_utils", + ":stablehlo_passes_optimization_inc_gen", + "@llvm-project//llvm:Support", + "@llvm-project//mlir:ArithDialect", + "@llvm-project//mlir:CommonFolders", + "@llvm-project//mlir:DialectUtils", + "@llvm-project//mlir:FuncDialect", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:Rewrite", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:TransformUtils", + "@llvm-project//mlir:Transforms", + ], +) + cc_library( name = "stablehlo_portable_api", srcs = [ @@ -1364,6 +1438,7 @@ cc_binary( ":linalg_passes", ":register", ":stablehlo_passes", + ":stablehlo_passes_optimization", ":tosa_passes", "//stablehlo/tests:check_ops", "//stablehlo/tests:test_utils", diff --git a/build_tools/github_actions/ci_build_docs.sh b/build_tools/github_actions/ci_build_docs.sh index c75e439aec..d19fde6eec 100755 --- a/build_tools/github_actions/ci_build_docs.sh +++ b/build_tools/github_actions/ci_build_docs.sh @@ -41,6 +41,7 @@ fi declare -A targets targets[":stablehlo_pass_inc_gen_filegroup"]="bazel-bin/stablehlo/transforms/stablehlo_passes.md" +targets[":stablehlo_passes_optimization_inc_gen_filegroup"]="bazel-bin/stablehlo/transforms/optimization/stablehlo_optimization_passes.md" targets[":interpreter_pass_inc_gen_filegroup"]="bazel-bin/stablehlo/reference/interpreter_passes.md" targets[":linalg_pass_inc_gen_filegroup"]="bazel-bin/stablehlo/conversions/linalg/transforms/stablehlo_linalg_passes.md" targets[":tosa_pass_inc_gen_filegroup"]="bazel-bin/stablehlo/conversions/tosa/transforms/stablehlo_tosa_passes.md" diff --git a/docs/_toc.yaml b/docs/_toc.yaml index b0d7ff765a..bcf846c47c 100644 --- a/docs/_toc.yaml +++ b/docs/_toc.yaml @@ -59,6 +59,8 @@ toc: path: /stablehlo/ide - title: StableHLO Passes path: /stablehlo/generated/stablehlo_passes + - title: StableHLO Optimization Passes + path: /stablehlo/generated/stablehlo_optimization_passes - title: StableHLO Interpreter Passes path: /stablehlo/generated/interpreter_passes - title: StableHLO Linalg Passes diff --git a/docs/generated/stablehlo_optimization_passes.md b/docs/generated/stablehlo_optimization_passes.md new file mode 100755 index 0000000000..224ffdd07e --- /dev/null +++ b/docs/generated/stablehlo_optimization_passes.md @@ -0,0 +1,110 @@ + +### `-stablehlo-aggressive-folder` + +_Folds StableHLO operations_ + + + +#### Options +``` +-fold-float : Allow for potentially lossy computations using float type. +``` +### `-stablehlo-aggressive-simplification` + +_Canonicalizes StableHLO operations_ + + + +Note: Prefer StablehloTargetIndependentOptimizationPass to get best results. + +Performs graph simplifications, including: + +``` +- add(cst, X) -> add(X, cst) +- add(X, 0) -> X +- and(cst, X) -> and(X, cst) +- and(X, 0) -> 0 +- and(X, 1) -> X +- broadcast_in_dim(broadcast_in_dim(X, [dimsA...]), [dimsB...]) -> broadcast_in_dim(X, merge(dimsA, dimsB)) +- broadcast_in_dim(X, [dims...]) -> transpose(X, [dims...]) [if same numel & rank] +- broadcast_in_dim(X, [iota...]) -> X +- broadcast_in_dim(X, [sorted...]) -> reshape(X, [sorted...]) [if same numel] +- compare(cst, X, comparator) -> compare(X, cst, inv(comparator)) +- compare(X, X, [EQ,GE,LE]) -> true +- compare(X, X, [NE,GT,LT]) -> false +- complex(real(X), imag(X))) -> X +- concatenate(concatenate(X, Y), Z) -> concatenate(X, Y, Z) +- concatenate(X) -> X +- concatenate(X, Y, []) -> concatenate(X, Y) +- convert(X, [X.type]) -> X +- dynamic_broadcast_in_dim(dynamic_broadcast_in_dim(X, _, [dimsA...]), shape, [dimsB...]) -> dynamic_broadcast_in_dim(X, shape, merge(dimsA, dimsB)) +- dynamic_broadcast_in_dim(dynamic_reshape(X, shape), shape) -> dynamic_reshape(X, shape) +- dynamic_broadcast_in_dim(X, _, _, [all_nonexpanding...]) -> convert(X) +- dynamic_broadcast_in_dim(X, shape_of(X)) -> X +- dynamic_gather(x, constant(slice_sizes)) -> gather(x, slice_sizes) +- dynamic_iota(shape, dim) -> +- dynamic_pad(X, low, high, interior) -> pad(X, low, high, interior) +- dynamic_reshape(dynamic_reshape(X, _), shape)) -> dynamic_reshape(X, shape) +- dynamic_reshape(op(dynamic_reshape(X, shape)), shape) +- dynamic_slice(X, begin, slice_sizes) -> slice(X, begin, slice_sizes) +- dynamic_update_slice(X, update, start_indices : zero)) -> update +- dynamic_update_slice(X, update : zero_extent)) -> X +- gather(X, cst_start_indices) -> slice(X, slice_start, slice_end) +- get_dimension_size(X, i) -> X.shape[i] +- get_tuple_element(tuple(X_0, X_1, ...), i) -> X_i +- imag(complex(R,I)) -> I +- iota(dim) : multi_rank +- iota(dim) : type -> constant(0) : type [if type[dim] == 1] +- max(cst, X) -> max(X, cst) +- minimum(cst, X) -> minimum(X, cst) +- multiply(cst, X) -> multiply(X, cst) +- multiply(X, 0i) -> 0i +- multiply(X, 1i) -> X +- op(X : zero_extent_tensor) -> constant([]) +- or(cst, X) -> or(X, cst) +- or(X, 0) -> X +- or(X, 1) -> 1 +- pad(empty_tensor, _) -> broadcast_in_dim(empty_tensor, _) +- real(complex(R,I)) -> X +- real_dynamic_slice(X, start, limit, strides) +- real_dynamic_slice(X, start, limit, strides) +- reduce[A](_, _, fn:return A) -> A... +- reduce(empty_0, empty_1, ...) -> [broadcast_in_dim(empty_i)...] +- reduce(in_1, in_2, _, _) -> reduce(in_1, _, _) [if unused(in_2)] +- reduce(X..., dims=[], add) -> X... +- reshape(reshape(X, _), [shape]) -> reshape(X, [shape]) +- reshape(X, [X.shape]) -> X +- select(broadcast(not(p)), t, f) => select(broadcast(p), f, t) +- select(not(p), t, f) => select(p, f, t) +- shape_of(dynamic_reshape(X, shape)) -> shape +- slice(concat(X,Y,Z,...),...) -> concat(slice(X),slice(Y),slice(Z)) +- slice(X, [A:A], [B:B], ...) -> X +- sort(X) -> sort(X, dim = N) [when dim can be inferred] +- sort(X,Y) -> sort(X) [if Y unused and unused in comparator] +- subtract(X, 0) -> X +- subtract(X, X) -> 0 +- transpose(X, [iota...]) -> X +- transpose(X, [no_mem_layout_change...]) -> reshape(X) +- tuple(get_tuple_element(X, 0), get_tuple_element(X, 1), ...) -> X +- while -> while (loop invariants as implicit captures) +- xor(cst, X) -> xor(X, cst) +- (+more) +``` + +This list is pulled from code comments so is not fully exhaustive, but has +high coverage of the pass today. +### `-stablehlo-target-independent-optimization` + +_Runs canonicalizers, folders, and other target-independent optimizations._ + +Uses patterns from StablehloAggressiveSimplificationPass and +StablehloAggressiveFolderPass together, allowing canonicalization and +folding to be performed in the same pattern set, often leading to better +results. + +Users should prefer this pass to calling the others directly. diff --git a/docs/generated/stablehlo_passes.md b/docs/generated/stablehlo_passes.md index 2fdda99136..1857b8b7ae 100755 --- a/docs/generated/stablehlo_passes.md +++ b/docs/generated/stablehlo_passes.md @@ -13,21 +13,6 @@ An experimental pass that legalizes shape-related ops to StableHLO ops. Bringing shape and data computations together via an optional pass will make it possible for the StableHLO ecosystem to potentially leverage the compilation pipelines that use StableHLO operations to model dynamism. -### `-stablehlo-aggressive-folder` - -_Folds StableHLO operations_ - - - -#### Options -``` --fold-float : Allow for potentially lossy computations using float type. -``` -### `-stablehlo-aggressive-simplification` - -_Canonicalizes StableHLO operations_ - - ### `-stablehlo-canonicalize-dynamism` _Canonicalizes dynamic StableHLO ops into static ops._ diff --git a/stablehlo/tests/transforms/stablehlo_target_independent_optimization.mlir b/stablehlo/tests/transforms/stablehlo_target_independent_optimization.mlir new file mode 100644 index 0000000000..069d59a764 --- /dev/null +++ b/stablehlo/tests/transforms/stablehlo_target_independent_optimization.mlir @@ -0,0 +1,12 @@ +// RUN: stablehlo-opt --stablehlo-target-independent-optimization --split-input-file %s | FileCheck %s + +// Check that simplificaiton and folding are both applied. + +// CHECK-LABEL: @add_cst_on_rhs +func.func @add_cst_on_rhs(%arg0: tensor) -> tensor { + %cst = stablehlo.constant dense<1.0> : tensor + %0 = stablehlo.add %cst, %cst : tensor + // CHECK: stablehlo.add %arg0, %cst : tensor + %1 = stablehlo.add %0, %arg0 : tensor + return %1 : tensor +} diff --git a/stablehlo/tools/StablehloOptMain.cpp b/stablehlo/tools/StablehloOptMain.cpp index 8562a46462..34f90d60e8 100644 --- a/stablehlo/tools/StablehloOptMain.cpp +++ b/stablehlo/tools/StablehloOptMain.cpp @@ -26,12 +26,14 @@ limitations under the License. #include "stablehlo/tests/CheckOps.h" #include "stablehlo/tests/TestUtils.h" #include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/optimization/Passes.h" int main(int argc, char **argv) { mlir::registerAllPasses(); mlir::hlo::registerAllTestPasses(); mlir::stablehlo::registerPassPipelines(); mlir::stablehlo::registerPasses(); + mlir::stablehlo::registerOptimizationPasses(); mlir::stablehlo::registerStablehloLinalgTransformsPasses(); mlir::stablehlo::registerInterpreterTransformsPasses(); mlir::tosa::registerStablehloTOSATransformsPasses(); diff --git a/stablehlo/transforms/CMakeLists.txt b/stablehlo/transforms/CMakeLists.txt index f5193cb61e..c07cc6a971 100644 --- a/stablehlo/transforms/CMakeLists.txt +++ b/stablehlo/transforms/CMakeLists.txt @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. +add_subdirectory(optimization) + set(LLVM_TARGET_DEFINITIONS Passes.td) mlir_tablegen(Passes.h.inc -gen-pass-decls) add_public_tablegen_target(PassesIncGen) @@ -20,10 +22,6 @@ set(LLVM_TARGET_DEFINITIONS ChloDecompositionPatterns.td) mlir_tablegen(ChloDecompositionPatterns.h.inc --gen-rewriters) add_public_tablegen_target(ChloDecompositionPatternsIncGen) -set(LLVM_TARGET_DEFINITIONS StablehloAggressiveSimplificationPatterns.td) -mlir_tablegen(StablehloAggressiveSimplificationPatterns.h.inc --gen-rewriters) -add_public_tablegen_target(StablehloAggressiveSimplificationPatternsIncGen) - set(LLVM_TARGET_DEFINITIONS StablehloCompatibilityExpanderPatterns.td) mlir_tablegen(StablehloCompatibilityExpanderPatterns.h.inc --gen-rewriters) add_public_tablegen_target(StablehloCompatibilityExpanderPatternsIncGen) @@ -46,8 +44,6 @@ add_mlir_dialect_library(StablehloPasses ChloLegalizeToStablehlo.cpp PassPipelines.cpp ShapeLegalizeToStablehlo.cpp - StablehloAggressiveFolder.cpp - StablehloAggressiveSimplification.cpp StablehloCanonicalizeDynamism.cpp StablehloConvertToSignless.cpp StablehloCompatibilityExpander.cpp @@ -67,7 +63,6 @@ add_mlir_dialect_library(StablehloPasses DEPENDS ChloDecompositionPatternsIncGen PassesIncGen - StablehloAggressiveSimplificationPatternsIncGen StablehloCompatibilityExpanderPatternsIncGen StablehloComplexMathExpanderPatternsIncGen StablehloLegalizeDeprecatedOpsPatternsIncGen @@ -93,6 +88,7 @@ add_mlir_dialect_library(StablehloPasses StablehloBroadcastUtils StablehloLinalgTransforms StablehloOps + StablehloOptimizationPasses StablehloTypeInference VhloOps ) diff --git a/stablehlo/transforms/Passes.h b/stablehlo/transforms/Passes.h index 3fede6e9eb..055768cacd 100644 --- a/stablehlo/transforms/Passes.h +++ b/stablehlo/transforms/Passes.h @@ -72,11 +72,6 @@ void populateChloToStablehloPatterns(MLIRContext *context, void populateChloConstantLikePattern(MLIRContext *context, RewritePatternSet *patterns); -/// Collection of folding patterns for StableHLO. -void populateStablehloAggressiveFolderPatterns(RewritePatternSet *patterns, - MLIRContext *context, - bool foldFloat); - /// Collection of rewrite patterns for lowering quantized StableHLO operations /// using uniform dequantize/quantize operations. void populateStablehloLegalizeQuantizedOpToQDQPatterns( @@ -88,17 +83,6 @@ void populateStablehloLegalizeQuantizedOpToQDQPatterns( void populateStablehloLegalizeQDQToQuantizedOpPatterns( RewritePatternSet *patterns, MLIRContext *context); -/// A subset of folding patterns for StableHLO that is necessary for shape -/// refinement. -void populateStablehloShapeFolderPatterns(RewritePatternSet *patterns, - MLIRContext *context, - bool foldFloat = false); - -/// Collection of canonicalization patterns for StableHLO. -void populateStablehloCanonicalizationPatterns(MLIRContext *context, - RewritePatternSet *patterns, - PatternBenefit benefit = 1); - /// Collection of patterns to upgrade deprecated ops to long-term supported ops. void populateStablehloLegalizeDeprecatedOpsPatterns( MLIRContext *context, RewritePatternSet *patterns); diff --git a/stablehlo/transforms/Passes.td b/stablehlo/transforms/Passes.td index aa9d696664..e0d9f317e1 100644 --- a/stablehlo/transforms/Passes.td +++ b/stablehlo/transforms/Passes.td @@ -36,27 +36,6 @@ def ShapeLegalizeToStablehloPass : Pass<"shape-legalize-to-stablehlo", "func::Fu let dependentDialects = ["mlir::stablehlo::StablehloDialect"]; } -def StablehloAggressiveFolderPass - : Pass<"stablehlo-aggressive-folder", "func::FuncOp"> { - let summary = "Folds StableHLO operations"; - let dependentDialects = [ - "mlir::stablehlo::StablehloDialect", - "mlir::tensor::TensorDialect", - ]; - let options = [ - Option<"foldFloat", "fold-float", "bool", /*default=*/"true", - "Allow for potentially lossy computations using float type.">, - ]; -} - -def StablehloAggressiveSimplificationPass - : Pass<"stablehlo-aggressive-simplification", "func::FuncOp"> { - let summary = "Canonicalizes StableHLO operations"; - let dependentDialects = [ - "mlir::stablehlo::StablehloDialect", - ]; -} - def StablehloCanonicalizeDynamismPass : Pass<"stablehlo-canonicalize-dynamism", "func::FuncOp"> { let summary = "Canonicalizes dynamic StableHLO ops into static ops."; let description = [{ diff --git a/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/transforms/StablehloRefineShapes.cpp index 4274c09413..c8fb1e515b 100644 --- a/stablehlo/transforms/StablehloRefineShapes.cpp +++ b/stablehlo/transforms/StablehloRefineShapes.cpp @@ -53,6 +53,7 @@ limitations under the License. #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/dialect/TypeInference.h" #include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/optimization/Passes.h" #define DEBUG_TYPE "stablehlo-refine-shapes" diff --git a/stablehlo/transforms/optimization/CMakeLists.txt b/stablehlo/transforms/optimization/CMakeLists.txt new file mode 100644 index 0000000000..d43d77beed --- /dev/null +++ b/stablehlo/transforms/optimization/CMakeLists.txt @@ -0,0 +1,45 @@ +# Copyright 2025 The StableHLO Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +set(LLVM_TARGET_DEFINITIONS Passes.td) +mlir_tablegen(Passes.h.inc -gen-pass-decls -name=Optimization) +add_public_tablegen_target(OptimizationPassesIncGen) + +set(LLVM_TARGET_DEFINITIONS StablehloAggressiveSimplificationPatterns.td) +mlir_tablegen(StablehloAggressiveSimplificationPatterns.h.inc --gen-rewriters) +add_public_tablegen_target(StablehloAggressiveSimplificationPatternsIncGen) + +add_mlir_dialect_library(StablehloOptimizationPasses + PARTIAL_SOURCES_INTENDED + StablehloAggressiveFolder.cpp + StablehloAggressiveSimplification.cpp + StablehloTargetIndependentOptimization.cpp + + DEPENDS + OptimizationPassesIncGen + StablehloAggressiveSimplificationPatternsIncGen + + LINK_LIBS PUBLIC + ChloOps + MLIRArithDialect + MLIRDialectUtils + MLIRFuncDialect + MLIRIR + MLIRRewrite + MLIRSupport + MLIRTransformUtils + StablehloBase + StablehloOps + StablehloTypeInference +) diff --git a/stablehlo/transforms/optimization/Passes.h b/stablehlo/transforms/optimization/Passes.h new file mode 100644 index 0000000000..5aa7ea55af --- /dev/null +++ b/stablehlo/transforms/optimization/Passes.h @@ -0,0 +1,56 @@ +/* Copyright 2025 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef STABLEHLO_TRANSFORMS_OPTIMIZATION_PASSES_H +#define STABLEHLO_TRANSFORMS_OPTIMIZATION_PASSES_H + +#include + +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/IR/BuiltinOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/dialect/StablehloOps.h" + +namespace mlir { +namespace stablehlo { + +#define GEN_PASS_DECL +#define GEN_PASS_REGISTRATION +#include "stablehlo/transforms/optimization/Passes.h.inc" + +/// Collection of canonicalization patterns for StableHLO. +void populateStablehloCanonicalizationPatterns(MLIRContext *context, + RewritePatternSet *patterns, + PatternBenefit benefit = 1); + +/// Collection of folding patterns for StableHLO. +void populateStablehloAggressiveFolderPatterns(RewritePatternSet *patterns, + MLIRContext *context, + bool foldFloat, + PatternBenefit benefit = 1); + +/// A subset of folding patterns for StableHLO that is necessary for shape +/// refinement. +void populateStablehloShapeFolderPatterns(RewritePatternSet *patterns, + MLIRContext *context, + bool foldFloat = false, + PatternBenefit benefit = 1); +} // namespace stablehlo +} // namespace mlir + +#endif // STABLEHLO_TRANSFORMS_OPTIMIZATION_PASSES_H diff --git a/stablehlo/transforms/optimization/Passes.td b/stablehlo/transforms/optimization/Passes.td new file mode 100644 index 0000000000..d766b249bc --- /dev/null +++ b/stablehlo/transforms/optimization/Passes.td @@ -0,0 +1,141 @@ +/* Copyright 2025 The StableHLO Authors. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +include "mlir/Pass/PassBase.td" + +def StablehloAggressiveFolderPass + : Pass<"stablehlo-aggressive-folder", "func::FuncOp"> { + let summary = "Folds StableHLO operations"; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + ]; + let options = [ + Option<"foldFloat", "fold-float", "bool", /*default=*/"true", + "Allow for potentially lossy computations using float type.">, + ]; +} + +def StablehloAggressiveSimplificationPass + : Pass<"stablehlo-aggressive-simplification", "func::FuncOp"> { + let summary = "Canonicalizes StableHLO operations"; + let description = [{ + + + + Note: Prefer StablehloTargetIndependentOptimizationPass to get best results. + + Performs graph simplifications, including: + + ``` + - add(cst, X) -> add(X, cst) + - add(X, 0) -> X + - and(cst, X) -> and(X, cst) + - and(X, 0) -> 0 + - and(X, 1) -> X + - broadcast_in_dim(broadcast_in_dim(X, [dimsA...]), [dimsB...]) -> broadcast_in_dim(X, merge(dimsA, dimsB)) + - broadcast_in_dim(X, [dims...]) -> transpose(X, [dims...]) [if same numel & rank] + - broadcast_in_dim(X, [iota...]) -> X + - broadcast_in_dim(X, [sorted...]) -> reshape(X, [sorted...]) [if same numel] + - compare(cst, X, comparator) -> compare(X, cst, inv(comparator)) + - compare(X, X, [EQ,GE,LE]) -> true + - compare(X, X, [NE,GT,LT]) -> false + - complex(real(X), imag(X))) -> X + - concatenate(concatenate(X, Y), Z) -> concatenate(X, Y, Z) + - concatenate(X) -> X + - concatenate(X, Y, []) -> concatenate(X, Y) + - convert(X, [X.type]) -> X + - dynamic_broadcast_in_dim(dynamic_broadcast_in_dim(X, _, [dimsA...]), shape, [dimsB...]) -> dynamic_broadcast_in_dim(X, shape, merge(dimsA, dimsB)) + - dynamic_broadcast_in_dim(dynamic_reshape(X, shape), shape) -> dynamic_reshape(X, shape) + - dynamic_broadcast_in_dim(X, _, _, [all_nonexpanding...]) -> convert(X) + - dynamic_broadcast_in_dim(X, shape_of(X)) -> X + - dynamic_gather(x, constant(slice_sizes)) -> gather(x, slice_sizes) + - dynamic_iota(shape, dim) -> + - dynamic_pad(X, low, high, interior) -> pad(X, low, high, interior) + - dynamic_reshape(dynamic_reshape(X, _), shape)) -> dynamic_reshape(X, shape) + - dynamic_reshape(op(dynamic_reshape(X, shape)), shape) + - dynamic_slice(X, begin, slice_sizes) -> slice(X, begin, slice_sizes) + - dynamic_update_slice(X, update, start_indices : zero)) -> update + - dynamic_update_slice(X, update : zero_extent)) -> X + - gather(X, cst_start_indices) -> slice(X, slice_start, slice_end) + - get_dimension_size(X, i) -> X.shape[i] + - get_tuple_element(tuple(X_0, X_1, ...), i) -> X_i + - imag(complex(R,I)) -> I + - iota(dim) : multi_rank + - iota(dim) : type -> constant(0) : type [if type[dim] == 1] + - max(cst, X) -> max(X, cst) + - minimum(cst, X) -> minimum(X, cst) + - multiply(cst, X) -> multiply(X, cst) + - multiply(X, 0i) -> 0i + - multiply(X, 1i) -> X + - op(X : zero_extent_tensor) -> constant([]) + - or(cst, X) -> or(X, cst) + - or(X, 0) -> X + - or(X, 1) -> 1 + - pad(empty_tensor, _) -> broadcast_in_dim(empty_tensor, _) + - real(complex(R,I)) -> X + - real_dynamic_slice(X, start, limit, strides) + - real_dynamic_slice(X, start, limit, strides) + - reduce[A](_, _, fn:return A) -> A... + - reduce(empty_0, empty_1, ...) -> [broadcast_in_dim(empty_i)...] + - reduce(in_1, in_2, _, _) -> reduce(in_1, _, _) [if unused(in_2)] + - reduce(X..., dims=[], add) -> X... + - reshape(reshape(X, _), [shape]) -> reshape(X, [shape]) + - reshape(X, [X.shape]) -> X + - select(broadcast(not(p)), t, f) => select(broadcast(p), f, t) + - select(not(p), t, f) => select(p, f, t) + - shape_of(dynamic_reshape(X, shape)) -> shape + - slice(concat(X,Y,Z,...),...) -> concat(slice(X),slice(Y),slice(Z)) + - slice(X, [A:A], [B:B], ...) -> X + - sort(X) -> sort(X, dim = N) [when dim can be inferred] + - sort(X,Y) -> sort(X) [if Y unused and unused in comparator] + - subtract(X, 0) -> X + - subtract(X, X) -> 0 + - transpose(X, [iota...]) -> X + - transpose(X, [no_mem_layout_change...]) -> reshape(X) + - tuple(get_tuple_element(X, 0), get_tuple_element(X, 1), ...) -> X + - while -> while (loop invariants as implicit captures) + - xor(cst, X) -> xor(X, cst) + - (+more) + ``` + + This list is pulled from code comments so is not fully exhaustive, but has + high coverage of the pass today. + }]; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + "mlir::arith::ArithDialect", + ]; +} + +def StablehloTargetIndependentOptimizationPass + : Pass<"stablehlo-target-independent-optimization", "func::FuncOp"> { + let summary = "Runs canonicalizers, folders, and other target-independent optimizations."; + let description = [{ + Uses patterns from StablehloAggressiveSimplificationPass and + StablehloAggressiveFolderPass together, allowing canonicalization and + folding to be performed in the same pattern set, often leading to better + results. + + Users should prefer this pass to calling the others directly. + }]; + let dependentDialects = [ + "mlir::stablehlo::StablehloDialect", + ]; +} diff --git a/stablehlo/transforms/StablehloAggressiveFolder.cpp b/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp similarity index 94% rename from stablehlo/transforms/StablehloAggressiveFolder.cpp rename to stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp index 49942c4241..2b5198b496 100644 --- a/stablehlo/transforms/StablehloAggressiveFolder.cpp +++ b/stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp @@ -26,7 +26,6 @@ limitations under the License. #include "llvm/Support/ErrorHandling.h" #include "mlir/Dialect/CommonFolders.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/BuiltinTypeInterfaces.h" @@ -47,13 +46,13 @@ limitations under the License. #include "mlir/Transforms/GreedyPatternRewriteDriver.h" #include "stablehlo/dialect/Base.h" #include "stablehlo/dialect/StablehloOps.h" -#include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/optimization/Passes.h" namespace mlir { namespace stablehlo { #define GEN_PASS_DEF_STABLEHLOAGGRESSIVEFOLDERPASS -#include "stablehlo/transforms/Passes.h.inc" +#include "stablehlo/transforms/optimization/Passes.h.inc" namespace { @@ -461,8 +460,9 @@ struct EvalConcatenateOpPattern : public OpRewritePattern { struct EvalConvertOpPattern : public OpRewritePattern { using OpRewritePattern::OpRewritePattern; - EvalConvertOpPattern(MLIRContext* context, bool foldFloat_) - : OpRewritePattern(context), foldFloat{foldFloat_} {} + EvalConvertOpPattern(MLIRContext* context, PatternBenefit benefit, + bool foldFloat_) + : OpRewritePattern(context, benefit), foldFloat{foldFloat_} {} LogicalResult matchAndRewrite(ConvertOp op, PatternRewriter& rewriter) const override { @@ -900,10 +900,11 @@ struct StablehloAggressiveFolderPass void populateStablehloAggressiveFolderPatterns(RewritePatternSet* patterns, MLIRContext* context, - bool foldFloat) { - populateStablehloShapeFolderPatterns(patterns, context, foldFloat); - patterns->add(context); - patterns->add(context); + bool foldFloat, + PatternBenefit benefit) { + populateStablehloShapeFolderPatterns(patterns, context, foldFloat, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); // TODO: Consolidate FoldOp patterns // One is used by Shape Refinement, the other is a generic folder. @@ -914,27 +915,27 @@ void populateStablehloAggressiveFolderPatterns(RewritePatternSet* patterns, } void populateStablehloShapeFolderPatterns(RewritePatternSet* patterns, - MLIRContext* context, - bool foldFloat) { - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context, foldFloat); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); - patterns->add(context); + MLIRContext* context, bool foldFloat, + PatternBenefit benefit) { + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit, foldFloat); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); + patterns->add(context, benefit); } } // namespace stablehlo diff --git a/stablehlo/transforms/StablehloAggressiveSimplification.cpp b/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp similarity index 99% rename from stablehlo/transforms/StablehloAggressiveSimplification.cpp rename to stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp index ebb6c027a3..f32f8d66b6 100644 --- a/stablehlo/transforms/StablehloAggressiveSimplification.cpp +++ b/stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp @@ -48,7 +48,7 @@ #include "stablehlo/dialect/Base.h" #include "stablehlo/dialect/StablehloOps.h" #include "stablehlo/transforms/PassUtils.h" -#include "stablehlo/transforms/Passes.h" +#include "stablehlo/transforms/optimization/Passes.h" using llvm::SmallBitVector; @@ -56,7 +56,7 @@ namespace mlir { namespace stablehlo { #define GEN_PASS_DEF_STABLEHLOAGGRESSIVESIMPLIFICATIONPASS -#include "stablehlo/transforms/Passes.h.inc" +#include "stablehlo/transforms/optimization/Passes.h.inc" namespace { // This is an upper limit on how many elements can be folded by an op folder. @@ -1499,7 +1499,7 @@ struct StablehloAggressiveSimplificationPass final FrozenRewritePatternSet patterns; }; -#include "stablehlo/transforms/StablehloAggressiveSimplificationPatterns.h.inc" +#include "stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.h.inc" } // namespace void populateStablehloCanonicalizationPatterns(MLIRContext *context, diff --git a/stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td b/stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td similarity index 100% rename from stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td rename to stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td diff --git a/stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp b/stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp new file mode 100644 index 0000000000..74f2698000 --- /dev/null +++ b/stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp @@ -0,0 +1,72 @@ +/* Copyright 2025 The StableHLO Authors. +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include +#include +#include + +#include "mlir/IR/BuiltinTypeInterfaces.h" +#include "mlir/IR/Diagnostics.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Rewrite/FrozenRewritePatternSet.h" +#include "mlir/Support/LLVM.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "stablehlo/transforms/optimization/Passes.h" + +namespace mlir { +namespace stablehlo { + +#define GEN_PASS_DEF_STABLEHLOTARGETINDEPENDENTOPTIMIZATIONPASS +#include "stablehlo/transforms/optimization/Passes.h.inc" + +// This is an upper limit on how many elements can be folded by an op folder. +// This limit doesn't apply to some special cases like adding a zero, +// multiplying by one, doing many operations with splats. +constexpr int64_t kFoldOpEltLimit = 65536; + +struct StablehloTargetIndependentOptimizationPass + : public impl::StablehloTargetIndependentOptimizationPassBase< + StablehloTargetIndependentOptimizationPass> { + using StablehloTargetIndependentOptimizationPassBase:: + StablehloTargetIndependentOptimizationPassBase; + + LogicalResult initialize(MLIRContext* context) override { + RewritePatternSet patterns_(context); + bool foldFloat = false; + populateStablehloCanonicalizationPatterns(context, &patterns_); + populateStablehloAggressiveFolderPatterns(&patterns_, context, foldFloat, + /*benefit=*/2); + patterns = std::move(patterns_); + + return success(); + } + + void runOnOperation() override { + GreedyRewriteConfig config; + config.fold = true; + config.cseConstants = true; + config.maxIterations = kFoldOpEltLimit; + config.useTopDownTraversal = false; + if (failed(applyPatternsGreedily(getOperation(), patterns, config))) + signalPassFailure(); + } + + private: + FrozenRewritePatternSet patterns; +}; + +} // namespace stablehlo +} // namespace mlir