diff --git a/yacl/link/context.cc b/yacl/link/context.cc index 986eaa9d..7ab05577 100644 --- a/yacl/link/context.cc +++ b/yacl/link/context.cc @@ -325,9 +325,13 @@ Buffer Context::RecvInternal(size_t src_rank, const std::string& key) { return value; } -std::unique_ptr Context::Spawn() { +std::unique_ptr Context::Spawn(const std::string& id) { ContextDesc sub_desc = desc_; - sub_desc.id = fmt::format("{}-{}", desc_.id, child_counter_++); + if (id.empty()) { + sub_desc.id = fmt::format("{}-{}", desc_.id, child_counter_++); + } else { + sub_desc.id = fmt::format("{}-{}", desc_.id, id); + } // sub-context share the same event-loop and statistics with parent. auto sub_ctx = diff --git a/yacl/link/context.h b/yacl/link/context.h index b291ae05..67f886ae 100644 --- a/yacl/link/context.h +++ b/yacl/link/context.h @@ -265,7 +265,7 @@ class Context { void ConnectToMesh( spdlog::level::level_enum connect_log_level = spdlog::level::debug); - std::unique_ptr Spawn(); + std::unique_ptr Spawn(const std::string& id = ""); // Create a new Context from a subset of original parities. // Party which not in `sub_parties` should not call the SubWorld() method. @@ -327,7 +327,6 @@ class Context { // stateful properties. size_t counter_ = 0U; // collective algorithm counter. std::map p2p_counter_; - size_t child_counter_ = 0U; uint64_t recv_timeout_ms_;