From 212ac38e08c47251356e0f0ee8f48e21a12b2293 Mon Sep 17 00:00:00 2001 From: Naoya Maruyama Date: Mon, 3 Feb 2025 19:22:10 -0800 Subject: [PATCH] Privatize up-cast ops for better segmentation (#3776) ## Problem This is to address yet another segmentation problem with RoPE. In particular, in Phi3 forward, there's two bfloat-to-float cast ops that are consumed by two segments. Find `T49` and `T36` below: ![phi3_fwd](https://github.com/user-attachments/assets/a3e1c256-6e58-4028-8fd2-a6725e5a810f) They are consumed by two segments, one with the blue color and another with the light purple color (not the one spanning vertically in the center of the graph). The problem here is that the upcast ops are grouped into the blue segment and their output float tensors are fed into the light purple segment. Specifically, we get this segment: ``` g{(resize) group id: 6 inputs: T0_g___bfloat[bS0{1}, iS1{8192}, iS2{9216}] __bfloat T34_g___bfloat[bS121{1}, bS122{1 ex 32}, iS123{8192}, iS124{96}] __bfloat T47_g___bfloat[bS177{1}, bS178{1 ex 32}, iS179{8192}, iS180{96}] __bfloat outputs: T36_g_float[bS129{1}, bS130{1 ex 32}, iS131{8192}, iS132{96}] float T49_g_float[bS185{1}, bS186{1 ex 32}, iS187{8192}, iS188{96}] float T52_g___bfloat[bS197{1}, iS198{32}, iS199{8192}, iS200{96}] __bfloat ``` which is followed by: ``` g{(resize) group id: 7 inputs: T0_g___bfloat[bS0{1}, iS1{8192}, iS2{9216}] __bfloat T36_g_float[bS129{1}, bS130{1 ex 32}, iS131{8192}, iS132{96}] float T49_g_float[bS185{1}, bS186{1 ex 32}, iS187{8192}, iS188{96}] float outputs: T66_g___bfloat[bS257{1}, iS258{32}, iS259{8192}, iS260{96}] __bfloat ``` Notice that the first segment produces `T36` and `T49`, which are just upcast of `T34` and `T47`, respectively, and then they are inputs of the following segment. This is not ideal. The second segment should just use `T34` and `T47` directly, and by doing so the first segment would not need to produce `T34` and `T47` as segment outputs. 
More concretely, in the current segmentation, there are two reads of bf16 tensors (`T34` and `T47`), two writes of fp32 tensors (`T36` and `T49`), and two reads of fp32 tensors (`T36` and `T49`). What we could do instead is just two reads of bf16 tensors (`T34` and `T47`) and another two reads of the same tensors. The fusion segmenter already addresses this problem partially by forwarding unary ops, but only for unary ops using fusion inputs, which doesn't apply to the Phi3 case. ## Fix The above problem with Phi3 wouldn't happen if `T49` and `T36` are not shared by the two segments. So, we first privatize all upcast tensors. This is done after the initial unsegmented trial and before the segmentation loop. https://github.com/NVIDIA/Fuser/pull/3776/files#diff-e2f2ad44a6dc03e4ad8e5f0f047be25eb1c142add431d48c1e046c968a577f3bR3958 That's all for the Phi3 case, but this privatization isn't necessary if the two segments were actually fused (which we don't support yet). If that actually happened, the fused segment would have something like: ``` T2 = bf16ToFp32(T0); T3 = bf16ToFp32(T0); T6 = T2 + T3 ``` Instead of: ``` T2 = bf16ToFp32(T0); T6 = T2 + T2 ``` This is functionally correct and shouldn't have any perf issue either, but just in case, we revert the privatization in the final segments. 
## Perf benefit Current resize schedule on H100: ``` NVFUSER_ENABLE=resize_scheduler pytest benchmarks/python/test_rope.py --benchmark-thunder -k 'hf_phi3_rope and fwd' ---------------------------------------------------------------------------- benchmark: 1 tests --------------------------------------------------------------------------- Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS (Kops/s) Rounds Iterations --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- test_rope_fwd_benchmark[executor='thunder'-variation='hf_phi3_rope'] 129.9170 132.9290 131.3976 0.7926 131.2950 0.7330 2;1 7.6105 10 1 --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- ``` With this PR: ``` -------------------------------------------------------------------------- benchmark: 1 tests ------------------------------------------------------------------------- Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS (Kops/s) Rounds Iterations ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- test_rope_fwd_benchmark[executor='thunder'-variation='hf_phi3_rope'] 97.0230 99.9030 98.9724 0.7649 99.1510 0.4780 2;1 10.1038 10 1 ----------------------------------------------------------------------------------------------------------------------------------------------------------------------- ``` It's also effective even without the resize scheduler. 
TOT: ``` ---------------------------------------------------------------------------- benchmark: 1 tests --------------------------------------------------------------------------- Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS (Kops/s) Rounds Iterations --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- test_rope_fwd_benchmark[executor='thunder'-variation='hf_phi3_rope'] 195.1030 196.3500 195.7106 0.3948 195.6955 0.5120 3;0 5.1096 10 1 --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- ``` With this PR: ``` ---------------------------------------------------------------------------- benchmark: 1 tests --------------------------------------------------------------------------- Name (time in us) Min Max Mean StdDev Median IQR Outliers OPS (Kops/s) Rounds Iterations --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- test_rope_fwd_benchmark[executor='thunder'-variation='hf_phi3_rope'] 141.1850 142.4950 141.7790 0.4813 141.7605 0.9600 5;0 7.0532 10 1 --------------------------------------------------------------------------------------------------------------------------------------------------------------------------- ``` --------- Co-authored-by: Liqiang Lu --- csrc/fusion_segmenter.cpp | 200 ++++++++++++++++++++++++++++++++ csrc/fusion_segmenter.h | 21 ++++ tests/cpp/test_segmentation.cpp | 103 ++++++++++++++++ 3 files changed, 324 insertions(+) diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp index 1c49713eaab..1a176a404fc 100644 --- a/csrc/fusion_segmenter.cpp +++ b/csrc/fusion_segmenter.cpp @@ -3955,6 +3955,7 @@ SegmentCandidateFinder::SegmentCandidateFinder( 
options_.run_final_merge), "Invalid Segmenter options"); segmented_fusion_ = std::make_unique(std::move(fusion)); + privatizeUpcast(); findSegments(); } @@ -4206,6 +4207,201 @@ void SegmentCandidateFinder::findSegments() { } } +void SegmentCandidateFinder::privatizeUpcast() { + if (getenv("DISABLE_PRIVATIZE")) { + return; + } + // Insert castOp to complete_fusion_ + FusionGuard fg(segmented_fusion_->complete_fusion_.get()); + + const auto exprs = segmented_fusion_->complete_fusion_->exprs(); + + for (auto expr : exprs) { + if (!ir_utils::isTvOp(expr)) { + continue; + } + + for (const auto i : c10::irange(expr->inputs().size())) { + auto maybe_upcast_out_tv = dynamic_cast(expr->input(i)); + if (maybe_upcast_out_tv == nullptr) { + continue; + } + + // Check if the input is an output of an upcast op + auto maybe_upcast_op = + dynamic_cast(maybe_upcast_out_tv->definition()); + if (maybe_upcast_op == nullptr || + maybe_upcast_op->getUnaryOpType() != UnaryOpType::Cast) { + continue; + } + + auto precisions = + ir_utils::getPrecisionOfProducerConsumerTensors(maybe_upcast_op); + if (!precisions.has_value() || precisions->first >= precisions->second) { + continue; + } + + // Check if there's multiple uses of the upcast output + auto uses_of_upcast_out_tv = maybe_upcast_out_tv->uses(); + if (uses_of_upcast_out_tv.size() < 2) { + continue; + } + + // If this is the first use of the upcast output, keep it as is + if (expr == uses_of_upcast_out_tv.front()) { + continue; + } + + auto upcast_out_tv_clone = + castOp(maybe_upcast_out_tv->dtype(), maybe_upcast_op->input(0)); + expr = ir_utils::replaceValInExprInputs( + expr, maybe_upcast_out_tv, upcast_out_tv_clone); + + privatized_upcast_ops_[maybe_upcast_op].insert( + upcast_out_tv_clone->definition()->as()); + } + } +} + +void SegmentCandidateFinder::revertPrivatizedUpcast(SegmentedGroup* group) { + // If a given consumer edge is a duplicate of another edge of the + // same producer group, remove the given edge from both the 
producer + // and consumer groups. + auto maybe_deduplicate_edge = + [](SegmentedEdge* maybe_duplicated_consumer_edge) { + SegmentedGroup* producer_group = maybe_duplicated_consumer_edge->from; + + auto same_edge_it = std::find_if( + producer_group->consumer_edges.begin(), + producer_group->consumer_edges.end(), + [&](SegmentedEdge* consumer_edge) { + return consumer_edge != maybe_duplicated_consumer_edge && + *consumer_edge == *maybe_duplicated_consumer_edge; + }); + + if (same_edge_it == producer_group->consumer_edges.end()) { + return; + } + + // maybe_duplicated_consumer_edge is redundant. Remove it from the + // from and the two groups + auto consumer_edge_to_remove = std::find( + producer_group->consumer_edges.begin(), + producer_group->consumer_edges.end(), + maybe_duplicated_consumer_edge); + NVF_ERROR( + consumer_edge_to_remove != producer_group->consumer_edges.end()); + producer_group->consumer_edges.erase(consumer_edge_to_remove); + + SegmentedGroup* consumer_group = maybe_duplicated_consumer_edge->to; + auto producer_edge_to_remove = std::find( + consumer_group->producer_edges.begin(), + consumer_group->producer_edges.end(), + maybe_duplicated_consumer_edge); + NVF_ERROR( + producer_edge_to_remove != consumer_group->producer_edges.end()); + consumer_group->producer_edges.erase(producer_edge_to_remove); + }; + + // Replace old_expr with new_expr if found in a given group. Return + // true if replaced. 
+ auto maybe_replace = + [](SegmentedGroup* group, Expr* old_expr, Expr* new_expr) -> bool { + auto it = std::find(group->exprs_.begin(), group->exprs_.end(), old_expr); + if (it != group->exprs_.end()) { + *it = new_expr; + return true; + } else { + return false; + } + }; + + for (const auto& [original_upcast, clones] : privatized_upcast_ops_) { + std::vector upcast_in_group; + Val* upcast_val_to_keep = nullptr; + for (auto uop : ir_utils::filterByType(group->exprs())) { + if (uop != original_upcast && !clones.count(uop)) { + continue; + } + + upcast_in_group.push_back(uop); + + auto upcast_tv = uop->out(); + + // Prefer the original upcast if found + if (upcast_val_to_keep == nullptr || + upcast_tv == original_upcast->out()) { + upcast_val_to_keep = upcast_tv; + } + } + + if (upcast_in_group.size() < 2) { + continue; + } + + for (auto uop : upcast_in_group) { + Val* upcast_val_to_replace = uop->out(); + if (upcast_val_to_replace == upcast_val_to_keep) { + // Keep this uop as is since its output replaces the other + // upcast outputs + continue; + } + + NVF_ERROR( + upcast_val_to_replace->uses().size() == 1, + "Multiple use of replicated upcast tensor found: ", + toDelimitedString(upcast_val_to_replace->uses())); + + auto use_of_upcast_val_to_replace = upcast_val_to_replace->uses().at(0); + + auto updated_expr = ir_utils::replaceValInExprInputs( + use_of_upcast_val_to_replace, + upcast_val_to_replace, + upcast_val_to_keep); + + // Replace use_of_upcast_val_to_replace with + // updated_expr. use_of_upcast_val_to_replace must be in the + // same group of its consumer groups + if (!maybe_replace(group, use_of_upcast_val_to_replace, updated_expr)) { + for (auto consumer_edge : group->consumer_edges) { + if (maybe_replace( + consumer_edge->to, + use_of_upcast_val_to_replace, + updated_expr)) { + break; + } + } + } + + // Update a consumer edge if its val is + // upcast_val_to_replace. Again, there must be at most one such + // edge. 
+ SegmentedEdge* consumer_edge_to_update = nullptr; + for (auto consumer_edge : group->consumer_edges) { + if (consumer_edge->val == upcast_val_to_replace) { + NVF_ERROR( + consumer_edge_to_update == nullptr, + "Multiple consumer edges using ", + upcast_val_to_replace->toString(), + " found"); + consumer_edge->val = upcast_val_to_keep; + consumer_edge_to_update = consumer_edge; + } + } + + // Now that the consumer edge is updated, it may be a duplicate + // of an exising edge. Remove if so. + if (consumer_edge_to_update != nullptr) { + maybe_deduplicate_edge(consumer_edge_to_update); + } + + // Note that it should not be necessary to do anything with + // group->output_vals since the inserted upcast ops should never produce + // fusion outputs. + } + } +} + // Decides whether we should forward an input (or a forwarded input) of a // fusion. Currently, we forward an input only when its single use is a UnaryOp. // Therefore, this function returns `v`'s single unary use or nullptr if it @@ -4632,6 +4828,10 @@ void SegmentCandidateFinder::finalize() { resolveScalarsInGroup(group); } + for (auto group : segmented_fusion_->groups()) { + revertPrivatizedUpcast(group); + } + // Finalize each group, fill in the missing inputs, i.e. tensor dims. for (auto g : groups()) { g->setSchedulerType(deriveSchedulerType(g)); diff --git a/csrc/fusion_segmenter.h b/csrc/fusion_segmenter.h index c70aab19e49..e5fcc70c7b7 100644 --- a/csrc/fusion_segmenter.h +++ b/csrc/fusion_segmenter.h @@ -40,6 +40,14 @@ struct SegmentedEdge { Val* val; void print() const; + + bool operator==(const SegmentedEdge& other) const { + return from == other.from && to == other.to && val == other.val; + } + + bool operator!=(const SegmentedEdge& other) const { + return !(*this == other); + } }; std::ostream& operator<<(std::ostream& os, const SegmentedEdge* edge); @@ -564,8 +572,18 @@ class SegmentCandidateFinder { void buildInitialSegments(); + // Replicate upcast ops when consumed by multiple expressions. 
This + // promotes segmented fusions to share pre-upcast tensors rather + // than post-upcast tensors. Replicated upcast ops will be reverted + // when they are grouped into the same segment. See + // https://github.com/NVIDIA/Fuser/pull/3776/ for more details. + void privatizeUpcast(); + void findSegments(); + // Revert privatized upcast ops when not necessary + void revertPrivatizedUpcast(SegmentedGroup* group); + //! Find a group found in candidates that can be merged with the //! given group and set them to be merged if found. When no //! candidate is given, SegmentedGroup::getMergeCandidates is used @@ -723,6 +741,9 @@ class SegmentCandidateFinder { // used for breaking the fusion into compute and communication segments std::optional runtime_info_; + std::unordered_map> + privatized_upcast_ops_; + //! Note: //! Segmenter should eventually rely only on runtime_info_ for //! safe caching. runtime_inputs_ is only used in translateWelford diff --git a/tests/cpp/test_segmentation.cpp b/tests/cpp/test_segmentation.cpp index 8bcf3b13d60..8fbc5bc0bc2 100644 --- a/tests/cpp/test_segmentation.cpp +++ b/tests/cpp/test_segmentation.cpp @@ -703,4 +703,107 @@ TEST_F(SegmentationTest, ForwardInputsToSegmenterSetIssue2658) { executor_cache.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__); } +// Test to verify an upcast is replicated between different segments +TEST_F(NVFuserTest, PrivatizeUpcast) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(2, DataType::BFloat16); + fusion.addInput(tv0); + + auto tv1 = segment_set(tv0); + auto tv2 = castOp(DataType::Float, tv1); + + auto tv3 = sum(tv2, {0}); + fusion.addOutput(tv3); + + auto tv4 = sum(tv2, {1}); + fusion.addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); + auto t0 = at::randn({16, 32}, options); + std::vector inputs({t0}); + + FusionExecutorCache 
executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + + // There must be three segments, one with ExprEvalExecutor and two + // with KernelExecutor. + FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_THAT(runtime->fusionSegments()->groups(), SizeIs(3)); + + for (const auto& executor : runtime->executors()) { + // Ignore the one taken care by ExprEvalExecutor + if (executor.get()->isA()) { + continue; + } + // This segment should corresponds to each of the reductions. Both + // of them should use tv1. + auto ke = dynamic_cast(executor.get()); + ASSERT_NE(ke, nullptr); + kir::Kernel* kernel = ke->compiledKernel()->kernel(); + EXPECT_EQ(kernel->inputs().size(), 1); + EXPECT_EQ(kernel->inputs().at(0)->name(), 1); + } +} + +// Unlike PrivatizeUpcast, verify replicated upcast ops are +// consolidated back as they are grouped into the same segment +TEST_F(NVFuserTest, RevertPrivatizedUpcast) { + auto fusion_ptr = std::make_unique(); + auto& fusion = *fusion_ptr; + FusionGuard fg(fusion_ptr.get()); + + auto tv0 = makeSymbolicTensor(2, DataType::BFloat16); + fusion.addInput(tv0); + + auto tv1 = segment_set(tv0); + auto tv2 = castOp(DataType::Float, tv1); + + auto tv3 = sum(tv2, {1}); + fusion.addOutput(tv3); + + auto tv4 = sum(tv2, {1}); + fusion.addOutput(tv4); + + auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0); + auto t0 = at::randn({16, 32}, options); + std::vector inputs({t0}); + + FusionExecutorCache executor_cache(std::move(fusion_ptr)); + auto outputs = executor_cache.runFusionWithInputs(inputs); + testValidate(&fusion, outputs, inputs, __LINE__, __FILE__); + + // There must be two segments, one with ExprEvalExecutor and another + // with KernelExecutor. 
+ FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime(); + EXPECT_THAT(runtime->fusionSegments()->groups(), SizeIs(2)); + + for (const auto& executor : runtime->executors()) { + // Ignore the one taken care by ExprEvalExecutor + if (executor.get()->isA()) { + continue; + } + // This segment should have the two reductions. There must be only + // one upcast op with tv1 as its producer. + auto ke = dynamic_cast(executor.get()); + ASSERT_NE(ke, nullptr); + kir::Kernel* kernel = ke->compiledKernel()->kernel(); + int64_t num_upcast_ops = 0; + for (auto expr : KernelExprVisitor::getAllExprs(kernel)) { + auto uop = dynamic_cast(expr); + if (uop == nullptr || uop->getUnaryOpType() != UnaryOpType::Cast) { + continue; + } + + EXPECT_EQ(uop->in()->as()->view()->name(), 1); + + ++num_upcast_ops; + } + EXPECT_EQ(num_upcast_ops, 1); + } +} + } // namespace nvfuser