Skip to content

Commit

Permalink
Communicator ID from 64 bits to 32 bits
Browse files Browse the repository at this point in the history
Reduced size of communicator ID (both for SENDRECV and RDMA) to 32 bits, since
the communicator ID is transmitted on the wire and 32 bits are more than enough.

Signed-off-by: Amedeo Sapio <asapio@amazon.com>
  • Loading branch information
AmedeoSapio committed Mar 20, 2024
1 parent 4cdc7cc commit e3f9fd4
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 37 deletions.
2 changes: 1 addition & 1 deletion include/nccl_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ _Static_assert(sizeof(nccl_ofi_connection_info_t) == 80,

typedef struct nccl_net_ofi_conn_handle {
char ep_name[MAX_EP_ADDR];
uint64_t comm_id;
uint32_t comm_id;
/* Save temporary communicator state when creating send communicator */
save_comm_state_t state;
} nccl_net_ofi_conn_handle_t;
Expand Down
22 changes: 11 additions & 11 deletions include/nccl_ofi_rdma.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,13 +92,13 @@ typedef struct nccl_net_ofi_rdma_ctrl_msg {

/* A comm identitifer that uniquely identifies the comm
* on the receiver side */
uint64_t remote_comm_id;
uint32_t remote_comm_id;

uint64_t buff_addr;
uint64_t buff_len;
uint64_t buff_mr_key[MAX_NUM_RAILS];
} nccl_net_ofi_rdma_ctrl_msg_t;
_Static_assert(sizeof(nccl_net_ofi_rdma_ctrl_msg_t) == 64,
_Static_assert(sizeof(nccl_net_ofi_rdma_ctrl_msg_t) == 56,
"Wrong size for RDMA Control message");

/* Structure used to store control messages in a free list */
Expand Down Expand Up @@ -325,18 +325,18 @@ typedef struct nccl_ofi_rdma_connection_info {

/* A comm identitifer that uniquely identifies the comm on the sender
side. The receiver must use this ID when sending messages to sender */
uint64_t local_comm_id;
uint32_t local_comm_id;

/* A comm identitifer that uniquely identifies the comm
* on the receiver side */
uint64_t remote_comm_id;
uint32_t remote_comm_id;

/* Array of `MAX_NUM_RAILS` `nccl_ofi_rdma_ep_name_t`
* structs. The member `num_rails` indicates the number of
* entries that are in use. */
nccl_ofi_rdma_ep_name_t ep_names[MAX_NUM_RAILS];
} nccl_ofi_rdma_connection_info_t;
_Static_assert(sizeof(nccl_ofi_rdma_connection_info_t) == 248,
_Static_assert(sizeof(nccl_ofi_rdma_connection_info_t) == 240,
"Wrong size for RDMA connect message");

/*
Expand Down Expand Up @@ -371,9 +371,9 @@ typedef struct nccl_net_ofi_rdma_send_comm {
nccl_ofi_freelist_t *nccl_ofi_reqs_fl;

/* Comm ID provided by the local endpoint */
uint64_t local_comm_id;
uint32_t local_comm_id;
/* Comm ID provided by remote endpoint */
uint64_t remote_comm_id;
uint32_t remote_comm_id;

/* Request to receive connect response message to finalize
* connection establishment */
Expand Down Expand Up @@ -449,9 +449,9 @@ typedef struct nccl_net_ofi_rdma_recv_comm {
nccl_ofi_freelist_t *nccl_ofi_reqs_fl;

/* Comm ID provided by the local endpoint */
uint64_t local_comm_id;
uint32_t local_comm_id;
/* Comm ID provided by remote endpoint */
uint64_t remote_comm_id;
uint32_t remote_comm_id;

/* The flush buffer */
nccl_net_ofi_rdma_flush_buffer_t flush_buff;
Expand All @@ -477,7 +477,7 @@ typedef struct nccl_net_ofi_rdma_listen_comm {
nccl_net_ofi_listen_comm_t base;

/* Comm ID provided by local endpoint */
uint64_t comm_id;
uint32_t comm_id;
struct fid_ep *leader_local_ep;

/* Communicator created while accept routine is executed */
Expand Down Expand Up @@ -653,7 +653,7 @@ typedef struct nccl_net_ofi_rdma_device {
char *prov_name;

/* Maximum number of supported communicator IDs */
uint64_t num_comm_ids;
uint32_t num_comm_ids;

/* Memory registration key pool */
nccl_ofi_idpool_t key_pool;
Expand Down
44 changes: 21 additions & 23 deletions src/nccl_ofi_rdma.c
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,7 @@ static inline int check_post_bounce_req(nccl_net_ofi_rdma_req_t *bounce_req);
/*
* @brief Get endpoint communicator with given ID
*/
static inline nccl_net_ofi_comm_t *get_comm(nccl_net_ofi_rdma_ep_t *ep,
int64_t local_comm_id)
static inline nccl_net_ofi_comm_t *get_comm(nccl_net_ofi_rdma_ep_t *ep, uint32_t local_comm_id)
{
assert(local_comm_id < NCCL_OFI_RDMA_MAX_COMMS);
return ep->comms[local_comm_id];
Expand All @@ -147,7 +146,7 @@ static inline nccl_net_ofi_comm_t *get_comm(nccl_net_ofi_rdma_ep_t *ep,
* @brief Set endpoint communicator with given ID
*/
static inline void set_comm(nccl_net_ofi_rdma_ep_t *ep,
int64_t local_comm_id,
uint32_t local_comm_id,
nccl_net_ofi_comm_t *comm)
{
assert(local_comm_id < NCCL_OFI_RDMA_MAX_COMMS);
Expand All @@ -157,7 +156,7 @@ static inline void set_comm(nccl_net_ofi_rdma_ep_t *ep,
/*
* @brief Get endpoint listen communicator with given comm_id
*/
static inline nccl_net_ofi_rdma_listen_comm_t *get_listen_comm(nccl_net_ofi_rdma_ep_t *ep, uint64_t local_comm_id)
static inline nccl_net_ofi_rdma_listen_comm_t *get_listen_comm(nccl_net_ofi_rdma_ep_t *ep, uint32_t local_comm_id)
{
nccl_net_ofi_rdma_listen_comm_t *l_comm = (nccl_net_ofi_rdma_listen_comm_t *)get_comm(ep, local_comm_id);
assert(l_comm->base.base.type == NCCL_NET_OFI_LISTEN_COMM);
Expand All @@ -167,8 +166,7 @@ static inline nccl_net_ofi_rdma_listen_comm_t *get_listen_comm(nccl_net_ofi_rdma
/*
* @brief Get endpoint send communicator with given ID
*/
static inline nccl_net_ofi_rdma_send_comm_t *get_send_comm(nccl_net_ofi_rdma_ep_t *ep,
uint64_t local_comm_id)
static inline nccl_net_ofi_rdma_send_comm_t *get_send_comm(nccl_net_ofi_rdma_ep_t *ep, uint32_t local_comm_id)
{
nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)
get_comm(ep, local_comm_id);
Expand All @@ -180,7 +178,7 @@ static inline nccl_net_ofi_rdma_send_comm_t *get_send_comm(nccl_net_ofi_rdma_ep_
* @brief Get endpoint recv communicator with given comm_id
*/
static inline nccl_net_ofi_rdma_recv_comm_t *get_recv_comm(nccl_net_ofi_rdma_ep_t *ep,
uint64_t local_comm_id)
uint32_t local_comm_id)
{
nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)
get_comm(ep, local_comm_id);
Expand Down Expand Up @@ -1273,7 +1271,7 @@ static inline int handle_bounce_recv(nccl_net_ofi_rdma_req_type_t msg_type, nccl
static inline nccl_net_ofi_rdma_req_t *get_req_from_imm_data
(nccl_net_ofi_rdma_ep_t *ep, uint64_t data)
{
uint16_t comm_id = GET_COMM_ID_FROM_IMM(data);
uint32_t comm_id = GET_COMM_ID_FROM_IMM(data);
nccl_net_ofi_rdma_recv_comm_t *r_comm = get_recv_comm(ep, comm_id);

uint16_t msg_seq_num = GET_SEQ_NUM_FROM_IMM(data);
Expand Down Expand Up @@ -3238,7 +3236,7 @@ static int recv_close(nccl_net_ofi_recv_comm_t *recv_comm)
/* Release communicator ID */
ret = nccl_ofi_idpool_free_id(ep->comm_idpool, r_comm->local_comm_id);
if (OFI_UNLIKELY(ret != 0)) {
NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", r_comm->local_comm_id);
NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", r_comm->local_comm_id);
}

free(r_comm);
Expand Down Expand Up @@ -3459,12 +3457,12 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device
r_comm->base.close = recv_close;

/* Allocate recv communicator ID */
int comm_id = nccl_ofi_idpool_allocate_id(ep->comm_idpool);
uint32_t comm_id = (uint32_t)nccl_ofi_idpool_allocate_id(ep->comm_idpool);
if (OFI_UNLIKELY(comm_id < 0)) {
r_comm->local_comm_id = ~0;
goto error;
}
r_comm->local_comm_id = (uint64_t)comm_id;
r_comm->local_comm_id = comm_id;

/* Validate received comm ID */
if (OFI_UNLIKELY(conn_msg->local_comm_id >= device->num_comm_ids)) {
Expand Down Expand Up @@ -3561,7 +3559,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device
if (~0 != r_comm->local_comm_id) {
ret = nccl_ofi_idpool_free_id(ep->comm_idpool, r_comm->local_comm_id);
if (ret != 0) {
NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", r_comm->local_comm_id);
NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", r_comm->local_comm_id);
}
}
free(r_comm);
Expand Down Expand Up @@ -3921,7 +3919,7 @@ static int listen_close(nccl_net_ofi_listen_comm_t *listen_comm)
/* Release communicator ID */
ret = nccl_ofi_idpool_free_id(((nccl_net_ofi_rdma_ep_t *)base_ep)->comm_idpool, l_comm->comm_id);
if (OFI_UNLIKELY(ret != 0)) {
NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", l_comm->comm_id);
NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", l_comm->comm_id);
}

free(l_comm);
Expand Down Expand Up @@ -3969,13 +3967,13 @@ static int listen(nccl_net_ofi_ep_t *base_ep,
l_comm->leader_local_ep = first_rail->ofi_ep;

/* Allocate listen communicator ID */
int comm_id = nccl_ofi_idpool_allocate_id(ep->comm_idpool);
uint32_t comm_id = (uint32_t)nccl_ofi_idpool_allocate_id(ep->comm_idpool);
if (OFI_UNLIKELY(comm_id < 0)) {
l_comm->comm_id = ~0;
ret = comm_id;
goto error;
}
l_comm->comm_id = (uint64_t)comm_id;
l_comm->comm_id = comm_id;
handle->comm_id = l_comm->comm_id;

/* Add listen comm to ep's lookup array */
Expand All @@ -3993,7 +3991,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep,
error:
if (l_comm && ~0 != l_comm->comm_id) {
if (0 != nccl_ofi_idpool_free_id(ep->comm_idpool, l_comm->comm_id)) {
NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", l_comm->comm_id);
NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", l_comm->comm_id);
}
}
free(l_comm);
Expand Down Expand Up @@ -4663,7 +4661,7 @@ static int send_close(nccl_net_ofi_rdma_send_comm_t *s_comm)
/* Release communicator ID */
ret = nccl_ofi_idpool_free_id(ep->comm_idpool, s_comm->local_comm_id);
if (OFI_UNLIKELY(ret != 0)) {
NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", s_comm->local_comm_id);
NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", s_comm->local_comm_id);
}

free(s_comm);
Expand Down Expand Up @@ -4726,8 +4724,8 @@ static int blocked_send_close(nccl_net_ofi_send_comm_t *send_comm)
* @return Connection information, on success
* NULL, on others
*/
static void prepare_send_connect_message(nccl_net_ofi_rdma_ep_t *ep, int dev_id, uint64_t local_comm_id,
uint64_t remote_comm_id, nccl_net_ofi_conn_handle_t *handle,
static void prepare_send_connect_message(nccl_net_ofi_rdma_ep_t *ep, int dev_id, uint32_t local_comm_id,
uint32_t remote_comm_id, nccl_net_ofi_conn_handle_t *handle,
nccl_ofi_rdma_connection_info_t *conn_msg)
{
int num_rails = ep->num_rails;
Expand Down Expand Up @@ -4918,13 +4916,13 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
ret_s_comm->remote_comm_id = handle->comm_id;

/* Allocate send communicator ID */
int comm_id = nccl_ofi_idpool_allocate_id(ep->comm_idpool);
uint32_t comm_id = (uint32_t)nccl_ofi_idpool_allocate_id(ep->comm_idpool);
if (OFI_UNLIKELY(comm_id < 0)) {
ret_s_comm->local_comm_id = ~0;
ret = comm_id;
goto error;
}
ret_s_comm->local_comm_id = (uint64_t)comm_id;
ret_s_comm->local_comm_id = comm_id;

/* Add ourselves to ep's lookup array */
set_comm(ep, ret_s_comm->local_comm_id, &ret_s_comm->base.base);
Expand Down Expand Up @@ -4977,7 +4975,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
error:
if (ret_s_comm && ~0 != ret_s_comm->local_comm_id) {
if (0 != nccl_ofi_idpool_free_id(ep->comm_idpool, ret_s_comm->local_comm_id)) {
NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", ret_s_comm->local_comm_id);
NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", ret_s_comm->local_comm_id);
}
}
free(ret_s_comm);
Expand Down Expand Up @@ -5617,7 +5615,7 @@ static int device_prepare_for_connection(nccl_net_ofi_rdma_device_t *device)
nccl_net_ofi_rdma_device_rail_t *begin = device->device_rails;
nccl_net_ofi_rdma_device_rail_t *end = device->device_rails + device->num_rails;

device->num_comm_ids = (uint64_t)NCCL_OFI_RDMA_MAX_COMMS;
device->num_comm_ids = (uint32_t)NCCL_OFI_RDMA_MAX_COMMS;

for (; begin != end; ++begin) {
ret = init_device_rail_ofi_resources(begin);
Expand Down
4 changes: 2 additions & 2 deletions src/nccl_ofi_sendrecv.c
Original file line number Diff line number Diff line change
Expand Up @@ -1478,7 +1478,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep,
local_ep_name = get_local_address(ep->ofi_ep);

memcpy(handle->ep_name, local_ep_name, MAX_EP_ADDR);
handle->comm_id = tag;
handle->comm_id = (uint32_t)tag;

/* Insert local EP address to AV. This will be used to issue local read operations */
num_addrs = fi_av_insert(ep->av, (void *)local_ep_name, 1,
Expand Down Expand Up @@ -1727,7 +1727,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,

/* Get tag and remote name from handle */
memcpy(&remote_ep_addr, handle->ep_name, MAX_EP_ADDR);
memcpy(&tag, &handle->comm_id, sizeof(tag));
memcpy(&tag, &handle->comm_id, sizeof(handle->comm_id));
if (tag < 1 || tag > max_tag) {
NCCL_OFI_WARN("Received an invalid tag %lu for device %d", tag,
device->base.dev_id);
Expand Down

0 comments on commit e3f9fd4

Please sign in to comment.