diff --git a/include/nccl_ofi_param.h b/include/nccl_ofi_param.h index ac4ddcac0..c84f5a669 100644 --- a/include/nccl_ofi_param.h +++ b/include/nccl_ofi_param.h @@ -265,6 +265,12 @@ OFI_NCCL_PARAM_INT(rdma_min_posted_bounce_buffers, "RDMA_MIN_POSTED_BOUNCE_BUFFE */ OFI_NCCL_PARAM_INT(rdma_max_posted_bounce_buffers, "RDMA_MAX_POSTED_BOUNCE_BUFFERS", 128); +/* + * Whether to spread the control message across multiple rails in round robin fashion or + * send it consistently on one rail with a dedicated endpoint. + */ +OFI_NCCL_PARAM_INT(rdma_rr_ctrl_msg, "RR_CTRL_MSG", 1); + /* * Internode network latency reported to NCCL. Defaults to 0, unless the configured * platform sets a specific value. diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h index 0d8a18f3f..a53dd6dfd 100644 --- a/include/nccl_ofi_rdma.h +++ b/include/nccl_ofi_rdma.h @@ -282,6 +282,10 @@ typedef struct { typedef struct { /* Pointer to the allocated control buffer from freelist */ nccl_net_ofi_rdma_ctrl_fl_item_t *ctrl_fl_item; + /* Schedule used to transfer the control buffer. We save the + * pointer to reference it when transferring the buffer over + * network. 
*/ + nccl_net_ofi_schedule_t *ctrl_schedule; /* Pointer to recv parent request */ nccl_net_ofi_rdma_req_t *recv_req; #if HAVE_NVTX_TRACING diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index 507d10e40..f803aa927 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -1925,9 +1925,11 @@ static int ofi_process_cq(nccl_net_ofi_rdma_ep_t *ep) } } - ret = ofi_process_cq_rail(ep, &ep->control_rail); - if (ret != 0) { - goto exit; + if (!ofi_nccl_rdma_rr_ctrl_msg()) { + ret = ofi_process_cq_rail(ep, &ep->control_rail); + if (ret != 0) { + goto exit; + } } /* Process any pending requests */ @@ -2114,6 +2116,12 @@ static inline int free_send_ctrl_req(nccl_net_ofi_rdma_req_t *req, (nccl_net_ofi_rdma_recv_comm_t *)req->comm; rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(req); + if (send_ctrl_data->ctrl_schedule != NULL) { + nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)req->comm->ep->device; + nccl_net_ofi_release_schedule(device->scheduler, send_ctrl_data->ctrl_schedule); + send_ctrl_data->ctrl_schedule = NULL; + } + if (send_ctrl_data->ctrl_fl_item) { nccl_ofi_freelist_entry_free(r_comm->ctrl_buff_fl, send_ctrl_data->ctrl_fl_item); send_ctrl_data->ctrl_fl_item = NULL; @@ -2309,10 +2317,12 @@ static inline int post_bounce_buffs(nccl_net_ofi_rdma_ep_t *ep) } } - ret = post_bounce_buffs_on_rail(ep, &ep->control_rail); - if (ret != 0) { - NCCL_OFI_WARN("Failed call to post_bounce_buffs_on_rail(control_rail)"); - goto exit; + if (!ofi_nccl_rdma_rr_ctrl_msg()) { + ret = post_bounce_buffs_on_rail(ep, &ep->control_rail); + if (ret != 0) { + NCCL_OFI_WARN("Failed call to post_bounce_buffs_on_rail(control_rail)"); + goto exit; + } } exit: @@ -2613,11 +2623,13 @@ static int dereg_rails(nccl_net_ofi_rdma_mr_handle_t *handle) int num_rails = handle->num_rails; /* Cleanup memory registration for control */ - rc = fi_close(&handle->control_mr->fid); - if (OFI_UNLIKELY(rc != 0)) { - NCCL_OFI_WARN("Unable to de-register memory on control 
mr. RC: %d, Error: %s", - rc, fi_strerror(-rc)); - ret = rc; + if (handle->control_mr != NULL) { + rc = fi_close(&handle->control_mr->fid); + if (OFI_UNLIKELY(rc != 0)) { + NCCL_OFI_WARN("Unable to de-register memory on control mr. RC: %d, Error: %s", + rc, fi_strerror(-rc)); + ret = rc; + } } /* Cleanup memory registration for data rails */ @@ -2754,13 +2766,15 @@ static inline int reg_mr_on_device(nccl_net_ofi_rdma_ep_t *ep, goto exit; } - ret = register_rail_mr_buffer(ep->control_rail.domain, ep->control_rail.ofi_ep, - -1, type, &mr_attr, regattr_flags, - &ret_handle->control_mr); - if (OFI_UNLIKELY(ret != 0)) { - free(ret_handle); - ret_handle = NULL; - goto exit; + if (!ofi_nccl_rdma_rr_ctrl_msg()) { + ret = register_rail_mr_buffer(ep->control_rail.domain, ep->control_rail.ofi_ep, + -1, type, &mr_attr, regattr_flags, + &ret_handle->control_mr); + if (OFI_UNLIKELY(ret != 0)) { + free(ret_handle); + ret_handle = NULL; + goto exit; + } } /* Register memory on each rail */ @@ -3071,6 +3085,7 @@ static inline int insert_send_ctrl_req( nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle, nccl_net_ofi_rdma_req_t *recv_req) { + nccl_net_ofi_scheduler_t *scheduler = device->scheduler; nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; nccl_net_ofi_rdma_req_t *send_ctrl_req = allocate_req(r_comm->nccl_ofi_reqs_fl); if (OFI_UNLIKELY(send_ctrl_req == NULL)) { @@ -3087,6 +3102,24 @@ static inline int insert_send_ctrl_req( rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(send_ctrl_req); + if (ofi_nccl_rdma_rr_ctrl_msg()) { + size_t ctrl_msg_len = nccl_net_ofi_rdma_ctrl_msg_size(ep->num_rails, ep->use_long_rkeys); + send_ctrl_data->ctrl_schedule = scheduler->get_schedule(scheduler, ctrl_msg_len, device->num_rails); + + if (OFI_UNLIKELY(!(send_ctrl_data->ctrl_schedule))) { + return -EINVAL; + } else if (OFI_UNLIKELY(send_ctrl_data->ctrl_schedule->num_xfer_infos != 1)) { + NCCL_OFI_WARN( + "Invalid schedule for outgoing control message (%zu 
bytes). Expected one rail, but got " + "%zu", + size, + send_ctrl_data->ctrl_schedule->num_xfer_infos); + return -EINVAL; + } + } else { + send_ctrl_data->ctrl_schedule = NULL; + } + send_ctrl_data->recv_req = recv_req; send_ctrl_data->ctrl_fl_item = NULL; @@ -4552,7 +4585,13 @@ static int post_send_conn_resp(nccl_net_ofi_rdma_recv_comm_t *r_comm, nccl_net_ofi_rdma_req_t *req) { ssize_t rc = 0; - nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = &r_comm->control_rail;; + nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail; + + if (ofi_nccl_rdma_rr_ctrl_msg()) { + comm_rail = rdma_recv_comm_get_rail(r_comm, 0); + } else { + comm_rail = &r_comm->control_rail; + } req->state = NCCL_OFI_RDMA_REQ_PENDING; rc = fi_send(comm_rail->local_ep, (void *)conn_resp, sizeof(nccl_ofi_rdma_connection_info_t), NULL, @@ -4864,14 +4903,24 @@ static int listen(nccl_net_ofi_ep_t *base_ep, /* Build handle */ memset(handle, 0, sizeof(nccl_net_ofi_conn_handle_t)); - assert(sizeof(handle->ep_name) == sizeof(ep->control_rail.local_ep_name)); - memcpy(handle->ep_name, ep->control_rail.local_ep_name, - ep->control_rail.local_ep_name_len); + /* We don't copy the size here since the handle doesn't have a size field. The size will be distributed later by the connect response message. Instead, zero the unused bytes here. 
*/ - memset(handle->ep_name + ep->control_rail.local_ep_name_len, 0, - sizeof(handle->ep_name) - ep->control_rail.local_ep_name_len); + if (!ofi_nccl_rdma_rr_ctrl_msg()) { + assert(sizeof(handle->ep_name) == sizeof(ep->control_rail.local_ep_name)); + memcpy(handle->ep_name, ep->control_rail.local_ep_name, ep->control_rail.local_ep_name_len); + memset(handle->ep_name + ep->control_rail.local_ep_name_len, + 0, + sizeof(handle->ep_name) - ep->control_rail.local_ep_name_len); + } else { + nccl_net_ofi_ep_rail_t *first_rail = rdma_endpoint_get_rail(ep, 0); + assert(sizeof(handle->ep_name) == sizeof(first_rail->local_ep_name)); + memcpy(handle->ep_name, first_rail->local_ep_name, first_rail->local_ep_name_len); + memset(handle->ep_name + ep->control_rail.local_ep_name_len, + 0, + sizeof(handle->ep_name) - ep->control_rail.local_ep_name_len); + } /* Build listen_comm */ l_comm = (nccl_net_ofi_rdma_listen_comm_t *)calloc(1, @@ -5265,9 +5314,9 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req) nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)req->comm; rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(req); nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; - - // Get communicator rail information to xfer the req - nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = &r_comm->control_rail; + nccl_net_ofi_schedule_t *schedule = send_ctrl_data->ctrl_schedule; + nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail; + void *desc; nccl_net_ofi_rdma_ctrl_fl_item_t *ctrl_fl_item = send_ctrl_data->ctrl_fl_item; @@ -5276,9 +5325,19 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req) (freelist_regmr_fn_handle_t *)ctrl_fl_item->fl_reginfo.mr_handle; nccl_net_ofi_rdma_mr_handle_t *mr_handle = fl_handle->mr_handle; - void *desc = fi_mr_desc(mr_handle->control_mr); - - NCCL_OFI_TRACE_SEND_CTRL_START(req->dev_id, ep->control_rail.rail_id, req->comm, req, req->msg_seq_num); + if (schedule != NULL) { + /* Use round robin 
schedule for ctrl message */ + nccl_net_ofi_xfer_info_t *xfer_info = &schedule->rail_xfer_infos[0]; + comm_rail = rdma_recv_comm_get_rail(r_comm, xfer_info->rail_id); + assert(xfer_info->rail_id < mr_handle->num_rails); + desc = fi_mr_desc(mr_handle->mr[xfer_info->rail_id]); + NCCL_OFI_TRACE_SEND_CTRL_START(req->dev_id, xfer_info->rail_id, req->comm, req, req->msg_seq_num); + } else { + /* Use control QP for ctrl message */ + comm_rail = &r_comm->control_rail; + desc = fi_mr_desc(mr_handle->control_mr); + NCCL_OFI_TRACE_SEND_CTRL_START(req->dev_id, ep->control_rail.rail_id, req->comm, req, req->msg_seq_num); + } size_t ctrl_msg_len = nccl_net_ofi_rdma_ctrl_msg_size(ep->num_rails, ep->use_long_rkeys); @@ -5577,15 +5636,17 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t goto error; } - /* look for control messages and then retry the message search - to avoid unnecessary polling / queueing. */ - if (OFI_UNLIKELY(!polled_cq && !have_ctrl)) { - ret = ofi_process_cq_rail(ep, &ep->control_rail); - if (ret != 0) { - goto error; + if (!ofi_nccl_rdma_rr_ctrl_msg()) { + /* look for control messages and then retry the message search + to avoid unnecessary polling / queueing. */ + if (OFI_UNLIKELY(!polled_cq && !have_ctrl)) { + ret = ofi_process_cq_rail(ep, &ep->control_rail); + if (ret != 0) { + goto error; + } + polled_cq = true; + goto retry; } - polled_cq = true; - goto retry; } /* Determine if this should be sent eagerly. 
*/ @@ -6021,9 +6082,16 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, nccl_net_ofi_rdma_send_comm_t *ret_s_comm = NULL; int num_rails = ep->num_rails; int rail_id = 0; - nccl_net_ofi_ep_rail_t *control_rail = &ep->control_rail; + nccl_net_ofi_ep_rail_t *control_rail; + *s_comm = NULL; + if (ofi_nccl_rdma_rr_ctrl_msg()) { + control_rail = rdma_endpoint_get_rail(ep, 0); + } else { + control_rail = &ep->control_rail; + } + /* Retrieve and validate device */ nccl_net_ofi_rdma_device_t *device = rdma_endpoint_get_device(ep); if (OFI_UNLIKELY(device == NULL)) { @@ -6098,12 +6166,18 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, goto error; } - /* Store remote address of first rail in communicator */ - ret_s_comm->control_rail.remote_addr = remote_addr; + if (ofi_nccl_rdma_rr_ctrl_msg()) { + /* Store remote address of first rail in communicator */ + ret_s_comm->rails[0].remote_addr = remote_addr; + /* Store local libfabric endpoint of first rail */ + ret_s_comm->rails[0].local_ep = control_rail->ofi_ep; + ret_s_comm->num_init_rails = 1; + } + /* Store remote address of control rail in communicator */ + ret_s_comm->control_rail.remote_addr = remote_addr; /* Store local libfabric endpoint of control rail */ ret_s_comm->control_rail.local_ep = control_rail->ofi_ep; - ret_s_comm->num_init_rails = 0; /* Allocate request free list */ ret = nccl_ofi_freelist_init(sizeof(nccl_net_ofi_rdma_req_t), 16, 16,