diff --git a/include/nccl_ofi_param.h b/include/nccl_ofi_param.h index 09715b8a3..7391a1b10 100644 --- a/include/nccl_ofi_param.h +++ b/include/nccl_ofi_param.h @@ -300,6 +300,12 @@ OFI_NCCL_PARAM_INT(rdma_min_posted_bounce_buffers, "RDMA_MIN_POSTED_BOUNCE_BUFFE */ OFI_NCCL_PARAM_INT(rdma_max_posted_bounce_buffers, "RDMA_MAX_POSTED_BOUNCE_BUFFERS", 128); +/* + * Whether to spread the control message across multiple rails in round robin fashion or + * send it consistently on one rail. + */ +OFI_NCCL_PARAM_INT(rdma_rr_ctrl_msg, "RR_CTRL_MSG", 0); + /* * Internode network latency reported to NCCL. Defaults to 0, unless the configured * platform sets a specific value. diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h index fc880e754..5c391b513 100644 --- a/include/nccl_ofi_rdma.h +++ b/include/nccl_ofi_rdma.h @@ -112,18 +112,22 @@ typedef uint16_t nccl_ofi_rdma_msg_type_t; /* * @brief Rdma memory registration handle - - * Note that the rdma memory registration handle has a variable array - * member. Use function `calloc_rdma_mr_handle(int num_rails)' to - * allocate a rdma memory registration handle with `num_rails' rails. + * + * Use function `calloc_rdma_mr_handle(int num_rails, int num_control_rails)' to + * allocate a RDMA memory registration handle with `num_rails`+`num_control_rails` rails. */ typedef struct nccl_net_ofi_rdma_mr_handle { - struct fid_mr *control_mr; int num_rails; + int num_control_rails; + /* Array of size `num_rails' */ - struct fid_mr *mr[]; + struct fid_mr **mr; + + /* Array of size `num_control_rails' */ + struct fid_mr **control_mr; + } nccl_net_ofi_rdma_mr_handle_t; /* Contents of ctrl message sent from receiver to sender to advertise @@ -282,6 +286,10 @@ typedef struct { typedef struct { /* Pointer to the allocated control buffer from freelist */ nccl_net_ofi_rdma_ctrl_fl_item_t *ctrl_fl_item; + /* Schedule used to transfer the control buffer.
We save the + * pointer to reference it when transferring the buffer over the + * network. */ + nccl_net_ofi_schedule_t *ctrl_schedule; /* Pointer to recv parent request */ nccl_net_ofi_rdma_req_t *recv_req; #if HAVE_NVTX_TRACING @@ -442,6 +450,7 @@ typedef struct nccl_ofi_rdma_connection_info { /* Number of rails */ uint16_t num_rails; + uint16_t num_control_rails; /* A comm identitifer that uniquely identifies the comm on the sender side. The receiver must use this ID when sending messages to sender */ @@ -451,15 +460,14 @@ typedef struct nccl_ofi_rdma_connection_info { * on the receiver side */ uint32_t remote_comm_id; - nccl_ofi_rdma_ep_name_t control_ep_name; - - /* Array of `MAX_NUM_RAILS` `nccl_ofi_rdma_ep_name_t` - * structs. The member `num_rails` indicates the number of - * entries that are in use. */ + /* Arrays of `MAX_NUM_RAILS` `nccl_ofi_rdma_ep_name_t` + * structs. The members `num_rails` and `num_control_rails` indicate + * the number of entries that are in use. */ + nccl_ofi_rdma_ep_name_t control_ep_names[MAX_NUM_RAILS]; nccl_ofi_rdma_ep_name_t ep_names[MAX_NUM_RAILS]; } nccl_ofi_rdma_connection_info_t; /* Since this is a message on the wire, check that it has the expected size */ -static_assert(sizeof(nccl_ofi_rdma_connection_info_t) == 336, +static_assert(sizeof(nccl_ofi_rdma_connection_info_t) == 528, "Wrong size for RDMA connect message"); /* @@ -480,9 +488,8 @@ typedef struct nccl_net_ofi_rdma_send_comm_rail { /* * @brief RDMA send communicator * - * Note that the RDMA send communicator has a variable array - * member. Use function `calloc_rdma_send_comm(int num_rails)' to - * allocate a RMDA send communicator with `num_rails' rails. + * Use function `calloc_rdma_send_comm(int num_rails, int num_control_rails)' to + * allocate a RDMA send communicator with `num_rails'+`num_control_rails' rails.
*/ typedef struct nccl_net_ofi_rdma_send_comm { /* This base send communicator must be the first member of this @@ -511,19 +518,19 @@ typedef struct nccl_net_ofi_rdma_send_comm { nccl_ofi_msgbuff_t *msgbuff; - nccl_net_ofi_rdma_send_comm_rail_t control_rail; - /* Number of rails */ int num_rails; + /* Number of control rails */ + int num_control_rails; /* Number of initialized rails. The function * `create_send_comm()' creates a send communicator with one - * initialized rail and sets `num_init_rails=0' after the + * initialized control rail and sets `num_init_control_rails=1' after the * out-of-bounds message is received. After the connect * response message has been received, the remaining rails * will be initialized via function `init_send_comm_rails()' - * and `num_init_rails' is adjusted. */ - int num_init_rails; + * and `num_init_control_rails' is adjusted. */ + int num_init_control_rails; #if HAVE_NVTX_TRACING nvtxDomainHandle_t nvtx_domain[NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; @@ -540,7 +547,9 @@ typedef struct nccl_net_ofi_rdma_send_comm { bool comm_active; /* Array of `num_rails` communicator rails */ - nccl_net_ofi_rdma_send_comm_rail_t rails[]; + nccl_net_ofi_rdma_send_comm_rail_t *rails; + /* Array of `num_control_rails` communicator rails */ + nccl_net_ofi_rdma_send_comm_rail_t *control_rails; } nccl_net_ofi_rdma_send_comm_t; @@ -573,9 +582,8 @@ typedef struct nccl_net_ofi_rdma_flush_buffer { /* * @brief RDMA receive communicator * - * Note that the RDMA receive communicator has a variable array - * member. Use function `calloc_rdma_recv_comm(int num_rails)' to - * allocate a RMDA receive communicator with `num_rails' rails. + * Use function `calloc_rdma_recv_comm(int num_rails, int num_control_rails)' to + * allocate a RDMA receive communicator with `num_rails'+`num_control_rails' rails.
*/ typedef struct nccl_net_ofi_rdma_recv_comm { /* This base receive communicator must be the first member of @@ -605,8 +613,6 @@ typedef struct nccl_net_ofi_rdma_recv_comm { #if HAVE_NVTX_TRACING nvtxDomainHandle_t nvtx_domain[NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; #endif - nccl_net_ofi_rdma_recv_comm_rail_t control_rail; - nccl_net_ofi_rdma_req_t *send_close_req; nccl_ofi_deque_elem_t cleanup_list_elem; @@ -618,11 +624,15 @@ typedef struct nccl_net_ofi_rdma_recv_comm { /* Number of rails */ int num_rails; + /* Number of control rails */ + int num_control_rails; bool comm_active; /* Array of `num_rails` communicator rails */ - nccl_net_ofi_rdma_recv_comm_rail_t rails[]; + nccl_net_ofi_rdma_recv_comm_rail_t *rails; + /* Array of `num_control_rails` communicator rails */ + nccl_net_ofi_rdma_recv_comm_rail_t *control_rails; } nccl_net_ofi_rdma_recv_comm_t; typedef struct nccl_net_ofi_rdma_listen_comm { @@ -708,16 +718,20 @@ struct nccl_net_ofi_rdma_ep { * and its base struct. */ nccl_net_ofi_ep_t base; - nccl_net_ofi_ep_rail_t control_rail; - /* Number of rails */ int num_rails; - bool use_long_rkeys; + /* Number of control rails */ + int num_control_rails; /* Array of `num_rails` endpoint rails */ nccl_net_ofi_ep_rail_t *rails; + /* Array of `num_control_rails` endpoint rails */ + nccl_net_ofi_ep_rail_t *control_rails; + + bool use_long_rkeys; + /* Pending requests queue */ nccl_ofi_deque_t *pending_reqs_queue; diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index 921cb935b..b1106fda0 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -250,11 +250,22 @@ static inline nccl_net_ofi_rdma_send_comm_rail_t *rdma_send_comm_get_rail(nccl_n int rail_id) { assert(s_comm->rails); - assert(rail_id < s_comm->num_init_rails); - assert(s_comm->num_init_rails <= s_comm->num_rails); + assert(rail_id < s_comm->num_rails); return &s_comm->rails[rail_id]; } +/* + * @brief Return send communicator control rail with index `rail_id` + */ +static inline 
nccl_net_ofi_rdma_send_comm_rail_t *rdma_send_comm_get_control_rail(nccl_net_ofi_rdma_send_comm_t *s_comm, + int rail_id) +{ + assert(s_comm->control_rails); + assert(rail_id < s_comm->num_init_control_rails); + assert(s_comm->num_init_control_rails <= s_comm->num_control_rails); + return &s_comm->control_rails[rail_id]; +} + /* * @brief Return receive communicator rail with index `rail_id` */ @@ -266,6 +277,16 @@ static inline nccl_net_ofi_rdma_recv_comm_rail_t *rdma_recv_comm_get_rail(nccl_n return &r_comm->rails[rail_id]; } +/* + * @brief Return receive communicator control rail with index `rail_id` + */ +static inline nccl_net_ofi_rdma_recv_comm_rail_t *rdma_recv_comm_get_control_rail(nccl_net_ofi_rdma_recv_comm_t *r_comm, + int rail_id) +{ + assert(r_comm->control_rails); + assert(rail_id < r_comm->num_control_rails); + return &r_comm->control_rails[rail_id]; +} static nccl_net_ofi_rdma_ep_t *rdma_recv_comm_get_ep(nccl_net_ofi_rdma_recv_comm_t *r_comm) { @@ -295,6 +316,17 @@ static inline nccl_net_ofi_ep_rail_t *rdma_endpoint_get_rail(nccl_net_ofi_rdma_e return &ep->rails[rail_id]; } +/* + * @brief Return control endpoint rail with index `rail_id` + */ +static inline nccl_net_ofi_ep_rail_t *rdma_endpoint_get_control_rail(nccl_net_ofi_rdma_ep_t *ep, + int rail_id) +{ + assert(ep->control_rails); + assert(rail_id < ep->num_control_rails); + return &ep->control_rails[rail_id]; +} + /* * @brief return the domain for the endpoint and rail. */ @@ -303,6 +335,14 @@ static inline struct fid_domain *rdma_endpoint_get_ofi_domain(nccl_net_ofi_rdma_ return rdma_endpoint_get_rail(ep, rail_id)->domain; } +/* + * @brief return the domain for the control endpoint and rail. 
+ */ +static inline struct fid_domain *rdma_endpoint_get_ofi_control_domain(nccl_net_ofi_rdma_ep_t *ep, int rail_id) +{ + return rdma_endpoint_get_control_rail(ep, rail_id)->domain; +} + /* * @brief Write topology to NCCL topology file * @@ -2108,6 +2148,12 @@ static inline int free_send_ctrl_req(nccl_net_ofi_rdma_req_t *req, (nccl_net_ofi_rdma_recv_comm_t *)req->comm; rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(req); + if (send_ctrl_data->ctrl_schedule != NULL) { + nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)req->comm->ep->device; + nccl_net_ofi_release_schedule(device->scheduler, send_ctrl_data->ctrl_schedule); + send_ctrl_data->ctrl_schedule = NULL; + } + if (send_ctrl_data->ctrl_fl_item) { nccl_ofi_freelist_entry_free(r_comm->ctrl_buff_fl, send_ctrl_data->ctrl_fl_item); send_ctrl_data->ctrl_fl_item = NULL; @@ -2293,9 +2339,10 @@ static inline int post_bounce_buffs_on_rail(nccl_net_ofi_rdma_ep_t *ep, static inline int post_bounce_buffs(nccl_net_ofi_rdma_ep_t *ep) { int ret = 0; + nccl_net_ofi_ep_rail_t *rail; for (int rail_id = 0; rail_id < ep->num_rails; ++rail_id) { - nccl_net_ofi_ep_rail_t *rail = rdma_endpoint_get_rail(ep, rail_id); + rail = rdma_endpoint_get_rail(ep, rail_id); ret = post_bounce_buffs_on_rail(ep, rail); if (ret != 0) { NCCL_OFI_WARN("Failed call to post_bounce_buffs_on_rail"); @@ -2303,10 +2350,13 @@ static inline int post_bounce_buffs(nccl_net_ofi_rdma_ep_t *ep) } } - ret = post_bounce_buffs_on_rail(ep, &ep->control_rail); - if (ret != 0) { - NCCL_OFI_WARN("Failed call to post_bounce_buffs_on_rail(control_rail)"); - goto exit; + for (int rail_id = 0; rail_id < ep->num_control_rails; ++rail_id) { + rail = rdma_endpoint_get_control_rail(ep, rail_id); + ret = post_bounce_buffs_on_rail(ep, rail); + if (ret != 0) { + NCCL_OFI_WARN("Failed call to post_bounce_buffs_on_rail(control_rail)"); + goto exit; + } } exit: @@ -2340,16 +2390,14 @@ static inline int post_bounce_buffs(nccl_net_ofi_rdma_ep_t *ep) 
static int init_send_comm_rails(nccl_net_ofi_rdma_send_comm_t *s_comm, nccl_net_ofi_rdma_ep_t *ep, int dev_id, nccl_ofi_rdma_ep_name_t *remote_ep_names, - int num_remote_rails) + int num_remote_rails, + nccl_ofi_rdma_ep_name_t *remote_control_ep_names, + int num_remote_control_rails) { int ret = 0; - - if (ep->num_rails != num_remote_rails) { - NCCL_OFI_WARN("Unexpected number of remote rails for dev %d. Expected %i but got %i", - dev_id, ep->num_rails, - num_remote_rails); - return -EINVAL; - } + nccl_net_ofi_rdma_send_comm_rail_t *comm_rail; + nccl_net_ofi_ep_rail_t *ep_rail; + nccl_ofi_rdma_ep_name_t *remote_rdma_ep_name; /** * In ENDPOINT_PER_COMM config, the ep address in the handle is not @@ -2360,13 +2408,32 @@ static int init_send_comm_rails(nccl_net_ofi_rdma_send_comm_t *s_comm, * is no longer an issue */ if (ofi_nccl_endpoint_per_communicator() != 0) { - s_comm->num_init_rails = 0; + s_comm->num_init_control_rails = 0; } - for (int rail_id = s_comm->num_init_rails; rail_id < s_comm->num_rails; ++rail_id) { - nccl_net_ofi_rdma_send_comm_rail_t *comm_rail = &s_comm->rails[rail_id]; - nccl_net_ofi_ep_rail_t *ep_rail = &ep->rails[rail_id]; - nccl_ofi_rdma_ep_name_t *remote_rdma_ep_name = &remote_ep_names[rail_id]; + for (int rail_id = s_comm->num_init_control_rails; rail_id < s_comm->num_control_rails; ++rail_id) { + comm_rail = &s_comm->control_rails[rail_id]; + ep_rail = &ep->control_rails[rail_id]; + remote_rdma_ep_name = &remote_control_ep_names[rail_id]; + + comm_rail->local_ep = ep_rail->ofi_ep; + + /* Insert remote EP address to AV */ + ret = fi_av_insert(ep_rail->av, (void *)remote_rdma_ep_name->ep_name, 1, + &comm_rail->remote_addr, 0, NULL); + if (OFI_UNLIKELY(ret != 1)) { + NCCL_OFI_WARN("Unable to insert remote address into address vector " + "for device %d. 
RC: %s", + dev_id, fi_strerror(-ret)); + return -EINVAL; + } + ++(s_comm->num_init_control_rails); + } + + for (int rail_id = 0; rail_id < s_comm->num_rails; ++rail_id) { + comm_rail = &s_comm->rails[rail_id]; + ep_rail = &ep->rails[rail_id]; + remote_rdma_ep_name = &remote_ep_names[rail_id]; comm_rail->local_ep = ep_rail->ofi_ep; @@ -2379,7 +2446,6 @@ static int init_send_comm_rails(nccl_net_ofi_rdma_send_comm_t *s_comm, dev_id, fi_strerror(-ret)); return -EINVAL; } - ++(s_comm->num_init_rails); } return 0; @@ -2438,6 +2504,13 @@ static int finish_connect(nccl_net_ofi_rdma_send_comm_t *s_comm) return -EINVAL; } + if (conn_resp->num_control_rails != ep->num_control_rails) { + NCCL_OFI_WARN("Unexpected number of remote control rails for dev %d. Expected %i but got %i", + dev_id, ep->num_control_rails, + conn_resp->num_control_rails); + return -EINVAL; + } + /* Validate received comm ID */ if (OFI_UNLIKELY(conn_resp->local_comm_id >= device->num_comm_ids)) { NCCL_OFI_WARN("Received an invalid communicator ID %u for device %d", conn_resp->local_comm_id, @@ -2451,7 +2524,9 @@ static int finish_connect(nccl_net_ofi_rdma_send_comm_t *s_comm) /* Initialize rails `1...num_rails-1' */ ret = init_send_comm_rails(s_comm, ep, dev_id, conn_resp->ep_names, - conn_resp->num_rails); + conn_resp->num_rails, + conn_resp->control_ep_names, + conn_resp->num_control_rails); if (ret != 0) { return ret; } @@ -2605,13 +2680,18 @@ static int dereg_rails(nccl_net_ofi_rdma_mr_handle_t *handle) int ret = 0; int rc = 0; int num_rails = handle->num_rails; + int num_control_rails = handle->num_control_rails; - /* Cleanup memory registration for control */ - rc = fi_close(&handle->control_mr->fid); - if (OFI_UNLIKELY(rc != 0)) { - NCCL_OFI_WARN("Unable to de-register memory on control mr. 
RC: %d, Error: %s", - rc, fi_strerror(-rc)); - ret = rc; + /* Cleanup memory registration for control rails */ + for (int rail_id = 0; rail_id != num_control_rails; ++rail_id) { + /* No memory registration available for this rail */ + if (!handle->control_mr[rail_id]) continue; + rc = fi_close(&handle->control_mr[rail_id]->fid); + if (OFI_UNLIKELY(rc != 0)) { + NCCL_OFI_WARN("Unable to de-register memory on control MR. RC: %d, Error: %s", + rc, fi_strerror(-rc)); + ret = rc; + } } /* Cleanup memory registration for data rails */ @@ -2629,19 +2709,55 @@ static int dereg_rails(nccl_net_ofi_rdma_mr_handle_t *handle) return ret; } +static inline void free_rdma_mr_handle(nccl_net_ofi_rdma_mr_handle_t *handle) { + if (handle) { + if (handle->control_mr) { + free(handle->control_mr); + } + if (handle->mr) { + free(handle->mr); + } + free(handle); + } +} + /* * @brief Allocate a rdma memory registration handle with `num_rails' rails using `calloc()' * * @param num_rails * The number of rails of the allocated receive communicator + * @param num_control_rails + * The number of control rails of the allocated receive communicator * @return handle, on success * NULL, on error */ -static inline nccl_net_ofi_rdma_mr_handle_t *calloc_rdma_mr_handle(int num_rails) +static inline nccl_net_ofi_rdma_mr_handle_t *calloc_rdma_mr_handle(int num_rails, int num_control_rails) { - return (nccl_net_ofi_rdma_mr_handle_t *)calloc( - 1, - sizeof(nccl_net_ofi_rdma_mr_handle_t) + num_rails * sizeof(struct fid_mr *)); + nccl_net_ofi_rdma_mr_handle_t *ret_handle = (nccl_net_ofi_rdma_mr_handle_t *)calloc(1, sizeof(nccl_net_ofi_rdma_mr_handle_t)); + + if (OFI_UNLIKELY(!ret_handle)) { + NCCL_OFI_WARN("Unable to allocate memory registration handle"); + goto error; + } + + ret_handle->mr = (struct fid_mr **)calloc(num_rails, sizeof(struct fid_mr *)); + if (OFI_UNLIKELY(!ret_handle->mr)) { + NCCL_OFI_WARN("Unable to allocate memory registration handles array"); + goto error; + } + + 
ret_handle->control_mr = (struct fid_mr **)calloc(num_control_rails, sizeof(struct fid_mr *)); + if (OFI_UNLIKELY(!ret_handle->control_mr)) { + NCCL_OFI_WARN("Unable to allocate memory registration control handles array"); + goto error; + } + + return ret_handle; + +error: + + free_rdma_mr_handle(ret_handle); + return NULL; } /* @@ -2706,7 +2822,7 @@ static int dereg_mr_ep(nccl_net_ofi_rdma_mr_handle_t *mr_handle, ret = dereg_rails(mr_handle); - free(mr_handle); + free_rdma_mr_handle(mr_handle); return ret; } @@ -2727,11 +2843,13 @@ static inline int reg_mr_on_device(nccl_net_ofi_rdma_ep_t *ep, assert(device != NULL); int dev_id = device->base.dev_id; - int num_rails = device->num_rails; + int num_rails = ep->num_rails; + int num_control_rails = ep->num_control_rails; + nccl_ofi_idpool_t *key_pool = &device->base.mr_rkey_pool; /* Allocate rdma memory registration handle */ - ret_handle = calloc_rdma_mr_handle(num_rails); + ret_handle = calloc_rdma_mr_handle(num_rails, num_control_rails); if (OFI_UNLIKELY(!ret_handle)) { NCCL_OFI_WARN("Unable to allocate memory registration handle"); ret = -ENOMEM; @@ -2743,16 +2861,7 @@ static inline int reg_mr_on_device(nccl_net_ofi_rdma_ep_t *ep, if (OFI_UNLIKELY(ret != 0)) { NCCL_OFI_WARN("Could not set registration request attributes, dev: %d", dev_id); - free(ret_handle); - ret_handle = NULL; - goto exit; - } - - ret = register_rail_mr_buffer(ep->control_rail.domain, ep->control_rail.ofi_ep, - -1, type, &mr_attr, regattr_flags, - &ret_handle->control_mr); - if (OFI_UNLIKELY(ret != 0)) { - free(ret_handle); + free_rdma_mr_handle(ret_handle); ret_handle = NULL; goto exit; } @@ -2775,6 +2884,24 @@ static inline int reg_mr_on_device(nccl_net_ofi_rdma_ep_t *ep, } } + /* Register memory on each control rail */ + ret_handle->num_control_rails = num_control_rails; + for (int rail_id = 0; rail_id != num_control_rails; ++rail_id) { + nccl_net_ofi_ep_rail_t *rail = rdma_endpoint_get_control_rail(ep, rail_id); + domain = 
rdma_endpoint_get_ofi_control_domain(ep, rail_id); + + ret = register_rail_mr_buffer(domain, rail->ofi_ep, + dev_id, type, &mr_attr, regattr_flags, + &ret_handle->control_mr[rail_id]); + if (OFI_UNLIKELY(ret != 0)) { + if (dereg_mr_ep(ret_handle, key_pool, NULL) != 0) { + NCCL_OFI_WARN("Error de-registering MR"); + } + ret_handle = NULL; + goto exit; + } + } + exit: *mhandle = ret_handle; return ret; @@ -3065,6 +3192,7 @@ static inline int insert_send_ctrl_req( nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle, nccl_net_ofi_rdma_req_t *recv_req) { + nccl_net_ofi_scheduler_t *scheduler = device->scheduler; nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; nccl_net_ofi_rdma_req_t *send_ctrl_req = allocate_req(r_comm->nccl_ofi_reqs_fl); if (OFI_UNLIKELY(send_ctrl_req == NULL)) { @@ -3081,6 +3209,24 @@ static inline int insert_send_ctrl_req( rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(send_ctrl_req); + if (ep->num_control_rails > 1) { + size_t ctrl_msg_len = nccl_net_ofi_rdma_ctrl_msg_size(ep->num_rails, ep->use_long_rkeys); + send_ctrl_data->ctrl_schedule = scheduler->get_schedule(scheduler, ctrl_msg_len, ep->num_control_rails); + + if (OFI_UNLIKELY(!(send_ctrl_data->ctrl_schedule))) { + return -EINVAL; + } else if (OFI_UNLIKELY(send_ctrl_data->ctrl_schedule->num_xfer_infos != 1)) { + NCCL_OFI_WARN( + "Invalid schedule for outgoing control message (%zu bytes). 
Expected one rail, but got " + "%zu", + size, + send_ctrl_data->ctrl_schedule->num_xfer_infos); + return -EINVAL; + } + } else { + send_ctrl_data->ctrl_schedule = NULL; + } + send_ctrl_data->recv_req = recv_req; send_ctrl_data->ctrl_fl_item = NULL; @@ -3562,6 +3708,18 @@ static int alloc_and_reg_flush_buff(nccl_net_ofi_rdma_recv_comm_t *r_comm, int d return ret; } +static inline void free_rdma_recv_comm(nccl_net_ofi_rdma_recv_comm_t *r_comm) { + if (r_comm) { + if (r_comm->control_rails) { + free(r_comm->control_rails); + } + if (r_comm->rails) { + free(r_comm->rails); + } + free(r_comm); + } +} + static int recv_comm_destroy(nccl_net_ofi_rdma_recv_comm_t *r_comm) { nccl_net_ofi_rdma_device_t *device = NULL; @@ -3632,7 +3790,7 @@ static int recv_comm_destroy(nccl_net_ofi_rdma_recv_comm_t *r_comm) return ret; } - free(r_comm); + free_rdma_recv_comm(r_comm); ret = ep->base.release_ep(&ep->base); @@ -3767,6 +3925,18 @@ static int recv_comm_process_all_finalizing(void) return ret; } +static inline void free_rdma_send_comm(nccl_net_ofi_rdma_send_comm_t *s_comm) { + if (s_comm) { + if (s_comm->control_rails) { + free(s_comm->control_rails); + } + if (s_comm->rails) { + free(s_comm->rails); + } + free(s_comm); + } +} + static int send_comm_destroy(nccl_net_ofi_rdma_send_comm_t *s_comm) { int ret = 0; @@ -3813,7 +3983,7 @@ static int send_comm_destroy(nccl_net_ofi_rdma_send_comm_t *s_comm) return ret; } - free(s_comm); + free_rdma_send_comm(s_comm); ret = ep->base.release_ep(&ep->base); @@ -4092,15 +4262,37 @@ static int flush(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers, * * @param num_rails * The number of rails of the allocated receive communicator + * @param num_control_rails + * The number of control rails of the allocated receive communicator * @return communicator, on success * NULL, on error */ -static inline nccl_net_ofi_rdma_recv_comm_t *calloc_rdma_recv_comm(int num_rails) +static inline nccl_net_ofi_rdma_recv_comm_t *calloc_rdma_recv_comm(int 
num_rails, int num_control_rails) { - return (nccl_net_ofi_rdma_recv_comm_t *)calloc( - 1, - sizeof(nccl_net_ofi_rdma_recv_comm_t) + - num_rails * sizeof(nccl_net_ofi_rdma_recv_comm_rail_t)); + nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)calloc(1, sizeof(nccl_net_ofi_rdma_recv_comm_t)); + if (OFI_UNLIKELY(!r_comm)) { + NCCL_OFI_WARN("Unable to allocate receive communicator"); + goto error; + } + + r_comm->rails = (nccl_net_ofi_rdma_recv_comm_rail_t *)calloc(num_rails, sizeof(nccl_net_ofi_rdma_recv_comm_rail_t)); + if (OFI_UNLIKELY(!r_comm->rails)) { + NCCL_OFI_WARN("Unable to allocate receive communicator rails array"); + goto error; + } + + r_comm->control_rails = (nccl_net_ofi_rdma_recv_comm_rail_t *)calloc(num_control_rails, sizeof(nccl_net_ofi_rdma_recv_comm_rail_t)); + if (OFI_UNLIKELY(!r_comm->control_rails)) { + NCCL_OFI_WARN("Unable to allocate receive communicator control rails array"); + goto error; + } + + return r_comm; + +error: + + free_rdma_recv_comm(r_comm); + return NULL; } static void init_rma_op_req(nccl_net_ofi_rdma_req_t *req, @@ -4258,14 +4450,19 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device nccl_net_ofi_rdma_ep_t *ep = NULL; int dev_id = device->base.dev_id; int num_rails = l_comm_ep->num_rails; + int num_control_rails = l_comm_ep->num_control_rails; if (num_rails < 1) { NCCL_OFI_WARN("Invalid number of rails. Expected at least one rail"); goto error; } + if (num_control_rails < 1) { + NCCL_OFI_WARN("Invalid number of control rails. 
Expected at least one rail"); + goto error; + } /* Build recv_comm */ - r_comm = calloc_rdma_recv_comm(num_rails); + r_comm = calloc_rdma_recv_comm(num_rails, num_control_rails); if (r_comm == NULL) { NCCL_OFI_WARN("Unable to allocate receive comm object for device %d", dev_id); @@ -4274,7 +4471,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device ret = nccl_net_ofi_mutex_init(&r_comm->ctrl_counter_lock, NULL); if (ret != 0) { - free(r_comm); + free_rdma_recv_comm(r_comm); return NULL; } @@ -4353,23 +4550,35 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device /* Add ourselves to ep's lookup array */ rdma_device_set_comm(device, r_comm->local_comm_id, &r_comm->base.base); - r_comm->control_rail.local_ep = ep->control_rail.ofi_ep;; - ret = fi_av_insert(ep->control_rail.av, (void *)conn_msg->control_ep_name.ep_name, 1, - &r_comm->control_rail.remote_addr, 0, NULL); - if (OFI_UNLIKELY(ret != 1)) { - NCCL_OFI_WARN("Unable to insert remote address into address vector " - "for device %d. RC: %s", - dev_id, fi_strerror(-ret)); - goto error; - } + /* Allocate array of control communicator rails */ + r_comm->num_control_rails = num_control_rails; - ret = fi_av_insert(ep->control_rail.av, (void *)ep->control_rail.local_ep_name, 1, - &r_comm->control_rail.local_addr, 0, NULL); - if (OFI_UNLIKELY(ret != 1)) { - NCCL_OFI_WARN("Unable to insert local address into address vector " - "for device %d. 
RC: %s", - dev_id, fi_strerror(-ret)); - goto error; + /* Initialize local and remote endpoint resources for each control rail */ + for (int rail_id = 0; rail_id != num_control_rails; ++rail_id) { + nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = rdma_recv_comm_get_control_rail(r_comm, rail_id); + nccl_net_ofi_ep_rail_t *rail = rdma_endpoint_get_control_rail(ep, rail_id); + nccl_ofi_rdma_ep_name_t *remote_ep_name = &conn_msg->control_ep_names[rail_id]; + + comm_rail->local_ep = rail->ofi_ep; + + /* Insert remote EP address to AV */ + ret = fi_av_insert(rail->av, (void *)remote_ep_name->ep_name, 1, + &comm_rail->remote_addr, 0, NULL); + if (OFI_UNLIKELY(ret != 1)) { + NCCL_OFI_WARN("Unable to insert remote address into address vector " + "for device %d. RC: %s", + dev_id, fi_strerror(-ret)); + goto error; + } + + ret = fi_av_insert(rail->av, (void *)rail->local_ep_name, 1, + &comm_rail->local_addr, 0, NULL); + if (OFI_UNLIKELY(ret != 1)) { + NCCL_OFI_WARN("Unable to insert local address into address vector " + "for device %d. 
RC: %s", + dev_id, fi_strerror(-ret)); + goto error; + } } /* Allocate array of communicator rails */ @@ -4429,7 +4638,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device r_comm->msgbuff = nccl_ofi_msgbuff_init(NCCL_OFI_RDMA_MSGBUFF_SIZE, NCCL_OFI_RDMA_SEQ_BITS); if (!r_comm->msgbuff) { NCCL_OFI_WARN("Failed to allocate and initialize message buffer"); - free(r_comm); + free_rdma_recv_comm(r_comm); return NULL; } @@ -4468,7 +4677,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device } } nccl_net_ofi_mutex_destroy(&r_comm->ctrl_counter_lock); - free(r_comm); + free_rdma_recv_comm(r_comm); } return NULL; @@ -4492,23 +4701,35 @@ static int prepare_conn_resp(nccl_net_ofi_rdma_ep_t *ep, int dev_id) { int num_rails = ep->num_rails; + int num_control_rails = ep->num_control_rails; nccl_ofi_rdma_connection_info_t *conn_resp = &l_comm->conn_msg; + nccl_ofi_rdma_ep_name_t *rdma_ep_name; + nccl_net_ofi_ep_rail_t *ep_rail; - if (num_rails > MAX_NUM_RAILS) { - NCCL_OFI_WARN("Unexpected number of rails. 
Expected at most %i but got %i", - MAX_NUM_RAILS, num_rails); - return -EINVAL; - } + assert(num_rails <= MAX_NUM_RAILS); + assert(num_control_rails <= MAX_NUM_RAILS); conn_resp->type = NCCL_OFI_RDMA_MSG_CONN_RESP; /* Set number of rails to be sent back to remote for verification */ conn_resp->num_rails = num_rails; + conn_resp->num_control_rails = num_control_rails; /* Set libfabric endpoint names for each rail */ for (int rail_id = 0; rail_id != num_rails; ++rail_id) { - nccl_ofi_rdma_ep_name_t *rdma_ep_name = &conn_resp->ep_names[rail_id]; - nccl_net_ofi_ep_rail_t *ep_rail = rdma_endpoint_get_rail(ep, rail_id); + rdma_ep_name = &conn_resp->ep_names[rail_id]; + ep_rail = rdma_endpoint_get_rail(ep, rail_id); + + assert(sizeof(rdma_ep_name->ep_name) == sizeof(ep_rail->local_ep_name)); + memcpy(rdma_ep_name->ep_name, ep_rail->local_ep_name, + ep_rail->local_ep_name_len); + rdma_ep_name->ep_name_len = ep_rail->local_ep_name_len; + } + + /* Set libfabric endpoint names for each control rail */ + for (int rail_id = 0; rail_id != num_control_rails; ++rail_id) { + rdma_ep_name = &conn_resp->control_ep_names[rail_id]; + ep_rail = rdma_endpoint_get_control_rail(ep, rail_id); assert(sizeof(rdma_ep_name->ep_name) == sizeof(ep_rail->local_ep_name)); memcpy(rdma_ep_name->ep_name, ep_rail->local_ep_name, @@ -4545,7 +4766,7 @@ static int post_send_conn_resp(nccl_net_ofi_rdma_recv_comm_t *r_comm, nccl_net_ofi_rdma_req_t *req) { ssize_t rc = 0; - nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = &r_comm->control_rail;; + nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = rdma_recv_comm_get_control_rail(r_comm, 0); req->state = NCCL_OFI_RDMA_REQ_PENDING; rc = fi_send(comm_rail->local_ep, (void *)conn_resp, sizeof(nccl_ofi_rdma_connection_info_t), NULL, @@ -4686,6 +4907,15 @@ static int accept(nccl_net_ofi_listen_comm_t *listen_comm, goto exit; } + /* Number of remote control rails and number of local control rails match */ + if (conn_msg->num_control_rails != 
l_comm_ep->num_control_rails) { + NCCL_OFI_WARN("Unexpected number of remote control rails for dev %d. Expected %i but got %i", + dev_id, l_comm_ep->num_control_rails, + conn_msg->num_control_rails); + ret = -EINVAL; + goto exit; + } + /* Prepare receive communicator object for the received peer connection */ r_comm = prepare_recv_comm(device, l_comm_ep, conn_msg); if (OFI_UNLIKELY(r_comm == NULL)) { @@ -4848,6 +5078,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, int comm_id = 0; nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)base_ep; + nccl_net_ofi_ep_rail_t *first_control_rail = rdma_endpoint_get_control_rail(ep, 0); /* Retrieve and validate device */ nccl_net_ofi_rdma_device_t *device = rdma_endpoint_get_device(ep); @@ -4857,14 +5088,14 @@ static int listen(nccl_net_ofi_ep_t *base_ep, /* Build handle */ memset(handle, 0, sizeof(nccl_net_ofi_conn_handle_t)); - assert(sizeof(handle->ep_name) == sizeof(ep->control_rail.local_ep_name)); - memcpy(handle->ep_name, ep->control_rail.local_ep_name, - ep->control_rail.local_ep_name_len); + assert(sizeof(handle->ep_name) == sizeof(first_control_rail->local_ep_name)); + memcpy(handle->ep_name, first_control_rail->local_ep_name, + first_control_rail->local_ep_name_len); /* We don't copy the size here since the handle doesn't have a size field. The size will be distributed later by the connect response message. Instead, zero the unused bytes here. 
*/ - memset(handle->ep_name + ep->control_rail.local_ep_name_len, 0, - sizeof(handle->ep_name) - ep->control_rail.local_ep_name_len); + memset(handle->ep_name + first_control_rail->local_ep_name_len, 0, + sizeof(handle->ep_name) - first_control_rail->local_ep_name_len); /* Build listen_comm */ l_comm = (nccl_net_ofi_rdma_listen_comm_t *)calloc(1, @@ -5259,8 +5490,10 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req) rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(req); nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; - // Get communicator rail information to xfer the req - nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = &r_comm->control_rail; + nccl_net_ofi_schedule_t *schedule = send_ctrl_data->ctrl_schedule; + nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail; + void *desc; + int rail_id; nccl_net_ofi_rdma_ctrl_fl_item_t *ctrl_fl_item = send_ctrl_data->ctrl_fl_item; @@ -5269,9 +5502,19 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req) (freelist_regmr_fn_handle_t *)ctrl_fl_item->fl_reginfo.mr_handle; nccl_net_ofi_rdma_mr_handle_t *mr_handle = fl_handle->mr_handle; - void *desc = fi_mr_desc(mr_handle->control_mr); + if (schedule != NULL) { + /* Use round robin schedule for ctrl message */ + nccl_net_ofi_xfer_info_t *xfer_info = &schedule->rail_xfer_infos[0]; + rail_id = xfer_info->rail_id; + } else { + /* Always use control rail 0 for ctrl message */ + rail_id = 0; + } - NCCL_OFI_TRACE_SEND_CTRL_START(req->dev_id, ep->control_rail.rail_id, req->comm, req, req->msg_seq_num); + comm_rail = rdma_recv_comm_get_control_rail(r_comm, rail_id); + assert(rail_id < mr_handle->num_control_rails); + desc = fi_mr_desc(mr_handle->control_mr[rail_id]); + NCCL_OFI_TRACE_SEND_CTRL_START(req->dev_id, rail_id, req->comm, req, req->msg_seq_num); size_t ctrl_msg_len = nccl_net_ofi_rdma_ctrl_msg_size(ep->num_rails, ep->use_long_rkeys); @@ -5573,9 +5816,13 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, 
int t /* look for control messages and then retry the message search to avoid unnecessary polling / queueing. */ if (OFI_UNLIKELY(!polled_cq && !have_ctrl)) { - ret = ofi_process_cq_rail(ep, &ep->control_rail); - if (ret != 0) { - goto error; + for (int rail_id = 0; rail_id != ep->num_control_rails; ++rail_id) { + nccl_net_ofi_ep_rail_t *rail = rdma_endpoint_get_control_rail(ep, rail_id); + + ret = ofi_process_cq_rail(ep, rail); + if (OFI_UNLIKELY(ret != 0)) { + goto error; + } } polled_cq = true; goto retry; @@ -5722,6 +5969,7 @@ static void prepare_send_connect_message(nccl_net_ofi_rdma_ep_t *ep, int dev_id, nccl_ofi_rdma_connection_info_t *conn_msg) { int num_rails = ep->num_rails; + int num_control_rails = ep->num_control_rails; conn_msg->type = NCCL_OFI_RDMA_MSG_CONN; @@ -5733,14 +5981,16 @@ static void prepare_send_connect_message(nccl_net_ofi_rdma_ep_t *ep, int dev_id, /* Set number of rails to be sent back to remote for verification */ conn_msg->num_rails = num_rails; + conn_msg->num_control_rails = num_control_rails; - /* Set libfabric endpoint name for control rail */ - memcpy(conn_msg->control_ep_name.ep_name, - ep->control_rail.local_ep_name, - ep->control_rail.local_ep_name_len); - conn_msg->control_ep_name.ep_name_len = - ep->control_rail.local_ep_name_len; - + /* Set libfabric endpoint names for each control rail */ + for (int rail_id = 0; rail_id != num_control_rails; ++rail_id) { + memcpy(conn_msg->control_ep_names[rail_id].ep_name, + ep->control_rails[rail_id].local_ep_name, + ep->control_rails[rail_id].local_ep_name_len); + conn_msg->control_ep_names[rail_id].ep_name_len = + ep->control_rails[rail_id].local_ep_name_len; + } /* Set libfabric endpoint names for each rail */ for (int rail_id = 0; rail_id != num_rails; ++rail_id) { @@ -5757,15 +6007,37 @@ static void prepare_send_connect_message(nccl_net_ofi_rdma_ep_t *ep, int dev_id, * * @param num_rails * The number of rails of the allocated send communicator + * @param num_control_rails + * The 
number of control rails of the allocated send communicator * @return communicator, on success * NULL, on error */ -static inline nccl_net_ofi_rdma_send_comm_t *calloc_rdma_send_comm(int num_rails) +static inline nccl_net_ofi_rdma_send_comm_t *calloc_rdma_send_comm(int num_rails, int num_control_rails) { - return (nccl_net_ofi_rdma_send_comm_t *)calloc( - 1, - sizeof(nccl_net_ofi_rdma_send_comm_t) + - num_rails * sizeof(nccl_net_ofi_rdma_send_comm_rail_t)); + nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)calloc(1, sizeof(nccl_net_ofi_rdma_send_comm_t)); + if (OFI_UNLIKELY(!s_comm)) { + NCCL_OFI_WARN("Unable to allocate send communicator"); + goto error; + } + + s_comm->rails = (nccl_net_ofi_rdma_send_comm_rail_t *)calloc(num_rails, sizeof(nccl_net_ofi_rdma_send_comm_rail_t)); + if (OFI_UNLIKELY(!s_comm->rails)) { + NCCL_OFI_WARN("Unable to allocate send communicator rails array"); + goto error; + } + + s_comm->control_rails = (nccl_net_ofi_rdma_send_comm_rail_t *)calloc(num_control_rails, sizeof(nccl_net_ofi_rdma_send_comm_rail_t)); + if (OFI_UNLIKELY(!s_comm->control_rails)) { + NCCL_OFI_WARN("Unable to allocate send communicator control rails array"); + goto error; + } + + return s_comm; + +error: + + free_rdma_send_comm(s_comm); + return NULL; } /* @@ -5780,6 +6052,7 @@ static inline nccl_net_ofi_rdma_send_comm_t *calloc_rdma_send_comm(int num_rails static inline int init_bounce_buffers(nccl_net_ofi_rdma_ep_t *ep) { int ret = 0; + nccl_net_ofi_ep_rail_t *rail; ret = nccl_ofi_freelist_init(sizeof(nccl_net_ofi_rdma_req_t), ofi_nccl_rdma_min_posted_bounce_buffers(), 16, 0, @@ -5804,25 +6077,29 @@ static inline int init_bounce_buffers(nccl_net_ofi_rdma_ep_t *ep) * The *_bounce_posted limits are used in the progress engine to * determine if the receive queue is hydrated with sufficient buffers. * The parameters account for all the rails, so scale down bounds to - * what a single rail would need for the control endpoint. 
+ * what a single rail would need. */ - ep->control_rail.min_bounce_posted = NCCL_OFI_DIV_CEIL( - ofi_nccl_rdma_min_posted_bounce_buffers(), ep->num_rails + for (int rail_id = 0; rail_id < ep->num_control_rails; ++rail_id) { + rail = rdma_endpoint_get_control_rail(ep, rail_id); + rail->min_bounce_posted = NCCL_OFI_DIV_CEIL( + ofi_nccl_rdma_min_posted_bounce_buffers(), ep->num_control_rails ); - ep->control_rail.max_bounce_posted = NCCL_OFI_DIV_CEIL( - ofi_nccl_rdma_max_posted_bounce_buffers(), ep->num_rails + rail->max_bounce_posted = NCCL_OFI_DIV_CEIL( + ofi_nccl_rdma_max_posted_bounce_buffers(), ep->num_control_rails ); - ep->control_rail.num_bounce_posted = 0; - ret = nccl_net_ofi_mutex_init(&ep->control_rail.bounce_mutex, NULL); + rail->num_bounce_posted = 0; + nccl_net_ofi_mutex_init(&rail->bounce_mutex, NULL); + } for (int rail_id = 0; rail_id < ep->num_rails; ++rail_id) { - nccl_net_ofi_ep_rail_t *rail = rdma_endpoint_get_rail(ep, rail_id); + rail = rdma_endpoint_get_rail(ep, rail_id); rail->min_bounce_posted = NCCL_OFI_DIV_CEIL( ofi_nccl_rdma_min_posted_bounce_buffers(), ep->num_rails ); rail->max_bounce_posted = NCCL_OFI_DIV_CEIL( ofi_nccl_rdma_max_posted_bounce_buffers(), ep->num_rails ); + rail->num_bounce_posted = 0; nccl_net_ofi_mutex_init(&rail->bounce_mutex, NULL); } @@ -5841,6 +6118,8 @@ static inline int init_bounce_buffers(nccl_net_ofi_rdma_ep_t *ep) static inline int fini_bounce_buffers(nccl_net_ofi_rdma_ep_t *ep) { int ret = 0; + nccl_net_ofi_ep_rail_t *rail; + ret = nccl_ofi_freelist_fini(ep->bounce_buff_fl); if (ret != 0) { NCCL_OFI_WARN("Failed to fini bounce_buff_fl"); @@ -5854,11 +6133,14 @@ static inline int fini_bounce_buffers(nccl_net_ofi_rdma_ep_t *ep) } for (int rail_id = 0; rail_id < ep->num_rails; ++rail_id) { - nccl_net_ofi_ep_rail_t *rail = rdma_endpoint_get_rail(ep, rail_id); + rail = rdma_endpoint_get_rail(ep, rail_id); nccl_net_ofi_mutex_destroy(&rail->bounce_mutex); } - 
nccl_net_ofi_mutex_destroy(&ep->control_rail.bounce_mutex); + for (int rail_id = 0; rail_id < ep->num_control_rails; ++rail_id) { + rail = rdma_endpoint_get_control_rail(ep, rail_id); + nccl_net_ofi_mutex_destroy(&rail->bounce_mutex); + } return ret; } @@ -5988,7 +6270,7 @@ static int rma_write_inline(nccl_net_ofi_send_comm_t *send_comm, void* src, size * @brief Creates send communication for a peer * * Allocate and Initalize send communicator and its resources; Only - * the first communicator rail is initialized. Use function + * the first communicator control rail is initialized. Use function * init_send_comm_rails() to initialize the remaining communicator * rails. * @@ -6013,8 +6295,11 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, fi_addr_t remote_addr; nccl_net_ofi_rdma_send_comm_t *ret_s_comm = NULL; int num_rails = ep->num_rails; + int num_control_rails = ep->num_control_rails; int rail_id = 0; - nccl_net_ofi_ep_rail_t *control_rail = &ep->control_rail; + nccl_net_ofi_ep_rail_t *first_control_rail = rdma_endpoint_get_control_rail(ep, 0); + nccl_net_ofi_rdma_send_comm_rail_t *first_comm_control_rail; + *s_comm = NULL; /* Retrieve and validate device */ @@ -6026,7 +6311,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, int dev_id = device->base.dev_id; /* Allocate and initialize send_comm */ - ret_s_comm = calloc_rdma_send_comm(num_rails); + ret_s_comm = calloc_rdma_send_comm(num_rails, num_control_rails); if (OFI_UNLIKELY(ret_s_comm == NULL)) { NCCL_OFI_WARN("Couldn't allocate send comm object for dev %d", dev_id); return -ENOMEM; @@ -6034,7 +6319,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, ret = nccl_net_ofi_mutex_init(&ret_s_comm->ctrl_recv_lock, NULL); if (ret != 0) { - free(ret_s_comm); + free_rdma_send_comm(ret_s_comm); return ret; } @@ -6079,9 +6364,10 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, /* Allocate communicator rails array */ 
ret_s_comm->num_rails = num_rails; + ret_s_comm->num_control_rails = num_control_rails; /* Insert remote name into AV of first rail */ - ret = fi_av_insert(control_rail->av, + ret = fi_av_insert(first_control_rail->av, (void *)handle->ep_name, 1, &remote_addr, 0, NULL); if (OFI_UNLIKELY(ret != 1)) { @@ -6092,11 +6378,12 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, } /* Store remote address of first rail in communicator */ - ret_s_comm->control_rail.remote_addr = remote_addr; + first_comm_control_rail = &ret_s_comm->control_rails[0]; + first_comm_control_rail->remote_addr = remote_addr; /* Store local libfabric endpoint of control rail */ - ret_s_comm->control_rail.local_ep = control_rail->ofi_ep; - ret_s_comm->num_init_rails = 0; + first_comm_control_rail->local_ep = first_control_rail->ofi_ep; + ret_s_comm->num_init_control_rails = 1; /* Allocate request free list */ ret = nccl_ofi_freelist_init(sizeof(nccl_net_ofi_rdma_req_t), 16, 16, @@ -6140,7 +6427,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, } } nccl_net_ofi_mutex_destroy(&ret_s_comm->ctrl_recv_lock); - free(ret_s_comm); + free_rdma_send_comm(ret_s_comm); } return ret; @@ -6217,7 +6504,7 @@ static int post_send_conn(nccl_net_ofi_rdma_send_comm_t *s_comm, nccl_net_ofi_rdma_req_t *req) { ssize_t rc = 0; - nccl_net_ofi_rdma_send_comm_rail_t *comm_rail = &s_comm->control_rail; + nccl_net_ofi_rdma_send_comm_rail_t *comm_rail = rdma_send_comm_get_control_rail(s_comm, 0); /* * TODO: replace it with API of FI_INJECT type when most of @@ -6447,8 +6734,13 @@ static void ep_rail_release(nccl_net_ofi_ep_rail_t *rail, int dev_id, struct fid */ static void release_rdma_ep_resources(nccl_net_ofi_rdma_ep_t *ep, int dev_id) { - nccl_net_ofi_ep_rail_t * rail = &ep->control_rail; - ep_rail_release(rail, dev_id, NULL); + nccl_net_ofi_ep_rail_t *rail; + + for (int rail_id = 0; rail_id != ep->num_control_rails; ++rail_id) { + rail = rdma_endpoint_get_control_rail(ep, 
rail_id); + ep_rail_release(rail, dev_id, NULL); + } + for (int rail_id = 0; rail_id != ep->num_rails; ++rail_id) { rail = rdma_endpoint_get_rail(ep, rail_id); ep_rail_release(rail, dev_id, rail->cq); @@ -6549,30 +6841,33 @@ static int init_rail_ofi_resources(nccl_net_ofi_rdma_device_t *device, int dev_id = device->base.dev_id; nccl_net_ofi_rdma_device_rail_t *rail_dev; nccl_net_ofi_ep_rail_t *rail; + nccl_net_ofi_ep_rail_t *control_rail; /* Initialize libfabric resources of endpoint rails */ - for (int rail_id = 0; rail_id != device->num_rails; ++rail_id) { + for (int rail_id = 0; rail_id != ep->num_rails; ++rail_id) { rail_dev = rdma_device_get_rail(device, rail_id); rail = rdma_endpoint_get_rail(ep, rail_id); ret = ep_rail_init(ep, dev_id, rail_id, rail_dev, rail); if (ret != 0) { + NCCL_OFI_WARN("Initializing rail %d failed", rail_id); goto exit; } } - /* we pass 0 as the railid for the control rail, so - * that any lookups based on railid in the domain find - * the right domain */ - rail_dev = rdma_device_get_rail(device, 0); - rail = rdma_endpoint_get_rail(ep, 0); - ep->control_rail.cq = rail->cq; - ret = ep_rail_init(ep, dev_id, 0, rail_dev, &ep->control_rail); - if (ret != 0) { - NCCL_OFI_WARN("Initializing control rail failed"); - goto exit; - } + /* Initialize libfabric resources of endpoint control rails */ + for (int rail_id = 0; rail_id != ep->num_control_rails; ++rail_id) { + rail_dev = rdma_device_get_rail(device, rail_id); + rail = rdma_endpoint_get_rail(ep, rail_id); + control_rail = rdma_endpoint_get_control_rail(ep, rail_id); + control_rail->cq = rail->cq; + ret = ep_rail_init(ep, dev_id, rail_id, rail_dev, control_rail); + if (ret != 0) { + NCCL_OFI_WARN("Initializing control rail %d failed", rail_id); + goto exit; + } + } exit: if (ret != 0) { @@ -6666,6 +6961,7 @@ static int nccl_net_ofi_rdma_endpoint_free(nccl_net_ofi_ep_t *base_ep) return ret; } + free(ep->control_rails); free(ep->rails); free(ep); @@ -6740,6 +7036,27 @@ static int 
nccl_net_ofi_rdma_device_create_endpoint(nccl_net_ofi_device_t *base_ goto error; } + if (ofi_nccl_rdma_rr_ctrl_msg()) { + /* + * Round robin the control message across all rails by using dedicated + * endpoints with CQs shared with the data endpoints. + */ + ep->num_control_rails = device->num_rails; + } else { + /* + * Use a single rail for control messages, with a dedicated + * endpoint and a CQ shared with the data endpoint. + */ + ep->num_control_rails = 1; + } + + ep->control_rails = (nccl_net_ofi_ep_rail_t *)calloc(ep->num_control_rails, sizeof(nccl_net_ofi_ep_rail_t)); + if (!ep->control_rails) { + NCCL_OFI_WARN("Unable to allocate rdma control rails"); + ret = -ENOMEM; + goto error; + } + ret = nccl_ofi_deque_init(&ep->pending_reqs_queue); if (ret != 0) { NCCL_OFI_WARN("Failed to init pending_reqs_queue: %d", ret);