
Allow inlining past loop broadcasts for MmaOp #3416

Merged
50 commits merged on Dec 11, 2024
Changes from 44 commits

Commits (50)
9bf2645
Move OptOutMutator tests to new file and add repro
jacobhinkle Nov 15, 2024
96dd201
Add additional_ids arg to big ctor
jacobhinkle Nov 15, 2024
c7c790b
Only check actually used IDs in predicate elimination
jacobhinkle Nov 15, 2024
29fe28b
Allow inlining loop broadcasts
jacobhinkle Nov 15, 2024
11083a1
clang-format
jacobhinkle Nov 15, 2024
11c43c4
clang-tidy of TensorDomain ctor
jacobhinkle Nov 15, 2024
683ae1e
Merge branch 'mutator_preserve_additional_ids' into mma_predicate_eli…
jacobhinkle Nov 15, 2024
c29470f
Merge branch 'main' into mma_predicate_elimination
jacobhinkle Nov 15, 2024
c06917f
Check IterType of loop broadcasts
jacobhinkle Nov 15, 2024
64be2c7
Remove debugging comment
jacobhinkle Nov 18, 2024
a9cc7aa
Merge remote-tracking branch 'origin/main' into mma_predicate_elimina…
jacobhinkle Nov 18, 2024
46e6ca8
Merge remote-tracking branch 'origin/mma_predicate_elimination' into …
jacobhinkle Nov 18, 2024
cc236fd
Track IDs used in indexing
jacobhinkle Nov 18, 2024
05d5ca4
[DO NOT MERGE] added throw to test impact on existing tests
jacobhinkle Nov 18, 2024
5974567
Refactor getting indexing IDs into utility
jacobhinkle Nov 19, 2024
e0ad380
Put back accidentally removed replay
jacobhinkle Nov 19, 2024
3c2631f
Add skipped root->logical mappings in c2p
jacobhinkle Nov 19, 2024
3342e77
Simplify getIndexIDs
jacobhinkle Nov 19, 2024
ee5329f
Remove NVF_THROW and disable matmul test for codediff
jacobhinkle Nov 19, 2024
0cf29e5
Enable test
jacobhinkle Nov 19, 2024
2475578
Merge remote-tracking branch 'origin/mma_predicate_elimination' into …
jacobhinkle Nov 20, 2024
381035f
Avoid processing non-indexing inputs to Merge
jacobhinkle Nov 20, 2024
3fa19a2
Merge remote-tracking branch 'origin/mma_predicate_elimination' into …
jacobhinkle Nov 20, 2024
732b873
Remove declaration that shadowed c2p_tmp
jacobhinkle Nov 20, 2024
60b23a3
Merge remote-tracking branch 'origin/main' into mma_predicate_elimina…
jacobhinkle Nov 20, 2024
9feb8f8
Update in light of #3452
jacobhinkle Nov 20, 2024
2f89ab4
Merge remote-tracking branch 'origin/main' into mma_predicate_elimina…
jacobhinkle Nov 20, 2024
85c4172
Merge remote-tracking branch 'origin/mma_predicate_elimination' into …
jacobhinkle Nov 20, 2024
c706046
Use getIndexIDs
jacobhinkle Nov 20, 2024
6f451f7
Only check index IDs for MmaOp
jacobhinkle Nov 27, 2024
434c8a1
Merge remote-tracking branch 'origin/mma_predicate_elimination' into …
jacobhinkle Dec 3, 2024
bd45e7e
Guard changes so they only affect MmaOp
jacobhinkle Dec 3, 2024
9742b5b
Merge remote-tracking branch 'origin/main' into mma_predicate_elimina…
jacobhinkle Dec 3, 2024
9fb9aad
Simplify utility to lower_utils::getIdsBetween
jacobhinkle Dec 3, 2024
e623561
Rename to getIdsAlongPathBetween and add example to comment
jacobhinkle Dec 3, 2024
0660f8c
Use loop group traversal from alloc to loop
jacobhinkle Dec 4, 2024
6e17d11
Remove getIdsAlongPathBetween
jacobhinkle Dec 5, 2024
ee6a89a
Use TensorIndexer and getValsBetween
jacobhinkle Dec 5, 2024
2959d88
Don't need to promote allocation IDs
jacobhinkle Dec 5, 2024
9a3cb54
Merge remote-tracking branch 'origin/mma_predicate_elimination' into …
jacobhinkle Dec 5, 2024
62f9ede
Use inlining graph path between loop domains instead of getIndexIDs
jacobhinkle Dec 5, 2024
0c25adc
Merge remote-tracking branch 'origin/main' into mma_inlining
jacobhinkle Dec 5, 2024
8913262
Merge remote-tracking branch 'origin/main' into mma_inlining
jacobhinkle Dec 6, 2024
4a3b0d2
Undo stale change reordering predicate elimination lowering pass
jacobhinkle Dec 6, 2024
cfc7ed9
Get path from mapped IDs. Improve comments. Try asserting more
jacobhinkle Dec 10, 2024
7f6f5a1
Merge remote-tracking branch 'origin/main' into mma_inlining
jacobhinkle Dec 10, 2024
6d99f0d
Add test to check improper inlining is not done
jacobhinkle Dec 10, 2024
a902803
Remove debug print
jacobhinkle Dec 10, 2024
ff358f7
Comment why tv1c is not inlined to position 1
jacobhinkle Dec 10, 2024
951757d
Add comment about the traversal
jacobhinkle Dec 10, 2024
47 changes: 45 additions & 2 deletions csrc/scheduler/tools/inlining.cpp
@@ -5,12 +5,14 @@
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#include <device_lower/utils.h>
#include <id_model/utils.h>
#include <ir/utils.h>
#include <iter_visitor.h>
#include <logical_domain_map.h>
#include <scheduler/tools/inlining.h>
#include <transform_iter.h>
#include <val_graph_visitor.h>

#include <utility>

@@ -193,6 +195,21 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer(
}
return producer->nDims();
} else {
std::unordered_set<ValGroup> loop_path_groups;
Collaborator: Since an empty set can mean either that it was never set or that there is genuinely no val group, I think this would be clearer and less error-prone:

Suggested change:
-  std::unordered_set<ValGroup> loop_path_groups;
+  std::optional<std::unordered_set<ValGroup>> loop_path_groups;
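As a minimal standalone illustration of the reviewer's point (plain C++, not Fuser code): with a bare set, "never computed" and "computed but empty" collide, while an optional keeps the two states distinct.

```cpp
#include <optional>
#include <unordered_set>

// Returns whether the path restriction should be applied at all.
bool restrictionApplies(
    const std::optional<std::unordered_set<int>>& groups) {
  if (!groups.has_value()) {
    return false;  // never computed, e.g. the non-MmaOp case
  }
  return true;  // computed; an empty set now genuinely means "no groups"
}
```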

if (consumer->definition()->isA<MmaOp>()) {
// Get ValGroups between producer and consumer loop in the inlining graph
Collaborator Author: Maybe it would make more sense to start by using a PairwiseLogicalDomainMap to find the mapped ID groups between producer and consumer, then do a BFS to find the path from those to both loop domains' val groups. That way we could avoid missing loop groups in both the producer and the consumer. That case won't occur for MmaOp, but it might be clearer than traversing from producer loop to consumer loop, since it's not immediately obvious why we go in that direction.
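A rough sketch of that alternative, assuming PairwiseLogicalDomainMap and getValsBetween are used the same way as elsewhere in this diff (this is the suggestion, not the merged code):

```cpp
// Seed with the ValGroups of the producer->consumer mapped logical IDs,
// then BFS outward to each loop domain so that unreachable loop groups on
// either side are left off the path.
const auto p2c =
    PairwiseLogicalDomainMap(producer, consumer).mapProducerToConsumer();
std::vector<ValGroup> seed_groups;
for (const auto& p2c_pair : p2c) {
  // Pairwise-mapped IDs share a group in the inlining graph, so either
  // side of the pair works as a seed.
  seed_groups.push_back(inliningGraph().toGroup(p2c_pair.second));
}
const std::vector<ValGroup> to_producer = getValsBetween<ValGraphBFS>(
    seed_groups, producer_loop_groups, inliningGraph());
const std::vector<ValGroup> to_consumer = getValsBetween<ValGraphBFS>(
    seed_groups, consumer_loop_groups, inliningGraph());
loop_path_groups.insert(to_producer.begin(), to_producer.end());
loop_path_groups.insert(to_consumer.begin(), to_consumer.end());
```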

std::vector<ValGroup> producer_loop_groups, consumer_loop_groups;
for (IterDomain* id : producer->getLoopDomain()) {
producer_loop_groups.push_back(inliningGraph().toGroup(id));
}
for (IterDomain* id : consumer->getLoopDomain()) {
consumer_loop_groups.push_back(inliningGraph().toGroup(id));
}
std::vector<ValGroup> group_path = getValsBetween<ValGraphBFS>(
producer_loop_groups, consumer_loop_groups, inliningGraph());
loop_path_groups.insert(group_path.begin(), group_path.end());
}

auto consumer_it = consumer->getLoopDomain().begin();
for (const auto producer_pos : c10::irange(producer->nDims())) {
auto p_id = producer->getLoopDomain().at(producer_pos);
@@ -211,8 +228,34 @@
}

IterDomain* c_id = *consumer_it;
if (!inliningGraph().disjointValSets().strictAreMapped(p_id, c_id) ||
!isAllowedID(c_id, consumer, best_effort, true, false, true)) {

// We can inline past consumer IDs that are not connected to the producer.
//
// For example, an MmaOp with no broadcasts could contain the following:
// tv0:
// root/logical: [ iS0, iS1 ]
// loop: [ iS0, bS7, iS1 ]
// tv1:
// root/logical: [ iS2, iS3 ]
// loop: [ bS8, iS2, iS3 ]
// tv2:
// root/logical/loop: [ iS4, iS5, rS6 ]
Collaborator (on lines +262 to +269): How are these MmaOp inputs and output actually scheduled? Do their loop domains look just like shown above? I'm asking because if, for example, bS8 gets merged with iS2, the merge output would be included in the BFS path between tv1 and tv2, so we wouldn't skip it; but since it isn't mapped with any of the loop IDs of tv2 in the broadcast graph (because bS8 is not mapped), inlining would be stopped at that point.

Collaborator Author: This is how the schedule looks in the test:

T2_l_float[iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16}] ca_pos( 2 ) produce_pos( 3 )
   = mma(T4_s___half[iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8}] ca_pos( 3 ),
         T5_s___half[bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8}] ca_pos( 3 ))

T4_s___half[iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8}] ca_pos( 3 )
 logical domain : (iS9{i0}, iS10{i1})
 allocation domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
 contiguity: t n t n t t t t t t
  Split: iS10{i1} by factor 128 -> iblockIdx.y31{( ceilDiv(i1, 128) )}, iS32{128}
  Split: iS9{i0} by factor 16 -> iS35{( ceilDiv(i0, 16) )}, iS36{16}
  Split: iS32{128} by factor 64 -> iS43{2}, iS44{64}
  Split: iS36{16} by factor 8 -> iB45{2}, iS46{8}
  Split: iS46{8} by factor 1 -> iS47{8}, iB48{1}
  Split: iS44{64} by factor 8 -> iS49{8}, iB50{8}
  Xor(2D): iS47{8} , iS49{8} -> iB51{8} , iB52{8}
 loop domain : (iblockIdx.y31{( ceilDiv(i1, 128) )}, bblockIdx.x33{1}, iS35{( ceilDiv(i0, 16) )}, bS34{256}, iS43{2}, iB45{2}, iB51{8}, iB48{1}, iB52{8}, iB50{8})
T5_s___half[bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8}] ca_pos( 3 )
 logical domain : (iS11{i3}, iS12{i4})
 allocation domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
 contiguity: n t t n t t t t t t
  Split: iS12{i4} by factor 256 -> iblockIdx.x39{( ceilDiv(i4, 256) )}, iS40{256}
  Split: iS11{i3} by factor 16 -> iS41{( ceilDiv(i3, 16) )}, iS42{16}
  Split: iS40{256} by factor 64 -> iS53{4}, iS54{64}
  Split: iS42{16} by factor 8 -> iB55{2}, iS56{8}
  Split: iS56{8} by factor 1 -> iS57{8}, iB58{1}
  Split: iS54{64} by factor 8 -> iS59{8}, iB60{8}
  Xor(2D): iS57{8} , iS59{8} -> iB61{8} , iB62{8}
 loop domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})
T2_l_float[iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16}] ca_pos( 2 ) produce_pos( 3 )
 logical domain : (iS4{i1}, iS5{i4}, rS6{i0})
 allocation domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, ithreadIdx.x87{128}, iMMA82{32}, iMMA81{2}, iMMA85{2}, rMMA90{2}, rMMA91{4}, rMMA89{2})
 contiguity: t t n t t t t t n n n
  Split: iS4{i1} by factor 128 -> iblockIdx.y17{( ceilDiv(i1, 128) )}, iS18{128}
  Split: iS5{i4} by factor 256 -> iblockIdx.x19{( ceilDiv(i4, 256) )}, iS20{256}
  Split: rS6{i0} by factor 16 -> rS21{( ceilDiv(i0, 16) )}, rMMA22{16}
  Split: iS18{128} by factor 64 -> iS63{2}, iMMA64{64}
  Split: iS20{256} by factor 256 -> iS65{1}, iMMA66{256}
  Merge: iS63{2} and iS65{1} -> ithreadIdx.y67{2}
 loop domain : (iblockIdx.y17{( ceilDiv(i1, 128) )}, iblockIdx.x19{( ceilDiv(i4, 256) )}, rS21{( ceilDiv(i0, 16) )}, ithreadIdx.y67{2}, iMMA64{64}, iMMA66{256}, rMMA22{16})

Also note that these expressions are not shown in the printout because they come from loop broadcasts bS15{1} and bS16{1}, but we split those broadcasts for each operand:

Split: bS15{1} by factor 256 -> bblockIdx.x33{1}, bS34{256}
Split: bS16{1} by factor 128 -> bblockIdx.y37{1}, bS38{128}

The issue comes when we want to inline past these two outer IDs.

You are right that in some other case we might merge with an actually mapped domain like iS2 in your example. In that case the merge output would be included in the path, so we would not skip it and we would not be able to inline, even though we probably should be able to. I'm not sure how to handle such a case. In your example, if the merge(iS2, bS8) ID in a producer is aligned with a merge(iS9, iS10) in the consumer, we have no way to represent that bS8 should map to iS10, since bS8 is a loop broadcast and will not Broadcast-map to anything. But that is exactly the type of relationship I'm trying to fake here: in my case I'm faking that the loop broadcast M or N dimension will Broadcast-map to the mma output M or N dimension.

Collaborator: Thanks. I think this is good enough for now, but I wonder if we should extend the inlining graph. Suppose we have:

t0: [i0]
t1: [i1, i2]

// [i0, b3]
t0->broadcast(0);

In this case, b3 is the added broadcast ID. Since it's not part of the math definition, it won't be mapped with anything. However, if we have:

// [i4]
t0->merge(0, 1);

// [i5]
t1->merge(0, 1);

Maybe we should map i4 and i5 because we want to allow inlining of i4 to i5.

And maybe we should actually map b3 and i2 for the inlining purpose.

My point could also be phrased as: instead of tweaking the inlining logic, should we define the inlining graph such that it allows patterns like the above? Again, I'm not asking for any change in this PR, but to me, if we call something the inlining graph, it should precisely reflect what can and cannot be inlined.
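For concreteness, here is the scenario above assembled into a single sketch, following the comment's own pseudocode (the broadcast/merge calls and ID names are taken from the comment; the i4/i5 and b3/i2 mappings are the open question, not implemented behavior):

```cpp
TensorView* t0 = makeSymbolicTensor(1);  // loop: [i0]
TensorView* t1 = makeSymbolicTensor(2);  // loop: [i1, i2]

// b3 is added to t0's loop domain only; since it is not part of the math
// definition, the inlining graph maps it to nothing.
t0->broadcast(0);  // loop: [i0, b3]

t0->merge(0, 1);  // loop: [i4]
t1->merge(0, 1);  // loop: [i5]

// Open question: should the inlining graph map i4 <-> i5, and b3 <-> i2
// for inlining purposes, so that these merged loop IDs can be inlined?
```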

Collaborator: I also wonder if we could solve this problem by just moving the broadcast IDs to the innermost positions. For T5:

loop domain : (bblockIdx.y37{1}, iblockIdx.x39{( ceilDiv(i4, 256) )}, iS41{( ceilDiv(i3, 16) )}, bS38{128}, iS53{4}, iB55{2}, iB61{8}, iB58{1}, iB62{8}, iB60{8})

If we move b37 and b38 to the innermost position, would that solve the inlining problem? Somewhat related prior fix: #2799

Collaborator Author:

> Suppose we have:

I think your example here sums up the issue. We would like to map b3 and i2 in the inlining graph, but we currently have no way to express that wish. One option would be to add an optional argument to TensorView::broadcast that lets us pass a vector of IterDomains we'd like to inline with. We could then detect these just after building the Broadcast graph, copy the graph if any such mappings are found, perform the mappings, and use the resulting graph as the inlining graph.

> but to me if we call something the inlining graph, it should precisely reflect what can and cannot be inlined.

👍

> moving the broadcast IDs to innermost

They are only broadcast in one of the operands. In the output tensor and in the other operand they are Iteration domains. I don't think we can move them because some of the ones that are Broadcast are the ones we need to inline; they are the outer split dimensions which we parallelize with BIDx/BIDy, i.e. these are the tile coordinates.

Collaborator Author:

> letting us pass a vector of IterDomains that we'd like to inline with

By this I mean "a vector of IterDomains we'd like to pretend our new broadcast loop ID is broadcast-mapped to".
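A purely hypothetical signature for that idea (nothing like this exists in the codebase; the parameter name is invented):

```cpp
// In TensorView: scheduling-only broadcast that also records which
// IterDomains the new broadcast loop ID should be treated as
// broadcast-mapped to when the inlining graph is built.
void broadcast(
    int64_t position,
    const std::vector<IterDomain*>& pretend_broadcast_mapped_to = {});
```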

Collaborator: I added this a little before:

https://github.com/NVIDIA/Fuser/blob/main/csrc/ir/nodes.cpp#L2533-L2535

Not recommended, though. If it's the only WAR, then sure, but otherwise I found it not robust. For example, if the IDs of x and y are registered as exactly mapped, that information may not be preserved when an iter domain is replaced by another (e.g., by replaceSymbolicSizes).
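For reference, a use of that mechanism might look like the following (the call is inferred from the linked code and the discussion; treat the exact name and signature as an assumption):

```cpp
// Register b3 and i2 as exactly mapped so downstream ID graphs unify them.
// As cautioned above, such a registration can be lost when an IterDomain
// is later replaced by another (e.g., by replaceSymbolicSizes).
fusion->registerExactMapping(b3, i2);
```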

//
// iS4 maps to iS0 so when producer==tv0 we inline past iS0. When
// producer==tv1, iS4 doesn't map to anything in tv1 and is not used for
// indexing, and bS8 is a loop broadcast so we inline past the first ID
// in that case also. Similarly, we inline past iS5, iS2, and bS7.
if ((loop_path_groups.empty() ||
loop_path_groups.count(inliningGraph().toGroup(p_id)) ||
loop_path_groups.count(inliningGraph().toGroup(c_id))) &&
Collaborator Author: For the MmaOp case, we will hit this when c_id is not in loop_path_groups because it is an M or N dimension. Note to self: I think we should probably assert that when this happens, both c_id and p_id are not found in that set, to avoid mistakenly inlining in a case where a mapped producer dimension is in the same position as an unmapped consumer dimension.

Collaborator: Is it always the case that an ignored producer ID is a broadcast?

Collaborator:

> Note to self: I think we should probably assert that when this happens, both c_id and p_id are not found in that set

If that's the case, do we need to check c_id? Isn't loop_path_groups.count(inliningGraph().toGroup(p_id)) sufficient?

Collaborator Author (@jacobhinkle, Dec 6, 2024):

> Is it always the case that an ignored producer ID is a broadcast?

This is currently the case: we create a loop broadcast for the operands to MmaOp.

Collaborator: Shouldn't we assert that? If it's not a broadcast, I'm not sure it's safe to skip.
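A sketch of the assertion being discussed, written against the names in this diff (an assumption about what the guard could look like, not the merged change):

```cpp
const bool p_on_path =
    loop_path_groups.count(inliningGraph().toGroup(p_id)) > 0;
const bool c_on_path =
    loop_path_groups.count(inliningGraph().toGroup(c_id)) > 0;
if (!p_on_path || !c_on_path) {
  // Only inline past this position when *both* sides are off the path and
  // the producer ID is the loop broadcast expected for MmaOp operands.
  NVF_ERROR(
      !p_on_path && !c_on_path,
      "Expected both producer and consumer IDs to be off the loop path");
  NVF_ERROR(
      p_id->isBroadcast(),
      "Expected the skipped producer ID to be a loop broadcast");
  continue;
}
```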

Collaborator: I think separating this logic out would be easier to follow, like lines 222-224.

Suggested change:
-  if ((loop_path_groups.empty() ||
-       loop_path_groups.count(inliningGraph().toGroup(p_id)) ||
-       loop_path_groups.count(inliningGraph().toGroup(c_id))) &&
+  // We can inline past consumer IDs that are not connected to the producer.
+  //
+  // For example, an MmaOp with no broadcasts could contain the following:
+  if (loop_path_groups.has_value() &&
+      (!loop_path_groups->count(inliningGraph().toGroup(p_id)) ||
+       !loop_path_groups->count(inliningGraph().toGroup(c_id)))) {
+    continue;
+  }

(!inliningGraph().disjointValSets().strictAreMapped(p_id, c_id) ||
!isAllowedID(
c_id,
consumer,
best_effort,
/*allow_reduction=*/true,
/*allow_vectorize=*/false,
/*allow_unmappable=*/true))) {
return producer_pos;
}

3 changes: 3 additions & 0 deletions tests/cpp/test_matmul.cpp
Collaborator: I feel we should have a few more tests. Can we create a test that should not get inlined even with the added condition?

Collaborator Author: I added a section to the test that swaps the No and Ko axes in the mma result (which is scheduled Mo, No, Ko, ...). This is done on a copy of the scheduled fusion; then I call inlineMost to check that we don't mistakenly inline past the unmapped No axis in tv0c. A sketch of that check follows.
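A hedged sketch of that negative check (tensor names and axis positions are assumptions based on the description above, with the mma result scheduled as [Mo, No, Ko, ...]):

```cpp
// Perturb a copy of the scheduled fusion: swap the No and Ko axes of the
// mma result so that the corresponding axis of tv0c no longer lines up.
Fusion fusion_copy(fusion);
// ... look up the cloned mma result and tv0c inside fusion_copy ...
mma_result->reorder({{1, 2}, {2, 1}});  // [Mo, No, Ko] -> [Mo, Ko, No]
inlineMost();
// Inlining must stop before the unmapped No axis instead of reaching 3.
EXPECT_LT(tv0c->getComputeAtPosition(), 3);
```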

@@ -3931,6 +3931,9 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) {

inlineMost();

EXPECT_EQ(tv0c->getComputeAtPosition(), 3);
EXPECT_EQ(tv1c->getComputeAtPosition(), 3);

if (stages > 1) {
tv0c->circularBuffer(stages, prefetch);
tv1c->circularBuffer(stages, prefetch);