diff --git a/include/nccl_ofi_param.h b/include/nccl_ofi_param.h
index 023f29629..a527aaf0c 100644
--- a/include/nccl_ofi_param.h
+++ b/include/nccl_ofi_param.h
@@ -396,6 +396,11 @@ OFI_NCCL_PARAM_INT(use_low_lat_tc, "USE_LOW_LATENCY_TC", 1);
  */
 OFI_NCCL_PARAM_INT(force_num_rails, "FORCE_NUM_RAILS", 0);
 
+/*
+ * 1 to enable early completion, 0 to disable it. Early completion also requires that the underlying provider use the FI_PROGRESS_AUTO data progress model.
+ */
+OFI_NCCL_PARAM_INT(is_early_completion_enabled, "EARLY_COMPL", 0);
+
 #ifdef __cplusplus
 } // End extern "C"
 #endif
diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h
index 7e317fb56..1145cda9d 100644
--- a/include/nccl_ofi_rdma.h
+++ b/include/nccl_ofi_rdma.h
@@ -104,6 +104,7 @@ enum nccl_ofi_rdma_msg_type {
 	NCCL_OFI_RDMA_MSG_CTRL,
 	NCCL_OFI_RDMA_MSG_EAGER,
 	NCCL_OFI_RDMA_MSG_CLOSE,
+	NCCL_OFI_RDMA_MSG_CTRL_FI_WRITE,
 	NCCL_OFI_RDMA_MSG_INVALID = 15,
 	NCCL_OFI_RDMA_MSG_MAX = NCCL_OFI_RDMA_MSG_INVALID,
 };
@@ -264,6 +265,8 @@ typedef struct {
 	/* Total number of completions. Expect one completion for receiving the
 	 * control message and one completion for each send segment. */
 	int total_num_compls;
+	/* True to use fi_write instead of fi_writedata in send() */
+	bool use_fi_write;
 #if HAVE_NVTX_TRACING
 	nvtxRangeId_t trace_id;
 	nvtxRangeId_t seg_trace_id[MAX_NUM_RAILS];
diff --git a/src/nccl_ofi_api.c b/src/nccl_ofi_api.c
index a77ac2678..6e9835e31 100644
--- a/src/nccl_ofi_api.c
+++ b/src/nccl_ofi_api.c
@@ -789,14 +789,6 @@ ncclResult_t nccl_net_ofi_irecv_v9(void* recvComm, int n, void** data,
 		return check_return(ncclInvalidArgument);
 	}
 
-	/*
-	 * Reset to NULL for now until optional receive completion logic is
-	 * implemented
-	 */
-	if (*request == (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION) {
-		*request = NULL;
-	}
-
 	ncclResult_t validation_result = msg_length_verify_max_size(sizes, n);
 	if (validation_result != ncclSuccess) {
 		return check_return(validation_result);
diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c
index 36f11cff9..a1914139e 100644
--- a/src/nccl_ofi_rdma.c
+++ b/src/nccl_ofi_rdma.c
@@ -107,6 +107,8 @@ static bool is_max_write_inline_size_initialized = false;
 /* CPU cache line size */
 static ssize_t cpu_cache_line_size;
 
+static bool early_completion = true;
+
 /* Function prototypes */
 static int send_progress(nccl_net_ofi_rdma_req_t *req);
 
@@ -978,6 +980,10 @@ static inline int update_send_data_from_remote(nccl_net_ofi_rdma_send_comm_t *s_
 	send_data->wdata =
 		GET_RDMA_WRITE_IMM_DATA(s_comm->remote_comm_id, req->msg_seq_num, send_data->schedule->num_xfer_infos);
 
+	send_data->use_fi_write = false;
+	if (ctrl_msg->type == NCCL_OFI_RDMA_MSG_CTRL_FI_WRITE) {
+		send_data->use_fi_write = true;
+	}
 	return 0;
 }
 
@@ -1378,6 +1384,8 @@ static inline int handle_rx_buff_recv(nccl_net_ofi_rdma_device_t *device, int ra
 			goto exit;
 		}
 		break;
+	case NCCL_OFI_RDMA_MSG_CTRL_FI_WRITE:
+		/* fall through to NCCL_OFI_RDMA_MSG_CTRL case */
 	case NCCL_OFI_RDMA_MSG_CTRL:
 		/* CTRL receive completion */
 		assert(cq_entry->len == nccl_net_ofi_rdma_ctrl_msg_size(ep->num_rails, ep->use_long_rkeys));
@@ -3344,7 +3352,8 @@ static inline int insert_send_ctrl_req(
 				nccl_net_ofi_rdma_recv_comm_t *r_comm, nccl_net_ofi_rdma_device_t *device,
 				int dev_id, uint16_t msg_seq_num, void *buff,
 				size_t size, nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle,
-				nccl_net_ofi_rdma_req_t *recv_req)
+				nccl_net_ofi_rdma_req_t *recv_req,
+				bool recv_completion_optional)
 {
 	nccl_net_ofi_scheduler_t *scheduler = device->scheduler;
 	nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
@@ -3406,7 +3415,8 @@ static inline int insert_send_ctrl_req(
 	nccl_net_ofi_rdma_ctrl_msg_t *ctrl_msg = rdma_send_ctrl_get_msg(send_ctrl_data);
 
-	ctrl_msg->type = NCCL_OFI_RDMA_MSG_CTRL;
+	/* If the v9 API requested optional receive completion, use NCCL_OFI_RDMA_MSG_CTRL_FI_WRITE */
+	ctrl_msg->type = recv_completion_optional ? NCCL_OFI_RDMA_MSG_CTRL_FI_WRITE : NCCL_OFI_RDMA_MSG_CTRL;
 	ctrl_msg->remote_comm_id = r_comm->remote_comm_id;
 	ctrl_msg->msg_seq_num = msg_seq_num;
 	ctrl_msg->buff_addr = (uint64_t)buff;
@@ -3482,7 +3492,8 @@ static inline int allocate_rdma_recv_req(
 				nccl_net_ofi_rdma_recv_comm_t *r_comm, nccl_net_ofi_rdma_device_t *device,
 				int dev_id, uint16_t msg_seq_num, void *buff,
 				size_t size, nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle,
-				nccl_net_ofi_rdma_req_t **ret_req)
+				nccl_net_ofi_rdma_req_t **ret_req,
+				bool recv_completion_optional)
 {
 	int ret = 0;
 	rdma_req_recv_data_t *recv_data;
@@ -3503,14 +3514,15 @@ static inline int allocate_rdma_recv_req(
 	req->msg_seq_num = msg_seq_num;
 
 	recv_data = get_recv_data(req);
-	recv_data->total_num_compls = 2;
+	/* With early completion, only the completion for the control message itself is needed */
+	recv_data->total_num_compls = recv_completion_optional ? 1 : 2;
 	recv_data->eager_copy_req = NULL;
 	recv_data->dst_buff = buff;
 	recv_data->dst_len = size;
 	recv_data->dest_mr_handle = buff_mr_handle;
 	/* TODO consolidate arguments to insert_send_ctrl_req and insert_recv_segms_req */
-	ret = insert_send_ctrl_req(r_comm, device, dev_id, msg_seq_num, buff, size, buff_mr_handle, req);
+	ret = insert_send_ctrl_req(r_comm, device, dev_id, msg_seq_num, buff, size, buff_mr_handle, req, recv_completion_optional);
 	if (ret) {
 		NCCL_OFI_WARN("Failed to insert send ctrl request into recv request");
 		return ret;
 	}
@@ -3606,9 +3618,20 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
 	nccl_net_ofi_rdma_mr_handle_t **mr_handles = (nccl_net_ofi_rdma_mr_handle_t **)mhandles;
 	uint16_t msg_seq_num = 0;
 	bool eager = false;
+	bool recv_completion_optional = false;
 
 	assert(r_comm != NULL);
 
+	/*
+	 * If the provider's data progress model is FI_PROGRESS_AUTO, we can
+	 * complete a recv() request early when the NCCL v9 API passes the
+	 * optional-completion hint for the LL* protocols. In that case, the
+	 * sender uses fi_write to avoid a CQ entry on the receive side.
+	 */
+	if (early_completion && *base_req == (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION) {
+		recv_completion_optional = true;
+	}
+
 	if (r_comm->comm_active == false) {
 		NCCL_OFI_WARN("Called irecv on inactive communicator");
 		ret = -EINVAL;
@@ -3677,7 +3700,7 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
 
 	ret = allocate_rdma_recv_req(r_comm, device, dev_id, msg_seq_num,
 				     buffers[0], sizes[0],
-				     mr_handles[0], &req);
+				     mr_handles[0], &req, recv_completion_optional);
 	if (ret != 0) {
 		goto error;
 	}
@@ -5499,7 +5522,8 @@ static int post_rma_write(nccl_net_ofi_rdma_req_t *req)
 
 static int post_rdma_write(nccl_net_ofi_rdma_req_t *req,
 			   nccl_net_ofi_rdma_send_comm_rail_t *comm_rail,
-			   nccl_net_ofi_xfer_info_t *xfer_info)
+			   nccl_net_ofi_xfer_info_t *xfer_info,
+			   bool use_fi_write)
 {
 	rdma_req_send_data_t *send_data = get_send_data(req);
 	assert(xfer_info->rail_id < send_data->buff_mr_handle->num_rails);
@@ -5509,12 +5533,19 @@ static int post_rdma_write(nccl_net_ofi_rdma_req_t *req,
 	ssize_t rc;
 
 	/* Post RDMA write */
-	rc = fi_writedata(comm_rail->local_ep, (void*)((uintptr_t)send_data->buff + xfer_info->offset),
-			  xfer_info->msg_size, desc, send_data->wdata,
-			  comm_rail->remote_addr,
-			  send_data->remote_buff + xfer_info->offset,
-			  send_data->remote_mr_key[rail_id], (void *)&req->ctx[rail_id]);
-
+	if (use_fi_write) {
+		rc = fi_write(comm_rail->local_ep, (void*)((uintptr_t)send_data->buff + xfer_info->offset),
+			      xfer_info->msg_size, desc,
+			      comm_rail->remote_addr,
+			      send_data->remote_buff + xfer_info->offset,
+			      send_data->remote_mr_key[rail_id], (void *)&req->ctx[rail_id]);
+	} else {
+		rc = fi_writedata(comm_rail->local_ep, (void*)((uintptr_t)send_data->buff + xfer_info->offset),
+				  xfer_info->msg_size, desc, send_data->wdata,
+				  comm_rail->remote_addr,
+				  send_data->remote_buff + xfer_info->offset,
+				  send_data->remote_mr_key[rail_id], (void *)&req->ctx[rail_id]);
+	}
 	if ((rc != 0) && (rc != -FI_EAGAIN)) {
 		NCCL_OFI_WARN("fi_writedata failed; RC: %zd, Error: %s",
 			      rc, fi_strerror(-rc));
@@ -5643,7 +5674,7 @@ static int send_progress(nccl_net_ofi_rdma_req_t *req)
 			nccl_net_ofi_rdma_send_comm_rail_t *comm_rail =
 				rdma_send_comm_get_rail(s_comm, xfer_info->rail_id);
 
-			ret = post_rdma_write(req, comm_rail, xfer_info);
+			ret = post_rdma_write(req, comm_rail, xfer_info, send_data->use_fi_write);
 			if (ret == 0) // Successfully sent the xfer with this rail
 				send_data->xferred_rail_id++;
 
@@ -8115,6 +8146,34 @@ int nccl_net_ofi_rdma_init(const char *provider_filter,
 		goto error;
 	}
 
+	/*
+	 * NCCL Net v9 API optimization for the LL/LL128 protocols
+	 *
+	 * Background:
+	 * When using the LL (low latency) or LL128 protocols, NCCL sets the request pointer
+	 * to NCCL_NET_OPTIONAL_RECV_COMPLETION in irecv() calls. This indicates that
+	 * the plugin may complete a receive request early, without explicitly
+	 * polling the CQ to confirm data arrival. This is safe because NCCL itself,
+	 * following LL protocol semantics, validates data arrival by checking the flag bytes.
+	 *
+	 * Plugin optimization details:
+	 * 1. Receiver side:
+	 *    - Marks the request complete as soon as the CTRL message send completes
+	 *    - Does not wait for the RDMA write operation to complete
+	 *
+	 * 2. Sender side:
+	 *    - Uses fi_write instead of fi_writedata to avoid unnecessary CQ entries on the RX side
+	 *
+	 * Requirements:
+	 *    - The provider must use the FI_PROGRESS_AUTO data progress model
+	 */
+	if (ofi_nccl_is_early_completion_enabled() && !data_progress_auto) {
+		NCCL_OFI_WARN("Failed to enable early completion: provider data progress model is not FI_PROGRESS_AUTO");
+		ret = -ENOTSUP;
+		goto error;
+	}
+	early_completion = ofi_nccl_is_early_completion_enabled();
+
 	/* Create NCCL OFI topology */
 	topo = nccl_ofi_topo_create(provider_list);
 	if (!topo) {
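
The decision the patch adds on the receive path is small enough to illustrate standalone. The sketch below is not part of the patch; it is a minimal, self-contained C illustration of the receiver-side choice: when early completion is enabled and NCCL passes the optional-completion hint, the recv request waits for one completion (the control message send) instead of two, and the control message type switches to the variant that tells the sender to post fi_write instead of fi_writedata. The sentinel value, struct, and function names here are simplified stand-ins, not the plugin's real definitions.

#include <stdbool.h>
#include <stdio.h>

/* Stand-in for NCCL_NET_OPTIONAL_RECV_COMPLETION (illustrative value only) */
#define OPTIONAL_RECV_COMPLETION_HINT ((void *)0x1)

enum ctrl_msg_type { MSG_CTRL, MSG_CTRL_FI_WRITE };

struct recv_plan {
	int total_num_compls;          /* completions the recv request must observe */
	enum ctrl_msg_type ctrl_type;  /* tells the sender which write API to use */
};

/* Mirror of the recv()-side decision: early completion applies only when the
 * feature is enabled and NCCL passed the optional-completion hint. */
static struct recv_plan plan_recv(bool early_completion, void *request_hint)
{
	bool optional = early_completion && request_hint == OPTIONAL_RECV_COMPLETION_HINT;
	struct recv_plan plan = {
		.total_num_compls = optional ? 1 : 2,
		.ctrl_type = optional ? MSG_CTRL_FI_WRITE : MSG_CTRL,
	};
	return plan;
}

int main(void)
{
	struct recv_plan plan = plan_recv(true, OPTIONAL_RECV_COMPLETION_HINT);
	printf("completions expected: %d, sender uses fi_write: %s\n",
	       plan.total_num_compls,
	       plan.ctrl_type == MSG_CTRL_FI_WRITE ? "yes" : "no");
	return 0;
}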