From 928fede79d1c9fe9d6fbb4f18b865cbcdafe9b1f Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Tue, 10 Dec 2024 14:30:51 -0800 Subject: [PATCH] cleanup --- .../scheduler/tools/loop_domain_scheduler.cpp | 21 +++++++++++-------- csrc/scheduler/tools/loop_domain_scheduler.h | 6 +++++- 2 files changed, 17 insertions(+), 10 deletions(-) diff --git a/csrc/scheduler/tools/loop_domain_scheduler.cpp b/csrc/scheduler/tools/loop_domain_scheduler.cpp index 818eceb816a..f04a2f2271e 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.cpp +++ b/csrc/scheduler/tools/loop_domain_scheduler.cpp @@ -102,8 +102,9 @@ class LoopDomainScheduler { public: LoopDomainScheduler( std::vector ref_loop_dom, - bool update_mode = false) - : ref_loop_dom_(std::move(ref_loop_dom)), update_mode_(update_mode) { + bool update_loop_domain_only = false) + : ref_loop_dom_(std::move(ref_loop_dom)), + update_loop_domain_only_(update_loop_domain_only) { NVF_ERROR(!ref_loop_dom_.empty()); // For now, ref must not be a broadcast domain @@ -178,7 +179,7 @@ class LoopDomainScheduler { std::vector ref_loop_dom_; // If true, uses the current loop domain as the starting domain and // updates it to make it look like the given reference loop domain - bool update_mode_ = false; + bool update_loop_domain_only_ = false; std::unique_ptr id_model_; ValGroups ref_id_groups_; ValGroups all_ancestors_of_ref_; @@ -197,9 +198,10 @@ void LoopDomainScheduler::schedule(TensorView* tv) const { std::unordered_map group_to_id; ValGroups all_id_groups; // When update_mode_ is true, only the loop domain IDs are reused as - // we attemp to transform the current loop domain to look like the + // we attempt to transform the current loop domain to look like the // reference loop domain. - auto all_ids = update_mode_ ? tv->getLoopDomain() : tv->domain()->allIDs(); + auto all_ids = + update_loop_domain_only_ ? tv->getLoopDomain() : tv->domain()->allIDs(); for (auto id : all_ids) { const auto& group = graph().toGroup(id); group_to_id.emplace(group, id); @@ -330,7 +332,8 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const { // In the case of the update mode, the target should be just the // current loop domain of the tensor. ValGroups tv_target_domains = graph().toGroups(TensorDomain::noBroadcasts( - update_mode_ ? tv->getLoopDomain() : tv->getMaybeRootDomain())); + update_loop_domain_only_ ? tv->getLoopDomain() + : tv->getMaybeRootDomain())); // If all the target domains are an ancestor of the reference // domains, just a single backward BFS should be enough to find a @@ -353,7 +356,7 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const { // In the case of the update mode, the path from the reference is // assumed to just a backward traversal path. NVF_ERROR( - !update_mode_, + !update_loop_domain_only_, "Trying to update the current loop domain but could not find a valid path from the reference: ", tv->toString()); @@ -394,12 +397,12 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const { void scheduleLoopDomainsLike( const std::vector& tvs, const std::vector& ref_loop_dom, - bool update_mode) { + bool update_loop_domain_only) { if (tvs.empty()) { return; } - LoopDomainScheduler scheduler(ref_loop_dom, update_mode); + LoopDomainScheduler scheduler(ref_loop_dom, update_loop_domain_only); for (auto tv : tvs) { // Loop domain of fusion inputs should have no meaning diff --git a/csrc/scheduler/tools/loop_domain_scheduler.h b/csrc/scheduler/tools/loop_domain_scheduler.h index 39d6ebabb49..5939c9d31e2 100644 --- a/csrc/scheduler/tools/loop_domain_scheduler.h +++ b/csrc/scheduler/tools/loop_domain_scheduler.h @@ -20,10 +20,14 @@ namespace scheduler_tools { // Create the loop domain of given tensors as specified by the // reference. The new loop domain is connected to the existing IDs of // each tensor by replaying exprs found in the Exact ValGraph. +// +// If update_loop_domain_only is true, uses the current loop domain as +// the starting domain and updates it to make it look like the given +// reference loop domain. void scheduleLoopDomainsLike( const std::vector& tvs, const std::vector& ref_loop_dom, - bool update_mode = false); + bool update_loop_domain_only = false); // Replay a transform expr on the loop domain of each of the given // tensors. If the input of the transform is exact mapped with the loop