diff --git a/csrc/fusion_segmenter.cpp b/csrc/fusion_segmenter.cpp
index 1c49713eaab..1a176a404fc 100644
--- a/csrc/fusion_segmenter.cpp
+++ b/csrc/fusion_segmenter.cpp
@@ -3955,6 +3955,7 @@ SegmentCandidateFinder::SegmentCandidateFinder(
           options_.run_final_merge),
       "Invalid Segmenter options");
   segmented_fusion_ = std::make_unique<SegmentedFusion>(std::move(fusion));
+  privatizeUpcast();
   findSegments();
 }
 
@@ -4206,6 +4207,201 @@ void SegmentCandidateFinder::findSegments() {
   }
 }
 
+void SegmentCandidateFinder::privatizeUpcast() {
+  if (getenv("DISABLE_PRIVATIZE")) {
+    return;
+  }
+  // Insert castOp to complete_fusion_
+  FusionGuard fg(segmented_fusion_->complete_fusion_.get());
+
+  const auto exprs = segmented_fusion_->complete_fusion_->exprs();
+
+  for (auto expr : exprs) {
+    if (!ir_utils::isTvOp(expr)) {
+      continue;
+    }
+
+    for (const auto i : c10::irange(expr->inputs().size())) {
+      auto maybe_upcast_out_tv = dynamic_cast<TensorView*>(expr->input(i));
+      if (maybe_upcast_out_tv == nullptr) {
+        continue;
+      }
+
+      // Check if the input is an output of an upcast op
+      auto maybe_upcast_op =
+          dynamic_cast<UnaryOp*>(maybe_upcast_out_tv->definition());
+      if (maybe_upcast_op == nullptr ||
+          maybe_upcast_op->getUnaryOpType() != UnaryOpType::Cast) {
+        continue;
+      }
+
+      auto precisions =
+          ir_utils::getPrecisionOfProducerConsumerTensors(maybe_upcast_op);
+      if (!precisions.has_value() || precisions->first >= precisions->second) {
+        continue;
+      }
+
+      // Check if there are multiple uses of the upcast output
+      auto uses_of_upcast_out_tv = maybe_upcast_out_tv->uses();
+      if (uses_of_upcast_out_tv.size() < 2) {
+        continue;
+      }
+
+      // If this is the first use of the upcast output, keep it as is
+      if (expr == uses_of_upcast_out_tv.front()) {
+        continue;
+      }
+
+      auto upcast_out_tv_clone =
+          castOp(maybe_upcast_out_tv->dtype(), maybe_upcast_op->input(0));
+      expr = ir_utils::replaceValInExprInputs(
+          expr, maybe_upcast_out_tv, upcast_out_tv_clone);
+
+      privatized_upcast_ops_[maybe_upcast_op].insert(
+          upcast_out_tv_clone->definition()->as<UnaryOp>());
+    }
+  }
+}
+
+void SegmentCandidateFinder::revertPrivatizedUpcast(SegmentedGroup* group) {
+  // If a given consumer edge is a duplicate of another edge of the
+  // same producer group, remove the given edge from both the producer
+  // and consumer groups.
+  auto maybe_deduplicate_edge =
+      [](SegmentedEdge* maybe_duplicated_consumer_edge) {
+        SegmentedGroup* producer_group = maybe_duplicated_consumer_edge->from;
+
+        auto same_edge_it = std::find_if(
+            producer_group->consumer_edges.begin(),
+            producer_group->consumer_edges.end(),
+            [&](SegmentedEdge* consumer_edge) {
+              return consumer_edge != maybe_duplicated_consumer_edge &&
+                  *consumer_edge == *maybe_duplicated_consumer_edge;
+            });
+
+        if (same_edge_it == producer_group->consumer_edges.end()) {
+          return;
+        }
+
+        // maybe_duplicated_consumer_edge is redundant.
+        // Remove it from both the producer and the consumer groups
+        auto consumer_edge_to_remove = std::find(
+            producer_group->consumer_edges.begin(),
+            producer_group->consumer_edges.end(),
+            maybe_duplicated_consumer_edge);
+        NVF_ERROR(
+            consumer_edge_to_remove != producer_group->consumer_edges.end());
+        producer_group->consumer_edges.erase(consumer_edge_to_remove);
+
+        SegmentedGroup* consumer_group = maybe_duplicated_consumer_edge->to;
+        auto producer_edge_to_remove = std::find(
+            consumer_group->producer_edges.begin(),
+            consumer_group->producer_edges.end(),
+            maybe_duplicated_consumer_edge);
+        NVF_ERROR(
+            producer_edge_to_remove != consumer_group->producer_edges.end());
+        consumer_group->producer_edges.erase(producer_edge_to_remove);
+      };
+
+  // Replace old_expr with new_expr if found in a given group. Return
+  // true if replaced.
+  auto maybe_replace =
+      [](SegmentedGroup* group, Expr* old_expr, Expr* new_expr) -> bool {
+    auto it = std::find(group->exprs_.begin(), group->exprs_.end(), old_expr);
+    if (it != group->exprs_.end()) {
+      *it = new_expr;
+      return true;
+    } else {
+      return false;
+    }
+  };
+
+  for (const auto& [original_upcast, clones] : privatized_upcast_ops_) {
+    std::vector<UnaryOp*> upcast_in_group;
+    Val* upcast_val_to_keep = nullptr;
+    for (auto uop : ir_utils::filterByType<UnaryOp>(group->exprs())) {
+      if (uop != original_upcast && !clones.count(uop)) {
+        continue;
+      }
+
+      upcast_in_group.push_back(uop);
+
+      auto upcast_tv = uop->out();
+
+      // Prefer the original upcast if found
+      if (upcast_val_to_keep == nullptr ||
+          upcast_tv == original_upcast->out()) {
+        upcast_val_to_keep = upcast_tv;
+      }
+    }
+
+    if (upcast_in_group.size() < 2) {
+      continue;
+    }
+
+    for (auto uop : upcast_in_group) {
+      Val* upcast_val_to_replace = uop->out();
+      if (upcast_val_to_replace == upcast_val_to_keep) {
+        // Keep this uop as is since its output replaces the other
+        // upcast outputs
+        continue;
+      }
+
+      NVF_ERROR(
+          upcast_val_to_replace->uses().size() == 1,
+          "Multiple uses of replicated upcast tensor found: ",
+          toDelimitedString(upcast_val_to_replace->uses()));
+
+      auto use_of_upcast_val_to_replace = upcast_val_to_replace->uses().at(0);
+
+      auto updated_expr = ir_utils::replaceValInExprInputs(
+          use_of_upcast_val_to_replace,
+          upcast_val_to_replace,
+          upcast_val_to_keep);
+
+      // Replace use_of_upcast_val_to_replace with
+      // updated_expr. use_of_upcast_val_to_replace must be in this
+      // group or one of its consumer groups
+      if (!maybe_replace(group, use_of_upcast_val_to_replace, updated_expr)) {
+        for (auto consumer_edge : group->consumer_edges) {
+          if (maybe_replace(
+                  consumer_edge->to,
+                  use_of_upcast_val_to_replace,
+                  updated_expr)) {
+            break;
+          }
+        }
+      }
+
+      // Update a consumer edge if its val is
+      // upcast_val_to_replace. Again, there must be at most one such
+      // edge.
+      SegmentedEdge* consumer_edge_to_update = nullptr;
+      for (auto consumer_edge : group->consumer_edges) {
+        if (consumer_edge->val == upcast_val_to_replace) {
+          NVF_ERROR(
+              consumer_edge_to_update == nullptr,
+              "Multiple consumer edges using ",
+              upcast_val_to_replace->toString(),
+              " found");
+          consumer_edge->val = upcast_val_to_keep;
+          consumer_edge_to_update = consumer_edge;
+        }
+      }
+
+      // Now that the consumer edge is updated, it may be a duplicate
+      // of an existing edge. Remove if so.
+      if (consumer_edge_to_update != nullptr) {
+        maybe_deduplicate_edge(consumer_edge_to_update);
+      }
+
+      // Note that it should not be necessary to do anything with
+      // group->output_vals since the inserted upcast ops should never
+      // produce fusion outputs.
+    }
+  }
+}
+
 // Decides whether we should forward an input (or a forwarded input) of a
 // fusion. Currently, we forward an input only when its single use is a UnaryOp.
 // Therefore, this function returns `v`'s single unary use or nullptr if it
@@ -4632,6 +4828,10 @@ void SegmentCandidateFinder::finalize() {
     resolveScalarsInGroup(group);
   }
 
+  for (auto group : segmented_fusion_->groups()) {
+    revertPrivatizedUpcast(group);
+  }
+
   // Finalize each group, fill in the missing inputs, i.e. tensor dims.
   for (auto g : groups()) {
     g->setSchedulerType(deriveSchedulerType(g));
diff --git a/csrc/fusion_segmenter.h b/csrc/fusion_segmenter.h
index c70aab19e49..e5fcc70c7b7 100644
--- a/csrc/fusion_segmenter.h
+++ b/csrc/fusion_segmenter.h
@@ -40,6 +40,14 @@ struct SegmentedEdge {
   Val* val;
 
   void print() const;
+
+  bool operator==(const SegmentedEdge& other) const {
+    return from == other.from && to == other.to && val == other.val;
+  }
+
+  bool operator!=(const SegmentedEdge& other) const {
+    return !(*this == other);
+  }
 };
 
 std::ostream& operator<<(std::ostream& os, const SegmentedEdge* edge);
@@ -564,8 +572,18 @@ class SegmentCandidateFinder {
 
   void buildInitialSegments();
 
+  // Replicate upcast ops when consumed by multiple expressions. This
+  // encourages segments to share pre-upcast tensors rather than
+  // post-upcast tensors. Replicated upcast ops will be reverted when
+  // they are grouped into the same segment. See
+  // https://github.com/NVIDIA/Fuser/pull/3776/ for more details.
+  void privatizeUpcast();
+
   void findSegments();
 
+  // Revert privatized upcast ops when not necessary
+  void revertPrivatizedUpcast(SegmentedGroup* group);
+
   //! Find a group found in candidates that can be merged with the
   //! given group and set them to be merged if found. When no
   //! candidate is given, SegmentedGroup::getMergeCandidates is used
@@ -723,6 +741,9 @@ class SegmentCandidateFinder {
   // used for breaking the fusion into compute and communication segments
   std::optional<SchedulerRuntimeInfo> runtime_info_;
 
+  std::unordered_map<UnaryOp*, std::unordered_set<UnaryOp*>>
+      privatized_upcast_ops_;
+
   //! Note:
   //! Segmenter should eventually rely only on runtime_info_ for
   //! safe caching. runtime_inputs_ is only used in translateWelford
diff --git a/tests/cpp/test_segmentation.cpp b/tests/cpp/test_segmentation.cpp
index 8bcf3b13d60..8fbc5bc0bc2 100644
--- a/tests/cpp/test_segmentation.cpp
+++ b/tests/cpp/test_segmentation.cpp
@@ -703,4 +703,107 @@ TEST_F(SegmentationTest, ForwardInputsToSegmenterSetIssue2658) {
       executor_cache.fusion(), out_tensors, {in_tensor}, __LINE__, __FILE__);
 }
 
+// Test to verify an upcast is replicated between different segments
+TEST_F(NVFuserTest, PrivatizeUpcast) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  auto& fusion = *fusion_ptr;
+  FusionGuard fg(fusion_ptr.get());
+
+  auto tv0 = makeSymbolicTensor(2, DataType::BFloat16);
+  fusion.addInput(tv0);
+
+  auto tv1 = segment_set(tv0);
+  auto tv2 = castOp(DataType::Float, tv1);
+
+  auto tv3 = sum(tv2, {0});
+  fusion.addOutput(tv3);
+
+  auto tv4 = sum(tv2, {1});
+  fusion.addOutput(tv4);
+
+  auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0);
+  auto t0 = at::randn({16, 32}, options);
+  std::vector<c10::IValue> inputs({t0});
+
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  auto outputs = executor_cache.runFusionWithInputs(inputs);
+  testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
+
+  // There must be three segments: one run by ExprEvalExecutor and two
+  // by KernelExecutor.
+  FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
+  EXPECT_THAT(runtime->fusionSegments()->groups(), SizeIs(3));
+
+  for (const auto& executor : runtime->executors()) {
+    // Ignore the segment taken care of by ExprEvalExecutor
+    if (executor.get()->isA<ExprEvalExecutor>()) {
+      continue;
+    }
+    // Each of these segments should correspond to one of the
+    // reductions. Both of them should use tv1.
+    auto ke = dynamic_cast<KernelExecutor*>(executor.get());
+    ASSERT_NE(ke, nullptr);
+    kir::Kernel* kernel = ke->compiledKernel()->kernel();
+    EXPECT_EQ(kernel->inputs().size(), 1);
+    EXPECT_EQ(kernel->inputs().at(0)->name(), 1);
+  }
+}
+
+// Unlike PrivatizeUpcast, verify replicated upcast ops are
+// consolidated back as they are grouped into the same segment
+TEST_F(NVFuserTest, RevertPrivatizedUpcast) {
+  auto fusion_ptr = std::make_unique<Fusion>();
+  auto& fusion = *fusion_ptr;
+  FusionGuard fg(fusion_ptr.get());
+
+  auto tv0 = makeSymbolicTensor(2, DataType::BFloat16);
+  fusion.addInput(tv0);
+
+  auto tv1 = segment_set(tv0);
+  auto tv2 = castOp(DataType::Float, tv1);
+
+  auto tv3 = sum(tv2, {1});
+  fusion.addOutput(tv3);
+
+  auto tv4 = sum(tv2, {1});
+  fusion.addOutput(tv4);
+
+  auto options = at::TensorOptions().dtype(at::kBFloat16).device(at::kCUDA, 0);
+  auto t0 = at::randn({16, 32}, options);
+  std::vector<c10::IValue> inputs({t0});
+
+  FusionExecutorCache executor_cache(std::move(fusion_ptr));
+  auto outputs = executor_cache.runFusionWithInputs(inputs);
+  testValidate(&fusion, outputs, inputs, __LINE__, __FILE__);
+
+  // There must be two segments: one run by ExprEvalExecutor and the
+  // other by KernelExecutor.
+  FusionKernelRuntime* runtime = executor_cache.getMostRecentKernelRuntime();
+  EXPECT_THAT(runtime->fusionSegments()->groups(), SizeIs(2));
+
+  for (const auto& executor : runtime->executors()) {
+    // Ignore the segment taken care of by ExprEvalExecutor
+    if (executor.get()->isA<ExprEvalExecutor>()) {
+      continue;
+    }
+    // This segment should have the two reductions. There must be only
+    // one upcast op, with tv1 as its producer.
+    auto ke = dynamic_cast<KernelExecutor*>(executor.get());
+    ASSERT_NE(ke, nullptr);
+    kir::Kernel* kernel = ke->compiledKernel()->kernel();
+    int64_t num_upcast_ops = 0;
+    for (auto expr : KernelExprVisitor::getAllExprs(kernel)) {
+      auto uop = dynamic_cast<UnaryOp*>(expr);
+      if (uop == nullptr || uop->getUnaryOpType() != UnaryOpType::Cast) {
+        continue;
+      }
+
+      EXPECT_EQ(uop->in()->as<kir::TensorIndex>()->view()->name(), 1);
+
+      ++num_upcast_ops;
+    }
+    EXPECT_EQ(num_upcast_ops, 1);
+  }
+}
+
 } // namespace nvfuser
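
As an illustration of the new pass, here is a minimal sketch in terms of the fusion built in the PrivatizeUpcast test above. The name tv2_clone is hypothetical; the actual clone is an unnamed tensor produced by castOp() inside privatizeUpcast():

// Before privatizeUpcast(), one upcast output is shared by both reductions,
// so a segment boundary between them has to transfer the wider Float tensor:
//
//   tv1 = segment_set(tv0);              // BFloat16
//   tv2 = castOp(DataType::Float, tv1);  // single upcast with two uses
//   tv3 = sum(tv2, {0});                 // first use, kept as is
//   tv4 = sum(tv2, {1});                 // second use, rewritten by the pass
//
// After privatizeUpcast(), each reduction consumes its own private upcast,
// so segments only need to exchange the narrower pre-upcast tensor tv1:
//
//   tv1       = segment_set(tv0);
//   tv2       = castOp(DataType::Float, tv1);  // original upcast
//   tv3       = sum(tv2, {0});
//   tv2_clone = castOp(DataType::Float, tv1);  // privatized clone
//   tv4       = sum(tv2_clone, {1});
//
// If both reductions are later grouped into one segment, as in the
// RevertPrivatizedUpcast test, revertPrivatizedUpcast() redirects tv4 to
// consume tv2 again and maybe_deduplicate_edge() drops any now-identical
// SegmentedEdge.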