From 31e4360c7de7d6ea52f16310e54240cda6d7b65d Mon Sep 17 00:00:00 2001
From: Naoya Maruyama
Date: Mon, 9 Dec 2024 21:53:39 -0800
Subject: [PATCH] wip

---
 CMakeLists.txt                     |   1 +
 csrc/scheduler/registry.cpp        |   3 +
 csrc/scheduler/resize.cpp          | 437 +++++++++++++++++++++++++++++
 csrc/scheduler/resize.h            |  41 +++
 csrc/scheduler/scheduler_types.cpp |   2 +
 csrc/scheduler/scheduler_types.h   |   6 +-
 6 files changed, 488 insertions(+), 2 deletions(-)
 create mode 100644 csrc/scheduler/resize.cpp
 create mode 100644 csrc/scheduler/resize.h

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 040d8129b72..dce8385ea08 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -231,6 +231,7 @@ list(APPEND NVFUSER_SRCS
     ${NVFUSER_SRCS_DIR}/scheduler/reduction_utils.cpp
     ${NVFUSER_SRCS_DIR}/scheduler/registry.cpp
     ${NVFUSER_SRCS_DIR}/scheduler/registry_utils.cpp
+    ${NVFUSER_SRCS_DIR}/scheduler/resize.cpp
     ${NVFUSER_SRCS_DIR}/scheduler/runtime_info.cpp
     ${NVFUSER_SRCS_DIR}/scheduler/scheduler_types.cpp
     ${NVFUSER_SRCS_DIR}/scheduler/tools/inlining.cpp
diff --git a/csrc/scheduler/registry.cpp b/csrc/scheduler/registry.cpp
index fd9573d6b32..039d94b38af 100644
--- a/csrc/scheduler/registry.cpp
+++ b/csrc/scheduler/registry.cpp
@@ -13,6 +13,7 @@
 #include
 #include
 #include
+#include <scheduler/resize.h>
 #include
 #include
 
@@ -90,6 +91,8 @@ std::unique_ptr<SchedulerEntry> SchedulerEntry::makeSchedulerInstance(
       return std::make_unique();
     case SchedulerType::ExprEval:
       return std::make_unique<ExprEvalScheduler>();
+    case SchedulerType::Resize:
+      return std::make_unique<ResizeScheduler>();
     default:
       NVF_THROW("unreachable");
   }
diff --git a/csrc/scheduler/resize.cpp b/csrc/scheduler/resize.cpp
new file mode 100644
index 00000000000..cd07f572872
--- /dev/null
+++ b/csrc/scheduler/resize.cpp
@@ -0,0 +1,437 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+namespace nvfuser {
+
+bool ResizeScheduler::canScheduleCompileTime(Fusion* fusion) {
+  std::cerr << "ResizeScheduler::canScheduleCompileTime\n";
+
+  if (!ir_utils::hasOpsOfType<SliceOp, PadOp>(fusion)) {
+    scheduler_debug_utils::canScheduleRejectReason(
+        schedulerType(), "No resize op to schedule");
+    return false;
+  }
+
+  if (scheduler_utils::isResharding(fusion)) {
+    scheduler_debug_utils::canScheduleRejectReason(
+        schedulerType(), "Fusion is resharding.");
+    return false;
+  }
+
+  if (ir_utils::hasAnyReductionOps(fusion)) {
+    scheduler_debug_utils::canScheduleRejectReason(
+        schedulerType(), "No support for reduction ops");
+    return false;
+  }
+
+  if (registry_utils::hasNonUniqueBcast(fusion)) {
+    scheduler_debug_utils::canScheduleRejectReason(
+        schedulerType(),
+        "Broadcasting dimension might be broadcasting to multiple sizes.");
+    return false;
+  }
+
+  // For now, only a single resize op is allowed to exist.
+  auto resize_based_tensor_ops = ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);
+  if (resize_based_tensor_ops.size() != 1) {
+    scheduler_debug_utils::canScheduleRejectReason(
+        schedulerType(), "Only a single resize op is allowed.");
+    return false;
+  }
+
+  // Slicing of or to a broadcast ID is not allowed yet.
+  for (auto tensor_op : resize_based_tensor_ops) {
+    TensorView* out_tv = tensor_op->output(0)->as<TensorView>();
+    for (auto logical_id : out_tv->getLogicalDomain()) {
+      Resize* resize = dynamic_cast<Resize*>(logical_id->definition());
+      if (resize == nullptr) {
+        continue;
+      }
+      if (resize->out()->isBroadcast()) {
+        scheduler_debug_utils::canScheduleRejectReason(
+            schedulerType(), "Resize to a broadcast ID is not allowed.");
+        return false;
+      }
+      if (resize->in()->isBroadcast()) {
+        scheduler_debug_utils::canScheduleRejectReason(
+            schedulerType(), "Resize of a broadcast ID is not allowed.");
+        return false;
+      }
+    }
+  }
+
+  return true;
+}
+
+std::unique_ptr<HeuristicParams> ResizeScheduler::computeHeuristics(
+    Fusion* fusion,
+    SchedulerRuntimeInfo& runtime_info,
+    HeuristicDataCache* data_cache) {
+  FUSER_PERF_SCOPE("ResizeScheduler::computeHeuristics");
+  auto params = std::make_unique<HeuristicParams>(SchedulerType::Resize);
+  params->cparams.index_type = runtime_info.getIndexType();
+  return params;
+}
+
+namespace {
+
+std::vector<std::pair<TensorView*, std::vector<TensorView*>>>
+getReferenceTensors(Fusion* fusion) {
+  std::vector<TensorView*> ref_candidates;
+
+  std::cerr << "getReferenceTensors\n";
+  fusion->printMath();
+
+  const auto all_tvs = fusion->allTvs();
+
+  DisjointSets<TensorView*> disjoint_val_sets;
+
+  std::vector<Expr*> resize_ops =
+      ir_utils::getOpsOfType<SliceOp, PadOp>(fusion);
+
+  // Group all tvs that are dependent on resize op outputs
+  for (Expr* resize_op : resize_ops) {
+    auto ref_tv = resize_op->output(0)->as<TensorView>();
+
+    auto dep_vals = DependencyCheck::getAllValsBetween(
+        {fusion->inputs().begin(), fusion->inputs().end()}, {ref_tv});
+
+    for (auto dep_tv : ir_utils::filterByType<TensorView>(dep_vals)) {
+      // Don't add inputs. Inputs are not replicated nor scheduled.
+      if (dep_tv->isFusionInput()) {
+        continue;
+      }
+      std::cerr << "Mapping " << ref_tv->toString() << " and "
+                << dep_tv->toString() << "\n";
+      disjoint_val_sets.mapEntries(ref_tv, dep_tv);
+    }
+  }
+
+  // TODO: Reuse
+  IdModel id_model(fusion, /*build_graphs=*/false);
+  const auto& broadcast_graph = id_model.buildBroadcastGraph();
+
+  for (const auto i : c10::irange(resize_ops.size() - 1)) {
+    for (const auto j : c10::irange(i + 1, resize_ops.size())) {
+      auto out_tv_i = resize_ops.at(i)->output(0)->as<TensorView>();
+      auto out_tv_j = resize_ops.at(j)->output(0)->as<TensorView>();
+      if (disjoint_val_sets.strictAreMapped(out_tv_i, out_tv_j)) {
+        continue;
+      }
+
+      const auto out_tv_i_loop_groups =
+          broadcast_graph.toGroups(out_tv_i->getLoopDomain());
+      const auto out_tv_j_loop_groups =
+          broadcast_graph.toGroups(out_tv_j->getLoopDomain());
+
+      bool same_loop_domain =
+          broadcast_graph.toGroups(out_tv_i->getLoopDomain()).set() ==
+          broadcast_graph.toGroups(out_tv_j->getLoopDomain()).set();
+      std::cerr << "Comparing " << out_tv_i->toString() << " and "
+                << out_tv_j->toString() << ": " << same_loop_domain << "\n";
+      if (!same_loop_domain) {
+        auto [path_from_i_to_j, all_visited] =
+            ValGraphBFS::getExprGroupsBetween(
+                broadcast_graph,
+                out_tv_i_loop_groups,
+                out_tv_j_loop_groups,
+                /*require_all_to_visited=*/false);
+        if (!all_visited) {
+          // There are some unreachable loop groups
+          continue;
+        }
+
+        // If there's any resize node, don't merge them
+        if (std::any_of(
+                path_from_i_to_j.begin(),
+                path_from_i_to_j.end(),
+                [](const auto& path_component) {
+                  return path_component.first->front()->template isA<Resize>();
+                })) {
+          continue;
+        }
+      }
+
+      std::cerr << "Same loop domain: " << out_tv_i->toString() << " and "
+                << out_tv_j->toString() << "\n";
+      disjoint_val_sets.mapEntries(out_tv_i, out_tv_j);
+    }
+  }
+
+  const auto num_disjoint_resize_groups = disjoint_val_sets.size();
+
std::cerr << "Number of disjoint resize groups: " + << num_disjoint_resize_groups << "\n"; + + std::cerr << "Initial disjoint grouping of tensors\n"; + for (const auto& set : disjoint_val_sets.disjointSets()) { + std::cerr << "\t{"; + for (auto tv : *set) { + std::cerr << " T" << tv->name(); + } + std::cerr << "}\n"; + } + + // Include outputs + for (Expr* resize_op : resize_ops) { + auto resize_out = resize_op->output(0)->as(); + auto output_dep_vals = + DependencyCheck::getAllValsBetween({resize_out}, fusion->outputs()); + for (auto tv : ir_utils::filterByType(output_dep_vals)) { + disjoint_val_sets.mapEntries(resize_out, tv); + } + } + + // Output dep vals should also be disjointly grouped, so the number + // of groups should not change + NVF_ERROR( + num_disjoint_resize_groups == disjoint_val_sets.size(), + "Expected number of groups: ", + num_disjoint_resize_groups, + ". Actual: ", + disjoint_val_sets.size()); + + // There can still be tensors that are not producers nor consumers + // of resize ops. They should be fine with any of the groups. + // All of them should now be privatized. + + auto first_group_tv = resize_ops.at(0)->output(0)->as(); + + for (auto tv : all_tvs) { + if (tv->isFusionInput() || disjoint_val_sets.mappingExists(tv)) { + continue; + } + + std::cerr << "Remaining tv: " << tv->toString() + << ". Put into the group of " << first_group_tv->toString() + << "\n"; + + auto dep_outputs = DependencyCheck::getAllOutputsOf({tv}); + NVF_ERROR(!dep_outputs.empty()); + + TensorView* first_dep_output = (*(dep_outputs.begin()))->as(); + bool added_to_group = false; + for (const auto& disjoint_set : disjoint_val_sets.disjointSets()) { + if (!disjoint_set->has(first_dep_output)) { + continue; + } + + // Make sure all outputs are in the same set + for (const auto dep_output : dep_outputs) { + NVF_ERROR(disjoint_set->has(dep_output->as())); + } + + disjoint_val_sets.mapEntries(tv, disjoint_set->front()); + added_to_group = true; + break; + } + + // Could not find any group to join + NVF_ERROR( + added_to_group, "Could not find any group to add ", tv->toString()); + } + + NVF_ERROR( + num_disjoint_resize_groups == disjoint_val_sets.size(), + "Expected number of groups: ", + num_disjoint_resize_groups, + ". Actual: ", + disjoint_val_sets.size()); + + std::cerr << "TV disjoint groups: " << disjoint_val_sets.size() << "\n"; + + std::vector>> ref_list; + + // Pick a reference in each disjoint set + for (const auto& disjoint_set : disjoint_val_sets.disjointSets()) { + TensorView* ref_tv = nullptr; + // TensorView* input_tv = nullptr; + std::unordered_set resize_op_outputs; +#if 0 + for (TensorView* tv : *disjoint_set) { + // All of the slice/pad/cat output tensors should have the same + // loop domain. Any of them can be equally used as the reference + // for this group. + // Update: But propagation could still fail due to the resize + // cyclic mapping. Don't use resize outputs as reference for + // now. 
+
+      if (auto def = tv->definition();
+          def != nullptr && def->isOneOf<SliceOp, PadOp>()) {
+        ref_tv = def->output(0)->as<TensorView>();
+        break;
+      }
+
+      if (auto def = tv->definition(); std::any_of(
+              def->inputs().begin(), def->inputs().end(), [](Val* input) {
+                return input->isA<TensorView>() && input->isFusionInput();
+              })) {
+        if (input_tv == nullptr ||
+            (input_tv->domain()->noBroadcasts().size() <
+             tv->domain()->noBroadcasts().size())) {
+          input_tv = tv;
+        }
+      }
+    }
+#endif
+
+    for (TensorView* tv : *disjoint_set) {
+      if (auto def = tv->definition();
+          def != nullptr && def->isOneOf<SliceOp, PadOp>()) {
+        resize_op_outputs.insert(def->output(0)->as<TensorView>());
+      }
+    }
+
+    for (TensorView* tv : *disjoint_set) {
+      if (!tv->isFusionOutput()) {
+        continue;
+      }
+
+      // Ref if all resize_outputs have a dependency with this output
+      auto all_dep_vals = DependencyCheck::getAllValsBetween(
+          {fusion->inputs().begin(), fusion->inputs().end()}, {tv});
+      bool all_resize_out_dependent = true;
+      for (auto resize_out : resize_op_outputs) {
+        auto it =
+            std::find(all_dep_vals.begin(), all_dep_vals.end(), resize_out);
+        if (it == all_dep_vals.end()) {
+          std::cerr << "Not a dependency: " << resize_out->toString() << " of "
+                    << tv->toString() << "\n";
+          all_resize_out_dependent = false;
+          break;
+        }
+      }
+
+      if (!all_resize_out_dependent) {
+        continue;
+      }
+
+      ref_tv = tv;
+    }
+
+    if (ref_tv) {
+      std::cerr << "Reference: " << ref_tv->toString() << "\n";
+
+      ref_list.emplace_back(ref_tv, std::vector<TensorView*>{});
+      auto& member_list = ref_list.back().second;
+      for (auto tv : all_tvs) {
+        if (disjoint_set->has(tv)) {
+          member_list.push_back(tv);
+        }
+      }
+
+      continue;
+    }
+
+    NVF_THROW(
+        "No reference found for ", toDelimitedString(disjoint_set->vector()));
+  }
+
+  std::cerr << "Disjoint grouping of tensors with representatives:\n";
+  for (const auto& [ref, set] : ref_list) {
+    std::cerr << "\tRepresentative: " << ref->toString() << "\n"
+              << "\t{";
+    for (auto tv : set) {
+      std::cerr << " T" << tv->name();
+    }
+    std::cerr << "}\n";
+  }
+
+  return ref_list;
+}
+
+} // namespace
+
+void ResizeScheduler::schedule(Fusion* fusion, const HeuristicParams* params) {
+  FUSER_PERF_SCOPE("ResizeScheduler::schedule");
+
+  DebugStreamGuard dsg(std::cerr);
+
+  FusionGuard fg(fusion);
+
+  std::cerr << "ResizeScheduler::schedule\n";
+
+  scheduler_utils::clearMemorySpace(fusion);
+
+  scheduler_utils::cacheInputs(fusion, true);
+
+  fusion->printMath();
+
+  const auto exprs = fusion->exprs();
+  for (auto expr : exprs) {
+    if (!expr->isOneOf<SliceOp, PadOp>()) {
+      continue;
+    }
+
+    std::cerr << "Propagating resize tensor op: " << expr->toString();
+    scheduler_tools::propagateResizeToInputs(expr);
+  }
+
+  const auto ref_tensors = getReferenceTensors(fusion);
+
+  for (const auto& [ref_tv, tvs_to_schedule] : ref_tensors) {
+    std::cerr << "Reference: " << ref_tv->toString() << "\n";
+    std::cerr << "Tvs to schedule: " << toDelimitedString(tvs_to_schedule)
+              << "\n";
+
+    // Schedule the reference as a flattened 1D domain split into
+    // [serial, BIDx (1 << 14), TIDx (128)].
+    ref_tv->flatten();
+    ref_tv->split(0, 128);
+    ref_tv->split(0, 1 << 14);
+    ref_tv->axis(-1)->parallelize(ParallelType::TIDx);
+    ref_tv->axis(-2)->parallelize(ParallelType::BIDx);
+
+    std::cerr << "Scheduled reference:\n";
+    ref_tv->printTransforms();
+
+    scheduler_tools::scheduleLoopDomainsLike(
+        tvs_to_schedule, ref_tv->getLoopDomain());
+  }
+
+  {
+    std::cerr << "All done\n";
+    fusion->printMath();
+    for (auto tv : fusion->allTvs()) {
+      std::cerr << "Final scheduled T" << tv->name() << "\n";
+      if (tv->hasRoot()) {
+        std::cerr << "\tRoot: " << toDelimitedString(tv->getRootDomain())
+                  << "\n";
+      }
+      std::cerr << "\tLogical: " << toDelimitedString(tv->getLogicalDomain())
+                << "\n";
+      std::cerr << "\tLoop: " << toDelimitedString(tv->getLoopDomain()) << "\n";
+      std::cerr << "\tAdditional ids: "
+                << toDelimitedString(tv->domain()->additionalIDs()) << "\n";
+      for (auto expr : tv->domain()->allExprs()) {
+        std::cerr << expr->toString(4);
+      }
+    }
+  }
+
+  inlineMost();
+
+  fusion->printMath();
+
+  return;
+}
+
+} // namespace nvfuser
diff --git a/csrc/scheduler/resize.h b/csrc/scheduler/resize.h
new file mode 100644
index 00000000000..b51ecf1e6dd
--- /dev/null
+++ b/csrc/scheduler/resize.h
@@ -0,0 +1,41 @@
+// clang-format off
+/*
+ * SPDX-FileCopyrightText: Copyright (c) 2023-present NVIDIA CORPORATION & AFFILIATES.
+ * All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ */
+// clang-format on
+#pragma once
+
+#include
+#include
+
+namespace nvfuser {
+
+class Fusion;
+class SchedulerRuntimeInfo;
+class HeuristicDataCache;
+
+class ResizeScheduler : public SchedulerEntry {
+ public:
+  bool canScheduleCompileTime(Fusion* fusion) override;
+  bool canScheduleRunTime(
+      Fusion* fusion,
+      SchedulerRuntimeInfo& runtime_info,
+      HeuristicDataCache* data_cache = nullptr) override {
+    return true;
+  }
+
+  std::unique_ptr<HeuristicParams> computeHeuristics(
+      Fusion* fusion,
+      SchedulerRuntimeInfo& runtime_info,
+      HeuristicDataCache* data_cache) override;
+
+  void schedule(Fusion* fusion, const HeuristicParams* params) override;
+
+  constexpr static SchedulerType schedulerType() {
+    return SchedulerType::Resize;
+  }
+};
+
+} // namespace nvfuser
diff --git a/csrc/scheduler/scheduler_types.cpp b/csrc/scheduler/scheduler_types.cpp
index 623d5a22697..cf9b974acf5 100644
--- a/csrc/scheduler/scheduler_types.cpp
+++ b/csrc/scheduler/scheduler_types.cpp
@@ -31,6 +31,8 @@ std::string toString(SchedulerType scheduler_type) {
       return "matmul";
     case SchedulerType::ExprEval:
       return "expr_eval";
+    case SchedulerType::Resize:
+      return "resize";
     case SchedulerType::None:
       return "none";
     default:
diff --git a/csrc/scheduler/scheduler_types.h b/csrc/scheduler/scheduler_types.h
index 275a1f372e7..caa389abb9a 100644
--- a/csrc/scheduler/scheduler_types.h
+++ b/csrc/scheduler/scheduler_types.h
@@ -56,15 +56,17 @@ enum class SchedulerType {
   InnerOuterPersistent,
   OuterPersistent,
   Transpose,
-  ExprEval
+  ExprEval,
+  Resize
 };
 
 //! Define a schedule table to loop over all the heuristics in priority order.
-constexpr std::array<SchedulerType, 9> all_heuristics_in_priority_order = {
+constexpr std::array<SchedulerType, 10> all_heuristics_in_priority_order = {
     SchedulerType::ExprEval,
     SchedulerType::NoOp,
     SchedulerType::Matmul,
     SchedulerType::Reduction,
+    SchedulerType::Resize,
     SchedulerType::Transpose,
     SchedulerType::PointWise,
     SchedulerType::InnerPersistent,