diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h index 59f36d1d2..703f1978d 100644 --- a/include/nccl_ofi_rdma.h +++ b/include/nccl_ofi_rdma.h @@ -111,6 +111,8 @@ typedef uint16_t nccl_ofi_rdma_msg_type_t; * allocate a rdma memory registration handle with `num_rails' rails. */ typedef struct nccl_net_ofi_rdma_mr_handle { + struct fid_mr *control_mr; + int num_rails; /* Array of size `num_rails' */ @@ -394,13 +396,15 @@ typedef struct nccl_ofi_rdma_connection_info { * on the receiver side */ uint32_t remote_comm_id; + nccl_ofi_rdma_ep_name_t control_ep_name; + /* Array of `MAX_NUM_RAILS` `nccl_ofi_rdma_ep_name_t` * structs. The member `num_rails` indicates the number of * entries that are in use. */ nccl_ofi_rdma_ep_name_t ep_names[MAX_NUM_RAILS]; } nccl_ofi_rdma_connection_info_t; /* Since this is a message on the wire, check that it has the expected size */ -_Static_assert(sizeof(nccl_ofi_rdma_connection_info_t) == 272, +_Static_assert(sizeof(nccl_ofi_rdma_connection_info_t) == 336, "Wrong size for RDMA connect message"); /* @@ -452,6 +456,8 @@ typedef struct nccl_net_ofi_rdma_send_comm { nccl_ofi_msgbuff_t *msgbuff; + nccl_net_ofi_rdma_send_comm_rail_t control_rail; + /* Number of rails */ int num_rails; @@ -534,6 +540,7 @@ typedef struct nccl_net_ofi_rdma_recv_comm { #if HAVE_NVTX_TRACING nvtxDomainHandle_t nvtx_domain[NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; #endif + nccl_net_ofi_rdma_recv_comm_rail_t control_rail; /* Number of rails */ int num_rails; @@ -626,6 +633,8 @@ struct nccl_net_ofi_rdma_ep { * and its base struct. */ nccl_net_ofi_ep_t base; + nccl_net_ofi_ep_rail_t control_rail; + /* Number of rails */ int num_rails; diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index af1d20758..fe5885beb 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -1677,6 +1677,11 @@ static int ofi_process_cq(nccl_net_ofi_rdma_ep_t *ep) } } + ret = ofi_process_cq_rail(ep, &ep->control_rail); + if (ret != 0) { + goto exit; + } + /* Process any pending requests */ ret = process_pending_reqs(ep); if (OFI_UNLIKELY(ret != 0 && ret != -FI_EAGAIN)) { @@ -2012,6 +2017,12 @@ static inline int post_bounce_buffs(nccl_net_ofi_rdma_ep_t *ep) } } + ret = post_bounce_buffs_on_rail(ep, &ep->control_rail); + if (ret != 0) { + NCCL_OFI_WARN("Failed call to post_bounce_buffs_on_rail(control_rail)"); + goto exit; + } + exit: return ret; } @@ -2304,14 +2315,23 @@ static int prepare_recv_conn_req(nccl_net_ofi_rdma_listen_comm_t *l_comm) */ static int dereg_rails(nccl_net_ofi_rdma_mr_handle_t *handle) { - /* Cleanup memory registration */ int ret = 0; + int rc = 0; int num_rails = handle->num_rails; + /* Cleanup memory registration for control */ + rc = fi_close(&handle->control_mr->fid); + if (OFI_UNLIKELY(rc != 0)) { + NCCL_OFI_WARN("Unable to de-register memory on control mr. RC: %d, Error: %s", + rc, fi_strerror(-rc)); + ret = rc; + } + + /* Cleanup memory registration for data rails */ for (int rail_id = 0; rail_id != num_rails; ++rail_id) { /* No memory registration available for this rail */ if (!handle->mr[rail_id]) continue; - int rc = fi_close(&handle->mr[rail_id]->fid); + rc = fi_close(&handle->mr[rail_id]->fid); if (OFI_UNLIKELY(rc != 0)) { NCCL_OFI_WARN("Unable to de-register memory. RC: %d, Error: %s", rc, fi_strerror(-rc)); @@ -2361,7 +2381,7 @@ static int dereg_mr_ep(nccl_net_ofi_rdma_mr_handle_t *mr_handle, return -EINVAL; } - if (OFI_UNLIKELY(mr_handle->num_rails < 1)) { + if (OFI_UNLIKELY(mr_handle->num_rails < 0)) { NCCL_OFI_WARN("Unexpected number of rails in rdma memory registration handle"); return -EINVAL; } @@ -2444,6 +2464,15 @@ static inline int reg_mr_on_device(nccl_net_ofi_rdma_ep_t *ep, goto exit; } + ret = register_rail_mr_buffer(get_device_rail(device, 0)->domain, ep->control_rail.ofi_ep, + -1, type, &mr_attr, + &ret_handle->control_mr); + if (OFI_UNLIKELY(ret != 0)) { + free(ret_handle); + ret_handle = NULL; + goto exit; + } + /* Register memory on each rail */ ret_handle->num_rails = num_rails; for (int rail_id = 0; rail_id != num_rails; ++rail_id) { @@ -2743,7 +2772,8 @@ static inline nccl_net_ofi_rdma_req_t *allocate_req(nccl_ofi_freelist_t *fl) } /** - * @brief Allocate a new send ctrl req from freelist + * @brief Allocate a new control message that the receiver will + * send to the sender describing the recv buffer. */ static inline int insert_send_ctrl_req( nccl_net_ofi_rdma_recv_comm_t *r_comm, @@ -3481,7 +3511,8 @@ static inline nccl_net_ofi_rdma_recv_comm_t *calloc_rdma_recv_comm(int num_rails * @return Receive communicator object, on success * NULL, on error */ -static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device_t *device, +static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen_comm_t *l_comm, + nccl_net_ofi_rdma_device_t *device, nccl_net_ofi_rdma_ep_t *l_comm_ep, nccl_ofi_rdma_connection_info_t *conn_msg) { @@ -3576,6 +3607,25 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device /* Add ourselves to ep's lookup array */ set_comm(device, r_comm->local_comm_id, &r_comm->base.base); + r_comm->control_rail.local_ep = l_comm->leader_local_ep; + ret = fi_av_insert(ep->control_rail.av, (void *)conn_msg->control_ep_name.ep_name, 1, + &r_comm->control_rail.remote_addr, 0, NULL); + if (OFI_UNLIKELY(ret != 1)) { + NCCL_OFI_WARN("Unable to insert remote address into address vector " + "for device %d. RC: %d", + dev_id, fi_strerror(-ret)); + goto error; + } + + ret = fi_av_insert(ep->control_rail.av, (void *)ep->control_rail.local_ep_name, 1, + &r_comm->control_rail.local_addr, 0, NULL); + if (OFI_UNLIKELY(ret != 1)) { + NCCL_OFI_WARN("Unable to insert local address into address vector " + "for device %d. RC: %d", + dev_id, fi_strerror(-ret)); + goto error; + } + /* Allocate array of communicator rails */ r_comm->num_rails = num_rails; @@ -3748,7 +3798,7 @@ static int post_send_conn_resp(nccl_net_ofi_rdma_recv_comm_t *r_comm, nccl_net_ofi_rdma_req_t *req) { ssize_t rc = 0; - nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = get_recv_comm_rail(r_comm, 0); + nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = &r_comm->control_rail;; req->state = NCCL_OFI_RDMA_REQ_PENDING; rc = fi_send(comm_rail->local_ep, (void *)conn_resp, sizeof(nccl_ofi_rdma_connection_info_t), NULL, @@ -3888,7 +3938,7 @@ static int accept(nccl_net_ofi_listen_comm_t *listen_comm, } /* Prepare receive communicator object for the received peer connection */ - r_comm = prepare_recv_comm(device, l_comm_ep, conn_msg); + r_comm = prepare_recv_comm(l_comm, device, l_comm_ep, conn_msg); if (OFI_UNLIKELY(r_comm == NULL)) { ret = -EINVAL; goto exit; @@ -4042,7 +4092,6 @@ static int listen(nccl_net_ofi_ep_t *base_ep, nccl_net_ofi_rdma_listen_comm_t *l_comm = NULL; nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)base_ep; - nccl_net_ofi_ep_rail_t *first_rail = get_rail(ep, 0); /* Retrieve and validate device */ nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t*)ep->base.device; @@ -4052,14 +4101,14 @@ static int listen(nccl_net_ofi_ep_t *base_ep, /* Build handle */ memset(handle, 0, sizeof(nccl_net_ofi_conn_handle_t)); - assert(sizeof(handle->ep_name) == sizeof(first_rail->local_ep_name)); - memcpy(handle->ep_name, first_rail->local_ep_name, - first_rail->local_ep_name_len); + assert(sizeof(handle->ep_name) == sizeof(ep->control_rail.local_ep_name)); + memcpy(handle->ep_name, ep->control_rail.local_ep_name, + ep->control_rail.local_ep_name_len); /* We don't copy the size here since the handle doesn't have a size field. The size will be distributed later by the connect response message. Instead, zero the unused bytes here. */ - memset(handle->ep_name + first_rail->local_ep_name_len, 0, - sizeof(handle->ep_name) - first_rail->local_ep_name_len); + memset(handle->ep_name + ep->control_rail.local_ep_name_len, 0, + sizeof(handle->ep_name) - ep->control_rail.local_ep_name_len); /* Build listen_comm */ l_comm = (nccl_net_ofi_rdma_listen_comm_t *)calloc(1, @@ -4076,7 +4125,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, l_comm->base.base.dev_id = dev_id; l_comm->base.accept = accept; l_comm->base.close = listen_close; - l_comm->leader_local_ep = first_rail->ofi_ep; + l_comm->leader_local_ep = ep->control_rail.ofi_ep; /* Allocate listen communicator ID */ int comm_id = nccl_ofi_idpool_allocate_id(device->comm_idpool); @@ -4381,11 +4430,9 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req) nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)req->comm; rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(req); nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; - const int control_rail_id = 0; // Get communicator rail information to xfer the req - nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail; - comm_rail = get_recv_comm_rail(r_comm, control_rail_id); + nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = &r_comm->control_rail; nccl_net_ofi_rdma_ctrl_fl_item_t *ctrl_fl_item = send_ctrl_data->ctrl_fl_item; @@ -4394,7 +4441,7 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req) (freelist_regmr_fn_handle_t *)ctrl_fl_item->fl_reginfo.mr_handle; nccl_net_ofi_rdma_mr_handle_t *mr_handle = fl_handle->mr_handle; - void *desc = fi_mr_desc(mr_handle->mr[control_rail_id]); + void *desc = fi_mr_desc(mr_handle->control_mr); NCCL_OFI_TRACE_SEND_CTRL_START(req->dev_id, xfer_info->rail_id, req->comm, req, req->msg_seq_num); @@ -4808,6 +4855,14 @@ static void prepare_send_connect_message(nccl_net_ofi_rdma_ep_t *ep, int dev_id, /* Set number of rails to be sent back to remote for verification */ conn_msg->num_rails = num_rails; + /* Set libfabric endpoint name for control rail */ + memcpy(conn_msg->control_ep_name.ep_name, + ep->control_rail.local_ep_name, + ep->control_rail.local_ep_name_len); + conn_msg->control_ep_name.ep_name_len = + ep->control_rail.local_ep_name_len; + + /* Set libfabric endpoint names for each rail */ for (int rail_id = 0; rail_id != num_rails; ++rail_id) { memcpy(conn_msg->ep_names[rail_id].ep_name, @@ -4866,6 +4921,14 @@ static inline int init_bounce_buffers(nccl_net_ofi_rdma_ep_t *ep) return ret; } + ep->control_rail.min_bounce_posted = NCCL_OFI_DIV_CEIL( + ofi_nccl_rdma_min_posted_bounce_buffers(), ep->num_rails + ); + ep->control_rail.max_bounce_posted = NCCL_OFI_DIV_CEIL( + ofi_nccl_rdma_max_posted_bounce_buffers(), ep->num_rails + ); + ret = pthread_mutex_init(&ep->control_rail.bounce_mutex, NULL); + for (int rail_id = 0; rail_id < ep->num_rails; ++rail_id) { nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id); rail->min_bounce_posted = NCCL_OFI_DIV_CEIL( @@ -4941,7 +5004,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, nccl_net_ofi_rdma_send_comm_t *ret_s_comm = NULL; int num_rails = ep->num_rails; int rail_id = 0; - nccl_net_ofi_ep_rail_t *first_rail = get_rail(ep, 0); + nccl_net_ofi_ep_rail_t *control_rail = &ep->control_rail; *s_comm = NULL; /* Retrieve and validate device */ @@ -4994,7 +5057,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, ret_s_comm->num_rails = num_rails; /* Insert remote name into AV of first rail */ - ret = fi_av_insert(first_rail->av, + ret = fi_av_insert(control_rail->av, (void *)handle->ep_name, 1, &remote_addr, 0, NULL); if (OFI_UNLIKELY(ret != 1)) { @@ -5005,11 +5068,11 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, } /* Store remote address of first rail in communicator */ - ret_s_comm->rails[0].remote_addr = remote_addr; + ret_s_comm->control_rail.remote_addr = remote_addr; - /* Store local libfabric endpoint of first rail */ - ret_s_comm->rails[0].local_ep = first_rail->ofi_ep; - ret_s_comm->num_init_rails = 1; + /* Store local libfabric endpoint of control rail */ + ret_s_comm->control_rail.local_ep = control_rail->ofi_ep; + ret_s_comm->num_init_rails = 0; /* Allocate request free list */ ret = nccl_ofi_freelist_init(sizeof(nccl_net_ofi_rdma_req_t), 16, 16, @@ -5127,7 +5190,7 @@ static int post_send_conn(nccl_net_ofi_rdma_send_comm_t *s_comm, nccl_net_ofi_rdma_req_t *req) { ssize_t rc = 0; - nccl_net_ofi_rdma_send_comm_rail_t *comm_rail = get_send_comm_rail(s_comm, 0); + nccl_net_ofi_rdma_send_comm_rail_t *comm_rail = &s_comm->control_rail; /* * TODO: replace it with API of FI_INJECT type when most of @@ -5349,6 +5412,7 @@ static void ep_rail_release(nccl_net_ofi_ep_rail_t *rail, int dev_id) */ static void release_rdma_ep_resources(nccl_net_ofi_rdma_ep_t *ep, int dev_id) { + ep_rail_release(&ep->control_rail, dev_id); for (int rail_id = 0; rail_id != ep->num_rails; ++rail_id) { ep_rail_release(get_rail(ep, rail_id), dev_id); } @@ -5584,6 +5648,15 @@ static int create_ep(nccl_net_ofi_rdma_device_t *device, ep->use_long_rkeys = device->use_long_rkeys; + /* we pass 0 as the railid for the control rail, so + * that any lookups based on railid in the domain find + * the right domain */ + ret = ep_rail_init(ep, device->base.dev_id, 0, &device->device_rails[0], &ep->control_rail); + if (ret != 0) { + NCCL_OFI_WARN("Initializing control rail failed"); + goto error; + } + ret = init_rail_ofi_resources(device, ep); if (ret != 0) { goto error;