Skip to content

Commit

Permalink
Propagating slice/pad/cat to inputs (#3549)
Browse files Browse the repository at this point in the history
Added a scheduler util function to schedule a fusion with resize-based ops
such as slice, pad and concat. This propagates resize ops to producers
so that all tensors have the exact-mapped loop domains.

Part of #3425. Extracted so that it can be individually tested.

(There's a follow-up PR: #3555)
  • Loading branch information
naoyam authored Dec 10, 2024
1 parent 9575fd6 commit 875d765
Show file tree
Hide file tree
Showing 6 changed files with 566 additions and 1 deletion.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,7 @@ list(APPEND NVFUSER_SRCS
${NVFUSER_SRCS_DIR}/scheduler/tools/inlining.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/loop_domain_scheduler.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/maxinfo_propagator.cpp
${NVFUSER_SRCS_DIR}/scheduler/tools/resize_utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/transpose.cpp
${NVFUSER_SRCS_DIR}/scheduler/utils.cpp
${NVFUSER_SRCS_DIR}/scheduler/vectorize_helper.cpp
Expand Down
4 changes: 4 additions & 0 deletions csrc/scheduler/tools/loop_domain_scheduler.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,10 @@ void scheduleLoopDomainsLike(
LoopDomainScheduler scheduler(ref_loop_dom);

for (auto tv : tvs) {
// Loop domain of fusion inputs should have no meaning
if (tv->isFusionInput()) {
continue;
}
scheduler.schedule(tv);
}
}
Expand Down
70 changes: 70 additions & 0 deletions csrc/scheduler/tools/resize_utils.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on

#include <id_model/id_model.h>
#include <ir/cloner.h>
#include <ir/utils.h>
#include <iter_visitor.h>
#include <logical_domain_map.h>
#include <scheduler/tools/loop_domain_scheduler.h>
#include <scheduler/tools/resize_utils.h>
#include <val_graph_visitor.h>

namespace nvfuser {
namespace scheduler_tools {

void propagateResizeToInputs(Expr* resize_tensor_op) {
  // Only slice and pad ops are supported here; anything else is a
  // programming error on the caller's side.
  NVF_ERROR(
      resize_tensor_op->isA<SliceOp>() || resize_tensor_op->isA<PadOp>(),
      "Unexpected resize tensor op: ",
      resize_tensor_op->toString());

  Fusion* fusion = resize_tensor_op->fusion();

  auto* producer_tv = resize_tensor_op->input(0)->as<TensorView>();
  auto* consumer_tv = resize_tensor_op->output(0)->as<TensorView>();

  // Collect every tensor on a path from a fusion input to the producer
  // of the resize op. Fusion inputs themselves are excluded since their
  // loop domains have no meaning for scheduling.
  const auto dep_vals = DependencyCheck::getAllValsBetween(
      {fusion->inputs().begin(), fusion->inputs().end()}, {producer_tv});

  std::vector<TensorView*> tvs_to_schedule;
  tvs_to_schedule.reserve(dep_vals.size());
  for (auto* val : dep_vals) {
    auto* tv = dynamic_cast<TensorView*>(val);
    if (tv != nullptr && !tv->isFusionInput()) {
      tvs_to_schedule.push_back(tv);
    }
  }

  // Ideally this would be a single call to
  // scheduler_tools::scheduleLoopDomainsLike with the consumer tensor
  // as the reference. However, due to the indexing issue with resize,
  // propagating the Resize iter-domain op that way may fail. To avoid
  // the problem, the resize ops are instead propagated explicitly with
  // scheduler_tools::scheduleLoopDomainsBy.
  //
  // Step 1: give all the dependent tensors a loop domain exact-mapped
  // with the producer's.
  scheduler_tools::scheduleLoopDomainsLike(
      tvs_to_schedule, producer_tv->getLoopDomain());

  // Step 2: now that the dependent tensors have uniform, exact-mapped
  // loop domains, propagate each Resize op that defines a logical ID of
  // the consumer tensor.
  for (auto* logical_id : consumer_tv->getLogicalDomain()) {
    if (auto* resize = dynamic_cast<Resize*>(logical_id->definition())) {
      scheduler_tools::scheduleLoopDomainsBy(tvs_to_schedule, resize);
    }
  }
}

} // namespace scheduler_tools
} // namespace nvfuser
23 changes: 23 additions & 0 deletions csrc/scheduler/tools/resize_utils.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
// clang-format off
/*
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
* All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*/
// clang-format on
#pragma once

namespace nvfuser {

class Expr;

namespace scheduler_tools {

// Given a resize-based tensor op (SliceOp or PadOp; any other op is an
// error), propagate the iter-domain ops of the op's output tensor to every
// dependent producer tensor so that their loop domains become exact-mapped
// with each other. Fusion inputs are skipped, as their loop domains have no
// meaning for scheduling.
void propagateResizeToInputs(Expr* resize_op);

} // namespace scheduler_tools
} // namespace nvfuser
7 changes: 6 additions & 1 deletion tests/cpp/test_loop_domain_scheduling.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,12 @@ TEST_F(LoopDomainSchedulingTest, ManyReshape) {
// The new loop domain of each tensor should be exactly mapped
// with the reference loop domain
for (const auto tv : fusion_copy.allTvs()) {
EXPECT_EQ(tv->getLoopDomain().size(), ref_loop.size());
// scheduleLoopDomainsLike skips fusion inputs
if (tv->isFusionInput()) {
continue;
}
EXPECT_EQ(tv->getLoopDomain().size(), ref_loop.size())
<< "Invalid rank of loop domain: " << tv->toString();
for (const auto i : c10::irange(ref_loop.size())) {
EXPECT_TRUE(exact_graph.disjointValSets().strictAreMapped(
tv->getLoopDomain().at(i), ref_loop.at(i)))
Expand Down
Loading

0 comments on commit 875d765

Please sign in to comment.