From 15d0adfc97b100f27245a6c275d33ee735d16418 Mon Sep 17 00:00:00 2001 From: Amedeo Sapio Date: Wed, 20 Mar 2024 00:58:06 +0000 Subject: [PATCH] Communicator ID from 64 bits to 32 bits Reduced size of communicator ID (both for SENDRECV and RDMA) to 32 bits, since the communicator ID is transmitted on the wire and 32 bits are more than enough. Signed-off-by: Amedeo Sapio --- include/nccl_ofi.h | 2 +- include/nccl_ofi_rdma.h | 22 ++++++++++----------- src/nccl_ofi_rdma.c | 44 ++++++++++++++++++++--------------------- src/nccl_ofi_sendrecv.c | 4 ++-- 4 files changed, 35 insertions(+), 37 deletions(-) diff --git a/include/nccl_ofi.h b/include/nccl_ofi.h index 246635506..d6eae6a6c 100644 --- a/include/nccl_ofi.h +++ b/include/nccl_ofi.h @@ -197,7 +197,7 @@ _Static_assert(sizeof(nccl_ofi_connection_info_t) == 80, typedef struct nccl_net_ofi_conn_handle { char ep_name[MAX_EP_ADDR]; - uint64_t comm_id; + uint32_t comm_id; /* Save temporary communicator state when creating send communicator */ save_comm_state_t state; } nccl_net_ofi_conn_handle_t; diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h index 76b67766b..ba1839433 100644 --- a/include/nccl_ofi_rdma.h +++ b/include/nccl_ofi_rdma.h @@ -92,14 +92,14 @@ typedef struct nccl_net_ofi_rdma_ctrl_msg { /* A comm identitifer that uniquely identifies the comm * on the receiver side */ - uint64_t remote_comm_id; + uint32_t remote_comm_id; uint64_t buff_addr; uint64_t buff_len; uint64_t buff_mr_key[MAX_NUM_RAILS]; } nccl_net_ofi_rdma_ctrl_msg_t; /* Since this is a message on the wire, check that it has the expected size */ -_Static_assert(sizeof(nccl_net_ofi_rdma_ctrl_msg_t) == 64, +_Static_assert(sizeof(nccl_net_ofi_rdma_ctrl_msg_t) == 56, "Wrong size for RDMA Control message"); /* Structure used to store control messages in a free list */ @@ -326,11 +326,11 @@ typedef struct nccl_ofi_rdma_connection_info { /* A comm identitifer that uniquely identifies the comm on the sender side. 
The receiver must use this ID when sending messages to sender */ - uint64_t local_comm_id; + uint32_t local_comm_id; /* A comm identitifer that uniquely identifies the comm * on the receiver side */ - uint64_t remote_comm_id; + uint32_t remote_comm_id; /* Array of `MAX_NUM_RAILS` `nccl_ofi_rdma_ep_name_t` * structs. The member `num_rails` indicates the number of @@ -338,7 +338,7 @@ typedef struct nccl_ofi_rdma_connection_info { nccl_ofi_rdma_ep_name_t ep_names[MAX_NUM_RAILS]; } nccl_ofi_rdma_connection_info_t; /* Since this is a message on the wire, check that it has the expected size */ -_Static_assert(sizeof(nccl_ofi_rdma_connection_info_t) == 248, +_Static_assert(sizeof(nccl_ofi_rdma_connection_info_t) == 240, "Wrong size for RDMA connect message"); /* @@ -373,9 +373,9 @@ typedef struct nccl_net_ofi_rdma_send_comm { nccl_ofi_freelist_t *nccl_ofi_reqs_fl; /* Comm ID provided by the local endpoint */ - uint64_t local_comm_id; + uint32_t local_comm_id; /* Comm ID provided by remote endpoint */ - uint64_t remote_comm_id; + uint32_t remote_comm_id; /* Request to receive connect response message to finalize * connection establishment */ @@ -451,9 +451,9 @@ typedef struct nccl_net_ofi_rdma_recv_comm { nccl_ofi_freelist_t *nccl_ofi_reqs_fl; /* Comm ID provided by the local endpoint */ - uint64_t local_comm_id; + uint32_t local_comm_id; /* Comm ID provided by remote endpoint */ - uint64_t remote_comm_id; + uint32_t remote_comm_id; /* The flush buffer */ nccl_net_ofi_rdma_flush_buffer_t flush_buff; @@ -479,7 +479,7 @@ typedef struct nccl_net_ofi_rdma_listen_comm { nccl_net_ofi_listen_comm_t base; /* Comm ID provided by local endpoint */ - uint64_t comm_id; + uint32_t comm_id; struct fid_ep *leader_local_ep; /* Communicator created while accept routine is executed */ @@ -655,7 +655,7 @@ typedef struct nccl_net_ofi_rdma_device { char *prov_name; /* Maximum number of supported communicator IDs */ - uint64_t num_comm_ids; + uint32_t num_comm_ids; /* Memory registration key 
pool */ nccl_ofi_idpool_t key_pool; diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index abd106b60..3b51d5602 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -136,8 +136,7 @@ static inline int check_post_bounce_req(nccl_net_ofi_rdma_req_t *bounce_req); /* * @brief Get endpoint communicator with given ID */ -static inline nccl_net_ofi_comm_t *get_comm(nccl_net_ofi_rdma_ep_t *ep, - int64_t local_comm_id) +static inline nccl_net_ofi_comm_t *get_comm(nccl_net_ofi_rdma_ep_t *ep, uint32_t local_comm_id) { assert(local_comm_id < NCCL_OFI_RDMA_MAX_COMMS); return ep->comms[local_comm_id]; @@ -147,7 +146,7 @@ static inline nccl_net_ofi_comm_t *get_comm(nccl_net_ofi_rdma_ep_t *ep, * @brief Set endpoint communicator with given ID */ static inline void set_comm(nccl_net_ofi_rdma_ep_t *ep, - int64_t local_comm_id, + uint32_t local_comm_id, nccl_net_ofi_comm_t *comm) { assert(local_comm_id < NCCL_OFI_RDMA_MAX_COMMS); @@ -157,7 +156,7 @@ static inline void set_comm(nccl_net_ofi_rdma_ep_t *ep, /* * @brief Get endpoint listen communicator with given comm_id */ -static inline nccl_net_ofi_rdma_listen_comm_t *get_listen_comm(nccl_net_ofi_rdma_ep_t *ep, uint64_t local_comm_id) +static inline nccl_net_ofi_rdma_listen_comm_t *get_listen_comm(nccl_net_ofi_rdma_ep_t *ep, uint32_t local_comm_id) { nccl_net_ofi_rdma_listen_comm_t *l_comm = (nccl_net_ofi_rdma_listen_comm_t *)get_comm(ep, local_comm_id); assert(l_comm->base.base.type == NCCL_NET_OFI_LISTEN_COMM); @@ -167,8 +166,7 @@ static inline nccl_net_ofi_rdma_listen_comm_t *get_listen_comm(nccl_net_ofi_rdma /* * @brief Get endpoint send communicator with given ID */ -static inline nccl_net_ofi_rdma_send_comm_t *get_send_comm(nccl_net_ofi_rdma_ep_t *ep, - uint64_t local_comm_id) +static inline nccl_net_ofi_rdma_send_comm_t *get_send_comm(nccl_net_ofi_rdma_ep_t *ep, uint32_t local_comm_id) { nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *) get_comm(ep, local_comm_id); @@ -180,7 +178,7 @@ static 
inline nccl_net_ofi_rdma_send_comm_t *get_send_comm(nccl_net_ofi_rdma_ep_ * @brief Get endpoint recv communicator with given comm_id */ static inline nccl_net_ofi_rdma_recv_comm_t *get_recv_comm(nccl_net_ofi_rdma_ep_t *ep, - uint64_t local_comm_id) + uint32_t local_comm_id) { nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *) get_comm(ep, local_comm_id); @@ -1273,7 +1271,7 @@ static inline int handle_bounce_recv(nccl_net_ofi_rdma_req_type_t msg_type, nccl static inline nccl_net_ofi_rdma_req_t *get_req_from_imm_data (nccl_net_ofi_rdma_ep_t *ep, uint64_t data) { - uint16_t comm_id = GET_COMM_ID_FROM_IMM(data); + uint32_t comm_id = GET_COMM_ID_FROM_IMM(data); nccl_net_ofi_rdma_recv_comm_t *r_comm = get_recv_comm(ep, comm_id); uint16_t msg_seq_num = GET_SEQ_NUM_FROM_IMM(data); @@ -3238,7 +3236,7 @@ static int recv_close(nccl_net_ofi_recv_comm_t *recv_comm) /* Release communicator ID */ ret = nccl_ofi_idpool_free_id(ep->comm_idpool, r_comm->local_comm_id); if (OFI_UNLIKELY(ret != 0)) { - NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", r_comm->local_comm_id); + NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", r_comm->local_comm_id); } free(r_comm); @@ -3459,12 +3457,12 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device r_comm->base.close = recv_close; /* Allocate recv communicator ID */ int comm_id = nccl_ofi_idpool_allocate_id(ep->comm_idpool); if (OFI_UNLIKELY(comm_id < 0)) { r_comm->local_comm_id = ~0; goto error; } - r_comm->local_comm_id = (uint64_t)comm_id; + r_comm->local_comm_id = (uint32_t)comm_id; /* Validate received comm ID */ if (OFI_UNLIKELY(conn_msg->local_comm_id >= device->num_comm_ids)) { @@ -3561,7 +3559,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device if (~0 != r_comm->local_comm_id) { ret = nccl_ofi_idpool_free_id(ep->comm_idpool, r_comm->local_comm_id); if (ret != 0) { -
NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", r_comm->local_comm_id); + NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", r_comm->local_comm_id); } } free(r_comm); @@ -3921,7 +3919,7 @@ static int listen_close(nccl_net_ofi_listen_comm_t *listen_comm) /* Release communicator ID */ ret = nccl_ofi_idpool_free_id(((nccl_net_ofi_rdma_ep_t *)base_ep)->comm_idpool, l_comm->comm_id); if (OFI_UNLIKELY(ret != 0)) { - NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", l_comm->comm_id); + NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", l_comm->comm_id); } free(l_comm); @@ -3969,13 +3967,13 @@ static int listen(nccl_net_ofi_ep_t *base_ep, l_comm->leader_local_ep = first_rail->ofi_ep; /* Allocate listen communicator ID */ int comm_id = nccl_ofi_idpool_allocate_id(ep->comm_idpool); if (OFI_UNLIKELY(comm_id < 0)) { l_comm->comm_id = ~0; ret = comm_id; goto error; } - l_comm->comm_id = (uint64_t)comm_id; + l_comm->comm_id = (uint32_t)comm_id; handle->comm_id = l_comm->comm_id; /* Add listen comm to ep's lookup array */ @@ -3993,7 +3991,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, error: if (l_comm && ~0 != l_comm->comm_id) { if (0 != nccl_ofi_idpool_free_id(ep->comm_idpool, l_comm->comm_id)) { - NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", l_comm->comm_id); + NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", l_comm->comm_id); } } free(l_comm); @@ -4663,7 +4661,7 @@ static int send_close(nccl_net_ofi_rdma_send_comm_t *s_comm) /* Release communicator ID */ ret = nccl_ofi_idpool_free_id(ep->comm_idpool, s_comm->local_comm_id); if (OFI_UNLIKELY(ret != 0)) { - NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", s_comm->local_comm_id); + NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", s_comm->local_comm_id); } free(s_comm); @@ -4726,8 +4724,8 @@ static int blocked_send_close(nccl_net_ofi_send_comm_t *send_comm) * @return Connection information, on
success * NULL, on others */ -static void prepare_send_connect_message(nccl_net_ofi_rdma_ep_t *ep, int dev_id, uint64_t local_comm_id, - uint64_t remote_comm_id, nccl_net_ofi_conn_handle_t *handle, +static void prepare_send_connect_message(nccl_net_ofi_rdma_ep_t *ep, int dev_id, uint32_t local_comm_id, + uint32_t remote_comm_id, nccl_net_ofi_conn_handle_t *handle, nccl_ofi_rdma_connection_info_t *conn_msg) { int num_rails = ep->num_rails; @@ -4918,13 +4916,13 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, ret_s_comm->remote_comm_id = handle->comm_id; /* Allocate send communicator ID */ int comm_id = nccl_ofi_idpool_allocate_id(ep->comm_idpool); if (OFI_UNLIKELY(comm_id < 0)) { ret_s_comm->local_comm_id = ~0; ret = comm_id; goto error; } - ret_s_comm->local_comm_id = (uint64_t)comm_id; + ret_s_comm->local_comm_id = (uint32_t)comm_id; /* Add ourselves to ep's lookup array */ set_comm(ep, ret_s_comm->local_comm_id, &ret_s_comm->base.base); @@ -4977,7 +4975,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, error: if (ret_s_comm && ~0 != ret_s_comm->local_comm_id) { if (0 != nccl_ofi_idpool_free_id(ep->comm_idpool, ret_s_comm->local_comm_id)) { - NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", ret_s_comm->local_comm_id); + NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", ret_s_comm->local_comm_id); } } free(ret_s_comm); @@ -5617,7 +5615,7 @@ static int device_prepare_for_connection(nccl_net_ofi_rdma_device_t *device) nccl_net_ofi_rdma_device_rail_t *begin = device->device_rails; nccl_net_ofi_rdma_device_rail_t *end = device->device_rails + device->num_rails; - device->num_comm_ids = (uint64_t)NCCL_OFI_RDMA_MAX_COMMS; + device->num_comm_ids = (uint32_t)NCCL_OFI_RDMA_MAX_COMMS; for (; begin != end; ++begin) { ret = init_device_rail_ofi_resources(begin); diff --git a/src/nccl_ofi_sendrecv.c b/src/nccl_ofi_sendrecv.c index
f8d80c111..dcbd811d0 100644 --- a/src/nccl_ofi_sendrecv.c +++ b/src/nccl_ofi_sendrecv.c @@ -1478,7 +1478,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, local_ep_name = get_local_address(ep->ofi_ep); memcpy(handle->ep_name, local_ep_name, MAX_EP_ADDR); - handle->comm_id = tag; + handle->comm_id = (uint32_t)tag; /* Insert local EP address to AV. This will be used to issue local read operations */ num_addrs = fi_av_insert(ep->av, (void *)local_ep_name, 1, @@ -1727,7 +1727,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, /* Get tag and remote name from handle */ memcpy(&remote_ep_addr, handle->ep_name, MAX_EP_ADDR); - memcpy(&tag, &handle->comm_id, sizeof(tag)); + tag = handle->comm_id; /* widen by value: a 4-byte memcpy into a wider tag is endian- and init-dependent */ if (tag < 1 || tag > max_tag) { NCCL_OFI_WARN("Received an invalid tag %lu for device %d", tag, device->base.dev_id);