Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
naoyam committed Dec 10, 2024
1 parent 6fa5165 commit 928fede
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 10 deletions.
21 changes: 12 additions & 9 deletions csrc/scheduler/tools/loop_domain_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -102,8 +102,9 @@ class LoopDomainScheduler {
public:
LoopDomainScheduler(
std::vector<IterDomain*> ref_loop_dom,
bool update_mode = false)
: ref_loop_dom_(std::move(ref_loop_dom)), update_mode_(update_mode) {
bool update_loop_domain_only = false)
: ref_loop_dom_(std::move(ref_loop_dom)),
update_loop_domain_only_(update_loop_domain_only) {
NVF_ERROR(!ref_loop_dom_.empty());

// For now, ref must not be a broadcast domain
Expand Down Expand Up @@ -178,7 +179,7 @@ class LoopDomainScheduler {
std::vector<IterDomain*> ref_loop_dom_;
// If true, uses the current loop domain as the starting domain and
// updates it to make it look like the given reference loop domain
bool update_mode_ = false;
bool update_loop_domain_only_ = false;
std::unique_ptr<IdModel> id_model_;
ValGroups ref_id_groups_;
ValGroups all_ancestors_of_ref_;
Expand All @@ -197,9 +198,10 @@ void LoopDomainScheduler::schedule(TensorView* tv) const {
std::unordered_map<ValGroup, IterDomain*> group_to_id;
ValGroups all_id_groups;
// When update_mode_ is true, only the loop domain IDs are reused as
// we attemp to transform the current loop domain to look like the
// we attempt to transform the current loop domain to look like the
// reference loop domain.
auto all_ids = update_mode_ ? tv->getLoopDomain() : tv->domain()->allIDs();
auto all_ids =
update_loop_domain_only_ ? tv->getLoopDomain() : tv->domain()->allIDs();
for (auto id : all_ids) {
const auto& group = graph().toGroup(id);
group_to_id.emplace(group, id);
Expand Down Expand Up @@ -330,7 +332,8 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const {
// In the case of the update mode, the target should be just the
// current loop domain of the tensor.
ValGroups tv_target_domains = graph().toGroups(TensorDomain::noBroadcasts(
update_mode_ ? tv->getLoopDomain() : tv->getMaybeRootDomain()));
update_loop_domain_only_ ? tv->getLoopDomain()
: tv->getMaybeRootDomain()));

// If all the target domains are an ancestor of the reference
// domains, just a single backward BFS should be enough to find a
Expand All @@ -353,7 +356,7 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const {
// In the case of the update mode, the path from the reference is
// assumed to just a backward traversal path.
NVF_ERROR(
!update_mode_,
!update_loop_domain_only_,
"Trying to update the current loop domain but could not find a valid path from the reference: ",
tv->toString());

Expand Down Expand Up @@ -394,12 +397,12 @@ ValGraphBFS::ExprPath LoopDomainScheduler::getReplayPath(TensorView* tv) const {
void scheduleLoopDomainsLike(
const std::vector<TensorView*>& tvs,
const std::vector<IterDomain*>& ref_loop_dom,
bool update_mode) {
bool update_loop_domain_only) {
if (tvs.empty()) {
return;
}

LoopDomainScheduler scheduler(ref_loop_dom, update_mode);
LoopDomainScheduler scheduler(ref_loop_dom, update_loop_domain_only);

for (auto tv : tvs) {
// Loop domain of fusion inputs should have no meaning
Expand Down
6 changes: 5 additions & 1 deletion csrc/scheduler/tools/loop_domain_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,14 @@ namespace scheduler_tools {
// Create the loop domain of given tensors as specified by the
// reference. The new loop domain is connected to the existing IDs of
// each tensor by replaying exprs found in the Exact ValGraph.
//
// If update_loop_domain_only is true, uses the current loop domain as
// the starting domain and updates it to make it look like the given
// reference loop domain.
void scheduleLoopDomainsLike(
const std::vector<TensorView*>& tvs,
const std::vector<IterDomain*>& ref_loop_dom,
bool update_mode = false);
bool update_loop_domain_only = false);

// Replay a transform expr on the loop domain of each of the given
// tensors. If the input of the transform is exact mapped with the loop
Expand Down

0 comments on commit 928fede

Please sign in to comment.