From da3fb7cc6fb3073bb3ca527288a1f975860a6136 Mon Sep 17 00:00:00 2001
From: Brian Barrett
Date: Wed, 13 Dec 2023 19:41:49 +0000
Subject: [PATCH] Move RDMA control messages to own endpoint

For the long message RDMA protocol, we want to make sure that the
sender is never starved for data to move, which means prioritizing
control messages from the receiver to the sender.  This patch moves
both the communicator setup and recv control messages to a new
endpoint, which always lives on device rail 0.  Future patches will
optimize polling of the control message CQ in the send path and set
priority bits on the control CQ.

Signed-off-by: Brian Barrett
---
 include/nccl_ofi_rdma.h |  24 ++-
 src/nccl_ofi_rdma.c     | 412 +++++++++++++++++++++++-----------------
 2 files changed, 252 insertions(+), 184 deletions(-)
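In outline, the change gives every endpoint and communicator one extra rail-shaped structure that is opened on device rail 0 and reserved for connection setup and receiver-to-sender control traffic. A minimal sketch of the resulting shape (illustrative type and field names only; the real definitions in include/nccl_ofi_rdma.h below carry many more fields):

    /* Sketch: the endpoint owns a dedicated control rail next to its data
     * rails, so control messages never queue behind bulk RDMA traffic. */
    struct ep_rail_sketch {
            struct fid_ep *ofi_ep;   /* libfabric endpoint */
            struct fid_cq *cq;       /* its completion queue */
            struct fid_av *av;       /* its address vector */
    };

    struct rdma_ep_sketch {
            struct ep_rail_sketch control_rail;  /* always device rail 0 */
            int num_rails;
            struct ep_rail_sketch *rails;        /* data rails, bulk xfers */
    };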
diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h
index 7238dedb8..a4ba63bf4 100644
--- a/include/nccl_ofi_rdma.h
+++ b/include/nccl_ofi_rdma.h
@@ -67,6 +67,8 @@ typedef enum nccl_net_ofi_rdma_req_type {
  * allocate a rdma memory registration handle with `num_rails' rails.
  */
 typedef struct nccl_net_ofi_rdma_mr_handle {
+	struct fid_mr *control_mr;
+
 	int num_rails;
 
 	/* Array of size `num_rails' */
@@ -102,6 +104,8 @@ struct nccl_net_ofi_rdma_req;
 struct nccl_net_ofi_rdma_ep;
 typedef struct nccl_net_ofi_rdma_req nccl_net_ofi_rdma_req_t;
 typedef struct nccl_net_ofi_rdma_ep nccl_net_ofi_rdma_ep_t;
+struct nccl_net_ofi_ep_rail;
+typedef struct nccl_net_ofi_ep_rail nccl_net_ofi_ep_rail_t;
 
 typedef struct {
 	/* Bounce buffer freelist item */
@@ -116,7 +120,7 @@ typedef struct {
 	 * This is useful for re-posting the bounce buffer on the same rail
 	 * when it gets completed.
 	 */
-	int bounce_rail_id;
+	nccl_net_ofi_ep_rail_t *rail;
 	/*
 	 * Back-pointer to associated endpoint
 	 */
@@ -157,10 +161,6 @@ typedef struct {
 typedef struct {
 	/* Pointer to the allocated control buffer from freelist */
 	nccl_net_ofi_rdma_ctrl_fl_item_t *ctrl_fl_item;
-	/* Schedule used to transfer the control buffer. We save the
-	 * pointer to reference it when transferring the buffer over
-	 * network. */
-	nccl_net_ofi_schedule_t *ctrl_schedule;
 	/* Pointer to recv parent request */
 	nccl_net_ofi_rdma_req_t *recv_req;
 } rdma_req_send_ctrl_data_t;
@@ -297,6 +297,8 @@ typedef struct nccl_ofi_rdma_connection_info {
 	   side. The receiver must use this tag when sending messages to
 	   sender */
 	uint64_t local_tag;
 
+	nccl_ofi_rdma_ep_name_t control_ep_name;
+
 	/* Number of rails */
 	int num_rails;
 
@@ -357,6 +359,8 @@ typedef struct nccl_net_ofi_rdma_send_comm {
 
 	nccl_ofi_msgbuff_t *msgbuff;
 
+	nccl_net_ofi_rdma_send_comm_rail_t control_rail;
+
 	/* Number of rails */
 	int num_rails;
 
@@ -430,6 +434,8 @@ typedef struct nccl_net_ofi_rdma_recv_comm {
 	/* Free list to track control buffers, for sending RDMA control messages */
 	nccl_ofi_freelist_t *ctrl_buff_fl;
 
+	nccl_net_ofi_rdma_recv_comm_rail_t control_rail;
+
 	/* Number of rails */
 	int num_rails;
 
@@ -467,7 +473,9 @@ typedef struct nccl_net_ofi_rdma_listen_comm {
  * Endpoint rail encapsulates data of an endpoint for a
  * specific rail.
  */
-typedef struct nccl_net_ofi_ep_rail {
+struct nccl_net_ofi_ep_rail {
+	int rail_id;
+
 	/* Local libfabric endpoint handle */
 	struct fid_ep *ofi_ep;
 
@@ -492,7 +500,7 @@ typedef struct nccl_net_ofi_ep_rail {
 	size_t max_bounce_posted;
 	/* Mutex for bounce buffer operations */
 	pthread_mutex_t bounce_mutex;
-} nccl_net_ofi_ep_rail_t;
+};
 
 /*
  * @brief RDMA Endpoint
@@ -516,6 +524,8 @@ struct nccl_net_ofi_rdma_ep {
 	/* Current available tag ID */
 	uint64_t tag;
 
+	nccl_net_ofi_ep_rail_t control_rail;
+
 	/* Number of rails */
 	int num_rails;
 
diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c
index a0f09b17b..03ed59d01 100644
--- a/src/nccl_ofi_rdma.c
+++ b/src/nccl_ofi_rdma.c
@@ -211,7 +211,7 @@ static int send_progress(nccl_net_ofi_rdma_req_t *req);
 
 static int receive_progress(nccl_net_ofi_rdma_req_t *req, bool add_to_pending);
 
-static int post_bounce_buffs_on_rail(nccl_net_ofi_rdma_ep_t *ep, int rail_id);
+static int post_bounce_buffs_on_rail(nccl_net_ofi_rdma_ep_t *ep, nccl_net_ofi_ep_rail_t *rail);
 
 static inline int repost_bounce_buff(nccl_net_ofi_rdma_ep_t *ep,
 				     nccl_net_ofi_rdma_req_t *bounce_req);
@@ -1001,14 +1001,12 @@ static void copy_ctrl_data(nccl_net_ofi_rdma_req_t *bounce_req, nccl_net_ofi_rdm
  * Post all bounce buffers for a rail if we don't have enough
  */
 static inline int check_post_bounce_buffers_rail(nccl_net_ofi_rdma_ep_t *ep,
-						 int rail_id)
+						 nccl_net_ofi_ep_rail_t *rail)
 {
-	nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id);
-
 	/* Not taking lock here since we are only reading a value.
 	   If needed, post_bounce_buffs_on_rail will take the lock. */
 	if (rail->num_bounce_posted < rail->min_bounce_posted) {
-		return post_bounce_buffs_on_rail(ep, rail_id);
+		return post_bounce_buffs_on_rail(ep, rail);
 	}
 
 	return 0;
@@ -1040,20 +1038,18 @@ static inline int repost_bounce_buff(nccl_net_ofi_rdma_ep_t *ep,
 	}
 
 	rdma_req_bounce_data_t *bounce_data = get_bounce_data(bounce_req);
-	int rail_id = bounce_data->bounce_rail_id;
 
 	/* Next, check the posted count and post more buffers if needed. */
-	return check_post_bounce_buffers_rail(ep, rail_id);
+	return check_post_bounce_buffers_rail(ep, bounce_data->rail);
 }
 
 /*
  * @brief Decrement the number of bounce buffers posted for the rail
  *	  corresponding to bounce_req
  */
-static inline int decrease_bounce_buff_cnt(nccl_net_ofi_rdma_ep_t *ep, int rail_id)
+static inline int decrease_bounce_buff_cnt(nccl_net_ofi_rdma_ep_t *ep,
+					   nccl_net_ofi_ep_rail_t *rail)
 {
-	nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id);
-
 	int ret = pthread_mutex_lock(&rail->bounce_mutex);
 	if (ret) {
 		NCCL_OFI_WARN("Failed to lock bounce_mutex");
@@ -1069,7 +1065,7 @@ static inline int decrease_bounce_buff_cnt(nccl_net_ofi_rdma_ep_t *ep, int rail_
 		return -ret;
 	}
 
-	return check_post_bounce_buffers_rail(ep, rail_id);
+	return check_post_bounce_buffers_rail(ep, rail);
 }
 
 /**
@@ -1083,7 +1079,6 @@ static inline int handle_ctrl_recv(nccl_net_ofi_rdma_send_comm_t *s_comm,
 					   nccl_net_ofi_rdma_ep_t *ep)
 {
 	int ret;
-	int bounce_rail_id = get_bounce_data(bounce_req)->bounce_rail_id;
 
 	nccl_ofi_msgbuff_status_t stat;
 	nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_insert(s_comm->msgbuff, msg_seq_num,
@@ -1092,7 +1087,7 @@ static inline int handle_ctrl_recv(nccl_net_ofi_rdma_send_comm_t *s_comm,
 	if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) {
 		/* Inserted! In this case sender has not yet called send() for this message, so
 		   return success and initiate RDMA write when sender calls send(). */
-		return decrease_bounce_buff_cnt(ep, bounce_rail_id);
+		return decrease_bounce_buff_cnt(ep, get_bounce_data(bounce_req)->rail);
 	}
 
 	if (mb_res != NCCL_OFI_MSGBUFF_INVALID_IDX || stat != NCCL_OFI_MSGBUFF_INPROGRESS) {
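The branch above is where the receiver-first and sender-first orderings meet: nccl_ofi_msgbuff_insert() reports whether send() has already been posted for this sequence number. Condensed, with hypothetical names standing in for the plugin's message-buffer API, the decision is roughly:

    /* Sketch: the ctrl message and the local send() race per seq number. */
    res = msgbuff_insert(msgbuff, seq, ctrl_msg);        /* hypothetical */
    if (res == MSGBUFF_SUCCESS) {
            /* ctrl arrived first: stash it, send() will start the write */
    } else if (res == MSGBUFF_INVALID_IDX && stat == MSGBUFF_INPROGRESS) {
            /* send() arrived first: remote buffer now known, start write */
    } else {
            /* duplicate or out-of-window sequence number: protocol error */
    }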
@@ -1202,10 +1197,9 @@ static inline int handle_eager_recv(nccl_net_ofi_rdma_recv_comm_t *r_comm,
 					    nccl_net_ofi_rdma_ep_t *ep)
 {
 	int ret;
-	int bounce_rail_id = get_bounce_data(bounce_req)->bounce_rail_id;
 
 	/* Decrease bounce buffer count. It will be incremented again when reposting */
-	ret = decrease_bounce_buff_cnt(ep, bounce_rail_id);
+	ret = decrease_bounce_buff_cnt(ep, get_bounce_data(bounce_req)->rail);
 	if (ret != 0) {
 		return ret;
 	}
@@ -1383,8 +1377,8 @@ static int finish_connect(nccl_net_ofi_rdma_send_comm_t *s_comm);
  *		error, on others
  */
 static inline int process_completions(struct fi_cq_tagged_entry *cq_entry,
-				      uint64_t num_cqes, nccl_net_ofi_rdma_ep_t *ep,
-				      int rail_id)
+				      uint64_t num_cqes, nccl_net_ofi_rdma_ep_t *ep,
+				      nccl_net_ofi_ep_rail_t *rail)
 {
 	int ret = 0;
 	nccl_net_ofi_rdma_req_t *req = NULL;
@@ -1453,16 +1447,16 @@ static inline int process_completions(struct fi_cq_tagged_entry *cq_entry,
 			/* This is a bounce buffer receive event. It could be a
 			   ctrl message receive (send comm) or an eager message
 			   receive (recv comm) */
-			ret = handle_bounce_recv(&cq_entry[comp_idx], rail_id);
+			ret = handle_bounce_recv(&cq_entry[comp_idx], rail->rail_id);
 
 		} else if (comp_flags & FI_REMOTE_WRITE) {
 			/* Type 6: Remote-initiated write is complete */
-			ret = handle_write_comp(&cq_entry[comp_idx], ep, rail_id);
+			ret = handle_write_comp(&cq_entry[comp_idx], ep, rail->rail_id);
 
 		} else if (comp_flags & FI_WRITE) {
 			/* Type 5: Local-initiated write is complete */
 			req = op_ctx;
 			rdma_req_send_data_t *send_data = get_send_data(req);
-			NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE(req->dev_id, rail_id, req->comm, req->msg_seq_num, req);
+			NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE(req->dev_id, rail->rail_id, req->comm, req->msg_seq_num, req);
 
 			if (inc_req_completion(req, 0, send_data->total_num_compls)) {
 				ret = ncclInternalError;
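process_completions() classifies each entry by its libfabric completion flags rather than by request type alone; the dispatch reduces to roughly:

    /* Sketch of the flag dispatch (flags as set by the provider in the CQE) */
    if (comp_flags & FI_RECV) {
            /* bounce buffer landed: a peer's ctrl or eager message */
    } else if (comp_flags & FI_REMOTE_WRITE) {
            /* a peer's RDMA write into one of our buffers completed */
    } else if (comp_flags & FI_WRITE) {
            /* one segment of our own RDMA write completed */
    }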
Error: %s", - rc, - fi_strerror(-rc)); + while (true) { + /* Receive completions for the given endpoint */ + rc = fi_cq_read(rail->cq, cqe_tagged_buffers, cq_read_count); + if (rc > 0) { + ret = process_completions( + cqe_tagged_buffers, rc, ep, rail); + if (OFI_UNLIKELY(ret != 0)) + goto exit; + } else if (OFI_UNLIKELY(rc == -FI_EAVAIL)) { + rc = fi_cq_readerr(rail->cq, &err_buffer, 0); + if (OFI_UNLIKELY(rc == -FI_EAGAIN)) { + /* + * Error not available yet. + * fi_cq_read will keep returning -FI_EAVAIL so just bail out and try again later. + */ + break; + } else if (OFI_UNLIKELY(rc < 0)) { + NCCL_OFI_WARN("Unable to read from fi_cq_readerr. RC: %zd. Error: %s", + rc, + fi_strerror(-rc)); + ret = ncclSystemError; + goto exit; + } + if (err_buffer.flags & FI_REMOTE_WRITE) { + req = get_req_from_imm_data(ep, err_buffer.data); + if (!req) { + NCCL_OFI_WARN("Unknown remote write error"); ret = ncclSystemError; goto exit; } - if (err_buffer.flags & FI_REMOTE_WRITE) { - req = get_req_from_imm_data(ep, err_buffer.data); - if (!req) { - NCCL_OFI_WARN("Unknown remote write error"); - ret = ncclSystemError; - goto exit; - } - } else { - /* For all other operations, ctx should be a req */ - if (!err_buffer.op_context) { - NCCL_OFI_WARN("Operation with NULL context completed with error!"); - ret = ncclSystemError; - goto exit; - } - req = err_buffer.op_context; - } - - NCCL_OFI_WARN("Request %p completed with error. RC: %d. Error: %s. Completed length: %ld, Request: %s", - req, - err_buffer.err, - fi_cq_strerror(rail->cq, - err_buffer.prov_errno, - err_buffer.err_data, NULL, 0), - (long)err_buffer.len, - nccl_net_ofi_req_str(req)); - set_request_state_to_error(req); - - if (req->type == NCCL_OFI_RDMA_BOUNCE) { - /* A bounce buffer receive failed -- this is an internal error so bail out */ - NCCL_OFI_WARN("Fatal: Bounce buffer recv completed with error"); + } else { + /* For all other operations, ctx should be a req */ + if (!err_buffer.op_context) { + NCCL_OFI_WARN("Operation with NULL context completed with error!"); ret = ncclSystemError; goto exit; } - } else if (rc == -FI_EAGAIN) { - /* No completions to process */ - break; - } else { - NCCL_OFI_WARN("Unable to retrieve completion queue entries. RC: %zd, ERROR: %s", - rc, fi_strerror(-rc)); + req = err_buffer.op_context; + } + + NCCL_OFI_WARN("Request %p completed with error. RC: %d. Error: %s. Completed length: %ld, Request: %s", + req, + err_buffer.err, + fi_cq_strerror(rail->cq, + err_buffer.prov_errno, + err_buffer.err_data, NULL, 0), + (long)err_buffer.len, + nccl_net_ofi_req_str(req)); + set_request_state_to_error(req); + + if (req->type == NCCL_OFI_RDMA_BOUNCE) { + /* A bounce buffer receive failed -- this is an internal error so bail out */ + NCCL_OFI_WARN("Fatal: Bounce buffer recv completed with error"); ret = ncclSystemError; goto exit; } + } else if (rc == -FI_EAGAIN) { + /* No completions to process */ + break; + } else { + NCCL_OFI_WARN("Unable to retrieve completion queue entries. RC: %zd, ERROR: %s", + rc, fi_strerror(-rc)); + ret = ncclSystemError; + goto exit; } + } + +exit: + return ret; +} + +/* + * @brief Process completion entries for the given completion quque. 
@@ -1925,12 +1935,6 @@ static inline int free_send_ctrl_req(nccl_net_ofi_rdma_req_t *req,
 		(nccl_net_ofi_rdma_recv_comm_t *)req->comm;
 	rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(req);
 
-	if (send_ctrl_data->ctrl_schedule) {
-		nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)req->comm->ep->device;
-		nccl_net_ofi_release_schedule(device->scheduler, send_ctrl_data->ctrl_schedule);
-		send_ctrl_data->ctrl_schedule = NULL;
-	}
-
 	if (send_ctrl_data->ctrl_fl_item) {
 		nccl_net_ofi_rdma_recv_comm_t *r_comm =
 			(nccl_net_ofi_rdma_recv_comm_t *)req->comm;
 		nccl_ofi_freelist_entry_free(r_comm->ctrl_buff_fl, send_ctrl_data->ctrl_fl_item);
@@ -2003,7 +2007,7 @@ static inline int free_bounce_req(nccl_net_ofi_rdma_req_t *req,
 }
 
 static inline nccl_net_ofi_rdma_req_t *alloc_bounce_req(nccl_net_ofi_rdma_ep_t *ep,
-							int rail_id)
+							nccl_net_ofi_ep_rail_t *rail)
 {
 	nccl_net_ofi_rdma_req_t *req = allocate_req(ep->bounce_buff_reqs_fl);
 	if (!req) return NULL;
@@ -2026,12 +2030,13 @@ static inline nccl_net_ofi_rdma_req_t *alloc_bounce_req(nccl_net_ofi_rdma_ep_t *
 
 	bounce_data->bounce_fl_item = bounce_fl_item;
 	bounce_data->buff_len = ep->bounce_buff_size;
-	bounce_data->bounce_rail_id = rail_id;
+	bounce_data->rail = rail;
 	bounce_data->ep = ep;
 	return req;
 }
 
-static inline int handle_bounce_eagain(nccl_net_ofi_rdma_ep_t *ep, int rail_id,
+static inline int handle_bounce_eagain(nccl_net_ofi_rdma_ep_t *ep,
+				       nccl_net_ofi_ep_rail_t *rail,
 				       nccl_net_ofi_rdma_req_t *req, size_t num_buffs_failed)
 {
 	/* Add to pending reqs queue */
@@ -2042,8 +2047,6 @@ static inline int handle_bounce_eagain(nccl_net_ofi_rdma_ep_t *ep, int rail_id,
 	}
 	NCCL_OFI_TRACE_PENDING_INSERT(req);
 
-	nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id);
-
 	ret = pthread_mutex_lock(&rail->bounce_mutex);
 	if (ret != 0) {
 		NCCL_OFI_WARN("Failed to lock bounce_mutex: %d", ret);
@@ -2062,12 +2065,11 @@ static inline int handle_bounce_eagain(nccl_net_ofi_rdma_ep_t *ep, int rail_id,
 	return ret;
 }
 
-static inline int post_bounce_buffs_on_rail(nccl_net_ofi_rdma_ep_t *ep, int rail_id)
+static inline int post_bounce_buffs_on_rail(nccl_net_ofi_rdma_ep_t *ep,
+					    nccl_net_ofi_ep_rail_t *rail)
 {
 	int ret = 0;
 
-	nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id);
-
 	ret = pthread_mutex_lock(&rail->bounce_mutex);
 	if (ret != 0) {
 		NCCL_OFI_WARN("Failed to lock bounce_mutex: %d", ret);
@@ -2087,7 +2089,7 @@ static inline int post_bounce_buffs_on_rail(nccl_net_ofi_rdma_ep_t *ep, int rail
 	/* Post all the bounce buffers we need */
 	for (size_t i = 0; i < buffers_needed; ++i) {
 		nccl_net_ofi_rdma_req_t *req =
-			alloc_bounce_req(ep, rail_id);
+			alloc_bounce_req(ep, rail);
 		if (!req) {
 			NCCL_OFI_WARN("Failed to allocate bounce req");
 			return -ENOMEM;
@@ -2097,7 +2099,7 @@ static inline int post_bounce_buffs_on_rail(nccl_net_ofi_rdma_ep_t *ep, int rail
 			/* Update posted count */
 			/* We failed to post num_buffs_failed buffers that we promised above */
 			size_t num_buffs_failed = buffers_needed - i - 1;
-			ret = handle_bounce_eagain(ep, rail_id, req, num_buffs_failed);
+			ret = handle_bounce_eagain(ep, rail, req, num_buffs_failed);
 			if (ret != 0) return ret;
 
 			break;
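post_bounce_buffs_on_rail() claims its whole refill quota under the rail's mutex before posting anything, so two threads can never both decide to top the same rail up. The accounting pattern is roughly as follows (post_one() stands in for the alloc_bounce_req()/send_progress() pair above):

    pthread_mutex_lock(&rail->bounce_mutex);
    size_t needed = 0;
    if (rail->num_bounce_posted < rail->min_bounce_posted)
            needed = rail->max_bounce_posted - rail->num_bounce_posted;
    rail->num_bounce_posted += needed;   /* claim the slots up front */
    pthread_mutex_unlock(&rail->bounce_mutex);

    for (size_t i = 0; i < needed; i++)
            post_one(ep, rail);   /* on EAGAIN, the claimed count is
                                     handed back via handle_bounce_eagain() */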
req"); return -ENOMEM; @@ -2097,7 +2099,7 @@ static inline int post_bounce_buffs_on_rail(nccl_net_ofi_rdma_ep_t *ep, int rail /* Update posted count */ /* We failed to post num_buffs_failed buffers that we promised above */ size_t num_buffs_failed = buffers_needed - i - 1; - ret = handle_bounce_eagain(ep, rail_id, req, num_buffs_failed); + ret = handle_bounce_eagain(ep, rail, req, num_buffs_failed); if (ret != 0) return ret; break; @@ -2118,13 +2120,20 @@ static inline int post_bounce_buffs(nccl_net_ofi_rdma_ep_t *ep) int ret = 0; for (int rail_id = 0; rail_id < ep->num_rails; ++rail_id) { - ret = post_bounce_buffs_on_rail(ep, rail_id); + nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id); + ret = post_bounce_buffs_on_rail(ep, rail); if (ret != 0) { NCCL_OFI_WARN("Failed call to post_bounce_buffs_on_rail"); goto exit; } } + ret = post_bounce_buffs_on_rail(ep, &ep->control_rail); + if (ret != 0) { + NCCL_OFI_WARN("Failed call to post_bounce_buffs_on_rail(control_rail)"); + goto exit; + } + exit: return ret; } @@ -2418,7 +2427,7 @@ static int post_recv_conn_resp(nccl_net_ofi_rdma_send_comm_t *s_comm, int ret = 0; int dev_id = s_comm->base.base.dev_id; assert(s_comm && s_comm->num_rails > 0); - nccl_net_ofi_rdma_send_comm_rail_t *comm_rail = get_send_comm_rail(s_comm, 0); + nccl_net_ofi_rdma_send_comm_rail_t *comm_rail = &s_comm->control_rail; nccl_net_ofi_rdma_req_t *req = s_comm->conn_resp_req; /* Post a buffer for receiving connect response requests */ @@ -2504,6 +2513,13 @@ static int dereg_rails(nccl_net_ofi_rdma_mr_handle_t *handle) int num_rails = handle->num_rails; int rc = 0; + rc = fi_close(&handle->control_mr->fid); + if (OFI_UNLIKELY(rc != 0)) { + NCCL_OFI_WARN("Unable to de-register memory on control mr. RC: %d, Error: %s", + rc, fi_strerror(-rc)); + ret = ncclSystemError; + } + for (int rail_id = 0; rail_id != num_rails; ++rail_id) { /* No memory registration available for this rail */ if (!handle->mr[rail_id]) continue; @@ -2583,9 +2599,17 @@ static int reg_mr_ep(nccl_net_ofi_rdma_ep_t *ep, void *data, goto exit; } - ret_handle->num_rails = num_rails; + ret = register_rail_mr_buffer(device->device_rails[0].domain, ep->control_rail.ofi_ep, + -1, type, &mr_attr, + &ret_handle->control_mr); + if (OFI_UNLIKELY(ret != 0)) { + free(ret_handle); + ret_handle = NULL; + goto exit; + } /* Register memory on each rail */ + ret_handle->num_rails = num_rails; for (int rail_id = 0; rail_id != num_rails; ++rail_id) { nccl_net_ofi_rdma_device_rail_t *dev_rail = get_device_rail(device, rail_id); nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id); @@ -2690,12 +2714,12 @@ static int dereg_mr_ep(nccl_net_ofi_rdma_mr_handle_t *mr_handle, int ret = 0; if (OFI_UNLIKELY(mr_handle == NULL)) { - NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET, "Null MR handle provided. This is an error."); + NCCL_OFI_WARN("Null MR handle provided. This is an error."); return ncclInternalError; } - if (OFI_UNLIKELY(mr_handle->num_rails < 1)) { - NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET, "Unexpected number of rails in rdma memory registration handle"); + if (OFI_UNLIKELY(mr_handle->num_rails < 0)) { + NCCL_OFI_WARN("Unexpected number of rails in rdma memory registration handle"); return ncclInternalError; } @@ -2822,7 +2846,8 @@ static inline nccl_net_ofi_rdma_req_t *allocate_req(nccl_ofi_freelist_t *fl) } /** - * @brief Allocate a new send ctrl req from freelist + * @brief Allocate a new control message that the receiver will + * send to the sender describing the recv buffer. 
@@ -2822,7 +2846,8 @@ static inline nccl_net_ofi_rdma_req_t *allocate_req(nccl_ofi_freelist_t *fl)
 }
 
 /**
- * @brief	Allocate a new send ctrl req from freelist
+ * @brief	Allocate a new control message that the receiver will
+ *		send to the sender describing the recv buffer.
 */
 static inline int insert_send_ctrl_req(
 				nccl_net_ofi_rdma_recv_comm_t *r_comm,
@@ -2832,7 +2857,6 @@ static inline int insert_send_ctrl_req(
 				nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle,
 				nccl_net_ofi_rdma_req_t *recv_req)
 {
-	nccl_net_ofi_scheduler_t *scheduler = device->scheduler;
 	nccl_net_ofi_rdma_req_t *send_ctrl_req = allocate_req(r_comm->nccl_ofi_reqs_fl);
 	if (OFI_UNLIKELY(send_ctrl_req == NULL)) {
 		NCCL_OFI_WARN("Unable to get NCCL OFI send control request for device %d",
@@ -2849,18 +2873,6 @@ static inline int insert_send_ctrl_req(
 	rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(send_ctrl_req);
 	send_ctrl_data->recv_req = recv_req;
 	send_ctrl_data->ctrl_fl_item = NULL;
-	send_ctrl_data->ctrl_schedule = scheduler->get_schedule(scheduler,
-								sizeof(nccl_net_ofi_rdma_ctrl_msg_t),
-								device->num_rails);
-
-	if (OFI_UNLIKELY(!(send_ctrl_data->ctrl_schedule))) {
-		return ncclInternalError;
-	} else if (OFI_UNLIKELY(send_ctrl_data->ctrl_schedule->num_xfer_infos != 1)) {
-		NCCL_OFI_WARN("Invalid schedule for outgoing control message (%zu bytes). Expected one rail, but got %zu",
-			      size,
-			      send_ctrl_data->ctrl_schedule->num_xfer_infos);
-		return ncclInternalError;
-	}
 
 	/*
 	 * Allocate RDMA control buffer which transfers the RDMA write buffer
@@ -3586,6 +3598,26 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen
 	/* Add ourselves to ep's lookup array */
 	set_comm(ep, r_comm->local_tag, &r_comm->base.base);
 
+	r_comm->control_rail.local_ep = l_comm->leader_local_ep;
+	ret = fi_av_insert(ep->control_rail.av, (void *)conn_msg->control_ep_name.ep_name, 1,
+			   &r_comm->control_rail.remote_addr, 0, NULL);
+	if (OFI_UNLIKELY(ret != 1)) {
+		NCCL_OFI_WARN("Unable to insert remote address into address vector "
+			      "for device %d. RC: %s",
+			      dev_id, fi_strerror(-ret));
+		goto error;
+	}
+
+	ret = fi_av_insert(ep->control_rail.av, (void *)ep->control_rail.local_ep_name, 1,
+			   &r_comm->control_rail.local_addr, 0, NULL);
+	if (OFI_UNLIKELY(ret != 1)) {
+		NCCL_OFI_WARN("Unable to insert local address into address vector "
+			      "for device %d. RC: %s",
+			      dev_id, fi_strerror(-ret));
+		goto error;
+	}
+
+
 	/* Allocate array of communicator rails */
 	r_comm->num_rails = num_rails;
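One detail worth keeping in mind when reading these checks: fi_av_insert() returns the number of addresses successfully inserted (or a negative fi_errno), so for a single endpoint name anything other than 1 is a failure. In isolation:

    fi_addr_t remote_addr;

    int n = fi_av_insert(av, ep_name, 1, &remote_addr, 0, NULL);
    if (n != 1)
            return -EINVAL;   /* on success n is a count, not zero */
    /* remote_addr is now a valid destination for fi_tsend()/fi_write() */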
RC: %d", + dev_id, fi_strerror(-ret)); + goto error; + } + + /* Allocate array of communicator rails */ r_comm->num_rails = num_rails; @@ -3742,7 +3774,7 @@ static int post_send_conn_resp(nccl_net_ofi_rdma_recv_comm_t *r_comm, nccl_net_ofi_rdma_req_t *req) { ssize_t rc = 0; - nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = get_recv_comm_rail(r_comm, 0); + nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = &r_comm->control_rail; req->state = NCCL_OFI_RDMA_REQ_PENDING; rc = fi_tsend(comm_rail->local_ep, (void *)conn_resp, @@ -4029,7 +4061,6 @@ static int listen(nccl_net_ofi_ep_t *base_ep, bool first_post = true; nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)base_ep; - nccl_net_ofi_ep_rail_t *first_rail = get_rail(ep, 0); /* Retrieve and validate device */ nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t*)ep->base.device; @@ -4044,9 +4075,9 @@ static int listen(nccl_net_ofi_ep_t *base_ep, /* Build handle */ memset(handle, 0, sizeof(nccl_net_ofi_conn_handle_t)); - assert(sizeof(handle->ep_name) == sizeof(first_rail->local_ep_name)); - memcpy(handle->ep_name, first_rail->local_ep_name, - sizeof(first_rail->local_ep_name)); + assert(sizeof(handle->ep_name) == sizeof(ep->control_rail.local_ep_name)); + memcpy(handle->ep_name, ep->control_rail.local_ep_name, + sizeof(ep->control_rail.local_ep_name)); handle->tag = ep->tag; /* Build listen_comm */ @@ -4064,7 +4095,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, l_comm->base.accept = accept; l_comm->base.close = listen_close; l_comm->tag = ep->tag; - l_comm->leader_local_ep = first_rail->ofi_ep; + l_comm->leader_local_ep = ep->control_rail.ofi_ep; /* Prepare receive request to accept connections */ ret = prepare_recv_conn_req(l_comm); @@ -4265,7 +4296,7 @@ static int post_bounce_buffer(nccl_net_ofi_rdma_req_t *req, rdma_req_bounce_data_t *bounce_data = get_bounce_data(req); nccl_net_ofi_rdma_bounce_fl_item_t *bounce_fl_item = bounce_data->bounce_fl_item; freelist_regmr_fn_handle_t *fl_mr_handle = bounce_fl_item->fl_reginfo.mr_handle; - void *desc = fi_mr_desc(fl_mr_handle->mr_handle->mr[bounce_data->bounce_rail_id]); + void *desc = fi_mr_desc(fl_mr_handle->mr_handle->mr[bounce_data->rail->rail_id]); /* Reset memcheck guards of bounce buffer freelist entry to * accessible but undefined to cover cases where the buffer @@ -4346,12 +4377,9 @@ static int send_progress(nccl_net_ofi_rdma_req_t *req) } else if (req->type == NCCL_OFI_RDMA_BOUNCE) { // Post Bounce Buffer rdma_req_bounce_data_t *bounce_data = get_bounce_data(req); /* Get ep rail information to xfer the req */ - nccl_net_ofi_rdma_ep_t *ep = bounce_data->ep; - assert(bounce_data->bounce_rail_id >=0 ); - assert(bounce_data->bounce_rail_id < ep->num_rails); - nccl_net_ofi_ep_rail_t *ep_rail = &ep->rails[bounce_data->bounce_rail_id]; + assert(bounce_data->rail != NULL); - ret = post_bounce_buffer(req, ep_rail); + ret = post_bounce_buffer(req, bounce_data->rail); } else { NCCL_OFI_WARN("Unexpected request type. 
Request type: %d", req->type); ret = -EINVAL; @@ -4365,16 +4393,9 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req) assert(req->type == NCCL_OFI_RDMA_SEND_CTRL); nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)req->comm; rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(req); - nccl_net_ofi_schedule_t *schedule = send_ctrl_data->ctrl_schedule; - - assert(schedule != NULL); - - // Should be using a single rail for posting the control message - nccl_net_ofi_xfer_info_t *xfer_info = &schedule->rail_xfer_infos[0]; // Get communicator rail information to xfer the req - nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail; - comm_rail = get_recv_comm_rail(r_comm, xfer_info->rail_id); + nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = &r_comm->control_rail; nccl_net_ofi_rdma_ctrl_fl_item_t *ctrl_fl_item = send_ctrl_data->ctrl_fl_item; @@ -4382,8 +4403,7 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req) freelist_regmr_fn_handle_t * fl_handle = ctrl_fl_item->fl_reginfo.mr_handle; nccl_net_ofi_rdma_mr_handle_t *mr_handle = fl_handle->mr_handle; - assert(xfer_info->rail_id < mr_handle->num_rails); - void *desc = fi_mr_desc(mr_handle->mr[xfer_info->rail_id]); + void *desc = fi_mr_desc(mr_handle->mr[0]); uint64_t data = GET_RDMA_WRITE_IMM_DATA(r_comm->remote_tag, req->msg_seq_num, 0); @@ -4415,7 +4435,7 @@ static int post_eager_copy(nccl_net_ofi_rdma_req_t *req) // Get communicator rail information to xfer the req nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail; - int bounce_rail_id = bounce_data->bounce_rail_id; + int bounce_rail_id = bounce_data->rail->rail_id; comm_rail = get_recv_comm_rail(r_comm, bounce_rail_id); /* Unpack mr_handle */ @@ -4497,9 +4517,8 @@ static inline int check_post_bounce_req(nccl_net_ofi_rdma_req_t *bounce_req) int ret = 0; rdma_req_bounce_data_t *bounce_data = get_bounce_data(bounce_req); nccl_net_ofi_rdma_ep_t *ep = bounce_data->ep; - int rail_id = bounce_data->bounce_rail_id; - nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id); + nccl_net_ofi_ep_rail_t *rail = bounce_data->rail; ret = pthread_mutex_lock(&rail->bounce_mutex); if (ret) { @@ -4536,7 +4555,7 @@ static inline int check_post_bounce_req(nccl_net_ofi_rdma_req_t *bounce_req) } /* Post more buffers if needed */ - ret = check_post_bounce_buffers_rail(ep, rail_id); + ret = check_post_bounce_buffers_rail(ep, rail); } else { ret = bounce_req->free(bounce_req, false); if (ret != 0) { @@ -4835,6 +4854,9 @@ static void prepare_send_connect_message(nccl_net_ofi_rdma_ep_t *ep, /* Send s_comm's local tag to be transferred to receiver */ conn_msg->local_tag = local_tag; + memcpy(conn_msg->control_ep_name.ep_name, ep->control_rail.local_ep_name, + sizeof(ep->control_rail.local_ep_name)); + /* Set number of rails to be sent back to remote for verification */ conn_msg->num_rails = num_rails; @@ -4892,6 +4914,14 @@ static inline int init_bounce_buffers(nccl_net_ofi_rdma_ep_t *ep) return ret; } + ep->control_rail.min_bounce_posted = NCCL_OFI_DIV_CEIL( + ofi_nccl_rdma_min_posted_bounce_buffers(), ep->num_rails + ); + ep->control_rail.max_bounce_posted = NCCL_OFI_DIV_CEIL( + ofi_nccl_rdma_max_posted_bounce_buffers(), ep->num_rails + ); + ret = pthread_mutex_init(&ep->control_rail.bounce_mutex, NULL); + for (int rail_id = 0; rail_id < ep->num_rails; ++rail_id) { nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id); rail->min_bounce_posted = NCCL_OFI_DIV_CEIL( @@ -4975,7 +5005,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, nccl_net_ofi_rdma_send_comm_t 
@@ -4975,7 +5005,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
 	nccl_net_ofi_rdma_send_comm_t *ret_s_comm = NULL;
 	int num_rails = ep->num_rails;
 	int rail_id = 0;
-	nccl_net_ofi_ep_rail_t *first_rail = get_rail(ep, 0);
+	nccl_net_ofi_ep_rail_t *control_rail = &ep->control_rail;
 	*s_comm = NULL;
 
 	/* Retrieve and validate device */
@@ -5025,7 +5055,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
 	ret_s_comm->num_rails = num_rails;
 
 	/* Insert remote name into AV of first rail */
-	ret = fi_av_insert(first_rail->av,
+	ret = fi_av_insert(control_rail->av,
 			   (void *)handle->ep_name, 1,
 			   &remote_addr, 0, NULL);
 	if (OFI_UNLIKELY(ret != 1)) {
@@ -5035,11 +5065,11 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
 	}
 
 	/* Store remote address of first rail in communicator */
-	ret_s_comm->rails[0].remote_addr = remote_addr;
+	ret_s_comm->control_rail.remote_addr = remote_addr;
 
-	/* Store local libfabric endpoint of first rail */
-	ret_s_comm->rails[0].local_ep = first_rail->ofi_ep;
-	ret_s_comm->num_init_rails = 1;
+	/* Store local libfabric endpoint of control rail */
+	ret_s_comm->control_rail.local_ep = control_rail->ofi_ep;
+	ret_s_comm->num_init_rails = 0;
 
 	/* Allocate request free list */
 	ret = nccl_ofi_freelist_init(sizeof(nccl_net_ofi_rdma_req_t), 16, 16,
@@ -5136,7 +5166,7 @@ static int post_send_conn(nccl_net_ofi_rdma_send_comm_t *s_comm,
 			  nccl_net_ofi_rdma_req_t *req)
 {
 	ssize_t rc = 0;
-	nccl_net_ofi_rdma_send_comm_rail_t *comm_rail = get_send_comm_rail(s_comm, 0);
+	nccl_net_ofi_rdma_send_comm_rail_t *comm_rail = &s_comm->control_rail;
 
 	/*
 	 * TODO: replace it with API of FI_INJECT type when most of
@@ -5338,19 +5368,24 @@ static int connect(nccl_net_ofi_ep_t *base_ep,
 	return ret;
 }
 
+
+static void ep_rail_release(nccl_net_ofi_ep_rail_t *rail, int dev_id)
+{
+	nccl_ofi_ofiutils_ep_release(rail->ofi_ep, rail->av,
+				     rail->cq, dev_id);
+	rail->ofi_ep = NULL;
+	rail->av = NULL;
+	rail->cq = NULL;
+}
+
+
 /*
  * @brief	Release libfabric resources of rdma endpoint
 */
 static void release_rdma_ep_resources(nccl_net_ofi_rdma_ep_t *ep, int dev_id)
 {
 	for (int rail_id = 0; rail_id != ep->num_rails; ++rail_id) {
-		nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id);
-
-		nccl_ofi_ofiutils_ep_release(rail->ofi_ep, rail->av,
-					     rail->cq, dev_id);
-		rail->ofi_ep = NULL;
-		rail->av = NULL;
-		rail->cq = NULL;
+		ep_rail_release(get_rail(ep, rail_id), dev_id);
 	}
 }
@@ -5386,6 +5421,32 @@ static inline int set_local_address(struct fid_ep *ep, nccl_net_ofi_ep_rail_t *r
 	return 0;
 }
 
+
+static int ep_rail_init(nccl_net_ofi_rdma_ep_t *ep,
+			int dev_id, int rail_id,
+			nccl_net_ofi_rdma_device_rail_t *dev_rail,
+			nccl_net_ofi_ep_rail_t *ep_rail)
+{
+	int ret = 0;
+
+	ret = nccl_ofi_ofiutils_init_connection(FI_VERSION(1, 18),
+						dev_rail->info, dev_rail->domain,
+						&ep_rail->ofi_ep, &ep_rail->av, &ep_rail->cq);
+	if (ret != 0) {
+		return ret;
+	}
+
+	ep_rail->rail_id = rail_id;
+
+	ret = set_local_address(ep_rail->ofi_ep, ep_rail);
+	if (ret != 0) {
+		ep_rail_release(ep_rail, dev_id);
+		return ret;
+	}
+
+	return 0;
+}
+
 
 /*
  * @brief	Initialize libfabric resources of endpoint rails
 */
@@ -5401,16 +5462,7 @@ static int init_rail_ofi_resources(nccl_net_ofi_rdma_device_t *device,
 			get_device_rail(device, rail_id);
 		nccl_net_ofi_ep_rail_t *rail = get_rail(ep, rail_id);
 
-		ret = nccl_ofi_ofiutils_init_connection(FI_VERSION(1, 18),
-							rail_dev->info,
-							rail_dev->domain,
-							&rail->ofi_ep,
-							&rail->av, &rail->cq);
-		if (ret != 0) {
-			goto exit;
-		}
-
-		ret = set_local_address(rail->ofi_ep, rail);
+		ret = ep_rail_init(ep, dev_id, rail_id, rail_dev, rail);
 		if (ret != 0) {
 			goto exit;
 		}
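ep_rail_init() is now the single constructor for any rail, control or data, and ep_rail_release() its only teardown; a failed init releases what it opened, so callers never clean up a rail that did not fully initialize. The intended usage, roughly:

    nccl_net_ofi_ep_rail_t rail = { 0 };

    ret = ep_rail_init(ep, dev_id, rail_id, dev_rail, &rail);
    if (ret != 0)
            return ret;           /* nothing to release on this path */

    /* ... rail.ofi_ep / rail.av / rail.cq are live here ... */

    ep_rail_release(&rail, dev_id);   /* closes them and NULLs the pointers */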
@@ -5589,6 +5641,12 @@ static int get_ep(nccl_net_ofi_device_t *base_dev,
 		goto unlock;
 	}
 
+	ret = ep_rail_init(ep, dev_id, 0, &device->device_rails[0], &ep->control_rail);
+	if (ret != 0) {
+		NCCL_OFI_WARN("Initializing control rail failed");
+		goto unlock;
+	}
+
 	ret = init_rail_ofi_resources(device, ep);
 	if (ret != 0) {
 		goto unlock;