Skip to content

Commit

Permalink
rdma: Support early completion of recv() requests
Browse files Browse the repository at this point in the history
  • Loading branch information
yexiang-aws committed Feb 28, 2025
1 parent 1ce9f19 commit 892a92f
Show file tree
Hide file tree
Showing 4 changed files with 85 additions and 22 deletions.
8 changes: 8 additions & 0 deletions include/nccl_ofi_param.h
Original file line number Diff line number Diff line change
Expand Up @@ -396,6 +396,14 @@ OFI_NCCL_PARAM_INT(use_low_lat_tc, "USE_LOW_LATENCY_TC", 1);
*/
OFI_NCCL_PARAM_INT(force_num_rails, "FORCE_NUM_RAILS", 0);

/*
* 1 to enable early completion, 0 to disable it.
* Defaults to -1, which follows the data progress model: the early
* completion feature is contingent on the FI_PROGRESS_AUTO data progress
* model, i.e. it is enabled when the provider uses FI_PROGRESS_AUTO and
* disabled otherwise.
*/
OFI_NCCL_PARAM_INT(early_completion, "EARLY_COMPLETION", -1);

#ifdef __cplusplus
} // End extern "C"
#endif
Expand Down
3 changes: 3 additions & 0 deletions include/nccl_ofi_rdma.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ enum nccl_ofi_rdma_msg_type {
NCCL_OFI_RDMA_MSG_CTRL,
NCCL_OFI_RDMA_MSG_EAGER,
NCCL_OFI_RDMA_MSG_CLOSE,
NCCL_OFI_RDMA_MSG_CTRL_NO_COMPLETION,
NCCL_OFI_RDMA_MSG_INVALID = 15,
NCCL_OFI_RDMA_MSG_MAX = NCCL_OFI_RDMA_MSG_INVALID,
};
Expand Down Expand Up @@ -264,6 +265,8 @@ typedef struct {
/* Total number of completions. Expect one completion for receiving the
* control message and one completion for each send segment. */
int total_num_compls;
/* True to use fi_write instead of fi_writedata in send() */
bool use_fi_write;
#if HAVE_NVTX_TRACING
nvtxRangeId_t trace_id;
nvtxRangeId_t seg_trace_id[MAX_NUM_RAILS];
Expand Down
8 changes: 0 additions & 8 deletions src/nccl_ofi_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -789,14 +789,6 @@ ncclResult_t nccl_net_ofi_irecv_v9(void* recvComm, int n, void** data,
return check_return(ncclInvalidArgument);
}

/*
* Reset to NULL for now until optional receive completion logic is
* implemented
*/
if (*request == (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION) {
*request = NULL;
}

ncclResult_t validation_result = msg_length_verify_max_size(sizes, n);
if (validation_result != ncclSuccess) {
return check_return(validation_result);
Expand Down
88 changes: 74 additions & 14 deletions src/nccl_ofi_rdma.c
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,8 @@ static bool is_max_write_inline_size_initialized = false;
/* CPU cache line size */
static ssize_t cpu_cache_line_size;

/*
 * Whether early completion of receive requests is enabled. Assigned in
 * nccl_net_ofi_rdma_init() from ofi_nccl_early_completion() and the
 * data progress model; recv() treats a request as completion-optional
 * only when this is true and NCCL passed NCCL_NET_OPTIONAL_RECV_COMPLETION.
 */
static bool early_completion = false;

/* Function prototypes */
static int send_progress(nccl_net_ofi_rdma_req_t *req);

Expand Down Expand Up @@ -978,6 +980,7 @@ static inline int update_send_data_from_remote(nccl_net_ofi_rdma_send_comm_t *s_
send_data->wdata =
GET_RDMA_WRITE_IMM_DATA(s_comm->remote_comm_id, req->msg_seq_num, send_data->schedule->num_xfer_infos);

send_data->use_fi_write = (ctrl_msg->type == NCCL_OFI_RDMA_MSG_CTRL_NO_COMPLETION);
return 0;
}

Expand Down Expand Up @@ -1378,6 +1381,8 @@ static inline int handle_rx_buff_recv(nccl_net_ofi_rdma_device_t *device, int ra
goto exit;
}
break;
case NCCL_OFI_RDMA_MSG_CTRL_NO_COMPLETION:
/* fall through to NCCL_OFI_RDMA_MSG_CTRL case */
case NCCL_OFI_RDMA_MSG_CTRL:
/* CTRL receive completion */
assert(cq_entry->len == nccl_net_ofi_rdma_ctrl_msg_size(ep->num_rails, ep->use_long_rkeys));
Expand Down Expand Up @@ -3344,7 +3349,8 @@ static inline int insert_send_ctrl_req(
int dev_id, uint16_t msg_seq_num, void *buff,
size_t size,
nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle,
nccl_net_ofi_rdma_req_t *recv_req)
nccl_net_ofi_rdma_req_t *recv_req,
bool recv_completion_optional)
{
nccl_net_ofi_scheduler_t *scheduler = device->scheduler;
nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
Expand Down Expand Up @@ -3406,7 +3412,8 @@ static inline int insert_send_ctrl_req(

nccl_net_ofi_rdma_ctrl_msg_t *ctrl_msg = rdma_send_ctrl_get_msg(send_ctrl_data);

ctrl_msg->type = NCCL_OFI_RDMA_MSG_CTRL;
/* If early completion is turned on, CTRL msg type will be NCCL_OFI_RDMA_MSG_CTRL_NO_COMPLETION to influence send() behavior */
ctrl_msg->type = recv_completion_optional ? NCCL_OFI_RDMA_MSG_CTRL_NO_COMPLETION : NCCL_OFI_RDMA_MSG_CTRL;
ctrl_msg->remote_comm_id = r_comm->remote_comm_id;
ctrl_msg->msg_seq_num = msg_seq_num;
ctrl_msg->buff_addr = (uint64_t)buff;
Expand Down Expand Up @@ -3482,7 +3489,8 @@ static inline int allocate_rdma_recv_req(
int dev_id, uint16_t msg_seq_num, void *buff,
size_t size,
nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle,
nccl_net_ofi_rdma_req_t **ret_req)
nccl_net_ofi_rdma_req_t **ret_req,
bool recv_completion_optional)
{
int ret = 0;
rdma_req_recv_data_t *recv_data;
Expand All @@ -3503,14 +3511,15 @@ static inline int allocate_rdma_recv_req(
req->msg_seq_num = msg_seq_num;

recv_data = get_recv_data(req);
recv_data->total_num_compls = 2;
/* In the case of early completion, only expect the completion for control msg itself */
recv_data->total_num_compls = recv_completion_optional ? 1 : 2;
recv_data->eager_copy_req = NULL;
recv_data->dst_buff = buff;
recv_data->dst_len = size;
recv_data->dest_mr_handle = buff_mr_handle;

/* TODO consolidate arguments to insert_send_ctrl_req and insert_recv_segms_req */
ret = insert_send_ctrl_req(r_comm, device, dev_id, msg_seq_num, buff, size, buff_mr_handle, req);
ret = insert_send_ctrl_req(r_comm, device, dev_id, msg_seq_num, buff, size, buff_mr_handle, req, recv_completion_optional);
if (ret) {
NCCL_OFI_WARN("Failed to insert send ctrl request into recv request");
return ret;
Expand Down Expand Up @@ -3606,9 +3615,14 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,
nccl_net_ofi_rdma_mr_handle_t **mr_handles = (nccl_net_ofi_rdma_mr_handle_t **)mhandles;
uint16_t msg_seq_num = 0;
bool eager = false;
bool recv_completion_optional = false;

assert(r_comm != NULL);

if (early_completion && *base_req == (void *)NCCL_NET_OPTIONAL_RECV_COMPLETION) {
recv_completion_optional = true;
}

if (r_comm->comm_active == false) {
NCCL_OFI_WARN("Called irecv on inactive communicator");
ret = -EINVAL;
Expand Down Expand Up @@ -3677,7 +3691,7 @@ static int recv(nccl_net_ofi_recv_comm_t *recv_comm, int n, void **buffers,

ret = allocate_rdma_recv_req(r_comm, device, dev_id, msg_seq_num,
buffers[0], sizes[0],
mr_handles[0], &req);
mr_handles[0], &req, recv_completion_optional);
if (ret != 0) {
goto error;
}
Expand Down Expand Up @@ -5499,7 +5513,8 @@ static int post_rma_write(nccl_net_ofi_rdma_req_t *req)

static int post_rdma_write(nccl_net_ofi_rdma_req_t *req,
nccl_net_ofi_rdma_send_comm_rail_t *comm_rail,
nccl_net_ofi_xfer_info_t *xfer_info)
nccl_net_ofi_xfer_info_t *xfer_info,
bool use_fi_write)
{
rdma_req_send_data_t *send_data = get_send_data(req);
assert(xfer_info->rail_id < send_data->buff_mr_handle->num_rails);
Expand All @@ -5509,12 +5524,19 @@ static int post_rdma_write(nccl_net_ofi_rdma_req_t *req,

ssize_t rc;
/* Post RDMA write */
rc = fi_writedata(comm_rail->local_ep, (void*)((uintptr_t)send_data->buff + xfer_info->offset),
xfer_info->msg_size, desc, send_data->wdata,
comm_rail->remote_addr,
send_data->remote_buff + xfer_info->offset,
send_data->remote_mr_key[rail_id], (void *)&req->ctx[rail_id]);

if (use_fi_write) {
rc = fi_write(comm_rail->local_ep, (void*)((uintptr_t)send_data->buff + xfer_info->offset),
xfer_info->msg_size, desc,
comm_rail->remote_addr,
send_data->remote_buff + xfer_info->offset,
send_data->remote_mr_key[rail_id], (void *)&req->ctx[rail_id]);
} else {
rc = fi_writedata(comm_rail->local_ep, (void*)((uintptr_t)send_data->buff + xfer_info->offset),
xfer_info->msg_size, desc, send_data->wdata,
comm_rail->remote_addr,
send_data->remote_buff + xfer_info->offset,
send_data->remote_mr_key[rail_id], (void *)&req->ctx[rail_id]);
}
if ((rc != 0) && (rc != -FI_EAGAIN)) {
NCCL_OFI_WARN("fi_writedata failed; RC: %zd, Error: %s",
rc, fi_strerror(-rc));
Expand Down Expand Up @@ -5643,7 +5665,7 @@ static int send_progress(nccl_net_ofi_rdma_req_t *req)
nccl_net_ofi_rdma_send_comm_rail_t *comm_rail =
rdma_send_comm_get_rail(s_comm, xfer_info->rail_id);

ret = post_rdma_write(req, comm_rail, xfer_info);
ret = post_rdma_write(req, comm_rail, xfer_info, send_data->use_fi_write);

if (ret == 0) // Successfully sent the xfer with this rail
send_data->xferred_rail_id++;
Expand Down Expand Up @@ -8115,6 +8137,44 @@ int nccl_net_ofi_rdma_init(const char *provider_filter,
goto error;
}

/*
* NCCL Net v9 API Optimization for LL/LL128 Protocols
*
* Background:
* When using the LL (Low Latency) or LL128 protocols, NCCL sets the request pointer
* to NCCL_NET_OPTIONAL_RECV_COMPLETION in irecv() calls. This indicates that
* the plugin may complete a receive request early, without explicitly
* polling the CQ to validate data arrival. This is possible because NCCL itself,
* following LL protocol semantics, validates data arrival by checking the flag bytes.
*
* Plugin Optimization Details:
* 1. Receiver Side:
* - Marks request completion immediately after CTRL message send completion
* - Does not wait for RDMA write operation completion
*
* 2. Sender Side:
* - Uses fi_write instead of fi_writedata, to eliminate unnecessary CQ entries on RX side
*
* Requirements:
* - Provider must use FI_PROGRESS_AUTO data progress model
*/
switch (ofi_nccl_early_completion()) {
case -1:
early_completion = data_progress_auto;
break;
case 1:
if (!data_progress_auto) {
NCCL_OFI_WARN("Early completion enablement failed due to provider data progress model is not FI_PROGRESS_AUTO");
ret = -ENOTSUP;
goto error;
}
early_completion = true;
break;
case 0:
early_completion = false;
break;
}

/* Create NCCL OFI topology */
topo = nccl_ofi_topo_create(provider_list);
if (!topo) {
Expand Down

0 comments on commit 892a92f

Please sign in to comment.