Skip to content

Commit

Permalink
Add target independent optimization pass
Browse files Browse the repository at this point in the history
  • Loading branch information
GleasonK committed Feb 4, 2025
1 parent 21e8078 commit 7b649f2
Show file tree
Hide file tree
Showing 17 changed files with 3,212 additions and 2,952 deletions.
82 changes: 76 additions & 6 deletions BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -316,11 +316,11 @@ gentbl_cc_library(
tbl_outs = [
(
["--gen-rewriters"],
"stablehlo/transforms/StablehloAggressiveSimplificationPatterns.h.inc",
"stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "stablehlo/transforms/StablehloAggressiveSimplificationPatterns.td",
td_file = "stablehlo/transforms/optimization/StablehloAggressiveSimplificationPatterns.td",
deps = [
":stablehlo_ops_td_files",
],
Expand Down Expand Up @@ -1125,15 +1125,33 @@ gentbl_cc_library(
deps = ["@llvm-project//mlir:PassBaseTdFiles"],
)

cc_library(
name = "stablehlo_pass_utils",
srcs = [
"stablehlo/transforms/PassUtils.cpp",
],
hdrs = [
"stablehlo/transforms/PassUtils.h",
],
strip_include_prefix = ".",
deps = [
":base",
":chlo_ops",
":stablehlo_ops",
":stablehlo_pass_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ComplexDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Support",
],
)

cc_library(
name = "stablehlo_passes",
srcs = [
"stablehlo/transforms/ChloLegalizeToStablehlo.cpp",
"stablehlo/transforms/PassPipelines.cpp",
"stablehlo/transforms/PassUtils.cpp",
"stablehlo/transforms/ShapeLegalizeToStablehlo.cpp",
"stablehlo/transforms/StablehloAggressiveFolder.cpp",
"stablehlo/transforms/StablehloAggressiveSimplification.cpp",
"stablehlo/transforms/StablehloCanonicalizeDynamism.cpp",
"stablehlo/transforms/StablehloCompatibilityExpander.cpp",
"stablehlo/transforms/StablehloComplexMathExpander.cpp",
Expand Down Expand Up @@ -1163,13 +1181,14 @@ cc_library(
":chlo_ops",
":chlo_rewriters_inc_gen",
":linalg_passes",
":stablehlo_aggressive_simplification_inc_gen",
":stablehlo_create_compatibility_expander_inc_gen",
":stablehlo_create_complex_math_expander_inc_gen",
":stablehlo_legalize_deprecated_ops_inc_gen",
":stablehlo_ops",
":stablehlo_ops_inc_gen",
":stablehlo_pass_inc_gen",
":stablehlo_pass_utils",
":stablehlo_passes_optimization",
":stablehlo_type_inference",
":version",
":vhlo_ops",
Expand Down Expand Up @@ -1198,6 +1217,56 @@ cc_library(
],
)

gentbl_cc_library(
name = "stablehlo_passes_optimization_inc_gen",
compatible_with = get_compatible_with_portable(),
strip_include_prefix = ".",
tbl_outs = [
(
[
"-gen-pass-decls",
"-name=Optimization",
],
"stablehlo/transforms/optimization/Passes.h.inc",
),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "stablehlo/transforms/optimization/Passes.td",
deps = ["@llvm-project//mlir:PassBaseTdFiles"],
)

cc_library(
name = "stablehlo_passes_optimization",
srcs = [
"stablehlo/transforms/optimization/StablehloAggressiveFolder.cpp",
"stablehlo/transforms/optimization/StablehloAggressiveSimplification.cpp",
"stablehlo/transforms/optimization/StablehloTargetIndependentOptimization.cpp",
],
hdrs = [
"stablehlo/transforms/optimization/Passes.h",
],
strip_include_prefix = ".",
deps = [
":base",
":stablehlo_aggressive_simplification_inc_gen",
":stablehlo_ops",
":stablehlo_pass_inc_gen",
":stablehlo_pass_utils",
":stablehlo_passes_optimization_inc_gen",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:ArithDialect",
"@llvm-project//mlir:CommonFolders",
"@llvm-project//mlir:DialectUtils",
"@llvm-project//mlir:FuncDialect",
"@llvm-project//mlir:IR",
"@llvm-project//mlir:Pass",
"@llvm-project//mlir:Rewrite",
"@llvm-project//mlir:Support",
"@llvm-project//mlir:TransformUtils",
"@llvm-project//mlir:Transforms",
],
)

cc_library(
name = "stablehlo_portable_api",
srcs = [
Expand Down Expand Up @@ -1364,6 +1433,7 @@ cc_binary(
":linalg_passes",
":register",
":stablehlo_passes",
":stablehlo_passes_optimization",
":tosa_passes",
"//stablehlo/tests:check_ops",
"//stablehlo/tests:test_utils",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
// RUN: stablehlo-opt --stablehlo-target-independent-optimization --split-input-file %s | FileCheck %s

// Check that simplificaiton and folding are both applied.

// CHECK-LABEL: @add_cst_on_rhs
func.func @add_cst_on_rhs(%arg0: tensor<f32>) -> tensor<f32> {
%cst = stablehlo.constant dense<1.0> : tensor<f32>
%0 = stablehlo.add %cst, %cst : tensor<f32>
// CHECK: stablehlo.add %arg0, %cst : tensor<f32>
%1 = stablehlo.add %0, %arg0 : tensor<f32>
return %1 : tensor<f32>
}
2 changes: 2 additions & 0 deletions stablehlo/tools/StablehloOptMain.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ limitations under the License.
#include "stablehlo/tests/CheckOps.h"
#include "stablehlo/tests/TestUtils.h"
#include "stablehlo/transforms/Passes.h"
#include "stablehlo/transforms/optimization/Passes.h"

int main(int argc, char **argv) {
mlir::registerAllPasses();
mlir::hlo::registerAllTestPasses();
mlir::stablehlo::registerPassPipelines();
mlir::stablehlo::registerPasses();
mlir::stablehlo::registerOptimizationPasses();
mlir::stablehlo::registerStablehloLinalgTransformsPasses();
mlir::stablehlo::registerInterpreterTransformsPasses();
mlir::tosa::registerStablehloTOSATransformsPasses();
Expand Down
7 changes: 3 additions & 4 deletions stablehlo/transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

add_subdirectory(optimization)

set(LLVM_TARGET_DEFINITIONS Passes.td)
mlir_tablegen(Passes.h.inc -gen-pass-decls)
add_public_tablegen_target(PassesIncGen)
Expand All @@ -20,10 +22,6 @@ set(LLVM_TARGET_DEFINITIONS ChloDecompositionPatterns.td)
mlir_tablegen(ChloDecompositionPatterns.h.inc --gen-rewriters)
add_public_tablegen_target(ChloDecompositionPatternsIncGen)

set(LLVM_TARGET_DEFINITIONS StablehloAggressiveSimplificationPatterns.td)
mlir_tablegen(StablehloAggressiveSimplificationPatterns.h.inc --gen-rewriters)
add_public_tablegen_target(StablehloAggressiveSimplificationPatternsIncGen)

set(LLVM_TARGET_DEFINITIONS StablehloCompatibilityExpanderPatterns.td)
mlir_tablegen(StablehloCompatibilityExpanderPatterns.h.inc --gen-rewriters)
add_public_tablegen_target(StablehloCompatibilityExpanderPatternsIncGen)
Expand Down Expand Up @@ -93,6 +91,7 @@ add_mlir_dialect_library(StablehloPasses
StablehloBroadcastUtils
StablehloLinalgTransforms
StablehloOps
StablehloOptimizationPasses
StablehloTypeInference
VhloOps
)
16 changes: 0 additions & 16 deletions stablehlo/transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,11 +72,6 @@ void populateChloToStablehloPatterns(MLIRContext *context,
void populateChloConstantLikePattern(MLIRContext *context,
RewritePatternSet *patterns);

/// Collection of folding patterns for StableHLO.
void populateStablehloAggressiveFolderPatterns(RewritePatternSet *patterns,
MLIRContext *context,
bool foldFloat);

/// Collection of rewrite patterns for lowering quantized StableHLO operations
/// using uniform dequantize/quantize operations.
void populateStablehloLegalizeQuantizedOpToQDQPatterns(
Expand All @@ -88,17 +83,6 @@ void populateStablehloLegalizeQuantizedOpToQDQPatterns(
void populateStablehloLegalizeQDQToQuantizedOpPatterns(
RewritePatternSet *patterns, MLIRContext *context);

/// A subset of folding patterns for StableHLO that is necessary for shape
/// refinement.
void populateStablehloShapeFolderPatterns(RewritePatternSet *patterns,
MLIRContext *context,
bool foldFloat = false);

/// Collection of canonicalization patterns for StableHLO.
void populateStablehloCanonicalizationPatterns(MLIRContext *context,
RewritePatternSet *patterns,
PatternBenefit benefit = 1);

/// Collection of patterns to upgrade deprecated ops to long-term supported ops.
void populateStablehloLegalizeDeprecatedOpsPatterns(
MLIRContext *context, RewritePatternSet *patterns);
Expand Down
21 changes: 0 additions & 21 deletions stablehlo/transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -36,27 +36,6 @@ def ShapeLegalizeToStablehloPass : Pass<"shape-legalize-to-stablehlo", "func::Fu
let dependentDialects = ["mlir::stablehlo::StablehloDialect"];
}

def StablehloAggressiveFolderPass
: Pass<"stablehlo-aggressive-folder", "func::FuncOp"> {
let summary = "Folds StableHLO operations";
let dependentDialects = [
"mlir::stablehlo::StablehloDialect",
"mlir::tensor::TensorDialect",
];
let options = [
Option<"foldFloat", "fold-float", "bool", /*default=*/"true",
"Allow for potentially lossy computations using float type.">,
];
}

def StablehloAggressiveSimplificationPass
: Pass<"stablehlo-aggressive-simplification", "func::FuncOp"> {
let summary = "Canonicalizes StableHLO operations";
let dependentDialects = [
"mlir::stablehlo::StablehloDialect",
];
}

def StablehloCanonicalizeDynamismPass : Pass<"stablehlo-canonicalize-dynamism", "func::FuncOp"> {
let summary = "Canonicalizes dynamic StableHLO ops into static ops.";
let description = [{
Expand Down
Loading

0 comments on commit 7b649f2

Please sign in to comment.