From 9bf2645f38acac19b93a331cb16486e00e85a397 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 14 Nov 2024 19:49:12 -0500 Subject: [PATCH 01/35] Move OptOutMutator tests to new file and add repro --- CMakeLists.txt | 1 + tests/cpp/test_dynamic_transform.cpp | 80 -------------- tests/cpp/test_mutator.cpp | 149 +++++++++++++++++++++++++++ 3 files changed, 150 insertions(+), 80 deletions(-) create mode 100644 tests/cpp/test_mutator.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 87f11f16658..9d7d7b32cdb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -560,6 +560,7 @@ list(APPEND JIT_TEST_SRCS ${NVFUSER_ROOT}/tests/cpp/test_memory.cpp ${NVFUSER_ROOT}/tests/cpp/test_move_split_cat.cpp ${NVFUSER_ROOT}/tests/cpp/test_move_pad.cpp + ${NVFUSER_ROOT}/tests/cpp/test_mutator.cpp ${NVFUSER_ROOT}/tests/cpp/test_no_op.cpp ${NVFUSER_ROOT}/tests/cpp/test_persistent_buffer.cpp ${NVFUSER_ROOT}/tests/cpp/test_pointwise.cpp diff --git a/tests/cpp/test_dynamic_transform.cpp b/tests/cpp/test_dynamic_transform.cpp index 8eb468999b7..e6c6b1292b3 100644 --- a/tests/cpp/test_dynamic_transform.cpp +++ b/tests/cpp/test_dynamic_transform.cpp @@ -1174,86 +1174,6 @@ TEST_F(NVFuserTest, Issue249InputNegative1_CUDA) { executor_cache.fusion(), outputs, {at_x, 2, 4, 15}, __LINE__, __FILE__); } -// Test that OptOutMutator mutates expressions in a predictable way -// See https://github.com/NVIDIA/Fuser/issues/852 -TEST_F(NVFuserTest, OptOutMutatorMutatedOutput) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion* fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - TensorView* tv0 = makeSymbolicTensor(1); - fusion->addInput(tv0); - - auto tv1 = neg(tv0); - - auto tv2 = set(tv1); - fusion->addOutput(tv2); - - auto tv3 = set(tv0); - - OptOutMutator mut; - mut.registerMutation(tv1, tv3); - - for (auto stmt : StmtSort::getStmts(fusion)) { - mut.dispatchMutate(stmt); - } - - EXPECT_NE(tv3->definition(), nullptr); - EXPECT_TRUE(tv3->definition()->isA()); - 
EXPECT_NE(tv2->definition(), nullptr); - EXPECT_TRUE(tv2->definition()->isA()); - EXPECT_EQ(tv2->definition()->input(0), tv3); - - auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); - at::Tensor t0 = at::randn({3}, options); - - inlineMost(); - - KernelExecutor ke; - ke.compile(fusion); - - auto outputs = ke.run({t0}); - - testValidate(fusion, outputs, {t0}, __LINE__, __FILE__); -} - -// Another test related to https://github.com/NVIDIA/Fuser/issues/852 -TEST_F(NVFuserTest, OptOutMutatorRedefinedConstant) { - std::unique_ptr fusion_ptr = std::make_unique(); - Fusion* fusion = fusion_ptr.get(); - FusionGuard fg(fusion); - - auto s0 = IrBuilder::create(DataType::Int); - fusion->addInput(s0); - auto s1 = neg(s0); - - auto tv0 = full({IrBuilder::create(2L)}, s1, DataType::Int); - fusion->addOutput(tv0); - - // After the following mutation, it's reasonable to expect the input scalar s0 - // to be ignored, and the output to just be ones. - OptOutMutator mut; - auto c = fusion->oneVal(DataType::Int); - mut.registerMutation(s1, c); - - for (auto stmt : StmtSort::getStmts(fusion)) { - mut.dispatchMutate(stmt); - } - - EXPECT_EQ( - c->definition(), nullptr); // Replacement value should not be redefined - EXPECT_EQ(tv0->definition()->as()->getFillValue(), c); - - inlineMost(); - - KernelExecutor ke; - ke.compile(fusion); - - auto outputs = ke.run({3L}); - - testValidate(fusion, outputs, {3L}, __LINE__, __FILE__); -} - // Test that we can squeeze Symbolic IterDomains and that we properly detect // improper concretizations where we have squeezed a dimension with extent // other than 1. diff --git a/tests/cpp/test_mutator.cpp b/tests/cpp/test_mutator.cpp new file mode 100644 index 00000000000..763c3b0c4ac --- /dev/null +++ b/tests/cpp/test_mutator.cpp @@ -0,0 +1,149 @@ +// clang-format off +/* + * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. + * All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + */ +// clang-format on +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include + +#include + +namespace nvfuser { + +// Test that OptOutMutator mutates expressions in a predictable way +// See https://github.com/NVIDIA/Fuser/issues/852 +TEST_F(NVFuserTest, OptOutMutatorMutatedOutput) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion* fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + TensorView* tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + + auto tv1 = neg(tv0); + + auto tv2 = set(tv1); + fusion->addOutput(tv2); + + auto tv3 = set(tv0); + + OptOutMutator mut; + mut.registerMutation(tv1, tv3); + + for (auto stmt : StmtSort::getStmts(fusion)) { + mut.dispatchMutate(stmt); + } + + EXPECT_NE(tv3->definition(), nullptr); + EXPECT_TRUE(tv3->definition()->isA()); + EXPECT_NE(tv2->definition(), nullptr); + EXPECT_TRUE(tv2->definition()->isA()); + EXPECT_EQ(tv2->definition()->input(0), tv3); + + auto options = at::TensorOptions().dtype(at::kFloat).device(at::kCUDA, 0); + at::Tensor t0 = at::randn({3}, options); + + inlineMost(); + + KernelExecutor ke; + ke.compile(fusion); + + auto outputs = ke.run({t0}); + + testValidate(fusion, outputs, {t0}, __LINE__, __FILE__); +} + +// Another test related to https://github.com/NVIDIA/Fuser/issues/852 +TEST_F(NVFuserTest, OptOutMutatorRedefinedConstant) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion* fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto s0 = IrBuilder::create(DataType::Int); + fusion->addInput(s0); + auto s1 = neg(s0); + + auto tv0 = full({IrBuilder::create(2L)}, s1, DataType::Int); + fusion->addOutput(tv0); + + // After the following mutation, it's reasonable to expect the input scalar s0 + // to be ignored, and the output to just be ones. 
+ OptOutMutator mut; + auto c = fusion->oneVal(DataType::Int); + mut.registerMutation(s1, c); + + for (auto stmt : StmtSort::getStmts(fusion)) { + mut.dispatchMutate(stmt); + } + + EXPECT_EQ( + c->definition(), nullptr); // Replacement value should not be redefined + EXPECT_EQ(tv0->definition()->as()->getFillValue(), c); + + inlineMost(); + + KernelExecutor ke; + ke.compile(fusion); + + auto outputs = ke.run({3L}); + + testValidate(fusion, outputs, {3L}, __LINE__, __FILE__); +} + +// Test that additional IDs are preserved when mutating a TensorView +TEST_F(NVFuserTest, OptOutMutatorAdditionalBroadcastID) { + std::unique_ptr fusion_ptr = std::make_unique(); + Fusion* fusion = fusion_ptr.get(); + FusionGuard fg(fusion); + + auto tv0 = makeSymbolicTensor(1); + fusion->addInput(tv0); + + auto tv1 = exp(tv0); + + fusion->addOutput(tv1); + + // We add a broadcast domain bS2{1}. This adds the new Broadcast ID to tv1->domain()->additionalIDs() + // logical: [ iS1{i0} ] + // loop: [ iS1{i0}, bS2{1} ] + // additional IDs: [ bS2{1} ] + tv1->broadcast(1); + EXPECT_FALSE(tv1->domain()->additionalIDs().empty()); + + // After this split we have + // logical: [ iS1{i0} ] + // loop: [ iS1{i0}, bS3{1}, bS4{2} ] + // additional IDs: [ bS2{1} ] + tv1->split(1, 2); + EXPECT_FALSE(tv1->domain()->additionalIDs().empty()); + + // Now register a mutation that will alter some IDs in the domain + OptOutMutator mut; + mut.registerMutation(tv1->axis(0)->extent(), IrBuilder::create(DataType::Index)); + TensorDomain* old_tensor_domain = tv1->domain(); + auto all_stmts = StmtSort::getStmts( + fusion, + /*traverse_members*/ true, + /*traverse_attributes*/ true, + /*traverse_siblings*/ true); + for (auto stmt : all_stmts) { + mut.dispatchMutate(stmt); + } + EXPECT_TRUE(tv1->domain() != old_tensor_domain) << "Mutation did not change the TensorDomain"; + + EXPECT_FALSE(tv1->domain()->additionalIDs().empty())<< "Mutation did not preserve additional IDs"; +} + + +} // namespace nvfuser From 
96dd201ee27d2d44b995e89687ba5617bc2070ed Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 14 Nov 2024 19:49:28 -0500 Subject: [PATCH 02/35] Add additional_ids arg to big ctor --- csrc/ir/internal_base_nodes.h | 3 ++- csrc/ir/nodes.cpp | 4 +++- csrc/mutator.cpp | 4 +++- 3 files changed, 8 insertions(+), 3 deletions(-) diff --git a/csrc/ir/internal_base_nodes.h b/csrc/ir/internal_base_nodes.h index f9f422cd994..58086ec1b0d 100644 --- a/csrc/ir/internal_base_nodes.h +++ b/csrc/ir/internal_base_nodes.h @@ -441,7 +441,8 @@ class TensorDomain : public Val { std::vector logical_domain, std::vector allocation, std::vector loop_domain, - std::vector> contiguity = {}); + std::vector> contiguity = {}, + std::vector additional_ids = {}); TensorDomain(IrBuilderPasskey, const TensorDomain* src); diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index 3d4213f68fe..f88618e03e5 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -3078,13 +3078,15 @@ TensorDomain::TensorDomain( std::vector logical_domain, std::vector allocation_domain, std::vector loop_domain, - std::vector> contiguity) + std::vector> contiguity, + std::vector additional_ids) : Val(passkey, ValType::TensorDomain, DataType::Null), root_domain_(std::move(root_domain)), logical_domain_(std::move(logical_domain)), allocation_domain_(std::move(allocation_domain)), loop_domain_(std::move(loop_domain)), initial_loop_domain_(loop_domain_), + additional_ids_(additional_ids), contiguity_( contiguity.empty() ? getContiguityFilledWith(maybeAllocation(), false) : std::move(contiguity)) { diff --git a/csrc/mutator.cpp b/csrc/mutator.cpp index 5f183bf4839..26d1ca82924 100644 --- a/csrc/mutator.cpp +++ b/csrc/mutator.cpp @@ -156,6 +156,7 @@ void OptOutMutator::mutate(TensorDomain* td) { ? 
updateIdVec(td->allocation()) : std::vector(); std::vector domain = updateIdVec(td->loop()); + std::vector additional_ids = updateIdVec(td->additionalIDs()); if (!mutated) { return; @@ -167,7 +168,8 @@ void OptOutMutator::mutate(TensorDomain* td) { logical_dom, allocation_dom, domain, - td->contiguity()); + td->contiguity(), + additional_ids); registerMutation(td, mutated_val); } From c7c790b24c1a187f0cc9072068addf8b2c164898 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 15 Nov 2024 09:46:47 -0500 Subject: [PATCH 03/35] Only check actually used IDs in predicate elimination --- .../analysis/predicate_elimination.cpp | 44 +++- tests/cpp/test_matmul.cpp | 209 +++++++++++++++--- 2 files changed, 222 insertions(+), 31 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 347cb63222c..3159de8a5db 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -188,7 +188,37 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map) .getReplay(); - ProducerConsumerPairAnalyzer analyzer(c2p); + // Find all IterDomains involved in index expressions + // TODO: do we need to find logical IDs in producer that are involved in + // its allocation domain too? 
+ std::vector mapped_root_vals, loop_vals; + for (IterDomain* id : consumer->getRootDomain()) { + if (c2p.find(id) != c2p.end()) { + mapped_root_vals.push_back(id); + } + } + for (IterDomain* id : consumer->getLoopDomain()) { + loop_vals.push_back(id); + } + + // Collect all IterDomains along path instead of Exprs + std::unordered_set index_ids; + for ([[maybe_unused]] auto [expr, dir] : IRBFS::getExprsBetween( + mapped_root_vals, + loop_vals, + /*require_all_to_visited=*/false)) { + for (Val* v : expr->inputs()) { + if (auto* id = dynamic_cast(v)) { + index_ids.insert(id); + } + } + for (Val* v : expr->outputs()) { + if (auto* id = dynamic_cast(v)) { + index_ids.insert(id); + } + } + } + ProducerConsumerPairAnalyzer analyzer(c2p, index_ids); for (auto id : consumer->getLoopDomain()) { if (analyzer.needsPredicate(id)) { @@ -201,12 +231,19 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { private: ProducerConsumerPairAnalyzer( - const std::unordered_map& c2p) - : c2p_(c2p) {} + const std::unordered_map& c2p, + const std::unordered_set index_ids) + : c2p_(c2p), index_ids_(index_ids) {} // Returns true if no out-of-bound accesses could occur with a // producer bool needsPredicate(IterDomain* consumer_id) { + // TODO: check that this consumer_id is actually involved in indexing the + // producer. If it is not connected to the producer allocation domain in + // the broadcast graph, then we can skip processing it. + if (index_ids_.find(consumer_id) == index_ids_.end()) { + return false; + } needs_predicate_ = false; handle(consumer_id); return needs_predicate_; @@ -297,6 +334,7 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { //! 
BestEffort map from consumer IDs to producer IDs const std::unordered_map& c2p_; bool needs_predicate_ = false; + std::unordered_set index_ids_; }; class PredicateChcker : public IterVisitor { diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index fa185665253..ec087e8fd6c 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -449,6 +449,35 @@ TEST_P(MatmulTestWithLayout, AmpereMatmulPipelineGmem) { } } +// Check that mma op is not predicated. +class PredicateChecker : public kir::IrVisitor { + public: + using kir::IrVisitor::handle; + bool found_mma = false; + + private: + void handle(kir::Asm* asm_) final { +#if IS_CPP20 + if (!asm_->code().starts_with("mma") && + !asm_->code().starts_with("wgmma")) { +#else + if (asm_->code().substr(0, 3) != "mma" && + asm_->code().substr(0, 5) != "wgmma") { +#endif + return; + } + found_mma = true; + for (auto expr : scope_exprs_) { + NVF_CHECK( + !expr->isA() || + expr->as()->predicate()->isTrivial(), + "MmaOp should't be predicated!", + " Get predicate ", + expr->as()->predicate()->toInlineString()); + } + } +}; + // Matmul test for Ampere MMA: checking CTA Swizzles TEST_P(MatmulTestWithLayout, AmpereSwizzle) { NVFUSER_TEST_CUDA_ARCH_RANGE_GUARD(7, 5, 9, 0); @@ -549,35 +578,9 @@ TEST_P(MatmulTestWithLayout, AmpereSwizzle) { runtime = 0; } - // Check that mma op is not predicated. 
This is a regression test for - // https://github.com/NVIDIA/Fuser/issues/95 - class PredicateChecker : public kir::IrVisitor { - public: - using kir::IrVisitor::handle; - bool found_mma = false; - - private: - void handle(kir::Asm* asm_) final { -#if IS_CPP20 - if (!asm_->code().starts_with("mma")) { -#else - if (asm_->code().substr(0, 3) != "mma") { -#endif - return; - } - found_mma = true; - for (auto expr : scope_exprs_) { - NVF_CHECK( - !expr->isA() || - expr->as()->predicate()->isTrivial(), - "MmaOp should't be predicated!", - " Get predicate ", - expr->as()->predicate()->toInlineString()); - } - } - } pred_checker; - + // This is a regression test for https://github.com/NVIDIA/Fuser/issues/95 GpuLower gpulw(&fusion); + PredicateChecker pred_checker; pred_checker.handle(gpulw.run()->topLevelExprs()); ASSERT_TRUE(pred_checker.found_mma); }; @@ -3798,4 +3801,154 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { EXPECT_TRUE(at::allclose(cg_outputs[0], tref, 1e-5, 1e-5)); } +// Test scheduling a Hopper matmul where the operands are 2D +TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { + Fusion fusion; + FusionGuard fg(&fusion); + + // constexpr int64_t M = 2048, N = 2048, K = 8192; + constexpr auto macro = MmaMacro::Hopper_64_256_16; + // constexpr auto layout = MmaLayout::NT; // [K, M] x [K, N] -> [M, N] + constexpr auto swizzle = MmaInputSmemSwizzle::B128; + const auto dtype = DataType::Half; + + constexpr int64_t stages = 1; + constexpr int64_t prefetch = 3; + const int64_t cta_m = 2 * getM(macro); + const int64_t cta_n = 1 * getN(macro); + + auto tv0 = makeContigConcreteTensor({-1, -1}, dtype); // [K, M] + auto tv1 = makeContigConcreteTensor({-1, -1}, dtype); // [K, N] + fusion.addInput(tv0); + fusion.addInput(tv1); + + // The output is [M, N, K] (no reordering needed) + MmaOp::AxisMapping axis_mapping{.a_axes = {1, -1, 0}, .b_axes = {-1, 1, 0}}; + auto tv2 = + fusedMultiplySum(tv0, tv1, /*axes=*/{-1}, /*init=*/nullptr, axis_mapping); + + auto tv3 = 
castOp(DataType::Half, tv2); + + fusion.addOutput(tv3); + + auto mma_ops = ir_utils::getOpsOfType(&fusion); + NVF_CHECK( + 1 == mma_ops.size(), + "Invalid number of MmaOp instances in fusion definition, expected 1, got ", + mma_ops.size()); + mma_ops.front()->setMacro(macro); + + // gmem [K, M] x gmem [K, N] -mma-> register [M, N, rK] + // register [M, N, rK] -cast-> gmem [M, N] + + auto tv0c = tv0->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv0c->setMemoryType(MemoryType::Shared); + auto tv1c = tv1->cacheAfter(LoadStoreOpType::CpAsyncBulkTensorTile); + tv1c->setMemoryType(MemoryType::Shared); + auto tv3c = tv3->cacheBefore(); + + tv0c->broadcast(-1); // [K, M] -> [K, M, 1] + tv1c->broadcast(-2); // [K, N] -> [K, 1, N] + + // gmem [K, M, 1] -TMA-> smem [K, M, 1] + // gmem [K, 1, N] -TMA-> smem [K, 1, N] + // smem [K, M, 1] x smem [K, 1, N] -mma-> register [M, N, rK] + // register [M, N, rK] -cast-> register [M, N] -set-> gmem [M, N] + + // Create tiles + tv2->split(-3, cta_m); + tv2->split(-2, cta_n); + tv2->split(-1, getK(macro)); + // [Mo, Mi, No, Ni, Ko, Ki] -> [Mo, No, Ko, Mi, Ni, Ki] + tv2->reorder({{-5, -3}, {-3, -2}}); + tv2->axis(0)->parallelize(ParallelType::BIDy); + tv2->axis(1)->parallelize(ParallelType::BIDx); + + // NOTE: since in this case we do not have "proper" broadcast in the inputs, + // we cannot simply propagate transforms to the operands. Instead, we + // propagate forward to the outputs and manually schedule the smem operands. + + // ComputeAtMap in this case finds bS15 and bS16. 
They are in the loop domain + // at this point + + scheduler_utils::BoundedDirectionalTransformPropagator::forward( + tv2, + -1, + {tv3}, + scheduler_utils::BoundedDirectionalTransformPropagator::Options() + .propagateParallelType() + .propagateToBoundary()); + + // Schedule operands + for (TensorView* tv : {tv0c, tv1c}) { + tv->reorder({{-3, -1}}); // [K, M, N] -> [M, N, K] + // NOTE: above axes are given in MNK order, but inputs are in KMN + tv->split(-3, cta_m); + tv->split(-2, cta_n); + tv->split(-1, getK(macro)); + // [Mo, Mi, No, Ni, Ko, Ki] -> [Mo, No, Ko, Mi, Ni, Ki] + // [Ko, Ki, Mo, Mi, No, Ni] -> [Mo, No, Ko, Mi, Ni, Ki] + tv->reorder({{-5, -3}, {-3, -2}}); + tv->axis(0)->parallelize(ParallelType::BIDy); + tv->axis(1)->parallelize(ParallelType::BIDx); + } + + // [..., Mi, Ni, Ki] -> [..., Ni, Ki, Mi] + tv0c->reorder({{-3, -1}}); + tv0c->applyMmaSwizzleForTMALoad(swizzle); + // [..., Mi, Ni, Ki] -> [..., Mi, Ki, Ni] + tv1c->reorder({{-1, -2}}); + tv1c->applyMmaSwizzleForTMALoad(swizzle); + + { + tv2->split(-3, getM(macro)); + tv2->split(-2, getN(macro)); + // [Mo, No, Ko, Mio, Mii, Nio, Nii, Ki] + // -> [Mo, No, Ko, Mio, Nio, Mii, Nii, Ki] + tv2->reorder({{-4, -3}}); + tv2->merge(-5); + tv2->axis(-4)->parallelize(ParallelType::TIDy); + scheduler_utils::BoundedDirectionalTransformPropagator::forward( + tv2, + -1, + {tv3}, + scheduler_utils::BoundedDirectionalTransformPropagator::Options() + .propagateParallelType() + .propagateToBoundary()); + } + + { + auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( + tv2->getLoopDomain()); + tv2->setAllocationDomain(s.as(), true); + tv2->axis(-1)->parallelize(ParallelType::Mma); + tv2->axis(-2)->parallelize(ParallelType::Mma); + tv2->axis(-3)->parallelize(ParallelType::Mma); + } + + for (auto tv : {tv3c, tv3}) { + auto s = mma_utils::MmaSwizzler::scheduleMmaOutputAllocation( + tv->getLoopDomain()); + tv->setLoopDomain(s.as()); + } + tv3->axis(-1)->parallelize(ParallelType::Vectorize); + + inlineMost(); + + 
if (stages > 1) { + tv0c->circularBuffer(stages, prefetch); + tv1c->circularBuffer(stages, prefetch); + } + + // Test that predicate elimination works when the MmaOp's operands have no + // logical broadcasts + GpuLower gpulw(&fusion); + kir::Kernel* kernel = gpulw.run(); + PredicateChecker pred_checker; + pred_checker.handle(kernel->topLevelExprs()); + ASSERT_TRUE(pred_checker.found_mma); + + // TODO: compile and run kernel once inlining is fixed +} + } // namespace nvfuser From 29fe28bdfeb2de83de3e5973ea7b834af2bf5495 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 15 Nov 2024 11:59:22 -0500 Subject: [PATCH 04/35] Allow inlining loop broadcasts --- csrc/scheduler/tools/inlining.cpp | 44 +++++++++++++++++++++++++++++-- tests/cpp/test_matmul.cpp | 3 +++ 2 files changed, 45 insertions(+), 2 deletions(-) diff --git a/csrc/scheduler/tools/inlining.cpp b/csrc/scheduler/tools/inlining.cpp index 16d65571625..35c231ee591 100644 --- a/csrc/scheduler/tools/inlining.cpp +++ b/csrc/scheduler/tools/inlining.cpp @@ -193,6 +193,34 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( } return producer->nDims(); } else { + // Gather sets of loop broadcasts (broadcast domains that are not connected + // to the logical domain) + std::unordered_set all_additional_ids{ + producer->domain()->additionalIDs().begin(), + producer->domain()->additionalIDs().end()}; + all_additional_ids.insert( + consumer->domain()->additionalIDs().begin(), + consumer->domain()->additionalIDs().end()); + std::unordered_set loop_broadcasts; + for (TensorView* tv : {producer, consumer}) { + for ([[maybe_unused]] auto [expr, dir] : IRBFS::getExprsBetween( + {tv->domain()->additionalIDs().begin(), + tv->domain()->additionalIDs().end()}, + {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}, + /*require_all_to_visited=*/false)) { + for (Val* v : expr->inputs()) { + if (auto* id = dynamic_cast(v)) { + loop_broadcasts.insert(id); + } + } + for (Val* v : expr->outputs()) { + if (auto* id = 
dynamic_cast(v)) { + loop_broadcasts.insert(id); + } + } + } + } + auto consumer_it = consumer->getLoopDomain().begin(); for (const auto producer_pos : c10::irange(producer->nDims())) { auto p_id = producer->getLoopDomain().at(producer_pos); @@ -211,8 +239,20 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( } IterDomain* c_id = *consumer_it; - if (!inliningGraph().disjointValSets().strictAreMapped(p_id, c_id) || - !isAllowedID(c_id, consumer, best_effort, true, false, true)) { + + // If c_id or p_id are a "loop broadcast", then allow inlining past them. + // TODO: should we verify that any non-broadcast IDs in this case are not + // actually used in indexing? + if (loop_broadcasts.count(c_id) == 0 && + loop_broadcasts.count(p_id) == 0 && + (!inliningGraph().disjointValSets().strictAreMapped(p_id, c_id) || + !isAllowedID( + c_id, + consumer, + best_effort, + /*allow_reduction=*/true, + /*allow_vectorize=*/false, + /*allow_unmappable=*/true))) { return producer_pos; } diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index ec087e8fd6c..94392c1d14b 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3935,6 +3935,9 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { inlineMost(); + EXPECT_EQ(tv0c->getComputeAtPosition(), 3); + EXPECT_EQ(tv1c->getComputeAtPosition(), 3); + if (stages > 1) { tv0c->circularBuffer(stages, prefetch); tv1c->circularBuffer(stages, prefetch); From 11083a17f7172bf959b193a0b31bbb118b4f8f5a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 15 Nov 2024 13:11:46 -0500 Subject: [PATCH 05/35] clang-format --- tests/cpp/test_mutator.cpp | 17 +++++++++-------- 1 file changed, 9 insertions(+), 8 deletions(-) diff --git a/tests/cpp/test_mutator.cpp b/tests/cpp/test_mutator.cpp index 763c3b0c4ac..ef1f3433a51 100644 --- a/tests/cpp/test_mutator.cpp +++ b/tests/cpp/test_mutator.cpp @@ -114,10 +114,9 @@ TEST_F(NVFuserTest, OptOutMutatorAdditionalBroadcastID) { fusion->addOutput(tv1); - // We 
add a broadcast domain bS2{1}. This adds the new Broadcast ID to tv1->domain()->additionalIDs() - // logical: [ iS1{i0} ] - // loop: [ iS1{i0}, bS2{1} ] - // additional IDs: [ bS2{1} ] + // We add a broadcast domain bS2{1}. This adds the new Broadcast ID to + // tv1->domain()->additionalIDs() logical: [ iS1{i0} ] loop: [ iS1{i0}, bS2{1} + // ] additional IDs: [ bS2{1} ] tv1->broadcast(1); EXPECT_FALSE(tv1->domain()->additionalIDs().empty()); @@ -130,7 +129,8 @@ TEST_F(NVFuserTest, OptOutMutatorAdditionalBroadcastID) { // Now register a mutation that will alter some IDs in the domain OptOutMutator mut; - mut.registerMutation(tv1->axis(0)->extent(), IrBuilder::create(DataType::Index)); + mut.registerMutation( + tv1->axis(0)->extent(), IrBuilder::create(DataType::Index)); TensorDomain* old_tensor_domain = tv1->domain(); auto all_stmts = StmtSort::getStmts( fusion, @@ -140,10 +140,11 @@ TEST_F(NVFuserTest, OptOutMutatorAdditionalBroadcastID) { for (auto stmt : all_stmts) { mut.dispatchMutate(stmt); } - EXPECT_TRUE(tv1->domain() != old_tensor_domain) << "Mutation did not change the TensorDomain"; + EXPECT_TRUE(tv1->domain() != old_tensor_domain) + << "Mutation did not change the TensorDomain"; - EXPECT_FALSE(tv1->domain()->additionalIDs().empty())<< "Mutation did not preserve additional IDs"; + EXPECT_FALSE(tv1->domain()->additionalIDs().empty()) + << "Mutation did not preserve additional IDs"; } - } // namespace nvfuser From 11c43c41a398bfc749f4f6c0b341ffeb8fd98304 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 15 Nov 2024 13:11:54 -0500 Subject: [PATCH 06/35] clang-tidy of TensorDomain ctor --- csrc/ir/nodes.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/ir/nodes.cpp b/csrc/ir/nodes.cpp index f88618e03e5..8758a8022cb 100644 --- a/csrc/ir/nodes.cpp +++ b/csrc/ir/nodes.cpp @@ -3086,7 +3086,7 @@ TensorDomain::TensorDomain( allocation_domain_(std::move(allocation_domain)), loop_domain_(std::move(loop_domain)), 
initial_loop_domain_(loop_domain_), - additional_ids_(additional_ids), + additional_ids_(std::move(additional_ids)), contiguity_( contiguity.empty() ? getContiguityFilledWith(maybeAllocation(), false) : std::move(contiguity)) { From c06917fc300d2110626e38d293844451f32b0887 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Fri, 15 Nov 2024 14:22:55 -0500 Subject: [PATCH 07/35] Check IterType of loop broadcasts --- csrc/scheduler/tools/inlining.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/scheduler/tools/inlining.cpp b/csrc/scheduler/tools/inlining.cpp index 35c231ee591..77c94c52ce3 100644 --- a/csrc/scheduler/tools/inlining.cpp +++ b/csrc/scheduler/tools/inlining.cpp @@ -209,12 +209,14 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}, /*require_all_to_visited=*/false)) { for (Val* v : expr->inputs()) { - if (auto* id = dynamic_cast(v)) { + if (auto* id = dynamic_cast(v); + id && id->isBroadcast()) { loop_broadcasts.insert(id); } } for (Val* v : expr->outputs()) { - if (auto* id = dynamic_cast(v)) { + if (auto* id = dynamic_cast(v); + id && id->isBroadcast()) { loop_broadcasts.insert(id); } } From 64be2c7a18cdb7d73e9134c39d0a354d674d6cdf Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 18 Nov 2024 08:48:02 -0500 Subject: [PATCH 08/35] Remove debugging comment --- tests/cpp/test_matmul.cpp | 4 ---- 1 file changed, 4 deletions(-) diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index ec087e8fd6c..c4fed527a42 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3867,10 +3867,6 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { // NOTE: since in this case we do not have "proper" broadcast in the inputs, // we cannot simply propagate transforms to the operands. Instead, we // propagate forward to the outputs and manually schedule the smem operands. - - // ComputeAtMap in this case finds bS15 and bS16. 
They are in the loop domain - // at this point - scheduler_utils::BoundedDirectionalTransformPropagator::forward( tv2, -1, From cc236fd83668b6f86c6b69bfcb92e23e00c8b646 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 18 Nov 2024 11:47:22 -0500 Subject: [PATCH 09/35] Track IDs used in indexing --- csrc/scheduler/tools/inlining.cpp | 86 ++++++++++++++++++++++--------- 1 file changed, 61 insertions(+), 25 deletions(-) diff --git a/csrc/scheduler/tools/inlining.cpp b/csrc/scheduler/tools/inlining.cpp index 77c94c52ce3..c1cb0526758 100644 --- a/csrc/scheduler/tools/inlining.cpp +++ b/csrc/scheduler/tools/inlining.cpp @@ -193,35 +193,55 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( } return producer->nDims(); } else { - // Gather sets of loop broadcasts (broadcast domains that are not connected - // to the logical domain) - std::unordered_set all_additional_ids{ - producer->domain()->additionalIDs().begin(), - producer->domain()->additionalIDs().end()}; - all_additional_ids.insert( - consumer->domain()->additionalIDs().begin(), - consumer->domain()->additionalIDs().end()); - std::unordered_set loop_broadcasts; - for (TensorView* tv : {producer, consumer}) { + // First we find the consumer root IDs that map to the producer + auto c2p = + PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); + // Track the IDs involved in indexing in both producer and consumer + std::vector consumer_root_indexing_ids; + std::vector producer_logical_indexing_ids; + for (IterDomain* id : consumer->getMaybeRootDomain()) { + auto it = c2p.find(id); + if (it == c2p.end()) { + continue; + } + // These are the immediately mapped consumer root and producer logical + // IDs. This is a starting point for our later traversals, which will fill + // these sets out. 
+ consumer_root_indexing_ids.push_back(it->first); + producer_logical_indexing_ids.push_back(it->second); + } + + // Now traverse from the starting set (which, as noted above is a subset of + // either the producer logical or consumer root) to the target which is + // either the producer allocation domain or the consumer loop domain. These + // are the IDs that will actually affect indexing. Any other IDs can be + // skipped. + auto traverse = [](const std::vector& start_domain, + const std::vector& target_domain) + -> std::unordered_set { + std::unordered_set indexing_ids{ + start_domain.begin(), start_domain.end()}; for ([[maybe_unused]] auto [expr, dir] : IRBFS::getExprsBetween( - {tv->domain()->additionalIDs().begin(), - tv->domain()->additionalIDs().end()}, - {tv->getLoopDomain().begin(), tv->getLoopDomain().end()}, + start_domain, + {target_domain.begin(), target_domain.end()}, /*require_all_to_visited=*/false)) { for (Val* v : expr->inputs()) { - if (auto* id = dynamic_cast(v); - id && id->isBroadcast()) { - loop_broadcasts.insert(id); + if (auto* id = dynamic_cast(v)) { + indexing_ids.insert(id); } } for (Val* v : expr->outputs()) { - if (auto* id = dynamic_cast(v); - id && id->isBroadcast()) { - loop_broadcasts.insert(id); + if (auto* id = dynamic_cast(v)) { + indexing_ids.insert(id); } } } - } + return indexing_ids; + }; + std::unordered_set producer_indexing_ids = traverse( + producer_logical_indexing_ids, producer->getMaybeAllocationDomain()); + std::unordered_set consumer_indexing_ids = + traverse(consumer_root_indexing_ids, consumer->getLoopDomain()); auto consumer_it = consumer->getLoopDomain().begin(); for (const auto producer_pos : c10::irange(producer->nDims())) { @@ -242,11 +262,27 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( IterDomain* c_id = *consumer_it; - // If c_id or p_id are a "loop broadcast", then allow inlining past them. 
- // TODO: should we verify that any non-broadcast IDs in this case are not - // actually used in indexing? - if (loop_broadcasts.count(c_id) == 0 && - loop_broadcasts.count(p_id) == 0 && + // If either ID is involved in indexing then we need to make sure they're + // both mapped in the inlining graph or that this is a special case + // covered by isAllowedID. + // + // For example, an MmaOp with no broadcasts could contain the following: + // tv0: + // root/logical: [ iS0, iS1 ] + // loop: [ iS0, bS7, iS1 ] + // tv1: + // root/logical: [ iS2, iS3 ] + // loop: [ bS8, iS2, iS3 ] + // tv2: + // root/logical/loop: [ iS4, iS5, rS6 ] + // + // iS4 maps to iS0 so when producer==tv0 we inline past iS0. When + // producer==tv1, iS4 doesn't map to anything in tv1 and is not used for + // indexing, and bS8 is also not used in indexing (it's a loop broadcast) + // so we inline past the first ID in that case also. Similarly, we inline + // past iS5, iS2, and bS7. + if (!(consumer_indexing_ids.count(c_id) == 0 && + producer_indexing_ids.count(p_id) == 0) && (!inliningGraph().disjointValSets().strictAreMapped(p_id, c_id) || !isAllowedID( c_id, From 05d5ca42b95d1bceb0794f35496c466cdeea02c3 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 18 Nov 2024 15:09:16 -0500 Subject: [PATCH 10/35] [DO NOT MERGE] added throw to test impact on existing tests --- csrc/device_lower/analysis/predicate_elimination.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 3159de8a5db..78c88bfea9a 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -242,6 +242,7 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { // producer. If it is not connected to the producer allocation domain in // the broadcast graph, then we can skip processing it. 
if (index_ids_.find(consumer_id) == index_ids_.end()) { + NVF_THROW("FOUND UNEXPECTED PATH IN TEST"); return false; } needs_predicate_ = false; From 59745676a37d983d748c88626f8e55c101f7280d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 18 Nov 2024 21:56:32 -0500 Subject: [PATCH 11/35] Refactor getting indexing IDs into utility This updates the NVF_THROW check to rule out the BroadcastOp case. --- .../analysis/predicate_elimination.cpp | 45 ++-------- csrc/device_lower/utils.cpp | 84 +++++++++++++++++++ csrc/device_lower/utils.h | 9 ++ 3 files changed, 101 insertions(+), 37 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 78c88bfea9a..4fd7a4a42b1 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -183,42 +183,11 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { return true; } - auto pairwise_map = PairwiseLogicalDomainMap(producer, consumer); auto c2p = - BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map) - .getReplay(); - - // Find all IterDomains involved in index expressions - // TODO: do we need to find logical IDs in producer that are involved in - // its allocation domain too? 
- std::vector mapped_root_vals, loop_vals; - for (IterDomain* id : consumer->getRootDomain()) { - if (c2p.find(id) != c2p.end()) { - mapped_root_vals.push_back(id); - } - } - for (IterDomain* id : consumer->getLoopDomain()) { - loop_vals.push_back(id); - } - - // Collect all IterDomains along path instead of Exprs - std::unordered_set index_ids; - for ([[maybe_unused]] auto [expr, dir] : IRBFS::getExprsBetween( - mapped_root_vals, - loop_vals, - /*require_all_to_visited=*/false)) { - for (Val* v : expr->inputs()) { - if (auto* id = dynamic_cast(v)) { - index_ids.insert(id); - } - } - for (Val* v : expr->outputs()) { - if (auto* id = dynamic_cast(v)) { - index_ids.insert(id); - } - } - } - ProducerConsumerPairAnalyzer analyzer(c2p, index_ids); + PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); + [[maybe_unused]] const auto [producer_index_ids, consumer_index_ids] = + lower_utils::getIndexIDs(producer, consumer, &c2p); + ProducerConsumerPairAnalyzer analyzer(c2p, consumer_index_ids); for (auto id : consumer->getLoopDomain()) { if (analyzer.needsPredicate(id)) { @@ -238,10 +207,12 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { // Returns true if no out-of-bound accesses could occur with a // producer bool needsPredicate(IterDomain* consumer_id) { - // TODO: check that this consumer_id is actually involved in indexing the + // Check that this consumer_id is actually involved in indexing the // producer. If it is not connected to the producer allocation domain in // the broadcast graph, then we can skip processing it. 
- if (index_ids_.find(consumer_id) == index_ids_.end()) { + if (!consumer_id->isBroadcast() && + index_ids_.find(consumer_id) == index_ids_.end()) { + // TODO: Remove this line and the isBroadcast check in the condition above NVF_THROW("FOUND UNEXPECTED PATH IN TEST"); return false; } diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index f77a7520e9a..d0a5b6488c2 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -2004,6 +2004,90 @@ std::vector getSyncExprs(AsyncOpType async_type, int64_t keep_stages) { return sync_exprs; } +std::pair, std::unordered_set> +getIndexIDs( + TensorView* producer, + TensorView* consumer, + const std::unordered_map* c2p) { + // First we find the consumer root IDs that map to the producer + std::unordered_map c2p_tmp; + if (c2p == nullptr) { + auto c2p_tmp = + PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); + c2p = &c2p_tmp; + } + // Track the IDs involved in indexing in both producer and consumer + std::unordered_set consumer_indexing_ids; + std::unordered_set producer_indexing_ids; + for (IterDomain* id : consumer->getMaybeRootDomain()) { + auto it = c2p->find(id); + if (it == c2p->end()) { + continue; + } + // These are the immediately mapped consumer root and producer logical + // IDs. This is a starting point for our later traversals, which will fill + // these sets out. + consumer_indexing_ids.insert(it->first); + producer_indexing_ids.insert(it->second); + } + + // Now traverse from the starting set (which, as noted above is a subset of + // either the producer logical or consumer root) to the target which is + // either the producer allocation domain or the consumer loop domain. These + // are the IDs that will actually affect indexing. Any other IDs can be + // skipped. 
+ auto traverse = [](std::unordered_set& indexing_ids, + const std::vector& start_domain, + const std::vector& target_domain) { + for (auto [expr, dir] : IRBFS::getExprsBetween( + {start_domain.begin(), start_domain.end()}, + {target_domain.begin(), target_domain.end()}, + /*require_all_to_visited=*/false)) { + // If there are any indexing IDs in the inputs, count all outputs as + // indexing IDs + if (dir == Direction::Forward) { + if (std::any_of( + expr->inputs().begin(), expr->inputs().end(), [&](Val* input) { + auto* id = dynamic_cast(input); + return id && indexing_ids.count(id) != 0; + })) { + for (Val* v : expr->outputs()) { + if (auto* id = dynamic_cast(v)) { + indexing_ids.insert(id); + } + } + } + } else if (dir == Direction::Backward) { + if (std::any_of( + expr->outputs().begin(), + expr->outputs().end(), + [&](Val* output) { + auto* id = dynamic_cast(output); + return id && indexing_ids.count(id) != 0; + })) { + for (Val* v : expr->inputs()) { + if (auto* id = dynamic_cast(v)) { + indexing_ids.insert(id); + } + } + } + } else { + NVF_THROW("Found unexpected direction"); + } + } + }; + traverse( + producer_indexing_ids, + /*start_domain=*/producer->getLogicalDomain(), + /*target_domain=*/producer->getMaybeAllocationDomain()); + traverse( + consumer_indexing_ids, + /*start_domain=*/consumer->getMaybeRootDomain(), + /*target_domain=*/consumer->getLoopDomain()); + + return {producer_indexing_ids, consumer_indexing_ids}; +} + } // namespace lower_utils } // namespace nvfuser diff --git a/csrc/device_lower/utils.h b/csrc/device_lower/utils.h index 2f53e7ed0ae..f78208243e8 100644 --- a/csrc/device_lower/utils.h +++ b/csrc/device_lower/utils.h @@ -19,6 +19,7 @@ #include #include +#include "logical_domain_map.h" // Provides utilities for dealing with nested ForLoop and IfThenElse scopes @@ -379,6 +380,14 @@ std::vector getSyncExprs( AsyncOpType async_type, int64_t keep_stages = 0); +//! 
Get the set of IterDomains on the shortest path from the producer allocation +//! domain to the consumer loop domain. +std::pair, std::unordered_set> +getIndexIDs( + TensorView* producer, + TensorView* consumer, + const std::unordered_map* c2p = nullptr); + } // namespace lower_utils } // namespace nvfuser From e0ad380f7bbefa68dc0bdcfe07adda7511a12826 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 19 Nov 2024 07:18:26 -0500 Subject: [PATCH 12/35] Put back accidentally removed replay --- csrc/device_lower/analysis/predicate_elimination.cpp | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 4fd7a4a42b1..34296fdf006 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -183,8 +183,10 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { return true; } + auto pairwise_map = PairwiseLogicalDomainMap(producer, consumer); auto c2p = - PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); + BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map) + .getReplay(); [[maybe_unused]] const auto [producer_index_ids, consumer_index_ids] = lower_utils::getIndexIDs(producer, consumer, &c2p); ProducerConsumerPairAnalyzer analyzer(c2p, consumer_index_ids); @@ -213,7 +215,7 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { if (!consumer_id->isBroadcast() && index_ids_.find(consumer_id) == index_ids_.end()) { // TODO: Remove this line and the isBroadcast check in the condition above - NVF_THROW("FOUND UNEXPECTED PATH IN TEST"); + NVF_THROW("FOUND UNEXPECTED PATH IN TEST ", consumer_id->toString()); return false; } needs_predicate_ = false; From 3c2631fbc202e805cc6ba3c09ecbeb8042d42d20 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 19 Nov 2024 09:25:52 -0500 Subject: [PATCH 13/35] Add skipped root->logical mappings 
in c2p --- csrc/device_lower/analysis/predicate_elimination.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 34296fdf006..fb1911a8425 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -185,8 +185,13 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { auto pairwise_map = PairwiseLogicalDomainMap(producer, consumer); auto c2p = - BestEffortReplay::replayPasC(producer, consumer, -1, pairwise_map) + BestEffortReplay::replayPasC( + producer, consumer, /*consumer_compute_at_axis=*/-1, pairwise_map) .getReplay(); + for (auto [c, p] : pairwise_map.mapConsumerToProducer()) { + // replayPasC skips mapping after the compute at position + c2p[c] = p; + } [[maybe_unused]] const auto [producer_index_ids, consumer_index_ids] = lower_utils::getIndexIDs(producer, consumer, &c2p); ProducerConsumerPairAnalyzer analyzer(c2p, consumer_index_ids); From 3342e77bfa5e45c42acf41f2dfeb9f41af9c25f4 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 19 Nov 2024 09:26:41 -0500 Subject: [PATCH 14/35] Simplify getIndexIDs --- csrc/device_lower/utils.cpp | 32 ++++++++++++-------------------- 1 file changed, 12 insertions(+), 20 deletions(-) diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index d0a5b6488c2..3fccec517ac 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -2045,32 +2045,24 @@ getIndexIDs( /*require_all_to_visited=*/false)) { // If there are any indexing IDs in the inputs, count all outputs as // indexing IDs - if (dir == Direction::Forward) { - if (std::any_of( - expr->inputs().begin(), expr->inputs().end(), [&](Val* input) { - auto* id = dynamic_cast(input); - return id && indexing_ids.count(id) != 0; - })) { - for (Val* v : expr->outputs()) { + const auto processExpr = [&indexing_ids]( + const 
std::vector& prev_vals, + const std::vector& next_vals) { + if (std::any_of(prev_vals.begin(), prev_vals.end(), [&](Val* prev) { + auto* id = dynamic_cast(prev); + return id && indexing_ids.count(id) != 0; + })) { + for (Val* v : next_vals) { if (auto* id = dynamic_cast(v)) { indexing_ids.insert(id); } } } + }; + if (dir == Direction::Forward) { + processExpr(expr->inputs(), expr->outputs()); } else if (dir == Direction::Backward) { - if (std::any_of( - expr->outputs().begin(), - expr->outputs().end(), - [&](Val* output) { - auto* id = dynamic_cast(output); - return id && indexing_ids.count(id) != 0; - })) { - for (Val* v : expr->inputs()) { - if (auto* id = dynamic_cast(v)) { - indexing_ids.insert(id); - } - } - } + processExpr(expr->outputs(), expr->inputs()); } else { NVF_THROW("Found unexpected direction"); } From ee5329fc21b4f8c995b5788baaa5a4f84199886d Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 19 Nov 2024 12:32:13 -0500 Subject: [PATCH 15/35] Remove NVF_THROW and disable matmul test for codediff --- csrc/device_lower/analysis/predicate_elimination.cpp | 5 +---- tests/cpp/test_matmul.cpp | 2 +- 2 files changed, 2 insertions(+), 5 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index fb1911a8425..dd357a62608 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -217,10 +217,7 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { // Check that this consumer_id is actually involved in indexing the // producer. If it is not connected to the producer allocation domain in // the broadcast graph, then we can skip processing it. 
- if (!consumer_id->isBroadcast() && - index_ids_.find(consumer_id) == index_ids_.end()) { - // TODO: Remove this line and the isBroadcast check in the condition above - NVF_THROW("FOUND UNEXPECTED PATH IN TEST ", consumer_id->toString()); + if (index_ids_.find(consumer_id) == index_ids_.end()) { return false; } needs_predicate_ = false; diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index c4fed527a42..2f6709d1e18 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3802,7 +3802,7 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { } // Test scheduling a Hopper matmul where the operands are 2D -TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { +TEST_F(HopperMatmulTest, DISABLED_HSH_NT_128BSwizzle_NoBroadcasts) { Fusion fusion; FusionGuard fg(&fusion); From 0cf29e5126f3752e6bcef703cd2869e6ff5e6b7a Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 19 Nov 2024 15:37:48 -0500 Subject: [PATCH 16/35] Enable test codediff passed --- tests/cpp/test_matmul.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 2f6709d1e18..c4fed527a42 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3802,7 +3802,7 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { } // Test scheduling a Hopper matmul where the operands are 2D -TEST_F(HopperMatmulTest, DISABLED_HSH_NT_128BSwizzle_NoBroadcasts) { +TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { Fusion fusion; FusionGuard fg(&fusion); From 381035fcace11805df0c4b9a813f53f29d533d93 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 20 Nov 2024 12:24:23 -0500 Subject: [PATCH 17/35] Avoid processing non-indexing inputs to Merge If we have a non-indexing ID id1 and an indexing ID id2 and we merge them, we should only need to process id2. 
--- csrc/device_lower/analysis/predicate_elimination.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index dd357a62608..ce0cc1d5c5a 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -226,6 +226,9 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { } void handle(IterDomain* consumer_id) override { + if (index_ids_.find(consumer_id) == index_ids_.end()) { + return; + } // The traversal should have ended if needs_predicate_ was true NVF_ERROR(!needs_predicate_); From 732b8738f9349757ad55c4ad4b047b5ce6ec4e87 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 20 Nov 2024 13:21:41 -0500 Subject: [PATCH 18/35] Remove declaration that shadowed c2p_tmp This doesn't affect this PR but has an impact on the inlining use case --- csrc/device_lower/utils.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index 3fccec517ac..c52267195ec 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -2012,7 +2012,7 @@ getIndexIDs( // First we find the consumer root IDs that map to the producer std::unordered_map c2p_tmp; if (c2p == nullptr) { - auto c2p_tmp = + c2p_tmp = PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); c2p = &c2p_tmp; } From 9feb8f8ca2f552eb142efb4ff7c40b5d38eb76d8 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 20 Nov 2024 13:25:47 -0500 Subject: [PATCH 19/35] Update in light of #3452 --- csrc/device_lower/utils.cpp | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index c52267195ec..75898dc0a21 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -2040,9 +2040,10 @@ getIndexIDs( const std::vector& start_domain, const std::vector& 
target_domain) { for (auto [expr, dir] : IRBFS::getExprsBetween( - {start_domain.begin(), start_domain.end()}, - {target_domain.begin(), target_domain.end()}, - /*require_all_to_visited=*/false)) { + {start_domain.begin(), start_domain.end()}, + {target_domain.begin(), target_domain.end()}, + /*require_all_to_visited=*/false) + .first) { // If there are any indexing IDs in the inputs, count all outputs as // indexing IDs const auto processExpr = [&indexing_ids]( From c70604616e601d38d66c593f0d552ef622b0bd40 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 20 Nov 2024 15:51:22 -0500 Subject: [PATCH 20/35] Use getIndexIDs --- csrc/scheduler/tools/inlining.cpp | 52 ++----------------------------- 1 file changed, 3 insertions(+), 49 deletions(-) diff --git a/csrc/scheduler/tools/inlining.cpp b/csrc/scheduler/tools/inlining.cpp index c1cb0526758..a957b36c609 100644 --- a/csrc/scheduler/tools/inlining.cpp +++ b/csrc/scheduler/tools/inlining.cpp @@ -5,6 +5,7 @@ * SPDX-License-Identifier: BSD-3-Clause */ // clang-format on +#include #include #include #include @@ -193,55 +194,8 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( } return producer->nDims(); } else { - // First we find the consumer root IDs that map to the producer - auto c2p = - PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); - // Track the IDs involved in indexing in both producer and consumer - std::vector consumer_root_indexing_ids; - std::vector producer_logical_indexing_ids; - for (IterDomain* id : consumer->getMaybeRootDomain()) { - auto it = c2p.find(id); - if (it == c2p.end()) { - continue; - } - // These are the immediately mapped consumer root and producer logical - // IDs. This is a starting point for our later traversals, which will fill - // these sets out. 
- consumer_root_indexing_ids.push_back(it->first); - producer_logical_indexing_ids.push_back(it->second); - } - - // Now traverse from the starting set (which, as noted above is a subset of - // either the producer logical or consumer root) to the target which is - // either the producer allocation domain or the consumer loop domain. These - // are the IDs that will actually affect indexing. Any other IDs can be - // skipped. - auto traverse = [](const std::vector& start_domain, - const std::vector& target_domain) - -> std::unordered_set { - std::unordered_set indexing_ids{ - start_domain.begin(), start_domain.end()}; - for ([[maybe_unused]] auto [expr, dir] : IRBFS::getExprsBetween( - start_domain, - {target_domain.begin(), target_domain.end()}, - /*require_all_to_visited=*/false)) { - for (Val* v : expr->inputs()) { - if (auto* id = dynamic_cast(v)) { - indexing_ids.insert(id); - } - } - for (Val* v : expr->outputs()) { - if (auto* id = dynamic_cast(v)) { - indexing_ids.insert(id); - } - } - } - return indexing_ids; - }; - std::unordered_set producer_indexing_ids = traverse( - producer_logical_indexing_ids, producer->getMaybeAllocationDomain()); - std::unordered_set consumer_indexing_ids = - traverse(consumer_root_indexing_ids, consumer->getLoopDomain()); + const auto [producer_indexing_ids, consumer_indexing_ids] = + lower_utils::getIndexIDs(producer, consumer); auto consumer_it = consumer->getLoopDomain().begin(); for (const auto producer_pos : c10::irange(producer->nDims())) { From 6f451f7ab294a48d551f421c6d07f534fec0d295 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 27 Nov 2024 09:36:57 -0500 Subject: [PATCH 21/35] Only check index IDs for MmaOp --- .../analysis/predicate_elimination.cpp | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index ce0cc1d5c5a..3f7e5cafed9 100644 --- 
a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -192,8 +192,14 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { // replayPasC skips mapping after the compute at position c2p[c] = p; } - [[maybe_unused]] const auto [producer_index_ids, consumer_index_ids] = - lower_utils::getIndexIDs(producer, consumer, &c2p); + std::unordered_set consumer_index_ids; + if (consumer->definition()->isA()) { + // NOTE: if consumer_index_ids is empty, it will be ignored. We only fill + // it for MmaOp for now in order to limit our changes to the only op that + // currently requires this analysis. + consumer_index_ids = + lower_utils::getIndexIDs(producer, consumer, &c2p).second; + } ProducerConsumerPairAnalyzer analyzer(c2p, consumer_index_ids); for (auto id : consumer->getLoopDomain()) { @@ -217,7 +223,8 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { // Check that this consumer_id is actually involved in indexing the // producer. If it is not connected to the producer allocation domain in // the broadcast graph, then we can skip processing it. 
- if (index_ids_.find(consumer_id) == index_ids_.end()) { + if (!index_ids_.empty() && + index_ids_.find(consumer_id) == index_ids_.end()) { return false; } needs_predicate_ = false; @@ -226,7 +233,8 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { } void handle(IterDomain* consumer_id) override { - if (index_ids_.find(consumer_id) == index_ids_.end()) { + if (!index_ids_.empty() && + index_ids_.find(consumer_id) == index_ids_.end()) { return; } // The traversal should have ended if needs_predicate_ was true From bd45e7ebdfded9e9c93a46bc26804c0bb7fe1ac5 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 3 Dec 2024 10:20:01 -0500 Subject: [PATCH 22/35] Guard changes so they only affect MmaOp --- csrc/scheduler/tools/inlining.cpp | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/csrc/scheduler/tools/inlining.cpp b/csrc/scheduler/tools/inlining.cpp index a957b36c609..9d47a7720aa 100644 --- a/csrc/scheduler/tools/inlining.cpp +++ b/csrc/scheduler/tools/inlining.cpp @@ -194,8 +194,12 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( } return producer->nDims(); } else { - const auto [producer_indexing_ids, consumer_indexing_ids] = - lower_utils::getIndexIDs(producer, consumer); + std::unordered_set producer_indexing_ids, + consumer_indexing_ids; + if (consumer->definition()->isA()) { + std::tie(producer_indexing_ids, consumer_indexing_ids) = + lower_utils::getIndexIDs(producer, consumer); + } auto consumer_it = consumer->getLoopDomain().begin(); for (const auto producer_pos : c10::irange(producer->nDims())) { @@ -235,8 +239,9 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( // indexing, and bS8 is also not used in indexing (it's a loop broadcast) // so we inline past the first ID in that case also. Similarly, we inline // past iS5, iS2, and bS7. 
- if (!(consumer_indexing_ids.count(c_id) == 0 && - producer_indexing_ids.count(p_id) == 0) && + if (!(!consumer_indexing_ids.empty() && !producer_indexing_ids.empty() && + (consumer_indexing_ids.count(c_id) == 0 && + producer_indexing_ids.count(p_id) == 0)) && (!inliningGraph().disjointValSets().strictAreMapped(p_id, c_id) || !isAllowedID( c_id, From 9fb9aad9d9681142f79f0dc588f9eaf13a99324e Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 3 Dec 2024 15:19:37 -0500 Subject: [PATCH 23/35] Simplify utility to lower_utils::getIdsBetween --- .../analysis/predicate_elimination.cpp | 19 ++-- csrc/device_lower/utils.cpp | 96 +++++-------------- csrc/device_lower/utils.h | 15 ++- 3 files changed, 44 insertions(+), 86 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 3f7e5cafed9..e89a850451e 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -188,17 +188,24 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { BestEffortReplay::replayPasC( producer, consumer, /*consumer_compute_at_axis=*/-1, pairwise_map) .getReplay(); - for (auto [c, p] : pairwise_map.mapConsumerToProducer()) { - // replayPasC skips mapping after the compute at position - c2p[c] = p; - } std::unordered_set consumer_index_ids; if (consumer->definition()->isA()) { // NOTE: if consumer_index_ids is empty, it will be ignored. We only fill // it for MmaOp for now in order to limit our changes to the only op that // currently requires this analysis. 
- consumer_index_ids = - lower_utils::getIndexIDs(producer, consumer, &c2p).second; + + // We flow from mapped IDs to the consumer's loop domain + std::vector mapped_ids; + auto root2logical = pairwise_map.mapConsumerToProducer(); + for (IterDomain* id : consumer->getMaybeRootDomain()) { + if (root2logical.find(id) != root2logical.end()) { + mapped_ids.push_back(id); + } + } + // This set will omit loop IDs that are not mapped to the producer, such + // as N dimensions when the producer is an A operand without broadcasts. + consumer_index_ids = lower_utils::getIdsBetween( + /*from=*/mapped_ids, /*to=*/consumer->getLoopDomain()); } ProducerConsumerPairAnalyzer analyzer(c2p, consumer_index_ids); diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index f4ea1c46596..52620dbd527 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -1987,81 +1987,33 @@ std::vector getSyncExprs(AsyncOpType async_type, int64_t keep_stages) { return sync_exprs; } -std::pair, std::unordered_set> -getIndexIDs( - TensorView* producer, - TensorView* consumer, - const std::unordered_map* c2p) { - // First we find the consumer root IDs that map to the producer - std::unordered_map c2p_tmp; - if (c2p == nullptr) { - c2p_tmp = - PairwiseLogicalDomainMap(producer, consumer).mapConsumerToProducer(); - c2p = &c2p_tmp; - } - // Track the IDs involved in indexing in both producer and consumer - std::unordered_set consumer_indexing_ids; - std::unordered_set producer_indexing_ids; - for (IterDomain* id : consumer->getMaybeRootDomain()) { - auto it = c2p->find(id); - if (it == c2p->end()) { - continue; - } - // These are the immediately mapped consumer root and producer logical - // IDs. This is a starting point for our later traversals, which will fill - // these sets out. 
- consumer_indexing_ids.insert(it->first); - producer_indexing_ids.insert(it->second); - } - - // Now traverse from the starting set (which, as noted above is a subset of - // either the producer logical or consumer root) to the target which is - // either the producer allocation domain or the consumer loop domain. These - // are the IDs that will actually affect indexing. Any other IDs can be - // skipped. - auto traverse = [](std::unordered_set& indexing_ids, - const std::vector& start_domain, - const std::vector& target_domain) { - for (auto [expr, dir] : IRBFS::getExprsBetween( - {start_domain.begin(), start_domain.end()}, - {target_domain.begin(), target_domain.end()}, - /*require_all_to_visited=*/false) - .first) { - // If there are any indexing IDs in the inputs, count all outputs as - // indexing IDs - const auto processExpr = [&indexing_ids]( - const std::vector& prev_vals, - const std::vector& next_vals) { - if (std::any_of(prev_vals.begin(), prev_vals.end(), [&](Val* prev) { - auto* id = dynamic_cast(prev); - return id && indexing_ids.count(id) != 0; - })) { - for (Val* v : next_vals) { - if (auto* id = dynamic_cast(v)) { - indexing_ids.insert(id); - } - } +std::unordered_set getIdsBetween( + const std::vector& from, + const std::vector& to) { + std::unordered_set ids{from.begin(), from.end()}; + for (auto [expr, dir] : getExprsBetween( + {from.begin(), from.end()}, + {to.begin(), to.end()}, + /*require_all_to_visited=*/false) + .first) { + const std::vector& prev_vals = + dir == Direction::Forward ? expr->inputs() : expr->outputs(); + const std::vector& next_vals = + dir == Direction::Forward ? 
expr->outputs() : expr->inputs(); + // If there are _any_ IDs that were found in prev_vals then we count all the + // next vals as found + if (std::any_of(prev_vals.begin(), prev_vals.end(), [&](Val* prev) { + auto* id = dynamic_cast(prev); + return id && ids.count(id) != 0; + })) { + for (Val* v : next_vals) { + if (auto* id = dynamic_cast(v)) { + ids.insert(id); } - }; - if (dir == Direction::Forward) { - processExpr(expr->inputs(), expr->outputs()); - } else if (dir == Direction::Backward) { - processExpr(expr->outputs(), expr->inputs()); - } else { - NVF_THROW("Found unexpected direction"); } } - }; - traverse( - producer_indexing_ids, - /*start_domain=*/producer->getLogicalDomain(), - /*target_domain=*/producer->getMaybeAllocationDomain()); - traverse( - consumer_indexing_ids, - /*start_domain=*/consumer->getMaybeRootDomain(), - /*target_domain=*/consumer->getLoopDomain()); - - return {producer_indexing_ids, consumer_indexing_ids}; + } + return ids; } } // namespace lower_utils diff --git a/csrc/device_lower/utils.h b/csrc/device_lower/utils.h index f187cccb1f9..2ddf0463646 100644 --- a/csrc/device_lower/utils.h +++ b/csrc/device_lower/utils.h @@ -19,7 +19,6 @@ #include #include -#include "logical_domain_map.h" // Provides utilities for dealing with nested ForLoop and IfThenElse scopes @@ -374,13 +373,13 @@ std::vector getSyncExprs( AsyncOpType async_type, int64_t keep_stages = 0); -//! Get the set of IterDomains on the shortest path from the producer allocation -//! domain to the consumer loop domain. -std::pair, std::unordered_set> -getIndexIDs( - TensorView* producer, - TensorView* consumer, - const std::unordered_map* c2p = nullptr); +//! Get a set of IterDomains in TV between two given domains +//! (inclusive). If `from` is provided, IDs without any producers in `from` will +//! be omitted. +//! 
TODO: example: +std::unordered_set getIdsBetween( + const std::vector& from, + const std::vector& to); } // namespace lower_utils From e623561ee5b0a1d61d835ed44a1351c2d1a9c61c Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 3 Dec 2024 16:01:26 -0500 Subject: [PATCH 24/35] Rename to getIdsAlongPathBetween and add example to comment --- .../analysis/predicate_elimination.cpp | 2 +- csrc/device_lower/utils.cpp | 2 +- csrc/device_lower/utils.h | 23 +++++++++++++++---- 3 files changed, 20 insertions(+), 7 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index e89a850451e..8e7a18ed095 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -204,7 +204,7 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { } // This set will omit loop IDs that are not mapped to the producer, such // as N dimensions when the producer is an A operand without broadcasts. 
- consumer_index_ids = lower_utils::getIdsBetween( + consumer_index_ids = lower_utils::getIdsAlongPathBetween( /*from=*/mapped_ids, /*to=*/consumer->getLoopDomain()); } ProducerConsumerPairAnalyzer analyzer(c2p, consumer_index_ids); diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index 52620dbd527..0d16621daa3 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -1987,7 +1987,7 @@ std::vector getSyncExprs(AsyncOpType async_type, int64_t keep_stages) { return sync_exprs; } -std::unordered_set getIdsBetween( +std::unordered_set getIdsAlongPathBetween( const std::vector& from, const std::vector& to) { std::unordered_set ids{from.begin(), from.end()}; diff --git a/csrc/device_lower/utils.h b/csrc/device_lower/utils.h index 2ddf0463646..b2f71730609 100644 --- a/csrc/device_lower/utils.h +++ b/csrc/device_lower/utils.h @@ -373,11 +373,24 @@ std::vector getSyncExprs( AsyncOpType async_type, int64_t keep_stages = 0); -//! Get a set of IterDomains in TV between two given domains -//! (inclusive). If `from` is provided, IDs without any producers in `from` will -//! be omitted. -//! TODO: example: -std::unordered_set getIdsBetween( +//! Get a set of IterDomains on a path between two given domains (inclusive). +//! +//! For example: +//! +//! i3 = merge(i0, i1) +//! i4, i5 = split(i3) +//! +//! If we are given +//! from = [ i0, i2 ] +//! to = [ i4 ] +//! This will return [ i0, i2, i3, i4, i5 ] +//! +//! If we are given +//! from = [ i4, i5 ] +//! to = [ i1 ] +//! This will return [ i4, i5, i3, i0, i1 ] +//! 
+std::unordered_set getIdsAlongPathBetween( const std::vector& from, const std::vector& to); From 0660f8c64d973efcc9c2ab10d593efe8bd186a0e Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Wed, 4 Dec 2024 09:27:32 -0500 Subject: [PATCH 25/35] Use loop group traversal from alloc to loop --- .../analysis/predicate_elimination.cpp | 93 ++++++++++++++----- 1 file changed, 71 insertions(+), 22 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 8e7a18ed095..9dc40160c39 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -18,6 +18,7 @@ #include #include #include +#include "val_graph_visitor.h" namespace nvfuser { @@ -188,26 +189,72 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { BestEffortReplay::replayPasC( producer, consumer, /*consumer_compute_at_axis=*/-1, pairwise_map) .getReplay(); - std::unordered_set consumer_index_ids; + + // The variables graph and alloc_to_loop_groups are used to check whether we + // need to check a particular consumer ID. The alloc_to_loop_groups set + // contains ValGroups along a shortest path in the loop graph from + // non-trivial dimensions in the allocation domain of the producer to the + // consumer's loop domain. Other domains might exist in the loop domain of + // the consumer: for example, for MmaOp we sometimes do not map the N + // dimension of the output logical domain to any ID in the A operand. We + // use this set to avoid performing unnecessary checks on these types of + // irrelevant consumer IDs. + // + // NOTE: if graph is nullptr, it will be + // ignored. We only fill it for MmaOp for now in order to limit our changes + // to the only op that currently requires this analysis. 
+ const ValGraph* graph = nullptr; + std::unordered_set alloc_to_loop_groups; if (consumer->definition()->isA()) { - // NOTE: if consumer_index_ids is empty, it will be ignored. We only fill - // it for MmaOp for now in order to limit our changes to the only op that - // currently requires this analysis. + // Fill ValGraph and grab all ValGroups on path from producer alloc to + // consumer loop. + IdModel& id_model = GpuLower::current()->idModel(); + id_model.maybeBuildGraph(IdMappingMode::LOOP); + const ValGraph& loop_graph = id_model.idGraph(IdMappingMode::LOOP); // We flow from mapped IDs to the consumer's loop domain - std::vector mapped_ids; - auto root2logical = pairwise_map.mapConsumerToProducer(); - for (IterDomain* id : consumer->getMaybeRootDomain()) { - if (root2logical.find(id) != root2logical.end()) { - mapped_ids.push_back(id); + ValGroups alloc_groups; + for (IterDomain* id : producer->getMaybeAllocationDomain()) { + if (!id->isBroadcast() && !id->isReduction()) { + alloc_groups.pushBack(loop_graph.toGroup(id)); + } + } + ValGroups loop_groups; + for (IterDomain* id : consumer->getLoopDomain()) { + loop_groups.pushBack(loop_graph.toGroup(id)); + } + + const auto [path, all_reached] = ValGraphBFS::getExprGroupsBetween( + loop_graph, + /*from=*/alloc_groups, + /*to=*/loop_groups, + /*require_all_to_visited=*/false); + + if (!all_reached) { + // If we reached all loop groups, there's no need to perform this check + graph = &loop_graph; + alloc_to_loop_groups.insert(alloc_groups.begin(), alloc_groups.end()); + for (const auto& [expr_group, direction] : path) { + const std::vector prev_groups = + direction == Direction::Forward + ? loop_graph.inputGroups(expr_group) + : loop_graph.outputGroups(expr_group); + const std::vector next_groups = + direction == Direction::Forward + ? 
loop_graph.outputGroups(expr_group) + : loop_graph.inputGroups(expr_group); + if (std::any_of( + prev_groups.begin(), + prev_groups.end(), + [&alloc_to_loop_groups](const ValGroup& group) { + return alloc_to_loop_groups.count(group) > 0; + })) { + alloc_to_loop_groups.insert(next_groups.begin(), next_groups.end()); + } } } - // This set will omit loop IDs that are not mapped to the producer, such - // as N dimensions when the producer is an A operand without broadcasts. - consumer_index_ids = lower_utils::getIdsAlongPathBetween( - /*from=*/mapped_ids, /*to=*/consumer->getLoopDomain()); } - ProducerConsumerPairAnalyzer analyzer(c2p, consumer_index_ids); + ProducerConsumerPairAnalyzer analyzer(c2p, graph, alloc_to_loop_groups); for (auto id : consumer->getLoopDomain()) { if (analyzer.needsPredicate(id)) { @@ -221,17 +268,18 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { private: ProducerConsumerPairAnalyzer( const std::unordered_map& c2p, - const std::unordered_set index_ids) - : c2p_(c2p), index_ids_(index_ids) {} + const ValGraph* graph, + const std::unordered_set alloc_to_loop_groups) + : c2p_(c2p), graph_(graph), alloc_to_loop_groups_(alloc_to_loop_groups) {} // Returns true if no out-of-bound accesses could occur with a // producer bool needsPredicate(IterDomain* consumer_id) { // Check that this consumer_id is actually involved in indexing the // producer. If it is not connected to the producer allocation domain in - // the broadcast graph, then we can skip processing it. - if (!index_ids_.empty() && - index_ids_.find(consumer_id) == index_ids_.end()) { + // the indexing graph, then we can skip processing it. 
+ if (graph_ != nullptr && + alloc_to_loop_groups_.count(graph_->toGroup(consumer_id)) == 0) { return false; } needs_predicate_ = false; @@ -240,8 +288,8 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { } void handle(IterDomain* consumer_id) override { - if (!index_ids_.empty() && - index_ids_.find(consumer_id) == index_ids_.end()) { + if (graph_ != nullptr && + alloc_to_loop_groups_.count(graph_->toGroup(consumer_id)) == 0) { return; } // The traversal should have ended if needs_predicate_ was true @@ -328,7 +376,8 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { //! BestEffort map from consumer IDs to producer IDs const std::unordered_map& c2p_; bool needs_predicate_ = false; - std::unordered_set index_ids_; + const ValGraph* graph_ = nullptr; + const std::unordered_set alloc_to_loop_groups_; }; class PredicateChcker : public IterVisitor { From 6e17d11232529cda7c3779283fad2a902d473a68 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 5 Dec 2024 07:52:27 -0500 Subject: [PATCH 26/35] Remove getIdsAlongPathBetween --- csrc/device_lower/utils.cpp | 29 ----------------------------- csrc/device_lower/utils.h | 21 --------------------- 2 files changed, 50 deletions(-) diff --git a/csrc/device_lower/utils.cpp b/csrc/device_lower/utils.cpp index 0d16621daa3..0ff59a27b26 100644 --- a/csrc/device_lower/utils.cpp +++ b/csrc/device_lower/utils.cpp @@ -1987,35 +1987,6 @@ std::vector getSyncExprs(AsyncOpType async_type, int64_t keep_stages) { return sync_exprs; } -std::unordered_set getIdsAlongPathBetween( - const std::vector& from, - const std::vector& to) { - std::unordered_set ids{from.begin(), from.end()}; - for (auto [expr, dir] : getExprsBetween( - {from.begin(), from.end()}, - {to.begin(), to.end()}, - /*require_all_to_visited=*/false) - .first) { - const std::vector& prev_vals = - dir == Direction::Forward ? expr->inputs() : expr->outputs(); - const std::vector& next_vals = - dir == Direction::Forward ? 
expr->outputs() : expr->inputs(); - // If there are _any_ IDs that were found in prev_vals then we count all the - // next vals as found - if (std::any_of(prev_vals.begin(), prev_vals.end(), [&](Val* prev) { - auto* id = dynamic_cast(prev); - return id && ids.count(id) != 0; - })) { - for (Val* v : next_vals) { - if (auto* id = dynamic_cast(v)) { - ids.insert(id); - } - } - } - } - return ids; -} - } // namespace lower_utils } // namespace nvfuser diff --git a/csrc/device_lower/utils.h b/csrc/device_lower/utils.h index b2f71730609..fa62d3b3f76 100644 --- a/csrc/device_lower/utils.h +++ b/csrc/device_lower/utils.h @@ -373,27 +373,6 @@ std::vector getSyncExprs( AsyncOpType async_type, int64_t keep_stages = 0); -//! Get a set of IterDomains on a path between two given domains (inclusive). -//! -//! For example: -//! -//! i3 = merge(i0, i1) -//! i4, i5 = split(i3) -//! -//! If we are given -//! from = [ i0, i2 ] -//! to = [ i4 ] -//! This will return [ i0, i2, i3, i4, i5 ] -//! -//! If we are given -//! from = [ i4, i5 ] -//! to = [ i1 ] -//! This will return [ i4, i5, i3, i0, i1 ] -//! 
-std::unordered_set getIdsAlongPathBetween( - const std::vector& from, - const std::vector& to); - } // namespace lower_utils } // namespace nvfuser From ee6a89abdf797cfacbfd9999b84c6a60295c74fa Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 5 Dec 2024 08:27:51 -0500 Subject: [PATCH 27/35] Use TensorIndexer and getValsBetween --- .../analysis/predicate_elimination.cpp | 51 ++++++------------- csrc/device_lower/lower2device.cpp | 10 ++-- 2 files changed, 20 insertions(+), 41 deletions(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index 9dc40160c39..c11efffde04 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -18,6 +18,7 @@ #include #include #include +#include "id_model/utils.h" #include "val_graph_visitor.h" namespace nvfuser { @@ -208,51 +209,29 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { if (consumer->definition()->isA()) { // Fill ValGraph and grab all ValGroups on path from producer alloc to // consumer loop. 
- IdModel& id_model = GpuLower::current()->idModel(); - id_model.maybeBuildGraph(IdMappingMode::LOOP); - const ValGraph& loop_graph = id_model.idGraph(IdMappingMode::LOOP); + + const IdModel& id_model = GpuLower::current()->idModel(); + graph = &GpuLower::current()->tensorIndexer().traversalGraph(); // We flow from mapped IDs to the consumer's loop domain - ValGroups alloc_groups; + std::vector alloc_groups; for (IterDomain* id : producer->getMaybeAllocationDomain()) { + id = getLoopPromotion(id, id_model); if (!id->isBroadcast() && !id->isReduction()) { - alloc_groups.pushBack(loop_graph.toGroup(id)); + alloc_groups.push_back(graph->toGroup(id)); } } - ValGroups loop_groups; + std::vector loop_groups; for (IterDomain* id : consumer->getLoopDomain()) { - loop_groups.pushBack(loop_graph.toGroup(id)); + id = getLoopPromotion(id, id_model); + loop_groups.push_back(graph->toGroup(id)); } - const auto [path, all_reached] = ValGraphBFS::getExprGroupsBetween( - loop_graph, - /*from=*/alloc_groups, - /*to=*/loop_groups, - /*require_all_to_visited=*/false); - - if (!all_reached) { - // If we reached all loop groups, there's no need to perform this check - graph = &loop_graph; - alloc_to_loop_groups.insert(alloc_groups.begin(), alloc_groups.end()); - for (const auto& [expr_group, direction] : path) { - const std::vector prev_groups = - direction == Direction::Forward - ? loop_graph.inputGroups(expr_group) - : loop_graph.outputGroups(expr_group); - const std::vector next_groups = - direction == Direction::Forward - ? 
loop_graph.outputGroups(expr_group) - : loop_graph.inputGroups(expr_group); - if (std::any_of( - prev_groups.begin(), - prev_groups.end(), - [&alloc_to_loop_groups](const ValGroup& group) { - return alloc_to_loop_groups.count(group) > 0; - })) { - alloc_to_loop_groups.insert(next_groups.begin(), next_groups.end()); - } - } - } + std::vector indexing_groups = + getValsBetween(alloc_groups, loop_groups, *graph); + + alloc_to_loop_groups.insert( + indexing_groups.begin(), indexing_groups.end()); } ProducerConsumerPairAnalyzer analyzer(c2p, graph, alloc_to_loop_groups); diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 79652bc67c5..4ce0cb10546 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -570,11 +570,6 @@ void GpuLower::analysis(Fusion* fusion) { nonDivisibleSplitInfo().build(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "build nonDivisibleSplitInfo"); - // Detects all exprssions that don't need predicates. Depends on - // nonDivisibleSplitInfo. - pred_elimination_ = std::make_unique(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "build predicateElimination"); - circularBufferInfo().build(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "build circularBufferInfo"); @@ -589,6 +584,11 @@ void GpuLower::analysis(Fusion* fusion) { tensor_indexer_ = std::make_unique(*id_model_); } + // Detects all exprssions that don't need predicates. Depends on + // nonDivisibleSplitInfo. 
+ pred_elimination_ = std::make_unique(fusion_); + dumpExprsIfEnabled(fusion_->exprs(), "build predicateElimination"); + consumerToTMAInfo() = getConsumerToTMAInfoMap(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "getConsumerToTMAInfoMap"); } From 2959d8899c72beba6a441f41a34712234a0ab74e Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 5 Dec 2024 08:41:29 -0500 Subject: [PATCH 28/35] Don't need to promote allocation IDs --- csrc/device_lower/analysis/predicate_elimination.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/csrc/device_lower/analysis/predicate_elimination.cpp b/csrc/device_lower/analysis/predicate_elimination.cpp index c11efffde04..e79cc177c17 100644 --- a/csrc/device_lower/analysis/predicate_elimination.cpp +++ b/csrc/device_lower/analysis/predicate_elimination.cpp @@ -216,7 +216,6 @@ class ProducerConsumerPairAnalyzer : public OptOutDispatch { // We flow from mapped IDs to the consumer's loop domain std::vector alloc_groups; for (IterDomain* id : producer->getMaybeAllocationDomain()) { - id = getLoopPromotion(id, id_model); if (!id->isBroadcast() && !id->isReduction()) { alloc_groups.push_back(graph->toGroup(id)); } From 62f9ede226e08a06cc605652f089595c7eda5290 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 5 Dec 2024 09:51:15 -0500 Subject: [PATCH 29/35] Use inlining graph path between loop domains instead of getIndexIDs --- csrc/scheduler/tools/inlining.cpp | 32 ++++++++++++++++++------------- 1 file changed, 19 insertions(+), 13 deletions(-) diff --git a/csrc/scheduler/tools/inlining.cpp b/csrc/scheduler/tools/inlining.cpp index 1f5125b85b5..e9dbae10926 100644 --- a/csrc/scheduler/tools/inlining.cpp +++ b/csrc/scheduler/tools/inlining.cpp @@ -12,6 +12,7 @@ #include #include #include +#include #include @@ -194,11 +195,19 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( } return producer->nDims(); } else { - std::unordered_set producer_indexing_ids, - consumer_indexing_ids; + std::unordered_set loop_path_groups; if 
(consumer->definition()->isA()) { - std::tie(producer_indexing_ids, consumer_indexing_ids) = - lower_utils::getIndexIDs(producer, consumer); + // Get ValGroups between producer and consumer loop in the inlining graph + std::vector producer_loop_groups, consumer_loop_groups; + for (IterDomain* id : producer->getLoopDomain()) { + producer_loop_groups.push_back(inliningGraph().toGroup(id)); + } + for (IterDomain* id : consumer->getLoopDomain()) { + consumer_loop_groups.push_back(inliningGraph().toGroup(id)); + } + std::vector group_path = getValsBetween( + producer_loop_groups, consumer_loop_groups, inliningGraph()); + loop_path_groups.insert(group_path.begin(), group_path.end()); } auto consumer_it = consumer->getLoopDomain().begin(); @@ -220,9 +229,7 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( IterDomain* c_id = *consumer_it; - // If either ID is involved in indexing then we need to make sure they're - // both mapped in the inlining graph or that this is a special case - // covered by isAllowedID. + // We can inline past consumer IDs that are not connected to the producer. // // For example, an MmaOp with no broadcasts could contain the following: // tv0: @@ -236,12 +243,11 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( // // iS4 maps to iS0 so when producer==tv0 we inline past iS0. When // producer==tv1, iS4 doesn't map to anything in tv1 and is not used for - // indexing, and bS8 is also not used in indexing (it's a loop broadcast) - // so we inline past the first ID in that case also. Similarly, we inline - // past iS5, iS2, and bS7. - if (!(!consumer_indexing_ids.empty() && !producer_indexing_ids.empty() && - (consumer_indexing_ids.count(c_id) == 0 && - producer_indexing_ids.count(p_id) == 0)) && + // indexing, and bS8 is a loop broadcast so we inline past the first ID + // in that case also. Similarly, we inline past iS5, iS2, and bS7. 
+ if ((loop_path_groups.empty() || + loop_path_groups.count(inliningGraph().toGroup(p_id)) || + loop_path_groups.count(inliningGraph().toGroup(c_id))) && (!inliningGraph().disjointValSets().strictAreMapped(p_id, c_id) || !isAllowedID( c_id, From 4a3b0d27d8341c54037a1ac7a948ea522a940140 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Thu, 5 Dec 2024 19:36:09 -0500 Subject: [PATCH 30/35] Undo stale change reordering predicate elimination lowering pass --- csrc/device_lower/lower2device.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/csrc/device_lower/lower2device.cpp b/csrc/device_lower/lower2device.cpp index 4ce0cb10546..79652bc67c5 100644 --- a/csrc/device_lower/lower2device.cpp +++ b/csrc/device_lower/lower2device.cpp @@ -570,6 +570,11 @@ void GpuLower::analysis(Fusion* fusion) { nonDivisibleSplitInfo().build(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "build nonDivisibleSplitInfo"); + // Detects all exprssions that don't need predicates. Depends on + // nonDivisibleSplitInfo. + pred_elimination_ = std::make_unique(fusion_); + dumpExprsIfEnabled(fusion_->exprs(), "build predicateElimination"); + circularBufferInfo().build(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "build circularBufferInfo"); @@ -584,11 +589,6 @@ void GpuLower::analysis(Fusion* fusion) { tensor_indexer_ = std::make_unique(*id_model_); } - // Detects all exprssions that don't need predicates. Depends on - // nonDivisibleSplitInfo. - pred_elimination_ = std::make_unique(fusion_); - dumpExprsIfEnabled(fusion_->exprs(), "build predicateElimination"); - consumerToTMAInfo() = getConsumerToTMAInfoMap(fusion_); dumpExprsIfEnabled(fusion_->exprs(), "getConsumerToTMAInfoMap"); } From cfc7ed92ae07c5123d2f2be466601b3cf09e0532 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Mon, 9 Dec 2024 19:55:46 -0500 Subject: [PATCH 31/35] Get path from mapped IDs. Improve comments. 
Try asserting more --- csrc/scheduler/tools/inlining.cpp | 85 ++++++++++++++++++++++--------- 1 file changed, 61 insertions(+), 24 deletions(-) diff --git a/csrc/scheduler/tools/inlining.cpp b/csrc/scheduler/tools/inlining.cpp index e9dbae10926..a81483edb13 100644 --- a/csrc/scheduler/tools/inlining.cpp +++ b/csrc/scheduler/tools/inlining.cpp @@ -195,19 +195,36 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( } return producer->nDims(); } else { - std::unordered_set loop_path_groups; + std::optional> loop_path_groups = std::nullopt; if (consumer->definition()->isA()) { - // Get ValGroups between producer and consumer loop in the inlining graph - std::vector producer_loop_groups, consumer_loop_groups; + // We handle MmaOp specially here since it is currently the only operation + // for which we generate code (i.e. not SdpaFwdOp or SdpaBwdOp) that has + // some output dimensions that do not map to input dimensions. For this + // case, we need to identify potential inlined pairs each ID of which is + // not mapped at all to the other TensorView (see example below). + + // Get ValGroups in loop domains of producer and consumer that are + // connected to _mapped_ IterDomains in the pairwise map. + std::vector pairwise_mapped_groups; + for (auto [c_id, p_id] : PairwiseLogicalDomainMap(producer, consumer) + .mapConsumerToProducer()) { + pairwise_mapped_groups.push_back(inliningGraph().toGroup(c_id)); + } + // We propagate toward the loop groups from both consumer and producer + std::vector all_loop_groups; for (IterDomain* id : producer->getLoopDomain()) { - producer_loop_groups.push_back(inliningGraph().toGroup(id)); + all_loop_groups.push_back(inliningGraph().toGroup(id)); } for (IterDomain* id : consumer->getLoopDomain()) { - consumer_loop_groups.push_back(inliningGraph().toGroup(id)); + all_loop_groups.push_back(inliningGraph().toGroup(id)); } + // getValsBetween does not require all target groups to be visited. 
This + // means the result contains the subset of both loop groups that we are + // looking for std::vector group_path = getValsBetween( - producer_loop_groups, consumer_loop_groups, inliningGraph()); - loop_path_groups.insert(group_path.begin(), group_path.end()); + pairwise_mapped_groups, all_loop_groups, inliningGraph()); + loop_path_groups = + std::unordered_set(group_path.begin(), group_path.end()); } auto consumer_it = consumer->getLoopDomain().begin(); @@ -229,9 +246,11 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( IterDomain* c_id = *consumer_it; - // We can inline past consumer IDs that are not connected to the producer. + // We can inline past positions in which both producer and consumer are + // not connected to any mapped logical IterDomain pairs. + // + // For example, an MmaOp can be constructed as follows: // - // For example, an MmaOp with no broadcasts could contain the following: // tv0: // root/logical: [ iS0, iS1 ] // loop: [ iS0, bS7, iS1 ] @@ -241,21 +260,39 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( // tv2: // root/logical/loop: [ iS4, iS5, rS6 ] // - // iS4 maps to iS0 so when producer==tv0 we inline past iS0. When - // producer==tv1, iS4 doesn't map to anything in tv1 and is not used for - // indexing, and bS8 is a loop broadcast so we inline past the first ID - // in that case also. Similarly, we inline past iS5, iS2, and bS7. - if ((loop_path_groups.empty() || - loop_path_groups.count(inliningGraph().toGroup(p_id)) || - loop_path_groups.count(inliningGraph().toGroup(c_id))) && - (!inliningGraph().disjointValSets().strictAreMapped(p_id, c_id) || - !isAllowedID( - c_id, - consumer, - best_effort, - /*allow_reduction=*/true, - /*allow_vectorize=*/false, - /*allow_unmappable=*/true))) { + // iS4 maps to iS0 so when producer==tv0 we can inline past iS0. When + // producer==tv1, iS4 doesn't map to anything in tv1 and bS8 is a loop + // broadcast in that position so we inline past the first ID in that + // case also. 
Similarly, we inline past iS5, iS2, and bS7. + if (loop_path_groups.has_value()) { + bool p_id_connected = + loop_path_groups->count(inliningGraph().toGroup(p_id)); + bool c_id_connected = + loop_path_groups->count(inliningGraph().toGroup(c_id)); + NVF_ERROR( + p_id_connected || + (consumer->definition()->isA() && p_id->isBroadcast()), + "Expected unmapped producer id to be broadcast domain in MmaOp input but found ", + p_id->toString()); + + if (!p_id_connected && !c_id_connected) { + NVF_ERROR( + p_id->isBroadcast(), + "Unmapped producer ID must be a broadcast created in scheduling but found ", + p_id->toString()); + ++consumer_it; + continue; + } + } + + if (!inliningGraph().disjointValSets().strictAreMapped(p_id, c_id) || + !isAllowedID( + c_id, + consumer, + best_effort, + /*allow_reduction=*/true, + /*allow_vectorize=*/false, + /*allow_unmappable=*/true)) { return producer_pos; } From 6d99f0d96500201c9f26764fcecf0b0638b84682 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 10 Dec 2024 11:36:11 -0500 Subject: [PATCH 32/35] Add test to check improper inlining is not done --- tests/cpp/test_matmul.cpp | 29 +++++++++++++++++++++-------- 1 file changed, 21 insertions(+), 8 deletions(-) diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 37a284c69cd..6b8d68af5f4 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3657,7 +3657,6 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { const auto dtype = DataType::Half; constexpr bool use_smem_epilogue = false; - constexpr bool use_warp_specialization = true; constexpr int64_t stages = 4; constexpr int64_t prefetch = 3; @@ -3788,13 +3787,8 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle) { inlineMost(); - if (use_warp_specialization) { - tv0c->circularBuffer(stages, prefetch, WarpSpecialized(ParallelType::TIDy)); - tv1c->circularBuffer(stages, prefetch, WarpSpecialized(ParallelType::TIDy)); - } else { - tv0c->circularBuffer(stages, prefetch); - tv1c->circularBuffer(stages, 
prefetch); - } + tv0c->circularBuffer(stages, prefetch); + tv1c->circularBuffer(stages, prefetch); auto inputs = matmulAtInput3DHopperSS(M, N, K, layout, data_type_to_aten(dtype)); @@ -3935,6 +3929,25 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { } tv3->axis(-1)->parallelize(ParallelType::Vectorize); + { + // Check using a copy that improperly aligned axes are not inlined + Fusion tmp_fusion; + IrCloner ir_cloner = Fusion::copy(&fusion, &tmp_fusion); + FusionGuard tmp_fg(&tmp_fusion); + // [Mo, No, Ko, Mio, Nio, Mii, Nii, Ki] + // Swap the No and Ko axes, but only in tv2, the mma output + // [Mo, Ko, No, Mio, Nio, Mii, Nii, Ki] + // This should mean the smem operands are now inlined at position 1 instead + // of 3 + ir_cloner.clone(tv2)->reorder({{2, 1}, {1, 2}}); + inlineMost(); + tmp_fusion.print(); + ir_cloner.clone(tv2)->reorder({{2, 1}, {1, 2}}); + EXPECT_EQ(ir_cloner.clone(tv0c)->getComputeAtPosition(), 1); + // TODO: why is tv1c not inlined past the broadcast Mo dimension? + EXPECT_EQ(ir_cloner.clone(tv1c)->getComputeAtPosition(), 0); + } + inlineMost(); EXPECT_EQ(tv0c->getComputeAtPosition(), 3); From a902803ba50fd55cad00db088964be5d69e33435 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 10 Dec 2024 11:41:53 -0500 Subject: [PATCH 33/35] Remove debug print --- tests/cpp/test_matmul.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index 6b8d68af5f4..eb49bcde59c 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3941,7 +3941,6 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { // of 3 ir_cloner.clone(tv2)->reorder({{2, 1}, {1, 2}}); inlineMost(); - tmp_fusion.print(); ir_cloner.clone(tv2)->reorder({{2, 1}, {1, 2}}); EXPECT_EQ(ir_cloner.clone(tv0c)->getComputeAtPosition(), 1); // TODO: why is tv1c not inlined past the broadcast Mo dimension? 
From ff358f7f001440f16a881784823f053d3c570ccd Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 10 Dec 2024 11:46:53 -0500 Subject: [PATCH 34/35] Comment why tv1c is not inlined to position 1 --- tests/cpp/test_matmul.cpp | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/tests/cpp/test_matmul.cpp b/tests/cpp/test_matmul.cpp index eb49bcde59c..5be1801e29e 100644 --- a/tests/cpp/test_matmul.cpp +++ b/tests/cpp/test_matmul.cpp @@ -3941,9 +3941,12 @@ TEST_F(HopperMatmulTest, HSH_NT_128BSwizzle_NoBroadcasts) { // of 3 ir_cloner.clone(tv2)->reorder({{2, 1}, {1, 2}}); inlineMost(); + tmp_fusion.printMath(); ir_cloner.clone(tv2)->reorder({{2, 1}, {1, 2}}); EXPECT_EQ(ir_cloner.clone(tv0c)->getComputeAtPosition(), 1); - // TODO: why is tv1c not inlined past the broadcast Mo dimension? + // The outermost loop dim of tv1c is a broadcast Mo axis, so + // tv1c->inlineAt(1) does not inline past that axis and we wind up with + // compute-at position 0. EXPECT_EQ(ir_cloner.clone(tv1c)->getComputeAtPosition(), 0); } From 951757dffeea4c41ca158edd24ab76ebc0b01411 Mon Sep 17 00:00:00 2001 From: Jacob Hinkle Date: Tue, 10 Dec 2024 15:27:53 -0500 Subject: [PATCH 35/35] Add comment about the traversal --- csrc/scheduler/tools/inlining.cpp | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/csrc/scheduler/tools/inlining.cpp b/csrc/scheduler/tools/inlining.cpp index a81483edb13..6064f4e6e7c 100644 --- a/csrc/scheduler/tools/inlining.cpp +++ b/csrc/scheduler/tools/inlining.cpp @@ -205,6 +205,14 @@ size_t MaxPosCalculator::getMaxProducerPosFromConsumer( // Get ValGroups in loop domains of producer and consumer that are // connected to _mapped_ IterDomains in the pairwise map. + // + // Note that for MmaOp, it would be sufficient to traverse from the + // producer loop to the consumer loop and identify when _either_ the + // consumer or producer ID is not mapped. 
Here we are instead traversing + // from mapped domains to both loop domains so that we can check that _both_ + // consumer and producer IDs are not mapped. This is slightly safer and this + // symmetry might be handy in handling new ops that use this feature in + // the future. std::vector pairwise_mapped_groups; for (auto [c_id, p_id] : PairwiseLogicalDomainMap(producer, consumer) .mapConsumerToProducer()) {