From 550d946af59c1af094648bc22bf5b6005e8e9c52 Mon Sep 17 00:00:00 2001 From: mlevesquedion Date: Thu, 1 Feb 2024 15:40:32 -0800 Subject: [PATCH] Perform initialization in pass initialize method (#1966) This is unlikely to give us performance gains since most of our passes run on modules anyway (so the initialization probably already occurs only once), but it is cleaner to separate the initialization of a pass from the actual running of the pass. The code in `initialize` will run when the pass runs regardless of whether there is at least one instance of the target operation. However, modules are pretty much always present so this is unlikely to change anything. --- .../transforms/StablehloLegalizeToLinalg.cpp | 28 +++++++++++------- stablehlo/tests/TestUtils.cpp | 14 +++++++-- .../StablehloCanonicalizeDynamism.cpp | 21 +++++++++----- .../transforms/StablehloLegalizeToVhlo.cpp | 29 ++++++++++++------- .../transforms/StablehloRefineShapes.cpp | 26 ++++++++++------- .../transforms/VhloLegalizeToStablehlo.cpp | 29 ++++++++++++------- stablehlo/transforms/VhloToVersion.cpp | 20 ++++++++----- 7 files changed, 107 insertions(+), 60 deletions(-) diff --git a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp index 8631dcb980..0cfef4dcbd 100644 --- a/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp +++ b/stablehlo/conversions/linalg/transforms/StablehloLegalizeToLinalg.cpp @@ -2573,27 +2573,33 @@ struct StablehloLegalizeToLinalgPass : impl::StablehloLegalizeToLinalgPassBase { using StablehloLegalizeToLinalgPassBase::StablehloLegalizeToLinalgPassBase; - void runOnOperation() override { - auto *context = &getContext(); - auto target = ConversionTarget{*context}; - auto patterns = RewritePatternSet{context}; - auto typeConverter = std::make_unique(); - - target.addLegalDialect< + LogicalResult initialize(MLIRContext *context) override { + target = std::make_shared(*context); + target->addLegalDialect< bufferization::BufferizationDialect, arith::ArithDialect, complex::ComplexDialect, linalg::LinalgDialect, math::MathDialect, tensor::TensorDialect, sparse_tensor::SparseTensorDialect, scf::SCFDialect, shape::ShapeDialect>(); - target.addLegalOp(); + target->addLegalOp(); - populateConversionPatterns(context, *typeConverter, &patterns, + RewritePatternSet patterns_(context); + populateConversionPatterns(context, converter, &patterns_, enablePrimitiveOps); + patterns = std::move(patterns_); + + return success(); + } - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) { + void runOnOperation() override { + if (failed(applyPartialConversion(getOperation(), *target, patterns))) { return signalPassFailure(); } } + + private: + std::shared_ptr target; + FrozenRewritePatternSet patterns; + LinalgTypeConverter converter; }; } // namespace } // namespace mlir::stablehlo diff --git a/stablehlo/tests/TestUtils.cpp b/stablehlo/tests/TestUtils.cpp index 0b5f065403..fac6486058 100644 --- a/stablehlo/tests/TestUtils.cpp +++ b/stablehlo/tests/TestUtils.cpp @@ -94,14 +94,22 @@ struct ReifyReturnTypeShapesPattern : public RewritePattern { #include "stablehlo/tests/TestUtils.h.inc" struct HloTestInferPass : public impl::HloTestInferPassBase { + LogicalResult initialize(MLIRContext *context) override { + RewritePatternSet patterns_(context); + patterns_.add(context); + patterns_.add(context); + patterns = std::move(patterns_); + return success(); + } + void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - patterns.add(&getContext()); if (failed( applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) return signalPassFailure(); } + + private: + FrozenRewritePatternSet patterns; }; #define GEN_PASS_REGISTRATION diff --git a/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp b/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp index 6f81d8fe7e..dd6f130943 100644 --- a/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp +++ b/stablehlo/transforms/StablehloCanonicalizeDynamism.cpp @@ -298,24 +298,31 @@ struct StablehloCanonicalizeDynamismPass using StablehloCanonicalizeDynamismPassBase:: StablehloCanonicalizeDynamismPassBase; - void runOnOperation() override { - GreedyRewriteConfig config; + LogicalResult initialize(MLIRContext* context) override { config.useTopDownTraversal = true; config.enableRegionSimplification = true; config.maxIterations = 2; config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; config.strictMode = GreedyRewriteStrictness::AnyOp; - RewritePatternSet patterns(&getContext()); - populateStablehloCanonicalizeDynamismPatterns(&patterns, &getContext()); + RewritePatternSet patterns_(context); + populateStablehloCanonicalizeDynamismPatterns(&patterns_, context); + patterns = std::move(patterns_); + + return success(); + } + + void runOnOperation() override { auto func = getOperation(); - if (failed( - applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { + if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) { func.emitError("Failed to converge StablehloCanonicalizeDynamism in ") << config.maxIterations << " iterations"; - return signalPassFailure(); } } + + private: + FrozenRewritePatternSet patterns; + GreedyRewriteConfig config; }; } // namespace diff --git a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp index 8026a6a215..d40accc2c6 100644 --- a/stablehlo/transforms/StablehloLegalizeToVhlo.cpp +++ b/stablehlo/transforms/StablehloLegalizeToVhlo.cpp @@ -862,24 +862,31 @@ void populateStablehloToVhloPatterns(RewritePatternSet* patterns, struct StablehloLegalizeToVhloPass : public impl::StablehloLegalizeToVhloPassBase< StablehloLegalizeToVhloPass> { - void runOnOperation() override { - ConversionTarget target(getContext()); - target.addIllegalDialect(); - target.addIllegalDialect(); - target.addLegalDialect(); + LogicalResult initialize(MLIRContext* context) override { + target = std::make_shared(*context); + target->addIllegalDialect(); + target->addIllegalDialect(); + target->addLegalDialect(); + + RewritePatternSet patterns_(context); + stablehlo::populateStablehloToVhloPatterns(&patterns_, &converter, context); + patterns = std::move(patterns_); - StablehloToVhloTypeConverter converter; - RewritePatternSet patterns(&getContext()); - stablehlo::populateStablehloToVhloPatterns(&patterns, &converter, - &getContext()); + return success(); + } + void runOnOperation() override { // StableHLO should always be convertible to VHLO. - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) { + if (failed(applyPartialConversion(getOperation(), *target, patterns))) { LLVM_DEBUG(llvm::dbgs() << "Failed partial conversion\n"); return signalPassFailure(); } } + + private: + StablehloToVhloTypeConverter converter; + FrozenRewritePatternSet patterns; + std::shared_ptr target; }; void populateStablehloToVhloPatterns(RewritePatternSet* patterns, diff --git a/stablehlo/transforms/StablehloRefineShapes.cpp b/stablehlo/transforms/StablehloRefineShapes.cpp index cfc29fe6d5..566a9c6f63 100644 --- a/stablehlo/transforms/StablehloRefineShapes.cpp +++ b/stablehlo/transforms/StablehloRefineShapes.cpp @@ -1064,33 +1064,39 @@ struct StablehloRefineShapesPass : public impl::StablehloRefineShapesPassBase { using StablehloRefineShapesPassBase::StablehloRefineShapesPassBase; - void runOnOperation() override { - auto func = getStablehloRefineShapesTarget(getOperation()); - if (!func) return signalPassFailure(); - + LogicalResult initialize(MLIRContext* context) override { // The algorithm behind this pass consists of a single traversal of the // function. This is sufficient because we only support one function per // program at the moment. // TODO(#1048): Find out why .maxIterations = 1 no longer works. // There have been recent refactors to applyPatternsAndFoldGreedily // upstream, and that might be the reason. - GreedyRewriteConfig config; config.useTopDownTraversal = true; config.enableRegionSimplification = true; config.maxIterations = 2; config.maxNumRewrites = GreedyRewriteConfig::kNoLimit; config.strictMode = GreedyRewriteStrictness::AnyOp; - RewritePatternSet patterns(&getContext()); + RewritePatternSet patterns_(context); + populateStablehloRefineShapesPatterns(&patterns_, context); + patterns = std::move(patterns_); + + return success(); + } + + void runOnOperation() override { + auto func = getStablehloRefineShapesTarget(getOperation()); + if (!func) return signalPassFailure(); - populateStablehloRefineShapesPatterns(&patterns, &getContext()); - if (failed( - applyPatternsAndFoldGreedily(func, std::move(patterns), config))) { + if (failed(applyPatternsAndFoldGreedily(func, patterns, config))) { func.emitError("Failed to converge StablehloRefineShapes in ") << config.maxIterations << " iterations"; - return signalPassFailure(); } } + + private: + FrozenRewritePatternSet patterns; + GreedyRewriteConfig config; }; } // namespace diff --git a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp index 1f50eea00a..9a61fdf11a 100644 --- a/stablehlo/transforms/VhloLegalizeToStablehlo.cpp +++ b/stablehlo/transforms/VhloLegalizeToStablehlo.cpp @@ -872,25 +872,32 @@ void populateVhloToStablehloPatterns(RewritePatternSet* patterns, struct VhloLegalizeToStablehloPass : public impl::VhloLegalizeToStablehloPassBase< VhloLegalizeToStablehloPass> { - void runOnOperation() override { - ConversionTarget target(getContext()); - target.addIllegalDialect(); - target.addLegalDialect(); - target.addLegalDialect(); + LogicalResult initialize(MLIRContext* context) override { + target = std::make_shared(*context); + target->addIllegalDialect(); + target->addLegalDialect(); + target->addLegalDialect(); + + RewritePatternSet patterns_(context); + stablehlo::populateVhloToStablehloPatterns(&patterns_, &converter, context); + patterns = std::move(patterns_); - VhloToStablehloTypeConverter converter; - RewritePatternSet patterns(&getContext()); - stablehlo::populateVhloToStablehloPatterns(&patterns, &converter, - &getContext()); + return success(); + } + void runOnOperation() override { // Upgraded VHLO should always be convertible to StableHLO. // Arbitrary VHLO might not be convertible if it uses deprecated features // which are no longer available in StableHLO. - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) { + if (failed(applyPartialConversion(getOperation(), *target, patterns))) { return signalPassFailure(); } } + + private: + VhloToStablehloTypeConverter converter; + FrozenRewritePatternSet patterns; + std::shared_ptr target; }; void populateVhloToStablehloPatterns(RewritePatternSet* patterns, diff --git a/stablehlo/transforms/VhloToVersion.cpp b/stablehlo/transforms/VhloToVersion.cpp index a8cf57d1aa..81bdd32959 100644 --- a/stablehlo/transforms/VhloToVersion.cpp +++ b/stablehlo/transforms/VhloToVersion.cpp @@ -216,6 +216,14 @@ struct VhloToVersionPass : public VhloToVersionPassBase { VhloToVersionPass(const VhloToVersionPassOptions& opts) : VhloToVersionPassBase(opts) {} + LogicalResult initialize(MLIRContext* context) override { + RewritePatternSet patterns_(context); + stablehlo::populateVhloToVersionPatterns(&patterns_, &converter, context); + patterns = std::move(patterns_); + + return success(); + } + void runOnOperation() override { ConversionTarget target(getContext()); @@ -248,16 +256,14 @@ struct VhloToVersionPass : public VhloToVersionPassBase { return isLegalOperation(op, targetVersion); }); - vhlo::VhloToVersionConverter converter; - RewritePatternSet patterns(&getContext()); - stablehlo::populateVhloToVersionPatterns(&patterns, &converter, - &getContext()); - // Conversions within VHLO may fail if new features or ops are used. - if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) + if (failed(applyPartialConversion(getOperation(), target, patterns))) return signalPassFailure(); } + + private: + vhlo::VhloToVersionConverter converter; + FrozenRewritePatternSet patterns; }; ////////////////////////////////////////////