Skip to content

Commit

Permalink
Add target independent optimization pass (#2707)
Browse files Browse the repository at this point in the history
  • Loading branch information
GleasonK authored Feb 4, 2025
1 parent 21e8078 commit 2720b90
Show file tree
Hide file tree
Showing 18 changed files with 560 additions and 98 deletions.
87 changes: 81 additions & 6 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -316,11 +316,11 @@ gentbl_cc_library(
tbl_outs = [
(
["--gen-rewriters"],
"stablehlo/transforms/StablehloAggressiveSimplificationPatterns.h.inc",
"stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td",
td_file = "stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td",
deps = [
":stablehlo_ops_td_files",
],
Expand Down Expand Up @@ -1125,15 +1125,33 @@ gentbl_cc_library(
deps = ["@llvm-project//mlir:PassBaseTdFiles"],
)

cc_library(
name = "stablehlo_pass_utils",
srcs = [
"stablehlo/transforms/PassUtils.cpp",
],
hdrs = [
"stablehlo/transforms/PassUtils.h",
],
strip_include_prefix = ".",
deps = [
":base",
":chlo_ops",
":stablehlo_ops",
":stablehlo_pass_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)

cc_library(
name = "stablehlo_passes",
srcs = [
"stablehlo/transforms/ChloLegalizeToStablehlo.cpp",
"stablehlo/transforms/PassPipelines.cpp",
"stablehlo/transforms/PassUtils.cpp",
"stablehlo/transforms/ShapeLegalizeToStablehlo.cpp",
"stablehlo/transforms/StablehloAggressiveFolder.cpp",
"stablehlo/transforms/StablehloAggressiveSimplification.cpp",
"stablehlo/transforms/StablehloCanonicalizeDynamism.cpp",
"stablehlo/transforms/StablehloCompatibilityExpander.cpp",
"stablehlo/transforms/StablehloComplexMathExpander.cpp",
Expand Down Expand Up @@ -1163,13 +1181,14 @@ cc_library(
":chlo_ops",
":chlo_rewriters_inc_gen",
":linalg_passes",
":stablehlo_aggressive_simplification_inc_gen",
":stablehlo_create_compatibility_expander_inc_gen",
":stablehlo_create_complex_math_expander_inc_gen",
":stablehlo_legalize_deprecated_ops_inc_gen",
":stablehlo_ops",
":stablehlo_ops_inc_gen",
":stablehlo_pass_inc_gen",
":stablehlo_pass_utils",
":stablehlo_passes_optimization",
":stablehlo_type_inference",
":version",
":vhlo_ops",
Expand Down Expand Up @@ -1198,6 +1217,61 @@ cc_library(
],
)

gentbl_cc_library(
name = "stablehlo_passes_optimization_inc_gen",
strip_include_prefix = ".",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=Optimization",
],
"stablehlo/transforms/optimization/Passes.h.inc",
),
(
[
"-gen-pass-doc",
],
"stablehlo/transforms/optimization/stablehlo_optimization_passes.md",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "stablehlo/transforms/optimization/Passes.td",
deps = ["@llvm-project//mlir:PassBaseTdFiles"],
)

cc_library(
name = "stablehlo_passes_optimization",
srcs = [
"stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp",
"stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp",
"stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp",
],
hdrs = [
"stablehlo/transforms/optimization/Passes.h",
],
strip_include_prefix = ".",
deps = [
":base",
":stablehlo_aggressive_simplification_inc_gen",
":stablehlo_ops",
":stablehlo_pass_inc_gen",
":stablehlo_pass_utils",
":stablehlo_passes_optimization_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:CommonFolders",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Rewrite",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
],
)

cc_library(
name = "stablehlo_portable_api",
srcs = [
Expand Down Expand Up @@ -1364,6 +1438,7 @@ cc_binary(
":linalg_passes",
":register",
":stablehlo_passes",
":stablehlo_passes_optimization",
":tosa_passes",
"//stablehlo/tests:check_ops",
"//stablehlo/tests:test_utils",
Expand Down
1 change: 1 addition & 0 deletions build_tools/github_actions/ci_build_docs.sh
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ fi

declare -A targets
targets[":stablehlo_pass_inc_gen_filegroup"]="bazel-bin/stablehlo/transforms/stablehlo_passes.md"
targets[":stablehlo_passes_optimization_inc_gen_filegroup"]="bazel-bin/stablehlo/transforms/optimization/stablehlo_optimization_passes.md"
targets[":interpreter_pass_inc_gen_filegroup"]="bazel-bin/stablehlo/reference/interpreter_passes.md"
targets[":linalg_pass_inc_gen_filegroup"]="bazel-bin/stablehlo/conversions/linalg/transforms/stablehlo_linalg_passes.md"
targets[":tosa_pass_inc_gen_filegroup"]="bazel-bin/stablehlo/conversions/tosa/transforms/stablehlo_tosa_passes.md"
Expand Down
2 changes: 2 additions & 0 deletions docs/_toc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ toc:
path: /stablehlo/ide
- title: StableHLO Passes
path: /stablehlo/generated/stablehlo_passes
- title: StableHLO Optimization Passes
path: /stablehlo/generated/stablehlo_optimization_passes
- title: StableHLO Interpreter Passes
path: /stablehlo/generated/interpreter_passes
- title: StableHLO Linalg Passes
Expand Down
110 changes: 110 additions & 0 deletions docs/generated/stablehlo_optimization_passes.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
<!-- Autogenerated by mlir-tblgen; don't manually edit -->
### `-stablehlo-aggressive-folder`

_Folds StableHLO operations_



#### Options
```
-fold-float : Allow for potentially lossy computations using float type.
```
### `-stablehlo-aggressive-simplification`

_Canonicalizes StableHLO operations_

<!--
The following is generated list of patterns from code comments and can be
reconstructed using the following command:
$ cd stablehlo/transforms/optimization
$ grep "// Pattern:" *Simplification* | sed 's/.*Pattern: /- /' | sort
-->

Note: Prefer StablehloTargetIndependentOptimizationPass to get best results.

Performs graph simplifications, including:

```
- add(cst, X) -> add(X, cst)
- add(X, 0) -> X
- and(cst, X) -> and(X, cst)
- and(X, 0) -> 0
- and(X, 1) -> X
- broadcast_in_dim(broadcast_in_dim(X, [dimsA...]), [dimsB...]) -> broadcast_in_dim(X, merge(dimsA, dimsB))
- broadcast_in_dim(X, [dims...]) -> transpose(X, [dims...]) [if same numel & rank]
- broadcast_in_dim(X, [iota...]) -> X
- broadcast_in_dim(X, [sorted...]) -> reshape(X, [sorted...]) [if same numel]
- compare(cst, X, comparator) -> compare(X, cst, inv(comparator))
- compare(X, X, [EQ,GE,LE]) -> true
- compare(X, X, [NE,GT,LT]) -> false
- complex(real(X), imag(X))) -> X
- concatenate(concatenate(X, Y), Z) -> concatenate(X, Y, Z)
- concatenate(X) -> X
- concatenate(X, Y, []) -> concatenate(X, Y)
- convert(X, [X.type]) -> X
- dynamic_broadcast_in_dim(dynamic_broadcast_in_dim(X, _, [dimsA...]), shape, [dimsB...]) -> dynamic_broadcast_in_dim(X, shape, merge(dimsA, dimsB))
- dynamic_broadcast_in_dim(dynamic_reshape(X, shape), shape) -> dynamic_reshape(X, shape)
- dynamic_broadcast_in_dim(X, _, _, [all_nonexpanding...]) -> convert(X)
- dynamic_broadcast_in_dim(X, shape_of(X)) -> X
- dynamic_gather(x, constant(slice_sizes)) -> gather(x, slice_sizes)
- dynamic_iota(shape, dim) ->
- dynamic_pad(X, low, high, interior) -> pad(X, low, high, interior)
- dynamic_reshape(dynamic_reshape(X, _), shape)) -> dynamic_reshape(X, shape)
- dynamic_reshape(op(dynamic_reshape(X, shape)), shape)
- dynamic_slice(X, begin, slice_sizes) -> slice(X, begin, slice_sizes)
- dynamic_update_slice(X, update, start_indices : zero)) -> update
- dynamic_update_slice(X, update : zero_extent)) -> X
- gather(X, cst_start_indices) -> slice(X, slice_start, slice_end)
- get_dimension_size(X, i) -> X.shape[i]
- get_tuple_element(tuple(X_0, X_1, ...), i) -> X_i
- imag(complex(R,I)) -> I
- iota(dim) : multi_rank
- iota(dim) : type -> constant(0) : type [if type[dim] == 1]
- max(cst, X) -> max(X, cst)
- minimum(cst, X) -> minimum(X, cst)
- multiply(cst, X) -> multiply(X, cst)
- multiply(X, 0i) -> 0i
- multiply(X, 1i) -> X
- op(X : zero_extent_tensor) -> constant([])
- or(cst, X) -> or(X, cst)
- or(X, 0) -> X
- or(X, 1) -> 1
- pad(empty_tensor, _) -> broadcast_in_dim(empty_tensor, _)
- real(complex(R,I)) -> X
- real_dynamic_slice(X, start, limit, strides)
- real_dynamic_slice(X, start, limit, strides)
- reduce[A](_, _, fn:return A) -> A...
- reduce(empty_0, empty_1, ...) -> [broadcast_in_dim(empty_i)...]
- reduce(in_1, in_2, _, _) -> reduce(in_1, _, _) [if unused(in_2)]
- reduce(X..., dims=[], add) -> X...
- reshape(reshape(X, _), [shape]) -> reshape(X, [shape])
- reshape(X, [X.shape]) -> X
- select(broadcast(not(p)), t, f) => select(broadcast(p), f, t)
- select(not(p), t, f) => select(p, f, t)
- shape_of(dynamic_reshape(X, shape)) -> shape
- slice(concat(X,Y,Z,...),...) -> concat(slice(X),slice(Y),slice(Z))
- slice(X, [A:A], [B:B], ...) -> X
- sort(X) -> sort(X, dim = N) [when dim can be inferred]
- sort(X,Y) -> sort(X) [if Y unused and unused in comparator]
- subtract(X, 0) -> X
- subtract(X, X) -> 0
- transpose(X, [iota...]) -> X
- transpose(X, [no_mem_layout_change...]) -> reshape(X)
- tuple(get_tuple_element(X, 0), get_tuple_element(X, 1), ...) -> X
- while -> while (loop invariants as implicit captures)
- xor(cst, X) -> xor(X, cst)
- (+more)
```

This list is pulled from code comments so is not fully exhaustive, but has
high coverage of the pass today.
### `-stablehlo-target-independent-optimization`

_Runs canonicalizers, folders, and other target-independent optimizations._

Uses patterns from StablehloAggressiveSimplificationPass and
StablehloAggressiveFolderPass together, allowing canonicalization and
folding to be performed in the same pattern set, often leading to better
results.

Users should prefer this pass to calling the others directly.
15 changes: 0 additions & 15 deletions docs/generated/stablehlo_passes.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,6 @@ An experimental pass that legalizes shape-related ops to StableHLO ops.
Bringing shape and data computations together via an optional pass will
make it possible for the StableHLO ecosystem to potentially leverage the
compilation pipelines that use StableHLO operations to model dynamism.
### `-stablehlo-aggressive-folder`

_Folds StableHLO operations_



#### Options
```
-fold-float : Allow for potentially lossy computations using float type.
```
### `-stablehlo-aggressive-simplification`

_Canonicalizes StableHLO operations_


### `-stablehlo-canonicalize-dynamism`

_Canonicalizes dynamic StableHLO ops into static ops._
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: stablehlo-opt --stablehlo-target-independent-optimization --split-input-file %s | FileCheck %s

// Check that simplificaiton and folding are both applied.

// CHECK-LABEL: @add_cst_on_rhs
func.func @add_cst_on_rhs(%arg0: tensor<f32>) -> tensor<f32> {
%cst = stablehlo.constant dense<1.0> : tensor<f32>
%0 = stablehlo.add %cst, %cst : tensor<f32>
// CHECK: stablehlo.add %arg0, %cst : tensor<f32>
%1 = stablehlo.add %0, %arg0 : tensor<f32>
return %1 : tensor<f32>
}
2 changes: 2 additions & 0 deletions stablehlo/tools/StablehloOptMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ limitations under the License.
#include "stablehlo/tests/CheckOps.h"
#include "stablehlo/tests/TestUtils.h"
#include "stablehlo/transforms/Passes.h"
#include "stablehlo/transforms/optimization/Passes.h"

int main(int argc, char **argv) {
mlir::registerAllPasses();
mlir::hlo::registerAllTestPasses();
mlir::stablehlo::registerPassPipelines();
mlir::stablehlo::registerPasses();
mlir::stablehlo::registerOptimizationPasses();
mlir::stablehlo::registerStablehloLinalgTransformsPasses();
mlir::stablehlo::registerInterpreterTransformsPasses();
mlir::tosa::registerStablehloTOSATransformsPasses();
Expand Down
10 changes: 3 additions & 7 deletions stablehlo/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

add_subdirectory(optimization)

set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(PassesIncGen)
Expand All @@ -20,10 +22,6 @@ set(LLVM_TARGET_DEFINITIONS ChloDecompositionPatterns.td)
mlir_tablegen(ChloDecompositionPatterns.h.inc --gen-rewriters)
add_public_tablegen_target(ChloDecompositionPatternsIncGen)

set(LLVM_TARGET_DEFINITIONS StablehloAggressiveSimplificationPatterns.td)
mlir_tablegen(StablehloAggressiveSimplificationPatterns.h.inc --gen-rewriters)
add_public_tablegen_target(StablehloAggressiveSimplificationPatternsIncGen)

set(LLVM_TARGET_DEFINITIONS StablehloCompatibilityExpanderPatterns.td)
mlir_tablegen(StablehloCompatibilityExpanderPatterns.h.inc --gen-rewriters)
add_public_tablegen_target(StablehloCompatibilityExpanderPatternsIncGen)
Expand All @@ -46,8 +44,6 @@ add_mlir_dialect_library(StablehloPasses
ChloLegalizeToStablehlo.cpp
PassPipelines.cpp
ShapeLegalizeToStablehlo.cpp
StablehloAggressiveFolder.cpp
StablehloAggressiveSimplification.cpp
StablehloCanonicalizeDynamism.cpp
StablehloConvertToSignless.cpp
StablehloCompatibilityExpander.cpp
Expand All @@ -67,7 +63,6 @@ add_mlir_dialect_library(StablehloPasses
DEPENDS
ChloDecompositionPatternsIncGen
PassesIncGen
StablehloAggressiveSimplificationPatternsIncGen
StablehloCompatibilityExpanderPatternsIncGen
StablehloComplexMathExpanderPatternsIncGen
StablehloLegalizeDeprecatedOpsPatternsIncGen
Expand All @@ -93,6 +88,7 @@ add_mlir_dialect_library(StablehloPasses
StablehloBroadcastUtils
StablehloLinalgTransforms
StablehloOps
StablehloOptimizationPasses
StablehloTypeInference
VhloOps
)
16 changes: 0 additions & 16 deletions stablehlo/transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,6 @@ void populateChloToStablehloPatterns(MLIRContext *context,
void populateChloConstantLikePattern(MLIRContext *context,
RewritePatternSet *patterns);

/// Collection of folding patterns for StableHLO.
void populateStablehloAggressiveFolderPatterns(RewritePatternSet *patterns,
MLIRContext *context,
bool foldFloat);

/// Collection of rewrite patterns for lowering quantized StableHLO operations
/// using uniform dequantize/quantize operations.
void populateStablehloLegalizeQuantizedOpToQDQPatterns(
Expand All @@ -88,17 +83,6 @@ void populateStablehloLegalizeQuantizedOpToQDQPatterns(
void populateStablehloLegalizeQDQToQuantizedOpPatterns(
RewritePatternSet *patterns, MLIRContext *context);

/// A subset of folding patterns for StableHLO that is necessary for shape
/// refinement.
void populateStablehloShapeFolderPatterns(RewritePatternSet *patterns,
MLIRContext *context,
bool foldFloat = false);

/// Collection of canonicalization patterns for StableHLO.
void populateStablehloCanonicalizationPatterns(MLIRContext *context,
RewritePatternSet *patterns,
PatternBenefit benefit = 1);

/// Collection of patterns to upgrade deprecated ops to long-term supported ops.
void populateStablehloLegalizeDeprecatedOpsPatterns(
MLIRContext *context, RewritePatternSet *patterns);
Expand Down
Loading

0 comments on commit 2720b90

Please sign in to comment.