diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h
index 7238dedb8..a4ba63bf4 100644
--- a/include/nccl_ofi_rdma.h
+++ b/include/nccl_ofi_rdma.h
@@ -67,6 +67,8 @@ typedef enum nccl_net_ofi_rdma_req_type {
  * allocate a rdma memory registration handle with `num_rails' rails.
  */
 typedef struct nccl_net_ofi_rdma_mr_handle {
+	struct fid_mr *control_mr;
+
 	int num_rails;
 
 	/* Array of size `num_rails' */
@@ -102,6 +104,8 @@ struct nccl_net_ofi_rdma_req;
 struct nccl_net_ofi_rdma_ep;
 typedef struct nccl_net_ofi_rdma_req nccl_net_ofi_rdma_req_t;
 typedef struct nccl_net_ofi_rdma_ep nccl_net_ofi_rdma_ep_t;
+struct nccl_net_ofi_ep_rail;
+typedef struct nccl_net_ofi_ep_rail nccl_net_ofi_ep_rail_t;
 
 typedef struct {
 	/* Bounce buffer freelist item */
@@ -116,7 +120,7 @@ typedef struct {
 	 * This is useful for re-posting the bounce buffer on the same rail
 	 * when it gets completed.
 	 */
-	int bounce_rail_id;
+	nccl_net_ofi_ep_rail_t *rail;
 	/*
 	 * Back-pointer to associated endpoint
 	 */
@@ -157,10 +161,6 @@ typedef struct {
 typedef struct {
 	/* Pointer to the allocated control buffer from freelist */
 	nccl_net_ofi_rdma_ctrl_fl_item_t *ctrl_fl_item;
-	/* Schedule used to transfer the control buffer. We save the
-	 * pointer to reference it when transferring the buffer over
-	 * network. */
-	nccl_net_ofi_schedule_t *ctrl_schedule;
 	/* Pointer to recv parent request */
 	nccl_net_ofi_rdma_req_t *recv_req;
 } rdma_req_send_ctrl_data_t;
@@ -297,6 +297,8 @@ typedef struct nccl_ofi_rdma_connection_info {
 	   side. The receiver must use this tag when sending messages to sender */
 	uint64_t local_tag;
 
+	nccl_ofi_rdma_ep_name_t control_ep_name;
+
 	/* Number of rails */
 	int num_rails;
 
@@ -357,6 +359,8 @@ typedef struct nccl_net_ofi_rdma_send_comm {
 
 	nccl_ofi_msgbuff_t *msgbuff;
 
+	nccl_net_ofi_rdma_send_comm_rail_t control_rail;
+
 	/* Number of rails */
 	int num_rails;
 
@@ -430,6 +434,8 @@ typedef struct nccl_net_ofi_rdma_recv_comm {
 	/* Free list to track control buffers, for sending RDMA control messages */
 	nccl_ofi_freelist_t *ctrl_buff_fl;
 
+	nccl_net_ofi_rdma_recv_comm_rail_t control_rail;
+
 	/* Number of rails */
 	int num_rails;
 
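The header changes above work together: `rdma_req_bounce_data_t` now records the owning rail as a `nccl_net_ofi_ep_rail_t *` instead of a `bounce_rail_id` index. The reason becomes visible later in the diff: the new `control_rail` lives outside the `ep->rails[]` array, so an index can no longer name every rail that may own a bounce buffer. A minimal sketch of the lookup this enables (the `on_control_rail` flag and the `bounce_rail_for()` helper are hypothetical, written against the plugin's internal headers, and not part of the patch):

```c
/* Hypothetical helper, illustration only: resolve the rail that should own
 * a bounce buffer. Data rails are indexable; the control rail is not. */
static nccl_net_ofi_ep_rail_t *bounce_rail_for(nccl_net_ofi_rdma_ep_t *ep,
					       bool on_control_rail, int rail_id)
{
	if (on_control_rail) {
		/* Standalone struct member; it has no index in ep->rails[] */
		return &ep->control_rail;
	}
	assert(rail_id >= 0 && rail_id < ep->num_rails);
	return get_rail(ep, rail_id);
}
```

Storing the pointer also removes a `get_rail()` lookup from every repost path, as the .c changes below show.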
@@ -467,7 +473,9 @@ typedef struct nccl_net_ofi_rdma_listen_comm {
  * Endpoint rail encapsulates data of an endpoint for a
  * specific rail.
  */
-typedef struct nccl_net_ofi_ep_rail {
+struct nccl_net_ofi_ep_rail {
+	int rail_id;
+
 	/* Local libfabric endpoint handle */
 	struct fid_ep *ofi_ep;
 
@@ -492,7 +500,7 @@ typedef struct nccl_net_ofi_ep_rail {
 	size_t max_bounce_posted;
 	/* Mutex for bounce buffer operations */
 	pthread_mutex_t bounce_mutex;
-} nccl_net_ofi_ep_rail_t;
+};
 
 /*
  * @brief RDMA Endpoint
@@ -516,6 +524,8 @@ struct nccl_net_ofi_rdma_ep {
 	/* Current available tag ID */
 	uint64_t tag;
 
+	nccl_net_ofi_ep_rail_t control_rail;
+
 	/* Number of rails */
 	int num_rails;
 
diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c
index a0f09b17b..03ed59d01 100644
--- a/src/nccl_ofi_rdma.c
+++ b/src/nccl_ofi_rdma.c
@@ -211,7 +211,7 @@ static int send_progress(nccl_net_ofi_rdma_req_t *req);
 
 static int receive_progress(nccl_net_ofi_rdma_req_t *req, bool add_to_pending);
 
-static int post_bounce_buffs_on_rail(nccl_net_ofi_rdma_ep_t *ep, int rail_id);
+static int post_bounce_buffs_on_rail(nccl_net_ofi_rdma_ep_t *ep, nccl_net_ofi_ep_rail_t *rail);
 
 static inline int repost_bounce_buff(nccl_net_ofi_rdma_ep_t *ep,
 				     nccl_net_ofi_rdma_req_t *bounce_req);
@@ -1001,14 +1001,12 @@ static void copy_ctrl_data(nccl_net_ofi_rdma_req_t *bounce_req, nccl_net_ofi_rdm
  * Post all bounce buffers for a rail if we don't have enough
 */
 static inline int check_post_bounce_buffers_rail(nccl_net_ofi_rdma_ep_t *ep,
-						 int rail_id)
+						 nccl_net_ofi_ep_rail_t *rail)
 {
-	nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id);
-
 	/* Not taking lock here since we are only reading a value.
 	   If needed, post_bounce_buffs_on_rail will take the lock. */
 	if (rail->num_bounce_posted < rail->min_bounce_posted) {
-		return post_bounce_buffs_on_rail(ep, rail_id);
+		return post_bounce_buffs_on_rail(ep, rail);
 	}
 
 	return 0;
@@ -1040,20 +1038,18 @@ static inline int repost_bounce_buff(nccl_net_ofi_rdma_ep_t *ep,
 	}
 
 	rdma_req_bounce_data_t *bounce_data = get_bounce_data(bounce_req);
-	int rail_id = bounce_data->bounce_rail_id;
 
 	/* Next, check the posted count and post more buffers if needed. */
-	return check_post_bounce_buffers_rail(ep, rail_id);
+	return check_post_bounce_buffers_rail(ep, bounce_data->rail);
 }
 
 /*
  * @brief Decrement the number of bounce buffers posted for the rail
  *	  corresponding to bounce_req
 */
-static inline int decrease_bounce_buff_cnt(nccl_net_ofi_rdma_ep_t *ep, int rail_id)
+static inline int decrease_bounce_buff_cnt(nccl_net_ofi_rdma_ep_t *ep,
+					   nccl_net_ofi_ep_rail_t *rail)
 {
-	nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id);
-
 	int ret = pthread_mutex_lock(&rail->bounce_mutex);
 	if (ret) {
 		NCCL_OFI_WARN("Failed to lock bounce_mutex");
@@ -1069,7 +1065,7 @@ static inline int decrease_bounce_buff_cnt(nccl_net_ofi_rdma_ep_t *ep, int rail_
 		return -ret;
 	}
 
-	return check_post_bounce_buffers_rail(ep, rail_id);
+	return check_post_bounce_buffers_rail(ep, rail);
 }
 
 /**
@@ -1083,7 +1079,6 @@ static inline int handle_ctrl_recv(nccl_net_ofi_rdma_send_comm_t *s_comm,
 				   nccl_net_ofi_rdma_ep_t *ep)
 {
 	int ret;
-	int bounce_rail_id = get_bounce_data(bounce_req)->bounce_rail_id;
 
 	nccl_ofi_msgbuff_status_t stat;
 	nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_insert(s_comm->msgbuff, msg_seq_num,
@@ -1092,7 +1087,7 @@ static inline int handle_ctrl_recv(nccl_net_ofi_rdma_send_comm_t *s_comm,
 	if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) {
 		/* Inserted! In this case sender has not yet called send() for this message, so
 		   return success and initiate RDMA write when sender calls send(). */
-		return decrease_bounce_buff_cnt(ep, bounce_rail_id);
+		return decrease_bounce_buff_cnt(ep, get_bounce_data(bounce_req)->rail);
 	}
 
 	if (mb_res != NCCL_OFI_MSGBUFF_INVALID_IDX || stat != NCCL_OFI_MSGBUFF_INPROGRESS) {
@@ -1202,10 +1197,9 @@ static inline int handle_eager_recv(nccl_net_ofi_rdma_recv_comm_t *r_comm,
 				    nccl_net_ofi_rdma_ep_t *ep)
 {
 	int ret;
-	int bounce_rail_id = get_bounce_data(bounce_req)->bounce_rail_id;
 
 	/* Decrease bounce buffer count. It will be incremented again when reposting */
-	ret = decrease_bounce_buff_cnt(ep, bounce_rail_id);
+	ret = decrease_bounce_buff_cnt(ep, get_bounce_data(bounce_req)->rail);
 	if (ret != 0) {
 		return ret;
 	}
@@ -1383,8 +1377,8 @@ static int finish_connect(nccl_net_ofi_rdma_send_comm_t *s_comm);
  *		error, on others
 */
 static inline int process_completions(struct fi_cq_tagged_entry *cq_entry,
-				      uint64_t num_cqes, nccl_net_ofi_rdma_ep_t *ep,
-				      int rail_id)
+				      uint64_t num_cqes, nccl_net_ofi_rdma_ep_t *ep,
+				      nccl_net_ofi_ep_rail_t *rail)
 {
 	int ret = 0;
 	nccl_net_ofi_rdma_req_t *req = NULL;
@@ -1453,16 +1447,16 @@ static inline int process_completions(struct fi_cq_tagged_entry *cq_entry,
 			/* This is a bounce buffer receive event. It could be a
 			   ctrl message receive (send comm) or an eager message
 			   receive (recv comm) */
-			ret = handle_bounce_recv(&cq_entry[comp_idx], rail_id);
+			ret = handle_bounce_recv(&cq_entry[comp_idx], rail->rail_id);
 
 		} else if (comp_flags & FI_REMOTE_WRITE) {
 			/* Type 6: Remote-initiated write is complete */
-			ret = handle_write_comp(&cq_entry[comp_idx], ep, rail_id);
+			ret = handle_write_comp(&cq_entry[comp_idx], ep, rail->rail_id);
 
 		} else if (comp_flags & FI_WRITE) {
 			/* Type 5: Local-initiated write is complete */
 			req = op_ctx;
 			rdma_req_send_data_t *send_data = get_send_data(req);
-			NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE(req->dev_id, rail_id, req->comm, req->msg_seq_num, req);
+			NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE(req->dev_id, rail->rail_id, req->comm, req->msg_seq_num, req);
 
 			if (inc_req_completion(req, 0, send_data->total_num_compls)) {
 				ret = ncclInternalError;
@@ -1676,14 +1670,7 @@ static int process_pending_reqs(nccl_net_ofi_rdma_ep_t *ep)
 	return rc;
 }
 
-/*
- * @brief Process completion entries for the given completion quque.
- *	  This also updates several request fileds like size, status, etc
- *
- * @return	0, on success
- *		error, on others
- */
-static int ofi_process_cq(nccl_net_ofi_rdma_ep_t *ep)
+static int ofi_process_cq_rail(nccl_net_ofi_rdma_ep_t *ep, nccl_net_ofi_ep_rail_t *rail)
 {
 	ssize_t rc = 0;
 	int ret = 0;
 	struct fi_cq_err_entry err_buffer = { 0 };
 	struct fi_cq_tagged_entry cqe_tagged_buffers[cq_read_count];
 	nccl_net_ofi_rdma_req_t *req = NULL;
 
-	for (int rail_id = 0; rail_id != ep->num_rails; ++rail_id) {
-		nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id);
-
-		while (true) {
-			/* Receive completions for the given endpoint */
-			rc = fi_cq_read(rail->cq, cqe_tagged_buffers, cq_read_count);
-			if (rc > 0) {
-				ret = process_completions(
-					cqe_tagged_buffers, rc, ep, rail_id);
-				if (OFI_UNLIKELY(ret != 0))
-					goto exit;
-			} else if (OFI_UNLIKELY(rc == -FI_EAVAIL)) {
-				rc = fi_cq_readerr(rail->cq, &err_buffer, 0);
-				if (OFI_UNLIKELY(rc == -FI_EAGAIN)) {
-					/*
-					 * Error not available yet.
-					 * fi_cq_read will keep returning -FI_EAVAIL so just bail out and try again later.
-					 */
-					break;
-				} else if (OFI_UNLIKELY(rc < 0)) {
-					NCCL_OFI_WARN("Unable to read from fi_cq_readerr. RC: %zd. Error: %s",
-						      rc,
-						      fi_strerror(-rc));
+	while (true) {
+		/* Receive completions for the given endpoint */
+		rc = fi_cq_read(rail->cq, cqe_tagged_buffers, cq_read_count);
+		if (rc > 0) {
+			ret = process_completions(
+				cqe_tagged_buffers, rc, ep, rail);
+			if (OFI_UNLIKELY(ret != 0))
+				goto exit;
+		} else if (OFI_UNLIKELY(rc == -FI_EAVAIL)) {
+			rc = fi_cq_readerr(rail->cq, &err_buffer, 0);
+			if (OFI_UNLIKELY(rc == -FI_EAGAIN)) {
+				/*
+				 * Error not available yet.
+				 * fi_cq_read will keep returning -FI_EAVAIL so just bail out and try again later.
+				 */
+				break;
+			} else if (OFI_UNLIKELY(rc < 0)) {
+				NCCL_OFI_WARN("Unable to read from fi_cq_readerr. RC: %zd. Error: %s",
+					      rc,
+					      fi_strerror(-rc));
+				ret = ncclSystemError;
+				goto exit;
+			}
+			if (err_buffer.flags & FI_REMOTE_WRITE) {
+				req = get_req_from_imm_data(ep, err_buffer.data);
+				if (!req) {
+					NCCL_OFI_WARN("Unknown remote write error");
 					ret = ncclSystemError;
 					goto exit;
 				}
-				if (err_buffer.flags & FI_REMOTE_WRITE) {
-					req = get_req_from_imm_data(ep, err_buffer.data);
-					if (!req) {
-						NCCL_OFI_WARN("Unknown remote write error");
-						ret = ncclSystemError;
-						goto exit;
-					}
-				} else {
-					/* For all other operations, ctx should be a req */
-					if (!err_buffer.op_context) {
-						NCCL_OFI_WARN("Operation with NULL context completed with error!");
-						ret = ncclSystemError;
-						goto exit;
-					}
-					req = err_buffer.op_context;
-				}
-
-				NCCL_OFI_WARN("Request %p completed with error. RC: %d. Error: %s. Completed length: %ld, Request: %s",
-					      req,
-					      err_buffer.err,
-					      fi_cq_strerror(rail->cq,
-							     err_buffer.prov_errno,
-							     err_buffer.err_data, NULL, 0),
-					      (long)err_buffer.len,
-					      nccl_net_ofi_req_str(req));
-				set_request_state_to_error(req);
-
-				if (req->type == NCCL_OFI_RDMA_BOUNCE) {
-					/* A bounce buffer receive failed -- this is an internal error so bail out */
-					NCCL_OFI_WARN("Fatal: Bounce buffer recv completed with error");
+			} else {
+				/* For all other operations, ctx should be a req */
+				if (!err_buffer.op_context) {
+					NCCL_OFI_WARN("Operation with NULL context completed with error!");
 					ret = ncclSystemError;
 					goto exit;
 				}
-			} else if (rc == -FI_EAGAIN) {
-				/* No completions to process */
-				break;
-			} else {
-				NCCL_OFI_WARN("Unable to retrieve completion queue entries. RC: %zd, ERROR: %s",
-					      rc, fi_strerror(-rc));
+				req = err_buffer.op_context;
+			}
+
+			NCCL_OFI_WARN("Request %p completed with error. RC: %d. Error: %s. Completed length: %ld, Request: %s",
+				      req,
+				      err_buffer.err,
+				      fi_cq_strerror(rail->cq,
+						     err_buffer.prov_errno,
+						     err_buffer.err_data, NULL, 0),
+				      (long)err_buffer.len,
+				      nccl_net_ofi_req_str(req));
+			set_request_state_to_error(req);
+
+			if (req->type == NCCL_OFI_RDMA_BOUNCE) {
+				/* A bounce buffer receive failed -- this is an internal error so bail out */
+				NCCL_OFI_WARN("Fatal: Bounce buffer recv completed with error");
 				ret = ncclSystemError;
 				goto exit;
 			}
+		} else if (rc == -FI_EAGAIN) {
+			/* No completions to process */
+			break;
+		} else {
+			NCCL_OFI_WARN("Unable to retrieve completion queue entries. RC: %zd, ERROR: %s",
+				      rc, fi_strerror(-rc));
+			ret = ncclSystemError;
+			goto exit;
 		}
 	}
 
+exit:
+	return ret;
+}
+
+/*
+ * @brief Process completion entries for the given completion queue.
+ *	  This also updates several request fields like size, status, etc.
+ *
+ * @return	0, on success
+ *		error, on others
+ */
+static int ofi_process_cq(nccl_net_ofi_rdma_ep_t *ep)
+{
+	int ret;
+	for (int rail_id = 0; rail_id != ep->num_rails; ++rail_id) {
+		nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id);
+
+		ret = ofi_process_cq_rail(ep, rail);
+		if (ret != 0) {
+			goto exit;
+		}
+	}
+
+	ret = ofi_process_cq_rail(ep, &ep->control_rail);
+	if (ret != 0) {
+		goto exit;
 	}
 
 	/* Process any pending requests */
-	rc = process_pending_reqs(ep);
-	if (OFI_UNLIKELY(rc != 0 && rc != -FI_EAGAIN)) {
-		NCCL_OFI_WARN("Failed call to process_pending_reqs: %zd", rc);
-		ret = ncclSystemError;
+	ret = process_pending_reqs(ep);
+	if (OFI_UNLIKELY(ret != 0 && ret != -FI_EAGAIN)) {
+		NCCL_OFI_WARN("Failed call to process_pending_reqs: %d", ret);
 	}
 
 exit:
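With the split into `ofi_process_cq_rail()`, a single `ofi_process_cq()` call now drains each data rail's CQ and then the control rail's CQ before retrying queued requests. A caller-side sketch of how this progress function is typically driven (the `wait_for_completion()` helper is hypothetical and not part of the patch; the request-state names are assumed to match the plugin's request state machine):

```c
/* Hypothetical usage sketch: busy-poll the endpoint until `req` resolves.
 * Each ofi_process_cq() pass covers all data rails plus the control rail. */
static int wait_for_completion(nccl_net_ofi_rdma_ep_t *ep,
			       nccl_net_ofi_rdma_req_t *req)
{
	int ret = 0;

	while (req->state != NCCL_OFI_RDMA_REQ_COMPLETED &&
	       req->state != NCCL_OFI_RDMA_REQ_ERROR) {
		ret = ofi_process_cq(ep);
		if (ret != 0) {
			break;	/* CQ-level failure; report it to the caller */
		}
	}

	return ret;
}
```

The control rail adds one extra `fi_cq_read()` per pass even when idle; that call should return -FI_EAGAIN immediately, so the added overhead is small.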
req"); return -ENOMEM; @@ -2097,7 +2099,7 @@ static inline int post_bounce_buffs_on_rail(nccl_net_ofi_rdma_ep_t *ep, int rail /* Update posted count */ /* We failed to post num_buffs_failed buffers that we promised above */ size_t num_buffs_failed = buffers_needed - i - 1; - ret = handle_bounce_eagain(ep, rail_id, req, num_buffs_failed); + ret = handle_bounce_eagain(ep, rail, req, num_buffs_failed); if (ret != 0) return ret; break; @@ -2118,13 +2120,20 @@ static inline int post_bounce_buffs(nccl_net_ofi_rdma_ep_t *ep) int ret = 0; for (int rail_id = 0; rail_id < ep->num_rails; ++rail_id) { - ret = post_bounce_buffs_on_rail(ep, rail_id); + nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id); + ret = post_bounce_buffs_on_rail(ep, rail); if (ret != 0) { NCCL_OFI_WARN("Failed call to post_bounce_buffs_on_rail"); goto exit; } } + ret = post_bounce_buffs_on_rail(ep, &ep->control_rail); + if (ret != 0) { + NCCL_OFI_WARN("Failed call to post_bounce_buffs_on_rail(control_rail)"); + goto exit; + } + exit: return ret; } @@ -2418,7 +2427,7 @@ static int post_recv_conn_resp(nccl_net_ofi_rdma_send_comm_t *s_comm, int ret = 0; int dev_id = s_comm->base.base.dev_id; assert(s_comm && s_comm->num_rails > 0); - nccl_net_ofi_rdma_send_comm_rail_t *comm_rail = get_send_comm_rail(s_comm, 0); + nccl_net_ofi_rdma_send_comm_rail_t *comm_rail = &s_comm->control_rail; nccl_net_ofi_rdma_req_t *req = s_comm->conn_resp_req; /* Post a buffer for receiving connect response requests */ @@ -2504,6 +2513,13 @@ static int dereg_rails(nccl_net_ofi_rdma_mr_handle_t *handle) int num_rails = handle->num_rails; int rc = 0; + rc = fi_close(&handle->control_mr->fid); + if (OFI_UNLIKELY(rc != 0)) { + NCCL_OFI_WARN("Unable to de-register memory on control mr. RC: %d, Error: %s", + rc, fi_strerror(-rc)); + ret = ncclSystemError; + } + for (int rail_id = 0; rail_id != num_rails; ++rail_id) { /* No memory registration available for this rail */ if (!handle->mr[rail_id]) continue; @@ -2583,9 +2599,17 @@ static int reg_mr_ep(nccl_net_ofi_rdma_ep_t *ep, void *data, goto exit; } - ret_handle->num_rails = num_rails; + ret = register_rail_mr_buffer(device->device_rails[0].domain, ep->control_rail.ofi_ep, + -1, type, &mr_attr, + &ret_handle->control_mr); + if (OFI_UNLIKELY(ret != 0)) { + free(ret_handle); + ret_handle = NULL; + goto exit; + } /* Register memory on each rail */ + ret_handle->num_rails = num_rails; for (int rail_id = 0; rail_id != num_rails; ++rail_id) { nccl_net_ofi_rdma_device_rail_t *dev_rail = get_device_rail(device, rail_id); nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id); @@ -2690,12 +2714,12 @@ static int dereg_mr_ep(nccl_net_ofi_rdma_mr_handle_t *mr_handle, int ret = 0; if (OFI_UNLIKELY(mr_handle == NULL)) { - NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET, "Null MR handle provided. This is an error."); + NCCL_OFI_WARN("Null MR handle provided. This is an error."); return ncclInternalError; } - if (OFI_UNLIKELY(mr_handle->num_rails < 1)) { - NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET, "Unexpected number of rails in rdma memory registration handle"); + if (OFI_UNLIKELY(mr_handle->num_rails < 0)) { + NCCL_OFI_WARN("Unexpected number of rails in rdma memory registration handle"); return ncclInternalError; } @@ -2822,7 +2846,8 @@ static inline nccl_net_ofi_rdma_req_t *allocate_req(nccl_ofi_freelist_t *fl) } /** - * @brief Allocate a new send ctrl req from freelist + * @brief Allocate a new control message that the receiver will + * send to the sender describing the recv buffer. 
@@ -2822,7 +2846,8 @@ static inline nccl_net_ofi_rdma_req_t *allocate_req(nccl_ofi_freelist_t *fl)
 }
 
 /**
- * @brief	Allocate a new send ctrl req from freelist
+ * @brief	Allocate a new control message that the receiver will
+ *		send to the sender describing the recv buffer.
 */
 static inline int insert_send_ctrl_req(
 				nccl_net_ofi_rdma_recv_comm_t *r_comm,
@@ -2832,7 +2857,6 @@ static inline int insert_send_ctrl_req(
 				nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle,
 				nccl_net_ofi_rdma_req_t *recv_req)
 {
-	nccl_net_ofi_scheduler_t *scheduler = device->scheduler;
 	nccl_net_ofi_rdma_req_t *send_ctrl_req = allocate_req(r_comm->nccl_ofi_reqs_fl);
 	if (OFI_UNLIKELY(send_ctrl_req == NULL)) {
 		NCCL_OFI_WARN("Unable to get NCCL OFI send control request for device %d",
@@ -2849,18 +2873,6 @@ static inline int insert_send_ctrl_req(
 	rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(send_ctrl_req);
 	send_ctrl_data->recv_req = recv_req;
 	send_ctrl_data->ctrl_fl_item = NULL;
-	send_ctrl_data->ctrl_schedule = scheduler->get_schedule(scheduler,
-								sizeof(nccl_net_ofi_rdma_ctrl_msg_t),
-								device->num_rails);
-
-	if (OFI_UNLIKELY(!(send_ctrl_data->ctrl_schedule))) {
-		return ncclInternalError;
-	} else if (OFI_UNLIKELY(send_ctrl_data->ctrl_schedule->num_xfer_infos != 1)) {
-		NCCL_OFI_WARN("Invalid schedule for outgoing control message (%zu bytes). Expected one rail, but got %zu",
-			      size,
-			      send_ctrl_data->ctrl_schedule->num_xfer_infos);
-		return ncclInternalError;
-	}
 
 	/*
 	 * Allocate RDMA control buffer which transfers the RDMA write buffer
@@ -3586,6 +3598,26 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen
 	/* Add ourselves to ep's lookup array */
 	set_comm(ep, r_comm->local_tag, &r_comm->base.base);
 
+	r_comm->control_rail.local_ep = l_comm->leader_local_ep;
+	ret = fi_av_insert(ep->control_rail.av, (void *)conn_msg->control_ep_name.ep_name, 1,
+			   &r_comm->control_rail.remote_addr, 0, NULL);
+	if (OFI_UNLIKELY(ret != 1)) {
+		NCCL_OFI_WARN("Unable to insert remote address into address vector "
+			      "for device %d. RC: %s",
+			      dev_id, fi_strerror(-ret));
+		goto error;
+	}
+
+	ret = fi_av_insert(ep->control_rail.av, (void *)ep->control_rail.local_ep_name, 1,
+			   &r_comm->control_rail.local_addr, 0, NULL);
+	if (OFI_UNLIKELY(ret != 1)) {
+		NCCL_OFI_WARN("Unable to insert local address into address vector "
+			      "for device %d. RC: %s",
+			      dev_id, fi_strerror(-ret));
+		goto error;
+	}
+
+
 	/* Allocate array of communicator rails */
 	r_comm->num_rails = num_rails;
 
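Both `fi_av_insert()` calls above rely on libfabric's convention that the function returns the number of addresses successfully inserted (negative on error), which is why success is tested as `ret != 1` rather than `ret == 0`. A compact sketch of that convention (the `av_insert_one()` wrapper is hypothetical, not part of the patch):

```c
/* Hypothetical wrapper: insert exactly one endpoint name into an AV.
 * fi_av_insert() returns how many of the `count` addresses it inserted,
 * so anything other than 1 here is a failure. */
static int av_insert_one(struct fid_av *av, void *ep_name, fi_addr_t *fi_addr)
{
	int ret = fi_av_insert(av, ep_name, 1, fi_addr, 0, NULL);
	if (OFI_UNLIKELY(ret != 1)) {
		NCCL_OFI_WARN("fi_av_insert() inserted %d of 1 addresses", ret);
		return ncclSystemError;
	}
	return 0;
}
```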
RC: %d", + dev_id, fi_strerror(-ret)); + goto error; + } + + /* Allocate array of communicator rails */ r_comm->num_rails = num_rails; @@ -3742,7 +3774,7 @@ static int post_send_conn_resp(nccl_net_ofi_rdma_recv_comm_t *r_comm, nccl_net_ofi_rdma_req_t *req) { ssize_t rc = 0; - nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = get_recv_comm_rail(r_comm, 0); + nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = &r_comm->control_rail; req->state = NCCL_OFI_RDMA_REQ_PENDING; rc = fi_tsend(comm_rail->local_ep, (void *)conn_resp, @@ -4029,7 +4061,6 @@ static int listen(nccl_net_ofi_ep_t *base_ep, bool first_post = true; nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)base_ep; - nccl_net_ofi_ep_rail_t *first_rail = get_rail(ep, 0); /* Retrieve and validate device */ nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t*)ep->base.device; @@ -4044,9 +4075,9 @@ static int listen(nccl_net_ofi_ep_t *base_ep, /* Build handle */ memset(handle, 0, sizeof(nccl_net_ofi_conn_handle_t)); - assert(sizeof(handle->ep_name) == sizeof(first_rail->local_ep_name)); - memcpy(handle->ep_name, first_rail->local_ep_name, - sizeof(first_rail->local_ep_name)); + assert(sizeof(handle->ep_name) == sizeof(ep->control_rail.local_ep_name)); + memcpy(handle->ep_name, ep->control_rail.local_ep_name, + sizeof(ep->control_rail.local_ep_name)); handle->tag = ep->tag; /* Build listen_comm */ @@ -4064,7 +4095,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, l_comm->base.accept = accept; l_comm->base.close = listen_close; l_comm->tag = ep->tag; - l_comm->leader_local_ep = first_rail->ofi_ep; + l_comm->leader_local_ep = ep->control_rail.ofi_ep; /* Prepare receive request to accept connections */ ret = prepare_recv_conn_req(l_comm); @@ -4265,7 +4296,7 @@ static int post_bounce_buffer(nccl_net_ofi_rdma_req_t *req, rdma_req_bounce_data_t *bounce_data = get_bounce_data(req); nccl_net_ofi_rdma_bounce_fl_item_t *bounce_fl_item = bounce_data->bounce_fl_item; freelist_regmr_fn_handle_t *fl_mr_handle = bounce_fl_item->fl_reginfo.mr_handle; - void *desc = fi_mr_desc(fl_mr_handle->mr_handle->mr[bounce_data->bounce_rail_id]); + void *desc = fi_mr_desc(fl_mr_handle->mr_handle->mr[bounce_data->rail->rail_id]); /* Reset memcheck guards of bounce buffer freelist entry to * accessible but undefined to cover cases where the buffer @@ -4346,12 +4377,9 @@ static int send_progress(nccl_net_ofi_rdma_req_t *req) } else if (req->type == NCCL_OFI_RDMA_BOUNCE) { // Post Bounce Buffer rdma_req_bounce_data_t *bounce_data = get_bounce_data(req); /* Get ep rail information to xfer the req */ - nccl_net_ofi_rdma_ep_t *ep = bounce_data->ep; - assert(bounce_data->bounce_rail_id >=0 ); - assert(bounce_data->bounce_rail_id < ep->num_rails); - nccl_net_ofi_ep_rail_t *ep_rail = &ep->rails[bounce_data->bounce_rail_id]; + assert(bounce_data->rail != NULL); - ret = post_bounce_buffer(req, ep_rail); + ret = post_bounce_buffer(req, bounce_data->rail); } else { NCCL_OFI_WARN("Unexpected request type. 
Request type: %d", req->type); ret = -EINVAL; @@ -4365,16 +4393,9 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req) assert(req->type == NCCL_OFI_RDMA_SEND_CTRL); nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)req->comm; rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(req); - nccl_net_ofi_schedule_t *schedule = send_ctrl_data->ctrl_schedule; - - assert(schedule != NULL); - - // Should be using a single rail for posting the control message - nccl_net_ofi_xfer_info_t *xfer_info = &schedule->rail_xfer_infos[0]; // Get communicator rail information to xfer the req - nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail; - comm_rail = get_recv_comm_rail(r_comm, xfer_info->rail_id); + nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = &r_comm->control_rail; nccl_net_ofi_rdma_ctrl_fl_item_t *ctrl_fl_item = send_ctrl_data->ctrl_fl_item; @@ -4382,8 +4403,7 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req) freelist_regmr_fn_handle_t * fl_handle = ctrl_fl_item->fl_reginfo.mr_handle; nccl_net_ofi_rdma_mr_handle_t *mr_handle = fl_handle->mr_handle; - assert(xfer_info->rail_id < mr_handle->num_rails); - void *desc = fi_mr_desc(mr_handle->mr[xfer_info->rail_id]); + void *desc = fi_mr_desc(mr_handle->mr[0]); uint64_t data = GET_RDMA_WRITE_IMM_DATA(r_comm->remote_tag, req->msg_seq_num, 0); @@ -4415,7 +4435,7 @@ static int post_eager_copy(nccl_net_ofi_rdma_req_t *req) // Get communicator rail information to xfer the req nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail; - int bounce_rail_id = bounce_data->bounce_rail_id; + int bounce_rail_id = bounce_data->rail->rail_id; comm_rail = get_recv_comm_rail(r_comm, bounce_rail_id); /* Unpack mr_handle */ @@ -4497,9 +4517,8 @@ static inline int check_post_bounce_req(nccl_net_ofi_rdma_req_t *bounce_req) int ret = 0; rdma_req_bounce_data_t *bounce_data = get_bounce_data(bounce_req); nccl_net_ofi_rdma_ep_t *ep = bounce_data->ep; - int rail_id = bounce_data->bounce_rail_id; - nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id); + nccl_net_ofi_ep_rail_t *rail = bounce_data->rail; ret = pthread_mutex_lock(&rail->bounce_mutex); if (ret) { @@ -4536,7 +4555,7 @@ static inline int check_post_bounce_req(nccl_net_ofi_rdma_req_t *bounce_req) } /* Post more buffers if needed */ - ret = check_post_bounce_buffers_rail(ep, rail_id); + ret = check_post_bounce_buffers_rail(ep, rail); } else { ret = bounce_req->free(bounce_req, false); if (ret != 0) { @@ -4835,6 +4854,9 @@ static void prepare_send_connect_message(nccl_net_ofi_rdma_ep_t *ep, /* Send s_comm's local tag to be transferred to receiver */ conn_msg->local_tag = local_tag; + memcpy(conn_msg->control_ep_name.ep_name, ep->control_rail.local_ep_name, + sizeof(ep->control_rail.local_ep_name)); + /* Set number of rails to be sent back to remote for verification */ conn_msg->num_rails = num_rails; @@ -4892,6 +4914,14 @@ static inline int init_bounce_buffers(nccl_net_ofi_rdma_ep_t *ep) return ret; } + ep->control_rail.min_bounce_posted = NCCL_OFI_DIV_CEIL( + ofi_nccl_rdma_min_posted_bounce_buffers(), ep->num_rails + ); + ep->control_rail.max_bounce_posted = NCCL_OFI_DIV_CEIL( + ofi_nccl_rdma_max_posted_bounce_buffers(), ep->num_rails + ); + ret = pthread_mutex_init(&ep->control_rail.bounce_mutex, NULL); + for (int rail_id = 0; rail_id < ep->num_rails; ++rail_id) { nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id); rail->min_bounce_posted = NCCL_OFI_DIV_CEIL( @@ -4975,7 +5005,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, nccl_net_ofi_rdma_send_comm_t 
@@ -4975,7 +5005,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
 	nccl_net_ofi_rdma_send_comm_t *ret_s_comm = NULL;
 	int num_rails = ep->num_rails;
 	int rail_id = 0;
-	nccl_net_ofi_ep_rail_t *first_rail = get_rail(ep, 0);
+	nccl_net_ofi_ep_rail_t *control_rail = &ep->control_rail;
 	*s_comm = NULL;
 
 	/* Retrieve and validate device */
@@ -5025,7 +5055,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
 	ret_s_comm->num_rails = num_rails;
 
 	/* Insert remote name into AV of first rail */
-	ret = fi_av_insert(first_rail->av,
+	ret = fi_av_insert(control_rail->av,
 			   (void *)handle->ep_name, 1,
 			   &remote_addr, 0, NULL);
 	if (OFI_UNLIKELY(ret != 1)) {
@@ -5035,11 +5065,11 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
 	}
 
 	/* Store remote address of first rail in communicator */
-	ret_s_comm->rails[0].remote_addr = remote_addr;
+	ret_s_comm->control_rail.remote_addr = remote_addr;
 
-	/* Store local libfabric endpoint of first rail */
-	ret_s_comm->rails[0].local_ep = first_rail->ofi_ep;
-	ret_s_comm->num_init_rails = 1;
+	/* Store local libfabric endpoint of control rail */
+	ret_s_comm->control_rail.local_ep = control_rail->ofi_ep;
+	ret_s_comm->num_init_rails = 0;
 
 	/* Allocate request free list */
 	ret = nccl_ofi_freelist_init(sizeof(nccl_net_ofi_rdma_req_t), 16, 16,
@@ -5136,7 +5166,7 @@ static int post_send_conn(nccl_net_ofi_rdma_send_comm_t *s_comm,
 			  nccl_net_ofi_rdma_req_t *req)
 {
 	ssize_t rc = 0;
-	nccl_net_ofi_rdma_send_comm_rail_t *comm_rail = get_send_comm_rail(s_comm, 0);
+	nccl_net_ofi_rdma_send_comm_rail_t *comm_rail = &s_comm->control_rail;
 
 	/*
 	 * TODO: replace it with API of FI_INJECT type when most of
@@ -5338,19 +5368,24 @@ static int connect(nccl_net_ofi_ep_t *base_ep,
 	return ret;
 }
 
+
+static void ep_rail_release(nccl_net_ofi_ep_rail_t *rail, int dev_id)
+{
+	nccl_ofi_ofiutils_ep_release(rail->ofi_ep, rail->av,
+				     rail->cq, dev_id);
+	rail->ofi_ep = NULL;
+	rail->av = NULL;
+	rail->cq = NULL;
+}
+
+
 /*
  * @brief	Release libfabric resources of rdma endpoint
 */
 static void release_rdma_ep_resources(nccl_net_ofi_rdma_ep_t *ep, int dev_id)
 {
 	for (int rail_id = 0; rail_id != ep->num_rails; ++rail_id) {
-		nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id);
-
-		nccl_ofi_ofiutils_ep_release(rail->ofi_ep, rail->av,
-					     rail->cq, dev_id);
-		rail->ofi_ep = NULL;
-		rail->av = NULL;
-		rail->cq = NULL;
+		ep_rail_release(get_rail(ep, rail_id), dev_id);
 	}
 }
 
@@ -5386,6 +5421,33 @@ static inline int set_local_address(struct fid_ep *ep, nccl_net_ofi_ep_rail_t *r
 	return 0;
 }
 
+
+static int ep_rail_init(nccl_net_ofi_rdma_ep_t *ep,
+			int dev_id, int rail_id,
+			nccl_net_ofi_rdma_device_rail_t *dev_rail,
+			nccl_net_ofi_ep_rail_t *ep_rail)
+{
+	int ret = 0;
+
+	ret = nccl_ofi_ofiutils_init_connection(FI_VERSION(1, 18),
+						dev_rail->info, dev_rail->domain,
+						&ep_rail->ofi_ep, &ep_rail->av, &ep_rail->cq);
+	if (ret != 0) {
+		return ret;
+	}
+
+	ep_rail->rail_id = rail_id;
+
+	ret = set_local_address(ep_rail->ofi_ep, ep_rail);
+	if (ret != 0) {
+		ep_rail_release(ep_rail, dev_id);
+		return ret;
+	}
+
+	return 0;
+}
+
+
 /*
  * @brief	Initialize libfabric resources of endpoint rails
 */
@@ -5401,16 +5462,7 @@ static int init_rail_ofi_resources(nccl_net_ofi_rdma_device_t *device,
 			get_device_rail(device, rail_id);
 		nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id);
 
-		ret = nccl_ofi_ofiutils_init_connection(FI_VERSION(1, 18),
-							rail_dev->info,
-							rail_dev->domain,
-							&rail->ofi_ep,
-							&rail->av, &rail->cq);
-		if (ret != 0) {
-			goto exit;
-		}
-
-		ret = set_local_address(rail->ofi_ep, rail);
+		ret = ep_rail_init(ep, dev_id, rail_id, rail_dev, rail);
 		if (ret != 0) {
 			goto exit;
 		}
@@ -5589,6 +5641,12 @@ static int get_ep(nccl_net_ofi_device_t *base_dev,
 			goto unlock;
 		}
 
+		ret = ep_rail_init(ep, dev_id, 0, &device->device_rails[0], &ep->control_rail);
+		if (ret != 0) {
+			NCCL_OFI_WARN("Initializing control rail failed");
+			goto unlock;
+		}
+
 		ret = init_rail_ofi_resources(device, ep);
 		if (ret != 0) {
 			goto unlock;
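Finally, `get_ep()` brings up the control rail through the same `ep_rail_init()` path as the data rails, and the whole connection handshake now runs over it. A condensed, illustrative trace of the flow implemented by the functions in this diff (comment form only; not code from the patch):

```c
/*
 * listener                                connector
 * --------                                ---------
 * listen():                               create_send_comm():
 *   handle->ep_name =                       fi_av_insert(control_rail->av,
 *     ep->control_rail.local_ep_name          handle->ep_name, ...)
 *   l_comm->leader_local_ep =             prepare_send_connect_message():
 *     ep->control_rail.ofi_ep               conn_msg->control_ep_name =
 *                                             ep->control_rail.local_ep_name
 * accept()/prepare_recv_comm():           post_send_conn():
 *   fi_av_insert(ep->control_rail.av,       send conn_msg over
 *     conn_msg->control_ep_name, ...)         s_comm->control_rail.local_ep
 * post_send_conn_resp():                  post_recv_conn_resp():
 *   fi_tsend() over r_comm->control_rail    receive over s_comm->control_rail
 *
 * Data-rail endpoint names still travel inside conn_msg as before; only
 * connection management moves onto the dedicated control rail.
 */
```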