Skip to content

Commit

Permalink
rdma: Support early completion of recv() requests
Browse files Browse the repository at this point in the history
  • Loading branch information
yexiang-aws committed Feb 28, 2025
1 parent 1ce9f19 commit a635123
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 22 deletions.
5 changes: 5 additions & 0 deletions include/nccl_ofi_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,11 @@ OFI_NCCL_PARAM_INT(use_low_lat_tc, "USE_LOW_LATENCY_TC", 1);
*/
OFI_NCCL_PARAM_INT(force_num_rails, "FORCE_NUM_RAILS", 0);

/*
* Early completion of recv() requests: 1 to enable, 0 to disable.
* Enabling it requires the underlying provider to use the
* FI_PROGRESS_AUTO data progress model.
*/
OFI_NCCL_PARAM_INT(is_early_completion_enabled, "EARLY_COMPL", 0);

#ifdef __cplusplus
} // End extern "C"
#endif
Expand Down
3 changes: 3 additions & 0 deletions include/nccl_ofi_rdma.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ enum nccl_ofi_rdma_msg_type {
NCCL_OFI_RDMA_MSG_CTRL,
NCCL_OFI_RDMA_MSG_EAGER,
NCCL_OFI_RDMA_MSG_CLOSE,
NCCL_OFI_RDMA_MSG_CTRL_FI_WRITE,
NCCL_OFI_RDMA_MSG_INVALID = 15,
NCCL_OFI_RDMA_MSG_MAX = NCCL_OFI_RDMA_MSG_INVALID,
};
Expand Down Expand Up @@ -264,6 +265,8 @@ typedef struct {
/* Total number of completions. Expect one completion for receiving the
* control message and one completion for each send segment. */
int total_num_compls;
/* True to use fi_write instead of fi_writedata in send() */
bool use_fi_write;
#if HAVE_NVTX_TRACING
nvtxRangeId_t trace_id;
nvtxRangeId_t seg_trace_id[MAX_NUM_RAILS];
Expand Down
8 changes: 0 additions & 8 deletions src/nccl_ofi_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -789,14 +789,6 @@ ncclResult_t nccl_net_ofi_irecv_v9(void* recvComm, int n, void** data,
return check_return(ncclInvalidArgument);
}

/*
* Reset to NULL for now until optional receive completion logic is
* implemented
*/
if (*request == (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION) {
*request = NULL;
}

ncclResult_t validation_result = msg_length_verify_max_size(sizes, n);
if (validation_result != ncclSuccess) {
return check_return(validation_result);
Expand Down
81 changes: 67 additions & 14 deletions src/nccl_ofi_rdma.c
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ static bool is_max_write_inline_size_initialized = false;
/* CPU cache line size */
static ssize_t cpu_cache_line_size;

/*
 * Whether early completion of recv() requests is in effect for this
 * transport. It starts out disabled: the feature is opt-in and is only
 * valid when the provider uses FI_PROGRESS_AUTO, so it must not appear
 * enabled before nccl_net_ofi_rdma_init() has checked the provider and
 * assigned the final value from the EARLY_COMPL parameter.
 */
static bool early_completion = false;

/* Function prototypes */
static int send_progress(nccl_net_ofi_rdma_req_t *req);

Expand Down Expand Up @@ -978,6 +980,10 @@ static inline int update_send_data_from_remote(nccl_net_ofi_rdma_send_comm_t *s_
send_data->wdata =
GET_RDMA_WRITE_IMM_DATA(s_comm->remote_comm_id, req->msg_seq_num, send_data->schedule->num_xfer_infos);

send_data->use_fi_write = false;
if (ctrl_msg->type == NCCL_OFI_RDMA_MSG_CTRL_FI_WRITE) {
send_data->use_fi_write = true;
}
return 0;
}

Expand Down Expand Up @@ -1378,6 +1384,8 @@ static inline int handle_rx_buff_recv(nccl_net_ofi_rdma_device_t *device, int ra
goto exit;
}
break;
case NCCL_OFI_RDMA_MSG_CTRL_FI_WRITE:
/* fall through to NCCL_OFI_RDMA_MSG_CTRL case */
case NCCL_OFI_RDMA_MSG_CTRL:
/* CTRL receive completion */
assert(cq_entry->len == nccl_net_ofi_rdma_ctrl_msg_size(ep->num_rails, ep->use_long_rkeys));
Expand Down Expand Up @@ -3344,7 +3352,8 @@ static inline int insert_send_ctrl_req(
int dev_id, uint16_t msg_seq_num, void *buff,
size_t size,
nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle,
nccl_net_ofi_rdma_req_t *recv_req)
nccl_net_ofi_rdma_req_t *recv_req,
bool recv_completion_optional)
{
nccl_net_ofi_scheduler_t *scheduler = device->scheduler;
nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
Expand Down Expand Up @@ -3406,7 +3415,8 @@ static inline int insert_send_ctrl_req(

nccl_net_ofi_rdma_ctrl_msg_t *ctrl_msg = rdma_send_ctrl_get_msg(send_ctrl_data);

ctrl_msg->type = NCCL_OFI_RDMA_MSG_CTRL;
/* When the receive completion is optional (early completion), the CTRL msg type is NCCL_OFI_RDMA_MSG_CTRL_FI_WRITE, which tells send() to use fi_write instead of fi_writedata */
ctrl_msg->type = recv_completion_optional ? NCCL_OFI_RDMA_MSG_CTRL_FI_WRITE : NCCL_OFI_RDMA_MSG_CTRL;
ctrl_msg->remote_comm_id = r_comm->remote_comm_id;
ctrl_msg->msg_seq_num = msg_seq_num;
ctrl_msg->buff_addr = (uint64_t)buff;
Expand Down Expand Up @@ -3482,7 +3492,8 @@ static inline int allocate_rdma_recv_req(
int dev_id, uint16_t msg_seq_num, void *buff,
size_t size,
nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle,
nccl_net_ofi_rdma_req_t **ret_req)
nccl_net_ofi_rdma_req_t **ret_req,
bool recv_completion_optional)
{
int ret = 0;
rdma_req_recv_data_t *recv_data;
Expand All @@ -3503,14 +3514,15 @@ static inline int allocate_rdma_recv_req(
req->msg_seq_num = msg_seq_num;

recv_data = get_recv_data(req);
recv_data->total_num_compls = 2;
/* With early completion, expect only the completion of the control message itself */
recv_data->total_num_compls = recv_completion_optional ? 1 : 2;
recv_data->eager_copy_req = NULL;
recv_data->dst_buff = buff;
recv_data->dst_len = size;
recv_data->dest_mr_handle = buff_mr_handle;

/* TODO consolidate arguments to insert_send_ctrl_req and insert_recv_segms_req */
ret = insert_send_ctrl_req(r_comm, device, dev_id, msg_seq_num, buff, size, buff_mr_handle, req);
ret = insert_send_ctrl_req(r_comm, device, dev_id, msg_seq_num, buff, size, buff_mr_handle, req, recv_completion_optional);
if (ret) {
NCCL_OFI_WARN("Failed to insert send ctrl request into recv request");
return ret;
Expand Down Expand Up @@ -3606,9 +3618,14 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
nccl_net_ofi_rdma_mr_handle_t **mr_handles = (nccl_net_ofi_rdma_mr_handle_t **)mhandles;
uint16_t msg_seq_num = 0;
bool eager = false;
bool recv_completion_optional = false;

assert(r_comm != NULL);

if (early_completion && *base_req == (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION) {
recv_completion_optional = true;
}

if (r_comm->comm_active == false) {
NCCL_OFI_WARN("Called irecv on inactive communicator");
ret = -EINVAL;
Expand Down Expand Up @@ -3677,7 +3694,7 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,

ret = allocate_rdma_recv_req(r_comm, device, dev_id, msg_seq_num,
buffers[0], sizes[0],
mr_handles[0], &req);
mr_handles[0], &req, recv_completion_optional);
if (ret != 0) {
goto error;
}
Expand Down Expand Up @@ -5499,7 +5516,8 @@ static int post_rma_write(nccl_net_ofi_rdma_req_t *req)

static int post_rdma_write(nccl_net_ofi_rdma_req_t *req,
nccl_net_ofi_rdma_send_comm_rail_t *comm_rail,
nccl_net_ofi_xfer_info_t *xfer_info)
nccl_net_ofi_xfer_info_t *xfer_info,
bool use_fi_write)
{
rdma_req_send_data_t *send_data = get_send_data(req);
assert(xfer_info->rail_id < send_data->buff_mr_handle->num_rails);
Expand All @@ -5509,12 +5527,19 @@ static int post_rdma_write(nccl_net_ofi_rdma_req_t *req,

ssize_t rc;
/* Post RDMA write */
rc = fi_writedata(comm_rail->local_ep, (void*)((uintptr_t)send_data->buff + xfer_info->offset),
xfer_info->msg_size, desc, send_data->wdata,
comm_rail->remote_addr,
send_data->remote_buff + xfer_info->offset,
send_data->remote_mr_key[rail_id], (void *)&req->ctx[rail_id]);

if (use_fi_write) {
rc = fi_write(comm_rail->local_ep, (void*)((uintptr_t)send_data->buff + xfer_info->offset),
xfer_info->msg_size, desc,
comm_rail->remote_addr,
send_data->remote_buff + xfer_info->offset,
send_data->remote_mr_key[rail_id], (void *)&req->ctx[rail_id]);
} else {
rc = fi_writedata(comm_rail->local_ep, (void*)((uintptr_t)send_data->buff + xfer_info->offset),
xfer_info->msg_size, desc, send_data->wdata,
comm_rail->remote_addr,
send_data->remote_buff + xfer_info->offset,
send_data->remote_mr_key[rail_id], (void *)&req->ctx[rail_id]);
}
if ((rc != 0) && (rc != -FI_EAGAIN)) {
NCCL_OFI_WARN("fi_writedata failed; RC: %zd, Error: %s",
rc, fi_strerror(-rc));
Expand Down Expand Up @@ -5643,7 +5668,7 @@ static int send_progress(nccl_net_ofi_rdma_req_t *req)
nccl_net_ofi_rdma_send_comm_rail_t *comm_rail =
rdma_send_comm_get_rail(s_comm, xfer_info->rail_id);

ret = post_rdma_write(req, comm_rail, xfer_info);
ret = post_rdma_write(req, comm_rail, xfer_info, send_data->use_fi_write);

if (ret == 0) // Successfully sent the xfer with this rail
send_data->xferred_rail_id++;
Expand Down Expand Up @@ -8115,6 +8140,34 @@ int nccl_net_ofi_rdma_init(const char *provider_filter,
goto error;
}

/*
* NCCL Net v9 API Optimization for LL/LL128 Protocols
*
* Background:
* When using the LL (Low Latency) or LL128 protocols, NCCL sets the request pointer
* to NCCL_NET_OPTIONAL_RECV_COMPLETION in irecv() calls. This indicates that
* the plugin may complete a receive request early, without explicitly polling
* the CQ to validate data arrival. This is safe because NCCL itself, following
* LL protocol semantics, validates data arrival by checking the flag bytes.
*
* Plugin Optimization Details:
* 1. Receiver Side:
* - Marks request completion immediately after CTRL message send completion
* - Does not wait for RDMA write operation completion
*
* 2. Sender Side:
* - Uses fi_write instead of fi_writedata, to eliminate unnecessary CQ entries on RX side
*
* Requirements:
* - Provider must use FI_PROGRESS_AUTO data progress model
*/
if (ofi_nccl_is_early_completion_enabled() && !data_progress_auto) {
NCCL_OFI_WARN("Early completion enablement failed due to provider data progress model is not FI_PROGRESS_AUTO");
ret = -ENOTSUP;
goto error;
}
early_completion = ofi_nccl_is_early_completion_enabled();

/* Create NCCL OFI topology */
topo = nccl_ofi_topo_create(provider_list);
if (!topo) {
Expand Down

0 comments on commit a635123

Please sign in to comment.