Allow inlining past loop broadcasts for MmaOp #3416
Changes from 44 commits
@@ -5,12 +5,14 @@
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#include <device_lower/utils.h>
#include <id_model/utils.h>
#include <ir/utils.h>
#include <iter_visitor.h>
#include <logical_domain_map.h>
#include <scheduler/tools/inlining.h>
#include <transform_iter.h>
#include <val_graph_visitor.h>

#include <utility>

@@ -193,6 +195,21 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
    }
    return producer->nDims();
  } else {
    std::unordered_set<ValGroup> loop_path_groups;
    if (consumer->definition()->isA<MmaOp>()) {
      // Get ValGroups between producer and consumer loop in the inlining graph
Review comment: Maybe it would make more sense to start by using a

      std::vector<ValGroup> producer_loop_groups, consumer_loop_groups;
      for (IterDomain* id : producer->getLoopDomain()) {
        producer_loop_groups.push_back(inliningGraph().toGroup(id));
      }
      for (IterDomain* id : consumer->getLoopDomain()) {
        consumer_loop_groups.push_back(inliningGraph().toGroup(id));
      }
      std::vector<ValGroup> group_path = getValsBetween<ValGraphBFS>(
          producer_loop_groups, consumer_loop_groups, inliningGraph());
      loop_path_groups.insert(group_path.begin(), group_path.end());
    }

    auto consumer_it = consumer->getLoopDomain().begin();
    for (const auto producer_pos : c10::irange(producer->nDims())) {
      auto p_id = producer->getLoopDomain().at(producer_pos);

@@ -211,8 +228,34 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
      }

      IterDomain* c_id = *consumer_it;
      if (!inliningGraph().disjointValSets().strictAreMapped(p_id, c_id) ||
          !isAllowedID(c_id, consumer, best_effort, true, false, true)) {

      // We can inline past consumer IDs that are not connected to the producer.
      //
      // For example, an MmaOp with no broadcasts could contain the following:
      //   tv0:
      //     root/logical: [ iS0, iS1 ]
      //     loop: [ iS0, bS7, iS1 ]
      //   tv1:
      //     root/logical: [ iS2, iS3 ]
      //     loop: [ bS8, iS2, iS3 ]
      //   tv2:
      //     root/logical/loop: [ iS4, iS5, rS6 ]
Comment on lines +262 to +269

Review comment: How are these MmaOp inputs and output actually scheduled? Do their loop domains look like those shown above? I'm asking because if, for example,

Review comment: This is how the schedule looks in the test:
Also note that these expressions are not shown in the printout because they come from loop broadcasts.
The issue comes when we want to inline past these two outer IDs. You are right that in some other case we might merge with some actual mapped domains like

Review comment: Thanks. I think this is good enough for now, but I wonder if we should extend the inlining graph. Suppose we have:
In this case,
Maybe we should map
And maybe we should actually map
I think my point can also be phrased as: instead of tweaking the inlining logic, should we define the inlining graph such that it would allow patterns like the above? Again, I'm not asking for any change in this PR, but to me, if we call something the inlining graph, it should precisely reflect what can and cannot be inlined.

Review comment: I also wonder if we could solve this problem by just moving the broadcast IDs to innermost. For
If we move

Review comment: I think your example here sums up the issue. We would like to map
👍
They are only broadcast in one of the operands. In the output tensor and in the other operand they are Iteration domains. I don't think we can move them because some of the ones that are Broadcast are the ones we need to inline; they are the outer split dimensions which we parallelize with BIDx/BIDy, i.e. these are the tile coordinates.

Review comment: By this I mean "a vector of IterDomains we'd like to pretend that our new broadcast loop ID is broadcast mapped to"

Review comment: I added this a little before: https://github.com/NVIDIA/Fuser/blob/main/csrc/ir/nodes.cpp#L2533-L2535 Not recommended, though. If it's the only WAR, then sure, but otherwise, I found it not robust. For example, if IDs of

      //
      // iS4 maps to iS0 so when producer==tv0 we inline past iS0. When
      // producer==tv1, iS4 doesn't map to anything in tv1 and is not used for
      // indexing, and bS8 is a loop broadcast so we inline past the first ID
      // in that case also. Similarly, we inline past iS5, iS2, and bS7.
      if ((loop_path_groups.empty() ||
           loop_path_groups.count(inliningGraph().toGroup(p_id)) ||
           loop_path_groups.count(inliningGraph().toGroup(c_id))) &&
Review comment: For the MmaOp case, we will hit this when

Review comment: Is it always the case that an ignored producer ID is a broadcast?

Review comment: If that's the case, do we need to check

Review comment: This is currently the case: we create a loop broadcast for the operands to MmaOp.

Review comment: Shouldn't we assert that? If not broadcast, I'm not sure if it's safe to skip.

Review comment: I think separating this logic would be easier to follow, like lines 222-224.
Suggested change

          (!inliningGraph().disjointValSets().strictAreMapped(p_id, c_id) ||
           !isAllowedID(
               c_id,
               consumer,
               best_effort,
               /*allow_reduction=*/true,
               /*allow_vectorize=*/false,
               /*allow_unmappable=*/true))) {
        return producer_pos;
      }

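To make the shape of the new condition easier to follow, here is a minimal standalone sketch. It is not the nvFuser implementation: `ToyId` and `toyMaxProducerPos` are hypothetical names, an `int` stands in for a `ValGroup`, a `bool` stands in for the result of `isAllowedID`, and producer and consumer loop IDs are paired position by position, which ignores whatever the real loop does to advance `consumer_it`.

```cpp
#include <cstddef>
#include <unordered_set>
#include <vector>

// Hypothetical stand-in for a loop IterDomain. Two IDs count as "mapped"
// exactly when they carry the same group number.
struct ToyId {
  int group;     // toy replacement for inliningGraph().toGroup(id)
  bool allowed;  // toy replacement for isAllowedID(...)
};

// Returns the first producer position that blocks inlining. A position
// blocks only if (a) no path set was computed, or the producer or consumer
// group lies on the path, and (b) the pair is unmapped or not allowed.
std::size_t toyMaxProducerPos(
    const std::vector<ToyId>& producer_loop,
    const std::vector<ToyId>& consumer_loop,
    const std::unordered_set<int>& loop_path_groups) {
  for (std::size_t pos = 0; pos < producer_loop.size(); ++pos) {
    if (pos >= consumer_loop.size()) {
      return pos;  // assumption: running out of consumer IDs stops inlining
    }
    const ToyId& p = producer_loop[pos];
    const ToyId& c = consumer_loop[pos];
    const bool on_path = loop_path_groups.empty() ||
        loop_path_groups.count(p.group) > 0 ||
        loop_path_groups.count(c.group) > 0;
    const bool blocked = p.group != c.group || !c.allowed;
    if (on_path && blocked) {
      return pos;
    }
  }
  return producer_loop.size();
}
```

Modelling the `tv1`/`tv2` example from the code comment with groups `{8, 5, 6}` for the producer and `{4, 5, 6}` for the consumer, a path set of `{5, 6}` lets the toy function skip the unconnected first pair and inline all three positions. With an empty path set, the same inputs stop at position 0, which matches the pre-existing behavior for consumers that are not an `MmaOp`.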
Review comment: I feel we should have a few more tests. Can we create a test that should not get inlined even with the added condition?

Review comment: I added a section to the test that swaps the No and Ko axes in the mma result (which is scheduled

Review comment: Since an empty set can mean it's just not set or there's indeed no val group, I think this would be clearer and less error-prone:
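The code that accompanied this suggestion is not preserved here. As a hedged sketch of one way to separate the two meanings (an assumption, not necessarily what the reviewer proposed), the set could be wrapped in `std::optional` so that "not computed" and "computed but empty" are distinct states. `toyIsOnLoopPath` is a hypothetical helper and `int` again stands in for `ValGroup`.

```cpp
#include <optional>
#include <unordered_set>

// Hypothetical helper; requires C++17 for std::optional.
//   std::nullopt -> no path was computed (e.g. the consumer is not an MmaOp),
//                   so every producer/consumer pair is checked as before.
//   empty set    -> a path was computed and contains no groups, so no pair
//                   is considered to be on the path.
bool toyIsOnLoopPath(
    const std::optional<std::unordered_set<int>>& loop_path_groups,
    int p_group,
    int c_group) {
  if (!loop_path_groups.has_value()) {
    return true;
  }
  return loop_path_groups->count(p_group) > 0 ||
      loop_path_groups->count(c_group) > 0;
}
```

With this shape, the `loop_path_groups.empty() ||` term in the condition is no longer needed to mean "unset", and a genuinely empty path no longer falls back silently to checking every pair.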