diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h
index 67c0a812e..16ecb16c6 100644
--- a/include/nccl_ofi_rdma.h
+++ b/include/nccl_ofi_rdma.h
@@ -80,8 +80,10 @@ typedef enum nccl_net_ofi_rdma_req_type {
 	NCCL_OFI_RDMA_RECV_SEGMS,
 	/* Eager local copy request. Subrequest of NCCL_OFI_RDMA_RECV */
 	NCCL_OFI_RDMA_EAGER_COPY,
-	/* Rx buff post request */
-	NCCL_OFI_RDMA_RX_BUFF,
+	/* Ctrl rx buff post request */
+	NCCL_OFI_RDMA_CTRL_RX_BUFF,
+	/* Eager rx buff post request */
+	NCCL_OFI_RDMA_EAGER_RX_BUFF,
 	/* Flush request */
 	NCCL_OFI_RDMA_FLUSH,
 	/* Connect message send request */
diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c
index 3b7532868..2346deeec 100644
--- a/src/nccl_ofi_rdma.c
+++ b/src/nccl_ofi_rdma.c
@@ -633,7 +633,8 @@ static inline int get_properties(nccl_net_ofi_device_t *base_dev,
  * @brief	Return rx data struct of rx request
  */
 static inline rdma_req_rx_buff_data_t *get_rx_buff_data(nccl_net_ofi_rdma_req_t *req) {
-	assert(req->type == NCCL_OFI_RDMA_RX_BUFF);
+	assert((req->type == NCCL_OFI_RDMA_CTRL_RX_BUFF) ||
+	       (req->type == NCCL_OFI_RDMA_EAGER_RX_BUFF));
 	return &req->rx_buff_data;
 }
 
@@ -1242,7 +1243,7 @@ static int finish_connect(nccl_net_ofi_rdma_send_comm_t *s_comm);
 
 static int handle_close_msg_recv(nccl_net_ofi_rdma_req_t *rx_buff_req)
 {
-	assert(rx_buff_req->type == NCCL_OFI_RDMA_RX_BUFF);
+	assert(rx_buff_req->type == NCCL_OFI_RDMA_CTRL_RX_BUFF);
 
 	rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(rx_buff_req);
 
@@ -1287,7 +1288,8 @@ static inline int handle_rx_buff_recv(nccl_net_ofi_rdma_device_t *device, int ra
 		NCCL_OFI_WARN("RECV event had NULL ctx!");
 		return -EINVAL;
 	}
-	if (OFI_UNLIKELY(rx_buff_req->type != NCCL_OFI_RDMA_RX_BUFF)) {
+	if (OFI_UNLIKELY((eager && (rx_buff_req->type != NCCL_OFI_RDMA_EAGER_RX_BUFF))
+			 || ((!eager) && (rx_buff_req->type != NCCL_OFI_RDMA_CTRL_RX_BUFF)))) {
 		NCCL_OFI_WARN("Invalid non-rx_buff request as ctx!");
 		return -EINVAL;
 	}
@@ -1530,8 +1532,10 @@ static const char *req_type_str(nccl_net_ofi_rdma_req_type_t type)
 		return "SEND_CLOSE";
 	case NCCL_OFI_RDMA_RECV_SEGMS:
 		return "RECV_SEGMS";
-	case NCCL_OFI_RDMA_RX_BUFF:
-		return "RX_BUFF";
+	case NCCL_OFI_RDMA_EAGER_RX_BUFF:
+		return "EAGER_RX_BUFF";
+	case NCCL_OFI_RDMA_CTRL_RX_BUFF:
+		return "CTRL_RX_BUFF";
 	case NCCL_OFI_RDMA_FLUSH:
 		return "FLUSH";
 	case NCCL_OFI_RDMA_EAGER_COPY:
@@ -1656,7 +1660,8 @@ static inline int process_completions(struct fi_cq_data_entry *cq_entry, uint64_
 	case NCCL_OFI_RDMA_SEND_CLOSE:
 	case NCCL_OFI_RDMA_RECV_SEGMS:
 	case NCCL_OFI_RDMA_EAGER_COPY:
-	case NCCL_OFI_RDMA_RX_BUFF:
+	case NCCL_OFI_RDMA_CTRL_RX_BUFF:
+	case NCCL_OFI_RDMA_EAGER_RX_BUFF:
 	case NCCL_OFI_RDMA_FLUSH:
 	case NCCL_OFI_RDMA_SEND_CONN:
 	case NCCL_OFI_RDMA_RECV_CONN:
@@ -1691,7 +1696,8 @@ static inline int process_completions(struct fi_cq_data_entry *cq_entry, uint64_
 	case NCCL_OFI_RDMA_SEND_CTRL:
 	case NCCL_OFI_RDMA_SEND_CLOSE:
 	case NCCL_OFI_RDMA_RECV_SEGMS:
-	case NCCL_OFI_RDMA_RX_BUFF:
+	case NCCL_OFI_RDMA_CTRL_RX_BUFF:
+	case NCCL_OFI_RDMA_EAGER_RX_BUFF:
 	case NCCL_OFI_RDMA_SEND_CONN:
 	case NCCL_OFI_RDMA_RECV_CONN:
 	case NCCL_OFI_RDMA_RECV_CONN_RESP:
@@ -1774,7 +1780,7 @@ static inline int process_err_completion(nccl_net_ofi_rdma_device_t *device,
 		      err_entry.prov_errno,
 		      fi_cq_strerror(cq, err_entry.prov_errno, err_entry.err_data, NULL, 0),
 		      (long)err_entry.len, nccl_net_ofi_req_str(req));
-	if (req->type == NCCL_OFI_RDMA_RX_BUFF) {
+	if ((req->type == NCCL_OFI_RDMA_CTRL_RX_BUFF) || (req->type == NCCL_OFI_RDMA_EAGER_RX_BUFF)) {
 		/* A rx buffer receive failed -- this is an internal error so bail out */
 		NCCL_OFI_WARN("Fatal: rx buffer recv completed with error");
 	} else {
@@ -1850,7 +1856,8 @@ static int receive_progress(nccl_net_ofi_rdma_req_t *req, bool add_to_pending)
 	case NCCL_OFI_RDMA_RECV:
 	case NCCL_OFI_RDMA_SEND:
 	case NCCL_OFI_RDMA_RECV_SEGMS:
-	case NCCL_OFI_RDMA_RX_BUFF:
+	case NCCL_OFI_RDMA_CTRL_RX_BUFF:
+	case NCCL_OFI_RDMA_EAGER_RX_BUFF:
 	case NCCL_OFI_RDMA_SEND_CONN:
 	case NCCL_OFI_RDMA_RECV_CONN:
 	case NCCL_OFI_RDMA_RECV_CONN_RESP:
@@ -1910,7 +1917,8 @@ static int process_pending_reqs(nccl_net_ofi_rdma_ep_t *ep)
 		switch (req->type) {
 		case NCCL_OFI_RDMA_WRITE:
 		case NCCL_OFI_RDMA_SEND:
-		case NCCL_OFI_RDMA_RX_BUFF:
+		case NCCL_OFI_RDMA_CTRL_RX_BUFF:
+		case NCCL_OFI_RDMA_EAGER_RX_BUFF:
 			rc = send_progress(req);
 			break;
 		case NCCL_OFI_RDMA_READ:
@@ -2290,7 +2298,7 @@ static inline int free_invalid(nccl_net_ofi_rdma_req_t *req,
 	return -EINVAL;
 }
 
-static inline int free_rx_buff_req(nccl_net_ofi_rdma_req_t *req,
+static inline int eager_rx_buff_req_free(nccl_net_ofi_rdma_req_t *req,
 				   bool dec_inflight_reqs)
 {
 	assert(!dec_inflight_reqs);
@@ -2303,16 +2311,58 @@ static inline int free_rx_buff_req(nccl_net_ofi_rdma_req_t *req,
 	return free_base_req(NULL, ep->rx_buff_reqs_fl, req, false);
 }
 
-static inline nccl_net_ofi_rdma_req_t *alloc_rx_buff_req(nccl_net_ofi_rdma_ep_t *ep,
-							 nccl_net_ofi_ep_rail_t *rail)
+static inline nccl_net_ofi_rdma_req_t *eager_rx_buff_req_alloc(nccl_net_ofi_rdma_ep_t *ep,
+							       nccl_net_ofi_ep_rail_t *rail)
 {
 	nccl_net_ofi_rdma_req_t *req = allocate_req(ep->rx_buff_reqs_fl);
 	if (!req) return NULL;
 
 	req->comm = NULL;
-	req->type = NCCL_OFI_RDMA_RX_BUFF;
+	req->type = NCCL_OFI_RDMA_EAGER_RX_BUFF;
 	req->dev_id = rdma_endpoint_get_device(ep)->base.dev_id;
-	req->free = free_rx_buff_req;
+	req->free = eager_rx_buff_req_free;
+
+	rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(req);
+
+	nccl_ofi_freelist_elem_t *rx_buff_fl_elem =
+		nccl_ofi_freelist_entry_alloc(ep->rx_buff_fl);
+	if (!rx_buff_fl_elem) {
+		NCCL_OFI_WARN("Failed to allocate rx_buff_fl_elem");
+		req->free(req, false);
+		return NULL;
+	}
+	assert(NCCL_OFI_IS_PTR_ALIGNED(rx_buff_fl_elem->ptr, EAGER_RX_BUFFER_ALIGNMENT));
+
+	rx_buff_data->rx_buff_fl_elem = rx_buff_fl_elem;
+	rx_buff_data->buff_len = ep->rx_buff_size;
+	rx_buff_data->rail = rail;
+	rx_buff_data->ep = ep;
+	return req;
+}
+
+static inline int ctrl_rx_buff_req_free(nccl_net_ofi_rdma_req_t *req,
+					bool dec_inflight_reqs)
+{
+	assert(!dec_inflight_reqs);
+	rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(req);
+	nccl_net_ofi_rdma_ep_t *ep = rx_buff_data->ep;
+	/* Free buffer */
+	if (rx_buff_data->rx_buff_fl_elem) {
+		nccl_ofi_freelist_entry_free(ep->rx_buff_fl, rx_buff_data->rx_buff_fl_elem);
+	}
+	return free_base_req(NULL, ep->rx_buff_reqs_fl, req, false);
+}
+
+static inline nccl_net_ofi_rdma_req_t *ctrl_rx_buff_req_alloc(nccl_net_ofi_rdma_ep_t *ep,
+							      nccl_net_ofi_ep_rail_t *rail)
+{
+	nccl_net_ofi_rdma_req_t *req = allocate_req(ep->rx_buff_reqs_fl);
+	if (!req) return NULL;
+
+	req->comm = NULL;
+	req->type = NCCL_OFI_RDMA_CTRL_RX_BUFF;
+	req->dev_id = rdma_endpoint_get_device(ep)->base.dev_id;
+	req->free = ctrl_rx_buff_req_free;
 
 	rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(req);
 
@@ -5551,7 +5601,8 @@ static int send_progress(nccl_net_ofi_rdma_req_t *req)
 			// Successfully sent the xfer with this rail
 			rma_op_data->xferred_rail_id++;
 		}
-	} else if (req->type == NCCL_OFI_RDMA_RX_BUFF) { // Post rx Buffer
+	} else if (req->type == NCCL_OFI_RDMA_CTRL_RX_BUFF ||
+		   req->type == NCCL_OFI_RDMA_EAGER_RX_BUFF) { // Post rx Buffer
 		rdma_req_rx_buff_data_t *rx_buff_data = get_rx_buff_data(req);
 		/* Get ep rail information to xfer the req */
 		assert(rx_buff_data->rail != NULL);
@@ -6170,7 +6221,7 @@ static inline int init_rx_buffers(nccl_net_ofi_rdma_ep_t *ep)
 		);
 		rail->num_rx_buff_posted = 0;
 		nccl_net_ofi_mutex_init(&rail->rx_buff_mutex, NULL);
-		rail->rx_buff_req_alloc = alloc_rx_buff_req;
+		rail->rx_buff_req_alloc = ctrl_rx_buff_req_alloc;
 	}
 
 	for (int rail_id = 0; rail_id < ep->num_rails; ++rail_id) {
@@ -6183,7 +6234,7 @@ static inline int init_rx_buffers(nccl_net_ofi_rdma_ep_t *ep)
 		);
 		rail->num_rx_buff_posted = 0;
 		nccl_net_ofi_mutex_init(&rail->rx_buff_mutex, NULL);
-		rail->rx_buff_req_alloc = alloc_rx_buff_req;
+		rail->rx_buff_req_alloc = eager_rx_buff_req_alloc;
 	}
 
 	return ret;
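
For reference, a minimal standalone C sketch of the pattern this patch applies. All types and names below are simplified stand-ins, not the plugin's real nccl_net_ofi_* structs: each rx-buffer traffic class gets its own request type with its own alloc/free pair, and each rail selects its allocator through a function pointer, the way init_rx_buffers() assigns rail->rx_buff_req_alloc above.

/* Standalone sketch (hypothetical, simplified types): one rx-buffer
 * request type per traffic class, each with a matching alloc/free pair,
 * selected per rail through a function pointer. */
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>

enum req_type { CTRL_RX_BUFF, EAGER_RX_BUFF };

struct req {
	enum req_type type;
	/* Type-specific destructor, mirroring req->free in the patch */
	void (*free_fn)(struct req *req);
};

static void ctrl_rx_buff_req_free(struct req *req)
{
	assert(req->type == CTRL_RX_BUFF);  /* only ctrl reqs land here */
	free(req);
}

static void eager_rx_buff_req_free(struct req *req)
{
	assert(req->type == EAGER_RX_BUFF); /* only eager reqs land here */
	free(req);
}

static struct req *ctrl_rx_buff_req_alloc(void)
{
	struct req *req = calloc(1, sizeof(*req));
	if (!req) return NULL;
	req->type = CTRL_RX_BUFF;
	req->free_fn = ctrl_rx_buff_req_free;
	return req;
}

static struct req *eager_rx_buff_req_alloc(void)
{
	struct req *req = calloc(1, sizeof(*req));
	if (!req) return NULL;
	req->type = EAGER_RX_BUFF;
	req->free_fn = eager_rx_buff_req_free;
	return req;
}

/* Per-rail allocator hook, as rail->rx_buff_req_alloc is set above */
struct rail {
	struct req *(*rx_buff_req_alloc)(void);
};

int main(void)
{
	struct rail ctrl_rail  = { .rx_buff_req_alloc = ctrl_rx_buff_req_alloc };
	struct rail eager_rail = { .rx_buff_req_alloc = eager_rx_buff_req_alloc };

	struct req *c = ctrl_rail.rx_buff_req_alloc();
	struct req *e = eager_rail.rx_buff_req_alloc();
	printf("ctrl type=%d eager type=%d\n", c->type, e->type);
	c->free_fn(c);
	e->free_fn(e);
	return 0;
}

The payoff of splitting the type, as the diff's assert changes show, is that completion and error paths can check for the exact request class they expect on a given rail instead of accepting any rx-buffer request.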