diff --git a/csrc/scheduler/tools/inlining.cpp b/csrc/scheduler/tools/inlining.cpp index 6e7b51caba4..6064f4e6e7c 100644 --- a/csrc/scheduler/tools/inlining.cpp +++ b/csrc/scheduler/tools/inlining.cpp @@ -5,12 +5,14 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include #include #include #include #include #include #include +#include #include @@ -193,6 +195,46 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( } return producer->nDims(); } else { + std::optional> loop_path_groups = std::nullopt; + if (consumer->definition()->isA()) { + // We handle MmaOp specially here since it is currently the only operation + // for which we generate code (i.e. not SdpaFwdOp or SdpaBwdOp) that has + // some output dimensions that do not map to input dimensions. For this + // case, we need to identify potential inlined pairs each ID of which is + // not mapped at all to the other TensorView (see example below). + + // Get ValGroups in loop domains of producer and consumer that are + // connected to _mapped_ IterDomains in the pairwise map. + // + // Note that for MmaOp, it would be sufficient to traverse from the + // producer loop to the consumer loop and identify when _either_ the + // consumer or producer ID is not mapped. Here we are instead traversing + // from mapped domains to both loop domains so that we can check that + // _both_ consumer and producer IDs are not mapped. This is slightly safer + // and this symmetry might be handy in handling new ops that use this + // feature in the future. 
+ std::vector pairwise_mapped_groups; + for (auto [c_id, p_id] : PairwiseLogicalDomainMap(producer, consumer) + .mapConsumerToProducer()) { + pairwise_mapped_groups.push_back(inliningGraph().toGroup(c_id)); + } + // We propagate toward the loop groups from both consumer and producer + std::vector all_loop_groups; + for (IterDomain* id : producer->getLoopDomain()) { + all_loop_groups.push_back(inliningGraph().toGroup(id)); + } + for (IterDomain* id : consumer->getLoopDomain()) { + all_loop_groups.push_back(inliningGraph().toGroup(id)); + } + // getValsBetween does not require all target groups to be visited. This + // means the result contains the subset of both loop groups that we are + // looking for + std::vector group_path = getValsBetween( + pairwise_mapped_groups, all_loop_groups, inliningGraph()); + loop_path_groups = + std::unordered_set(group_path.begin(), group_path.end()); + } + auto consumer_it = consumer->getLoopDomain().begin(); for (const auto producer_pos : c10::irange(producer->nDims())) { auto p_id = producer->getLoopDomain().at(producer_pos); @@ -211,8 +253,54 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( } IterDomain* c_id = *consumer_it; + + // We can inline past positions in which both producer and consumer are + // not connected to any mapped logical IterDomain pairs. + // + // For example, an MmaOp can be constructed as follows: + // + // tv0: + // root/logical: [ iS0, iS1 ] + // loop: [ iS0, bS7, iS1 ] + // tv1: + // root/logical: [ iS2, iS3 ] + // loop: [ bS8, iS2, iS3 ] + // tv2: + // root/logical/loop: [ iS4, iS5, rS6 ] + // + // iS4 maps to iS0 so when producer==tv0 we can inline past iS0. When + // producer==tv1, iS4 doesn't map to anything in tv1 and bS8 is a loop + // broadcast in that position so we inline past the first ID in that + // case also. Similarly, we inline past iS5, iS2, and bS7. 
+ if (loop_path_groups.has_value()) { + bool p_id_connected = + loop_path_groups->count(inliningGraph().toGroup(p_id)); + bool c_id_connected = + loop_path_groups->count(inliningGraph().toGroup(c_id)); + NVF_ERROR( + p_id_connected || + (consumer->definition()->isA() && p_id->isBroadcast()), + "Expected unmapped producer id to be broadcast domain in MmaOp input but found ", + p_id->toString()); + + if (!p_id_connected && !c_id_connected) { + NVF_ERROR( + p_id->isBroadcast(), + "Unmapped producer ID must be a broadcast created in scheduling but found ", + p_id->toString()); + ++consumer_it; + continue; + } + } + if (!inliningGraph().disjointValSets().strictAreMapped(p_id, c_id) || - !isAllowedID(c_id, consumer, best_effort, true, false, true)) { + !isAllowedID( + c_id, + consumer, + best_effort, + /*allow_reduction=*/true, + /*allow_vectorize=*/false, + /*allow_unmappable=*/true)) { return producer_pos; } diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 87657e50997..5be1801e29e 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3657,7 +3657,6 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { const auto dtype = DataType::Half; constexpr bool use_smem_epilogue = false; - constexpr bool use_warp_specialization = true; constexpr int64_t stages = 4; constexpr int64_t prefetch = 3; @@ -3788,13 +3787,8 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { inlineMost(); - if (use_warp_specialization) { - tv0c->circularBuffer(stages, prefetch, WarpSpecialized(ParallelType::TIDy)); - tv1c->circularBuffer(stages, prefetch, WarpSpecialized(ParallelType::TIDy)); - } else { - tv0c->circularBuffer(stages, prefetch); - tv1c->circularBuffer(stages, prefetch); - } + tv0c->circularBuffer(stages, prefetch); + tv1c->circularBuffer(stages, prefetch); auto inputs = matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype)); @@ -3935,8 +3929,32 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { } 
tv3->axis(-1)->parallelize(ParallelType::Vectorize); + { + // Check using a copy that improperly aligned axes are not inlined + Fusion tmp_fusion; + IrCloner ir_cloner = Fusion::copy(&fusion, &tmp_fusion); + FusionGuard tmp_fg(&tmp_fusion); + // [Mo, No, Ko, Mio, Nio, Mii, Nii, Ki] + // Swap the No and Ko axes, but only in tv2, the mma output + // [Mo, Ko, No, Mio, Nio, Mii, Nii, Ki] + // This should mean the smem operands are now inlined at position 1 instead + // of 3 + ir_cloner.clone(tv2)->reorder({{2, 1}, {1, 2}}); + inlineMost(); + tmp_fusion.printMath(); + ir_cloner.clone(tv2)->reorder({{2, 1}, {1, 2}}); + EXPECT_EQ(ir_cloner.clone(tv0c)->getComputeAtPosition(), 1); + // The outermost loop dim of tv1c is a broadcast Mo axis, so + // tv1c->inlineAt(1) does not inline past that axis and we wind up with + // compute-at position 0. + EXPECT_EQ(ir_cloner.clone(tv1c)->getComputeAtPosition(), 0); + } + inlineMost(); + EXPECT_EQ(tv0c->getComputeAtPosition(), 3); + EXPECT_EQ(tv1c->getComputeAtPosition(), 3); + if (stages > 1) { tv0c->circularBuffer(stages, prefetch); tv1c->circularBuffer(stages, prefetch);