-
Notifications
You must be signed in to change notification settings - Fork 54
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Propagating slice/pad/cat to inputs (#3549)
Added a scheduler util function to schedule a fusion with resize-based ops such as slice, pad and concat. This propagates resize ops to producers so that all tensors have the exact-mapped loop domains. Part of #3425. Extracted so that it can be individually tested. (There's a follow-up PR: #3555)
- Loading branch information
Showing
6 changed files
with
566 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,70 @@ | ||
// clang-format off | ||
/* | ||
* SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES. | ||
* All rights reserved. | ||
* SPDX-License-Identifier: BSD-3-Clause | ||
*/ | ||
// clang-format on | ||
|
||
#include <id_model/id_model.h> | ||
#include <ir/cloner.h> | ||
#include <ir/utils.h> | ||
#include <iter_visitor.h> | ||
#include <logical_domain_map.h> | ||
#include <scheduler/tools/loop_domain_scheduler.h> | ||
#include <scheduler/tools/resize_utils.h> | ||
#include <val_graph_visitor.h> | ||
|
||
namespace nvfuser { | ||
namespace scheduler_tools { | ||
|
||
// Propagate the resize iter-domain ops of a slice/pad output back to all
// non-input tensors that the op's producer depends on, so that every
// dependent tensor ends up with an exact-mapped loop domain.
void propagateResizeToInputs(Expr* resize_tensor_op) {
  NVF_ERROR(
      resize_tensor_op->isA<SliceOp>() || resize_tensor_op->isA<PadOp>(),
      "Unexpected resize tensor op: ",
      resize_tensor_op->toString());

  Fusion* fusion = resize_tensor_op->fusion();

  auto* in_tv = resize_tensor_op->input(0)->as<TensorView>();
  auto* out_tv = resize_tensor_op->output(0)->as<TensorView>();

  // Every val on a path from a fusion input to the producer of the
  // resize op is affected by the propagation.
  auto dep_vals = DependencyCheck::getAllValsBetween(
      {fusion->inputs().begin(), fusion->inputs().end()}, {in_tv});

  // Collect the dependent TensorViews, skipping fusion inputs since
  // their loop domains are irrelevant for scheduling.
  std::vector<TensorView*> dependent_tvs;
  dependent_tvs.reserve(dep_vals.size());
  for (auto* dep_val : dep_vals) {
    if (!dep_val->isA<TensorView>() || dep_val->isFusionInput()) {
      continue;
    }
    dependent_tvs.push_back(dep_val->as<TensorView>());
  }

  // Ideally, this should be just calling
  // scheduler_tools::scheduleLoopDomainsLike once with the consumer
  // tensor as a reference. However, due to the indexing issue with
  // resize, propagating the Resize iter-domain op may fail. To avoid
  // the problem, the propagation of the resize op is explicitly done
  // by using scheduler_tools::scheduleLoopDomainsBy.
  //
  // Before doing so, all the dependent tensors need to have the
  // exact-mapped loop domain, based on the producer's loop domain.
  scheduler_tools::scheduleLoopDomainsLike(
      dependent_tvs, in_tv->getLoopDomain());

  // Now that all the dependent tensors have the uniform, exact-mapped
  // loop domains, propagate each Resize op found in the consumer's
  // logical domain.
  for (auto* logical_id : out_tv->getLogicalDomain()) {
    auto* resize_op = dynamic_cast<Resize*>(logical_id->definition());
    if (resize_op == nullptr) {
      continue;
    }

    scheduler_tools::scheduleLoopDomainsBy(dependent_tvs, resize_op);
  }
}
|
||
} // namespace scheduler_tools | ||
} // namespace nvfuser |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
// clang-format off
/*
 * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
 * All rights reserved.
 * SPDX-License-Identifier: BSD-3-Clause
 */
// clang-format on
#pragma once

namespace nvfuser {

class Expr;

namespace scheduler_tools {

// For a given resize-based tensor op such as SliceOp and PadOp, make the loop
// domain of each dependent producer tensor exact-mapped by propagating
// the iter-domain ops of the output tensor of the given op. Note that
// fusion inputs are skipped as their loop domains don't matter.
//
// Precondition: resize_op must be a SliceOp or a PadOp; anything else
// is rejected with an error.
void propagateResizeToInputs(Expr* resize_op);

} // namespace scheduler_tools
} // namespace nvfuser
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.