From 19524d8229c4e1f25db6443c53ab3753126cfd04 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Fri, 13 Dec 2024 12:36:11 -0800 Subject: [PATCH] Grab all IDs and exprs with StmtSort::getAllStmts (and fix replaceSymbolicSizes) (#3578) Stacked on #3585 `StmtSort::getStmtsTo` may not grab all active iter domains if IDs are connected in an unconventional way. For example, we can set the loop domain of a tensor as a producer of its logical domain, but due to the nature of `IterVisitor`, such ID dependency patterns are not supported, meaning `StmtSort::getStmtsTo` would fail to grab all valid IDs and their exprs. I only recently noticed this issue while working on #3556; specifically, it was exposed as an inconsistent replacement of extent vals. I've been experimenting with such patterns of domains, but I hadn't seen this before, likely because I was using just static shape tensors for convenience. To fix the issue, I added a variation of `StmtSort::getStmtsTo`, `StmtSort::getAllStmtsTo`, which traverses a fusion as usual but stops at TensorView. For each TensorView, instead of using `IterVisitor`, it uses `TensorDomain::allStatements()`, which combines both `TensorDomain::allIDs()` and `TensorDomain::allExprs()`, and traverses the IDs and exprs in the returned order. It's a somewhat naive implementation, but I think this is good enough for now, and I don't have any other immediate ideas to try. I changed `ValReplacementMutator` to use the new interface. That's the only use for now. 
--------- Co-authored-by: Jacob Hinkle <1454944+jacobhinkle@users.noreply.github.com> --- csrc/device_lower/pass/expr_sort.cpp | 2 +- csrc/device_lower/pass/replace_size.cpp | 1 - csrc/ir/internal_base_nodes.h | 3 + csrc/ir/nodes.cpp | 28 +++++++ csrc/ir/utils.cpp | 20 ++++- csrc/iter_visitor.cpp | 92 +++++++++++++++++++++++ csrc/iter_visitor.h | 22 ++++++ tests/cpp/test_loop_domain_scheduling.cpp | 32 ++++++++ tests/cpp/test_resize.cpp | 4 - 9 files changed, 197 insertions(+), 7 deletions(-) diff --git a/csrc/device_lower/pass/expr_sort.cpp b/csrc/device_lower/pass/expr_sort.cpp index 00fb02dffe3..8b0ec328613 100644 --- a/csrc/device_lower/pass/expr_sort.cpp +++ b/csrc/device_lower/pass/expr_sort.cpp @@ -1088,7 +1088,7 @@ bool ExprSegmentationSorter::interIterUpdate() { NVF_ERROR( !fallback_mode_enabled_, "Couldn't succcessfully sort out the fusion expressions. ", - "There are remaining connections of the heirarchical segmentation which should have been ", + "There are remaining connections of the hierarchical segmentation which should have been ", "flattened to a single ordered group, or disjoint ordered groups.\n", toString()); // We didn't finish, but we haven't tried the fallback, try again with that. diff --git a/csrc/device_lower/pass/replace_size.cpp b/csrc/device_lower/pass/replace_size.cpp index 4952f01c774..a6d46aef2d5 100644 --- a/csrc/device_lower/pass/replace_size.cpp +++ b/csrc/device_lower/pass/replace_size.cpp @@ -266,7 +266,6 @@ void replaceSymbolicSizes(Fusion* fusion) { } } - // Run mutation on the fusion with the tensor_dim_map ir_utils::replaceValue(fusion, extent_simplification_map); } diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index 58086ec1b0d..4aaafd7482d 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -610,6 +610,9 @@ class TensorDomain : public Val { // Similar to allIDs but returns all ID expressions. 
std::vector allExprs() const; + // Returns all IDs of allIDs together with their definitions, each definition placed before the IDs it produces + std::vector allStatements() const; + const std::vector& maybeAllocation() const { return hasAllocation() ? allocation_domain_ : logical(); }; diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 9c033a66c2c..423035367ae 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -3803,6 +3803,34 @@ std::vector TensorDomain::allExprs() const { return exprs.vector(); } +std::vector TensorDomain::allStatements() const { + auto all_ids = allIDs(); + std::unordered_set all_id_set{all_ids.begin(), all_ids.end()}; + + VectorOfUniqueEntries stmts; + for (auto id : all_ids) { + // Add the definition if available and all of its inputs are in all_ids; otherwise none of its inputs may be in all_ids + auto def = id->definition(); + if (def != nullptr) { + if (std::all_of( + def->inputs().begin(), def->inputs().end(), [&](Val* inp) { + return all_id_set.find(inp) != all_id_set.end(); + })) { + stmts.pushBack(def); + } else { + NVF_ERROR(std::none_of( + def->inputs().begin(), def->inputs().end(), [&](Val* inp) { + return all_id_set.find(inp) != all_id_set.end(); + })); + } + } + + stmts.pushBack(id); + } + + return stmts.vector(); +} + Split::Split( IrBuilderPasskey passkey, IterDomain* outer, diff --git a/csrc/ir/utils.cpp b/csrc/ir/utils.cpp index 79b52d5dbe3..7eadac6abaa 100644 --- a/csrc/ir/utils.cpp +++ b/csrc/ir/utils.cpp @@ -451,7 +451,7 @@ class ValReplacementMutator : private OptOutMutator { // typically not used by anything else. If we don't grab that count, then it // would be a tensorview that doesn't get updated extents. Therefore, first // grab all leaves towards outputs and grab stmts from there. 
- auto stmts = StmtSort::getStmtsTo(allLeafOuts(fusion), true, true); + auto stmts = StmtSort::getAllStmtsTo(allLeafOuts(fusion), true, true); // Some fusions, such as standalone rand_like, can have disconnected DAG, so // we need some mechanism to make sure our replacement set is as complete as @@ -501,6 +501,24 @@ class ValReplacementMutator : private OptOutMutator { std::unordered_set outputs; std::vector ordered_outputs; for (auto expr : exprs) { + // Iter domains and their exprs are taken care of by traversing + // from TensorDomain with TensorDomain::allStatements, so they + // don't need to be included here + if (std::any_of( + expr->outputs().begin(), expr->outputs().end(), [](Val* output) { + return output->isA(); + })) { + NVF_ERROR(std::all_of( + expr->outputs().begin(), expr->outputs().end(), [](Val* output) { + return output->isA(); + })); + NVF_ERROR(std::all_of( + expr->inputs().begin(), expr->inputs().end(), [](Val* input) { + return input->isA(); + })); + continue; + } + inputs.insert(expr->inputs().begin(), expr->inputs().end()); outputs.insert(expr->outputs().begin(), expr->outputs().end()); ordered_outputs.insert( diff --git a/csrc/iter_visitor.cpp b/csrc/iter_visitor.cpp index 121ed1813e8..d559b714214 100644 --- a/csrc/iter_visitor.cpp +++ b/csrc/iter_visitor.cpp @@ -938,6 +938,98 @@ std::vector StmtSort::getStmtsTo( return es.stmts; } +std::vector StmtSort::getAllStmts( + Fusion* fusion, + bool traverse_members, + bool traverse_attributes, + bool traverse_siblings) { + return getAllStmtsTo( + fusion->getTerminatingOutputs(), + traverse_members, + traverse_attributes, + traverse_siblings); +} + +std::vector StmtSort::getAllStmtsTo( + const std::vector& to, + bool traverse_members, + bool traverse_attributes, + bool traverse_siblings) { + // If members are not traversed, this can just be handled by getStmtsTo + if (!traverse_members) { + return getStmtsTo( + to, traverse_members, traverse_attributes, traverse_siblings); + } + + // to is assumed to 
include only scalar or TensorView + NVF_ERROR(std::all_of(to.begin(), to.end(), [](Val* to_val) { + return to_val->vtype() == ValType::TensorView || + to_val->vtype() == ValType::Others; + })); + + // First, grab all statements without traversing tensor members + auto stmts = getStmtsTo(to, false, traverse_attributes, traverse_siblings); + + VectorOfUniqueEntries all_stmts; + + // For TensorView, further traverse into its members. Note that + // traverse_members is always true here + for (auto stmt : stmts) { + auto tv = dynamic_cast(stmt); + if (tv == nullptr) { + all_stmts.pushBack(stmt); + continue; + } + + // Instead of using MemberStatements, grab the iter domains and + // their exprs with TensorDomain::allStatements(), which + // internally uses IRBFS. + auto all_id_stmts = tv->domain()->allStatements(); + + // For iter domains, further traverse their members and then visit + // the iter domains themselves. For ID exprs, traverse attributes and then visit the exprs + // themselves. + for (auto id_stmt : all_id_stmts) { + if (auto id = dynamic_cast(id_stmt)) { + auto id_members = MemberStatements::next(id); + // Note that traverse_members is always true at this point + for (auto id_member : id_members) { + for (auto stmt_dep : StmtSort::getStmtsTo( + {id_member->as()}, + /*traverse_members=*/true, + traverse_attributes, + traverse_siblings)) { + all_stmts.pushBack(stmt_dep); + } + } + all_stmts.pushBack(id); + } else { + auto expr = dynamic_cast(id_stmt); + NVF_ERROR(expr != nullptr); + if (traverse_attributes) { + for (auto attr : expr->attributes()) { + for (auto stmt_dep : StmtSort::getStmtsTo( + {attr->as()}, + /*traverse_members=*/true, + traverse_attributes, + traverse_siblings)) { + all_stmts.pushBack(stmt_dep); + } + } + } + all_stmts.pushBack(expr); + } + } + + // All dependent vals and exprs for this TensorDomain are in + // all_stmts. 
Append TensorDomain and then TensorView + all_stmts.pushBack(tv->domain()); + all_stmts.pushBack(tv); + } + + return all_stmts.vector(); +} + std::vector StmtSort::getStmtsBetween( const std::vector& from, const std::vector& to, diff --git a/csrc/iter_visitor.h b/csrc/iter_visitor.h index dec7c79c252..a174b47a2c6 100644 --- a/csrc/iter_visitor.h +++ b/csrc/iter_visitor.h @@ -316,6 +316,28 @@ class StmtSort : public IterVisitor { bool traverse_attributes = false, bool traverse_siblings = false); + // Returns all ordered Statements of a given fusion. Unlike + // getStmts, when traverse_members is true, all iter domains and exprs of + // each TensorDomain are grabbed with TensorDomain::allStatements. + NVF_API static std::vector getAllStmts( + Fusion* fusion, + bool traverse_members = false, + bool traverse_attributes = false, + bool traverse_siblings = false); + + // Returns ordered Statements required to produce 'to', including + // 'to'. Unlike getStmtsTo, all iter domains and exprs of each + // TensorDomain are grabbed with TensorDomain::allStatements. If + // `traverse_members` is false, this is the same as getStmtsTo. + // + // The to vals are assumed to be either TensorView or scalar + // Val. This assumption could be removed if desired. + NVF_API static std::vector getAllStmtsTo( + const std::vector& to, + bool traverse_members = false, + bool traverse_attributes = false, + bool traverse_siblings = false); + // Returns ordered Statements required to produce from, including from. // Stops traversal once hiting any Statements in to. Includes Statements in // to. 
diff --git a/tests/cpp/test_loop_domain_scheduling.cpp b/tests/cpp/test_loop_domain_scheduling.cpp index 5d7c4c6f508..107e0081eee 100644 --- a/tests/cpp/test_loop_domain_scheduling.cpp +++ b/tests/cpp/test_loop_domain_scheduling.cpp @@ -9,6 +9,7 @@ #include #include +#include #include #include #include @@ -20,6 +21,25 @@ namespace nvfuser { +namespace { + +void checkGetAllStmts(Fusion* fusion) { + // Check that StmtSort::getAllStmts returns every ID and ID expr given by TensorDomain::allStatements for each tensor, including IDs that are + // producers of root IDs + auto all_stmts = StmtSort::getAllStmts(fusion, /*traverse_members=*/true); + std::unordered_set all_stmt_set{ + all_stmts.begin(), all_stmts.end()}; + for (auto tv : fusion->allTvs()) { + for (auto id_or_expr : tv->domain()->allStatements()) { + EXPECT_TRUE(all_stmt_set.count(id_or_expr)) + << "Not found: " << id_or_expr->toString() << " of " + << tv->toString(); + } + } +} + +} // namespace + class LoopDomainSchedulingTest : public NVFuserTest { protected: void SetUp() override { @@ -82,6 +102,8 @@ TEST_F(LoopDomainSchedulingTest, ReshapeSplitThenMerge) { } } + checkGetAllStmts(&fusion); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn({10}, options); std::vector inputs({t0}); @@ -143,6 +165,8 @@ TEST_F(LoopDomainSchedulingTest, Slice) { tv->axis(1)->parallelize(ParallelType::TIDx); } + checkGetAllStmts(&fusion); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn(shape, options); std::vector aten_inputs({t0}); @@ -245,6 +269,8 @@ TEST_F(LoopDomainSchedulingTest, ReshapeTraversalDirection) { tv5_loop_to_logical.at(3).first, tv4->getLogicalDomain().at(0)->definition()) && tv5_loop_to_logical.at(3).second == Direction::Forward); + + checkGetAllStmts(&fusion); } // Using the same fusion as ReshapeTraversalDirection, try each one of @@ -309,6 +335,8 @@ TEST_F(LoopDomainSchedulingTest, ManyReshape) { EXPECT_EQ(tv->getComputeAtPosition(), tv->getLoopDomain().size()); } + 
checkGetAllStmts(&fusion_copy); + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); auto t0 = at::randn({12}, options); std::vector aten_inputs({t0}); @@ -383,6 +411,8 @@ TEST_F(LoopDomainSchedulingTest, ScheduleLoopDomainsBy1) { EXPECT_EQ(tv1->getLoopDomain(), tv1_loop_domain); EXPECT_EQ(tv2->getLoopDomain(), tv2_loop_domain); + + checkGetAllStmts(&fusion); } // Testing scheduleLoopDomainBy on its insertion position of new IDs @@ -414,6 +444,8 @@ TEST_F(LoopDomainSchedulingTest, ScheduleLoopDomainsBy2) { EXPECT_EQ( exact_graph.toGroups(tv1->getLoopDomain()), exact_graph.toGroups(tv2->getLoopDomain())); + + checkGetAllStmts(&fusion); } } // namespace nvfuser diff --git a/tests/cpp/test_resize.cpp b/tests/cpp/test_resize.cpp index 40dff5237b1..0b7e816cc46 100644 --- a/tests/cpp/test_resize.cpp +++ b/tests/cpp/test_resize.cpp @@ -4090,8 +4090,6 @@ TEST_F(ResizeTest, PropagateSliceToInputsWithReshape1) { // Fusion should have a uniform loop domain checkLoopDomainEquivalence(ref_tv); - fusion.print(); - // Schedule the reference ref_tv->flatten(); // For TIDx @@ -4241,8 +4239,6 @@ TEST_F(ResizeTest, PropagateMultipleSlicesToInputs) { // Fusion should have a uniform loop domain checkLoopDomainEquivalence(ref_tv); - fusion.print(); - // Schedule the reference ref_tv->flatten(); // For TIDx