From 7f2db6dda29823f597db3d329f63970a7476e04b Mon Sep 17 00:00:00 2001 From: Brian Barrett Date: Mon, 24 Feb 2025 22:09:03 +0000 Subject: [PATCH] Remove support for not creating local MRs Commits 27548d4 and a1ac0f5 fixed providers that required memory registrations when FI_MR_LOCAL was set, but also broke providers that clear FI_MR_LOCAL (such as HPE's provider), because I did not account for the mr_local handling in the send/recv transport. We don't have a great way to test that case from AWS, the vast majority of transfers will either be HMEM (which creates an MR anyway) or control messages (which have a freelist which creates an MR), and the Libfabric specification allows passing an MR descriptor to a local operation even if the provider clears the FI_MR_LOCAL bit. Therefore, the best path forward seems to be removing the code to skip registration if FI_MR_LOCAL is cleared, and always creating an MR. Signed-off-by: Brian Barrett --- include/nccl_ofi.h | 5 +---- src/nccl_ofi_net.c | 13 ------------- src/nccl_ofi_rdma.c | 30 ++++++++++++------------------ src/nccl_ofi_sendrecv.c | 30 +----------------------------- 4 files changed, 14 insertions(+), 64 deletions(-) diff --git a/include/nccl_ofi.h b/include/nccl_ofi.h index ca0353c43..6e79dd239 100644 --- a/include/nccl_ofi.h +++ b/include/nccl_ofi.h @@ -103,9 +103,6 @@ extern int nic_dup_conns; read in the polling loop without protection of a lock. */ extern size_t cq_read_count; -/* Indicates if memory registration of local buffers is required */ -extern bool local_mr; - /* Indicates if endpoint memory registration is required */ extern bool endpoint_mr; @@ -745,7 +742,7 @@ int nccl_net_ofi_dealloc_mr_buffer(void *ptr, size_t size); * @return 0 (Success) * * Set required behavior flags (and print debugging information) for - * local_mr, virt_addr_mr, and endpoint_mr. + * virt_addr_mr, and endpoint_mr. 
*/ int nccl_net_ofi_query_provider_capabilities(const struct fi_info *selected_provider, unsigned int num_providers); diff --git a/src/nccl_ofi_net.c b/src/nccl_ofi_net.c index 058e9db56..1f60596ee 100644 --- a/src/nccl_ofi_net.c +++ b/src/nccl_ofi_net.c @@ -50,8 +50,6 @@ int nic_dup_conns = 0; read in the polling loop without protection of a lock. */ size_t cq_read_count = 1; -/* Indicates if memory registration of local buffers is required */ -bool local_mr = false; /* Indicates if endpoint memory registration is required */ bool endpoint_mr = false; @@ -613,17 +611,6 @@ int nccl_net_ofi_query_provider_capabilities(const struct fi_info *selected_prov } } - /* Check if provider requires local memory registration */ - if (selected_provider->domain_attr->mr_mode & FI_MR_LOCAL) { - NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET, "Provider %s requires registration of local memory buffers", - selected_provider->fabric_attr->prov_name); - local_mr = true; - } else { - NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET, "Provider %s does not require registration of local memory buffers", - selected_provider->fabric_attr->prov_name); - local_mr = false; - } - /* Check if provider uses remote virtual addressing */ if (selected_provider->domain_attr->mr_mode & FI_MR_VIRT_ADDR) { NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET, "Provider %s uses remote virtual addressing", diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index 36f11cff9..fbc900ed5 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -3832,25 +3832,19 @@ static int alloc_and_reg_flush_buff(nccl_net_ofi_rdma_recv_comm_t *r_comm, int d /* make sure flush destination address does not overflow beyond host buffer */ assert(((cpu_cache_line_size * ep->num_rails) + flush_buff->size) <= system_page_size); - /* Check if provider requires registration of local buffers */ - if (local_mr == true) { - /* Register flush dummy buffer for provider access */ - ret = reg_internal_mr_ep(ep, flush_buff->host_buffer, system_page_size, - NCCL_PTR_HOST, 
&mr_handle); - if (OFI_UNLIKELY(ret != 0)) { - NCCL_OFI_WARN("Could not register dummy buffer for flush, dev: %d", - dev_id); - rc = nccl_net_ofi_dealloc_mr_buffer(flush_buff->host_buffer, - system_page_size); - if (rc != 0) { - NCCL_OFI_WARN("Unable to deallocate flush buffer (%d)", - rc); - } - flush_buff->host_buffer = MAP_FAILED; + /* Register flush dummy buffer for provider access */ + ret = reg_internal_mr_ep(ep, flush_buff->host_buffer, system_page_size, + NCCL_PTR_HOST, &mr_handle); + if (OFI_UNLIKELY(ret != 0)) { + NCCL_OFI_WARN("Could not register dummy buffer for flush, dev: %d", + dev_id); + rc = nccl_net_ofi_dealloc_mr_buffer(flush_buff->host_buffer, + system_page_size); + if (rc != 0) { + NCCL_OFI_WARN("Unable to deallocate flush buffer (%d)", + rc); } - } else { - NCCL_OFI_TRACE(NCCL_NET, - "Skip registering host buffer. local_mr: %d", local_mr); + flush_buff->host_buffer = MAP_FAILED; } flush_buff->mr_handle = mr_handle; diff --git a/src/nccl_ofi_sendrecv.c b/src/nccl_ofi_sendrecv.c index 6261fc57d..7129842db 100644 --- a/src/nccl_ofi_sendrecv.c +++ b/src/nccl_ofi_sendrecv.c @@ -538,17 +538,6 @@ static inline struct fid_domain* sendrecv_endpoint_get_ofi_domain(nccl_net_ofi_s return domain->domain; } -/* - * @brief Returns whether the registration of local buffers is not required by - * the provider. - * - * @return true if registration is not required; otherwise, false - */ - -static bool sendrecv_mr_buffer_skip_local_registration(int type) { - return (local_mr != true) && (type == NCCL_PTR_HOST); -} - /* * @brief Registers memory region (both HOST and CUDA) * @@ -568,18 +557,6 @@ static int sendrecv_mr_buffers_register(struct fid_domain *domain, struct fi_mr_attr mr_attr = {}; uint64_t regattr_flags = 0; - /* Check if provider requires registration of local buffers */ - if (sendrecv_mr_buffer_skip_local_registration(type)) { - NCCL_OFI_TRACE(NCCL_NET, - "Skip registering host buffer. 
local_mr: %d", local_mr); - /* the mr handle will still be threaded through NCCL, - * so we still need some sentinal to tell us not to try - * and use the registration. NULL is as good as any. - */ - *mr_handle = NULL; - goto exit; - } - mr_attr.access = FI_SEND | FI_RECV; nccl_ofi_mr_ckey_fill_mr_attrs(ckey, &mr_attr, &regattr_flags); switch (type) { @@ -822,11 +799,6 @@ static int sendrecv_comm_mr_base_reg(nccl_net_ofi_comm_t *base_comm, nccl_ofi_mr_cache_t *mr_cache = domain->base.mr_cache; void *ret_handle = NULL; - if (sendrecv_mr_buffer_skip_local_registration(type)) { - /* Registraton and caching are unnecessary */ - goto exit; - } - if (mr_cache) { /* * MR cache is locked between lookup and insert, to be sure we @@ -870,7 +842,7 @@ static int sendrecv_comm_mr_base_reg(nccl_net_ofi_comm_t *base_comm, if (mr_cache) { nccl_net_ofi_mutex_unlock(&mr_cache->lock); } -exit: + *mhandle = ret_handle; return ret; }