rdma: add option to round robin the ctrl msg
PR aws#543 moved the control message to its own dedicated endpoint on a
single rail. As a result, the control message is no longer sent across
all rails in a round-robin fashion. This has hurt performance in some
cases, so this commit adds an environment variable to optionally enable
the separate control message endpoint; for now the default remains the old
behavior of round-robining the control message across the rails.

Signed-off-by: Amedeo Sapio <asapio@amazon.com>
AmedeoSapio committed Oct 18, 2024
1 parent b93dbdf commit 5c6b2cb
Showing 3 changed files with 127 additions and 43 deletions.
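
As a point of reference before the diff, here is a minimal standalone sketch of the behavior this commit keeps as the default: each control message takes the next rail in turn. This is illustration only, not plugin code; the real plugin asks its scheduler for a one-rail schedule per message (see insert_send_ctrl_req and post_rdma_ctrl below), and the NUM_RAILS constant and modulo counter here are assumptions made purely for the example.

#include <stdio.h>

#define NUM_RAILS 4  /* illustrative rail count, not a plugin constant */

/* Pick the rail for the next control message, round robin. */
static int next_ctrl_rail(void)
{
	static unsigned int counter = 0;
	return (int)(counter++ % NUM_RAILS);
}

int main(void)
{
	for (int msg_seq_num = 0; msg_seq_num < 8; msg_seq_num++) {
		printf("ctrl msg %d -> rail %d\n", msg_seq_num, next_ctrl_rail());
	}
	return 0;
}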
6 changes: 6 additions & 0 deletions include/nccl_ofi_param.h
@@ -265,6 +265,12 @@ OFI_NCCL_PARAM_INT(rdma_min_posted_bounce_buffers, "RDMA_MIN_POSTED_BOUNCE_BUFFE
*/
OFI_NCCL_PARAM_INT(rdma_max_posted_bounce_buffers, "RDMA_MAX_POSTED_BOUNCE_BUFFERS", 128);

/*
* Whether to spread the control message across multiple rails in a round-robin fashion or
* send it consistently on one rail with a dedicated endpoint.
*/
OFI_NCCL_PARAM_INT(rdma_rr_ctrl_msg, "RR_CTRL_MSG", 1);

/*
* Internode network latency reported to NCCL. Defaults to 0, unless the configured
* platform sets a specific value.
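
For reference, a hedged sketch of what the OFI_NCCL_PARAM_INT(rdma_rr_ctrl_msg, "RR_CTRL_MSG", 1) line above is assumed to provide: a cached getter named ofi_nccl_rdma_rr_ctrl_msg() (the name the rdma code below calls) that reads an environment variable and falls back to the default of 1. The exact macro expansion and the OFI_NCCL_ prefix on the variable name are assumptions here, not shown in this diff.

#include <stdint.h>
#include <stdlib.h>

/* Assumed shape of the generated getter; the real macro lives in
 * nccl_ofi_param.h and may differ in details such as parsing and caching. */
static inline int64_t ofi_nccl_rdma_rr_ctrl_msg(void)
{
	static int initialized = 0;
	static int64_t value = 1;  /* default: round robin the ctrl msg */

	if (!initialized) {
		const char *env = getenv("OFI_NCCL_RR_CTRL_MSG");  /* prefix assumed */
		if (env != NULL) {
			value = strtoll(env, NULL, 0);
		}
		initialized = 1;
	}
	return value;
}

With the default of 1 the round-robin behavior is kept; assuming the usual OFI_NCCL_ prefix, setting OFI_NCCL_RR_CTRL_MSG=0 in the environment opts into the dedicated control endpoint introduced by aws#543.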
4 changes: 4 additions & 0 deletions include/nccl_ofi_rdma.h
@@ -282,6 +282,10 @@ typedef struct {
typedef struct {
/* Pointer to the allocated control buffer from freelist */
nccl_net_ofi_rdma_ctrl_fl_item_t *ctrl_fl_item;
/* Schedule used to transfer the control buffer. We save the
* pointer to reference it when transferring the buffer over the
* network. */
nccl_net_ofi_schedule_t *ctrl_schedule;
/* Pointer to recv parent request */
nccl_net_ofi_rdma_req_t *recv_req;
#if HAVE_NVTX_TRACING
160 changes: 117 additions & 43 deletions src/nccl_ofi_rdma.c
@@ -1925,9 +1925,11 @@ static int ofi_process_cq(nccl_net_ofi_rdma_ep_t *ep)
}
}

ret = ofi_process_cq_rail(ep, &ep->control_rail);
if (ret != 0) {
goto exit;
if (!ofi_nccl_rdma_rr_ctrl_msg()) {
ret = ofi_process_cq_rail(ep, &ep->control_rail);
if (ret != 0) {
goto exit;
}
}

/* Process any pending requests */
@@ -2114,6 +2116,12 @@ static inline int free_send_ctrl_req(nccl_net_ofi_rdma_req_t *req,
(nccl_net_ofi_rdma_recv_comm_t *)req->comm;
rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(req);

if (send_ctrl_data->ctrl_schedule != NULL) {
nccl_net_ofi_rdma_device_t *device = (nccl_net_ofi_rdma_device_t *)req->comm->ep->device;
nccl_net_ofi_release_schedule(device->scheduler, send_ctrl_data->ctrl_schedule);
send_ctrl_data->ctrl_schedule = NULL;
}

if (send_ctrl_data->ctrl_fl_item) {
nccl_ofi_freelist_entry_free(r_comm->ctrl_buff_fl, send_ctrl_data->ctrl_fl_item);
send_ctrl_data->ctrl_fl_item = NULL;
@@ -2309,10 +2317,12 @@ static inline int post_bounce_buffs(nccl_net_ofi_rdma_ep_t *ep)
}
}

ret = post_bounce_buffs_on_rail(ep, &ep->control_rail);
if (ret != 0) {
NCCL_OFI_WARN("Failed call to post_bounce_buffs_on_rail(control_rail)");
goto exit;
if (!ofi_nccl_rdma_rr_ctrl_msg()) {
ret = post_bounce_buffs_on_rail(ep, &ep->control_rail);
if (ret != 0) {
NCCL_OFI_WARN("Failed call to post_bounce_buffs_on_rail(control_rail)");
goto exit;
}
}

exit:
@@ -2613,11 +2623,13 @@ static int dereg_rails(nccl_net_ofi_rdma_mr_handle_t *handle)
int num_rails = handle->num_rails;

/* Cleanup memory registration for control */
rc = fi_close(&handle->control_mr->fid);
if (OFI_UNLIKELY(rc != 0)) {
NCCL_OFI_WARN("Unable to de-register memory on control mr. RC: %d, Error: %s",
rc, fi_strerror(-rc));
ret = rc;
if (handle->control_mr != NULL) {
rc = fi_close(&handle->control_mr->fid);
if (OFI_UNLIKELY(rc != 0)) {
NCCL_OFI_WARN("Unable to de-register memory on control mr. RC: %d, Error: %s",
rc, fi_strerror(-rc));
ret = rc;
}
}

/* Cleanup memory registration for data rails */
@@ -2754,13 +2766,15 @@ static inline int reg_mr_on_device(nccl_net_ofi_rdma_ep_t *ep,
goto exit;
}

ret = register_rail_mr_buffer(ep->control_rail.domain, ep->control_rail.ofi_ep,
-1, type, &mr_attr, regattr_flags,
&ret_handle->control_mr);
if (OFI_UNLIKELY(ret != 0)) {
free(ret_handle);
ret_handle = NULL;
goto exit;
if (!ofi_nccl_rdma_rr_ctrl_msg()) {
ret = register_rail_mr_buffer(ep->control_rail.domain, ep->control_rail.ofi_ep,
-1, type, &mr_attr, regattr_flags,
&ret_handle->control_mr);
if (OFI_UNLIKELY(ret != 0)) {
free(ret_handle);
ret_handle = NULL;
goto exit;
}
}

/* Register memory on each rail */
@@ -3071,6 +3085,7 @@ static inline int insert_send_ctrl_req(
nccl_net_ofi_rdma_mr_handle_t *buff_mr_handle,
nccl_net_ofi_rdma_req_t *recv_req)
{
nccl_net_ofi_scheduler_t *scheduler = device->scheduler;
nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;
nccl_net_ofi_rdma_req_t *send_ctrl_req = allocate_req(r_comm->nccl_ofi_reqs_fl);
if (OFI_UNLIKELY(send_ctrl_req == NULL)) {
@@ -3087,6 +3102,24 @@

rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(send_ctrl_req);

if (ofi_nccl_rdma_rr_ctrl_msg()) {
size_t ctrl_msg_len = nccl_net_ofi_rdma_ctrl_msg_size(ep->num_rails, ep->use_long_rkeys);
send_ctrl_data->ctrl_schedule = scheduler->get_schedule(scheduler, ctrl_msg_len, device->num_rails);

if (OFI_UNLIKELY(!(send_ctrl_data->ctrl_schedule))) {
return -EINVAL;
} else if (OFI_UNLIKELY(send_ctrl_data->ctrl_schedule->num_xfer_infos != 1)) {
NCCL_OFI_WARN(
"Invalid schedule for outgoing control message (%zu bytes). Expected one rail, but got "
"%zu",
size,
send_ctrl_data->ctrl_schedule->num_xfer_infos);
return -EINVAL;
}
} else {
send_ctrl_data->ctrl_schedule = NULL;
}

send_ctrl_data->recv_req = recv_req;
send_ctrl_data->ctrl_fl_item = NULL;

@@ -4552,7 +4585,13 @@ static int post_send_conn_resp(nccl_net_ofi_rdma_recv_comm_t *r_comm,
nccl_net_ofi_rdma_req_t *req)
{
ssize_t rc = 0;
nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = &r_comm->control_rail;;
nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail;

if (ofi_nccl_rdma_rr_ctrl_msg()) {
comm_rail = rdma_recv_comm_get_rail(r_comm, 0);
} else {
comm_rail = &r_comm->control_rail;
}

req->state = NCCL_OFI_RDMA_REQ_PENDING;
rc = fi_send(comm_rail->local_ep, (void *)conn_resp, sizeof(nccl_ofi_rdma_connection_info_t), NULL,
@@ -4864,14 +4903,24 @@ static int listen(nccl_net_ofi_ep_t *base_ep,

/* Build handle */
memset(handle, 0, sizeof(nccl_net_ofi_conn_handle_t));
assert(sizeof(handle->ep_name) == sizeof(ep->control_rail.local_ep_name));
memcpy(handle->ep_name, ep->control_rail.local_ep_name,
ep->control_rail.local_ep_name_len);

/* We don't copy the size here since the handle doesn't have a size field.
The size will be distributed later by the connect response message.
Instead, zero the unused bytes here. */
memset(handle->ep_name + ep->control_rail.local_ep_name_len, 0,
sizeof(handle->ep_name) - ep->control_rail.local_ep_name_len);
if (!ofi_nccl_rdma_rr_ctrl_msg()) {
assert(sizeof(handle->ep_name) == sizeof(ep->control_rail.local_ep_name));
memcpy(handle->ep_name, ep->control_rail.local_ep_name, ep->control_rail.local_ep_name_len);
memset(handle->ep_name + ep->control_rail.local_ep_name_len,
0,
sizeof(handle->ep_name) - ep->control_rail.local_ep_name_len);
} else {
nccl_net_ofi_ep_rail_t *first_rail = rdma_endpoint_get_rail(ep, 0);
assert(sizeof(handle->ep_name) == sizeof(first_rail->local_ep_name));
memcpy(handle->ep_name, first_rail->local_ep_name, first_rail->local_ep_name_len);
memset(handle->ep_name + ep->control_rail.local_ep_name_len,
0,
sizeof(handle->ep_name) - ep->control_rail.local_ep_name_len);
}

/* Build listen_comm */
l_comm = (nccl_net_ofi_rdma_listen_comm_t *)calloc(1,
@@ -5265,9 +5314,9 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req)
nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)req->comm;
rdma_req_send_ctrl_data_t *send_ctrl_data = get_send_ctrl_data(req);
nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep;

// Get communicator rail information to xfer the req
nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = &r_comm->control_rail;
nccl_net_ofi_schedule_t *schedule = send_ctrl_data->ctrl_schedule;
nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail;
void *desc;

nccl_net_ofi_rdma_ctrl_fl_item_t *ctrl_fl_item = send_ctrl_data->ctrl_fl_item;

Expand All @@ -5276,9 +5325,19 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req)
(freelist_regmr_fn_handle_t *)ctrl_fl_item->fl_reginfo.mr_handle;
nccl_net_ofi_rdma_mr_handle_t *mr_handle = fl_handle->mr_handle;

void *desc = fi_mr_desc(mr_handle->control_mr);

NCCL_OFI_TRACE_SEND_CTRL_START(req->dev_id, ep->control_rail.rail_id, req->comm, req, req->msg_seq_num);
if (schedule != NULL) {
/* Use round robin schedule for ctrl message */
nccl_net_ofi_xfer_info_t *xfer_info = &schedule->rail_xfer_infos[0];
comm_rail = rdma_recv_comm_get_rail(r_comm, xfer_info->rail_id);
assert(xfer_info->rail_id < mr_handle->num_rails);
desc = fi_mr_desc(mr_handle->mr[xfer_info->rail_id]);
NCCL_OFI_TRACE_SEND_CTRL_START(req->dev_id, xfer_info->rail_id, req->comm, req, req->msg_seq_num);
} else {
/* Use control QP for ctrl message */
comm_rail = &r_comm->control_rail;
desc = fi_mr_desc(mr_handle->control_mr);
NCCL_OFI_TRACE_SEND_CTRL_START(req->dev_id, ep->control_rail.rail_id, req->comm, req, req->msg_seq_num);
}

size_t ctrl_msg_len = nccl_net_ofi_rdma_ctrl_msg_size(ep->num_rails, ep->use_long_rkeys);

@@ -5577,15 +5636,17 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t
goto error;
}

/* look for control messages and then retry the message search
to avoid unnecessary polling / queueing. */
if (OFI_UNLIKELY(!polled_cq && !have_ctrl)) {
ret = ofi_process_cq_rail(ep, &ep->control_rail);
if (ret != 0) {
goto error;
if (!ofi_nccl_rdma_rr_ctrl_msg()) {
/* look for control messages and then retry the message search
to avoid unnecessary polling / queueing. */
if (OFI_UNLIKELY(!polled_cq && !have_ctrl)) {
ret = ofi_process_cq_rail(ep, &ep->control_rail);
if (ret != 0) {
goto error;
}
polled_cq = true;
goto retry;
}
polled_cq = true;
goto retry;
}

/* Determine if this should be sent eagerly. */
@@ -6021,9 +6082,16 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
nccl_net_ofi_rdma_send_comm_t *ret_s_comm = NULL;
int num_rails = ep->num_rails;
int rail_id = 0;
nccl_net_ofi_ep_rail_t *control_rail = &ep->control_rail;
nccl_net_ofi_ep_rail_t *control_rail;

*s_comm = NULL;

if (ofi_nccl_rdma_rr_ctrl_msg()) {
control_rail = rdma_endpoint_get_rail(ep, 0);
} else {
control_rail = &ep->control_rail;
}

/* Retrieve and validate device */
nccl_net_ofi_rdma_device_t *device = rdma_endpoint_get_device(ep);
if (OFI_UNLIKELY(device == NULL)) {
@@ -6098,12 +6166,18 @@
goto error;
}

/* Store remote address of first rail in communicator */
ret_s_comm->control_rail.remote_addr = remote_addr;
if (ofi_nccl_rdma_rr_ctrl_msg()) {
/* Store remote address of first rail in communicator */
ret_s_comm->rails[0].remote_addr = remote_addr;
/* Store local libfabric endpoint of first rail */
ret_s_comm->rails[0].local_ep = control_rail->ofi_ep;
ret_s_comm->num_init_rails = 1;
}

/* Store remote address of control rail in communicator */
ret_s_comm->control_rail.remote_addr = remote_addr;
/* Store local libfabric endpoint of control rail */
ret_s_comm->control_rail.local_ep = control_rail->ofi_ep;
ret_s_comm->num_init_rails = 0;

/* Allocate request free list */
ret = nccl_ofi_freelist_init(sizeof(nccl_net_ofi_rdma_req_t), 16, 16,
