Narrow communicator IDs from 64-bit to 32-bit integers.

Communicator IDs are bounded by NCCL_OFI_RDMA_MAX_COMMS, so the handle,
the on-wire control/connect messages and the communicator structs carry
them as uint32_t. The static size assertions for the wire messages are
updated to match.

The value returned by nccl_ofi_idpool_allocate_id() is deliberately kept
in a signed `int` until it has been checked: the failure branch tests
`comm_id < 0`, which can never be true for an unsigned variable. The ID
is narrowed to uint32_t only after the check succeeds.

The sendrecv connect path reads the tag with a plain assignment rather
than a partial memcpy into a wider variable, so the result depends on
neither byte order nor on the destination being pre-zeroed.

diff --git a/include/nccl_ofi.h b/include/nccl_ofi.h
index 246635506..d6eae6a6c 100644
--- a/include/nccl_ofi.h
+++ b/include/nccl_ofi.h
@@ -197,7 +197,7 @@ _Static_assert(sizeof(nccl_ofi_connection_info_t) == 80,
 
 typedef struct nccl_net_ofi_conn_handle {
 	char ep_name[MAX_EP_ADDR];
-	uint64_t comm_id;
+	uint32_t comm_id;
 	/* Save temporary communicator state when creating send communicator */
 	save_comm_state_t state;
 } nccl_net_ofi_conn_handle_t;
diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h
index 76b67766b..ba1839433 100644
--- a/include/nccl_ofi_rdma.h
+++ b/include/nccl_ofi_rdma.h
@@ -92,14 +92,14 @@ typedef struct nccl_net_ofi_rdma_ctrl_msg {
 
 	/* A comm identitifer that uniquely identifies the comm
 	 * on the receiver side */
-	uint64_t remote_comm_id;
+	uint32_t remote_comm_id;
 
 	uint64_t buff_addr;
 	uint64_t buff_len;
 	uint64_t buff_mr_key[MAX_NUM_RAILS];
 } nccl_net_ofi_rdma_ctrl_msg_t;
 /* Since this is a message on the wire, check that it has the expected size */
-_Static_assert(sizeof(nccl_net_ofi_rdma_ctrl_msg_t) == 64,
+_Static_assert(sizeof(nccl_net_ofi_rdma_ctrl_msg_t) == 56,
 	       "Wrong size for RDMA Control message");
 
 /* Structure used to store control messages in a free list */
@@ -326,11 +326,11 @@ typedef struct nccl_ofi_rdma_connection_info {
 
 	/* A comm identitifer that uniquely identifies the comm on the sender side.
 	   The receiver must use this ID when sending messages to sender */
-	uint64_t local_comm_id;
+	uint32_t local_comm_id;
 
 	/* A comm identitifer that uniquely identifies the comm
 	 * on the receiver side */
-	uint64_t remote_comm_id;
+	uint32_t remote_comm_id;
 
 	/* Array of `MAX_NUM_RAILS` `nccl_ofi_rdma_ep_name_t`
 	 * structs. The member `num_rails` indicates the number of
@@ -338,7 +338,7 @@ typedef struct nccl_ofi_rdma_connection_info {
 	nccl_ofi_rdma_ep_name_t ep_names[MAX_NUM_RAILS];
 } nccl_ofi_rdma_connection_info_t;
 /* Since this is a message on the wire, check that it has the expected size */
-_Static_assert(sizeof(nccl_ofi_rdma_connection_info_t) == 248,
+_Static_assert(sizeof(nccl_ofi_rdma_connection_info_t) == 240,
 	       "Wrong size for RDMA connect message");
 
 /*
@@ -373,9 +373,9 @@ typedef struct nccl_net_ofi_rdma_send_comm {
 	nccl_ofi_freelist_t *nccl_ofi_reqs_fl;
 
 	/* Comm ID provided by the local endpoint */
-	uint64_t local_comm_id;
+	uint32_t local_comm_id;
 	/* Comm ID provided by remote endpoint */
-	uint64_t remote_comm_id;
+	uint32_t remote_comm_id;
 
 	/* Request to receive connect response message to finalize
 	 * connection establishment */
@@ -451,9 +451,9 @@ typedef struct nccl_net_ofi_rdma_recv_comm {
 	nccl_ofi_freelist_t *nccl_ofi_reqs_fl;
 
 	/* Comm ID provided by the local endpoint */
-	uint64_t local_comm_id;
+	uint32_t local_comm_id;
 	/* Comm ID provided by remote endpoint */
-	uint64_t remote_comm_id;
+	uint32_t remote_comm_id;
 
 	/* The flush buffer */
 	nccl_net_ofi_rdma_flush_buffer_t flush_buff;
@@ -479,7 +479,7 @@ typedef struct nccl_net_ofi_rdma_listen_comm {
 	nccl_net_ofi_listen_comm_t base;
 
 	/* Comm ID provided by local endpoint */
-	uint64_t comm_id;
+	uint32_t comm_id;
 	struct fid_ep *leader_local_ep;
 
 	/* Communicator created while accept routine is executed */
@@ -655,7 +655,7 @@ typedef struct nccl_net_ofi_rdma_device {
 	char *prov_name;
 
 	/* Maximum number of supported communicator IDs */
-	uint64_t num_comm_ids;
+	uint32_t num_comm_ids;
 
 	/* Memory registration key pool */
 	nccl_ofi_idpool_t key_pool;
diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c
index abd106b60..3b51d5602 100644
--- a/src/nccl_ofi_rdma.c
+++ b/src/nccl_ofi_rdma.c
@@ -136,8 +136,7 @@ static inline int check_post_bounce_req(nccl_net_ofi_rdma_req_t *bounce_req);
 /*
  * @brief	Get endpoint communicator with given ID
  */
-static inline nccl_net_ofi_comm_t *get_comm(nccl_net_ofi_rdma_ep_t *ep,
-					    int64_t local_comm_id)
+static inline nccl_net_ofi_comm_t *get_comm(nccl_net_ofi_rdma_ep_t *ep, uint32_t local_comm_id)
 {
 	assert(local_comm_id < NCCL_OFI_RDMA_MAX_COMMS);
 	return ep->comms[local_comm_id];
@@ -147,7 +146,7 @@ static inline nccl_net_ofi_comm_t *get_comm(nccl_net_ofi_rdma_ep_t *ep,
  * @brief	Set endpoint communicator with given ID
  */
 static inline void set_comm(nccl_net_ofi_rdma_ep_t *ep,
-			    int64_t local_comm_id,
+			    uint32_t local_comm_id,
 			    nccl_net_ofi_comm_t *comm)
 {
 	assert(local_comm_id < NCCL_OFI_RDMA_MAX_COMMS);
@@ -157,7 +156,7 @@ static inline void set_comm(nccl_net_ofi_rdma_ep_t *ep,
 /*
  * @brief	Get endpoint listen communicator with given comm_id
  */
-static inline nccl_net_ofi_rdma_listen_comm_t *get_listen_comm(nccl_net_ofi_rdma_ep_t *ep, uint64_t local_comm_id)
+static inline nccl_net_ofi_rdma_listen_comm_t *get_listen_comm(nccl_net_ofi_rdma_ep_t *ep, uint32_t local_comm_id)
 {
 	nccl_net_ofi_rdma_listen_comm_t *l_comm = (nccl_net_ofi_rdma_listen_comm_t *)get_comm(ep, local_comm_id);
 	assert(l_comm->base.base.type == NCCL_NET_OFI_LISTEN_COMM);
@@ -167,8 +166,7 @@ static inline nccl_net_ofi_rdma_listen_comm_t *get_listen_comm(nccl_net_ofi_rdma
 /*
  * @brief	Get endpoint send communicator with given ID
  */
-static inline nccl_net_ofi_rdma_send_comm_t *get_send_comm(nccl_net_ofi_rdma_ep_t *ep,
-							   uint64_t local_comm_id)
+static inline nccl_net_ofi_rdma_send_comm_t *get_send_comm(nccl_net_ofi_rdma_ep_t *ep, uint32_t local_comm_id)
 {
 	nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)
 		get_comm(ep, local_comm_id);
@@ -180,7 +178,7 @@ static inline nccl_net_ofi_rdma_send_comm_t *get_send_comm(nccl_net_ofi_rdma_ep_
  * @brief	Get endpoint recv communicator with given comm_id
  */
 static inline nccl_net_ofi_rdma_recv_comm_t *get_recv_comm(nccl_net_ofi_rdma_ep_t *ep,
-							   uint64_t local_comm_id)
+							   uint32_t local_comm_id)
 {
 	nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)
 		get_comm(ep, local_comm_id);
@@ -1273,7 +1271,7 @@ static inline int handle_bounce_recv(nccl_net_ofi_rdma_req_type_t msg_type, nccl
 static inline nccl_net_ofi_rdma_req_t *get_req_from_imm_data
 	(nccl_net_ofi_rdma_ep_t *ep, uint64_t data)
 {
-	uint16_t comm_id = GET_COMM_ID_FROM_IMM(data);
+	uint32_t comm_id = GET_COMM_ID_FROM_IMM(data);
 	nccl_net_ofi_rdma_recv_comm_t *r_comm = get_recv_comm(ep, comm_id);
 
 	uint16_t msg_seq_num = GET_SEQ_NUM_FROM_IMM(data);
@@ -3238,7 +3236,7 @@ static int recv_close(nccl_net_ofi_recv_comm_t *recv_comm)
 	/* Release communicator ID */
 	ret = nccl_ofi_idpool_free_id(ep->comm_idpool, r_comm->local_comm_id);
 	if (OFI_UNLIKELY(ret != 0)) {
-		NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", r_comm->local_comm_id);
+		NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", r_comm->local_comm_id);
 	}
 
 	free(r_comm);
@@ -3459,12 +3457,12 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device
 	r_comm->base.close = recv_close;
 
 	/* Allocate recv communicator ID */
 	int comm_id = nccl_ofi_idpool_allocate_id(ep->comm_idpool);
 	if (OFI_UNLIKELY(comm_id < 0)) {
 		r_comm->local_comm_id = ~0;
 		goto error;
 	}
-	r_comm->local_comm_id = (uint64_t)comm_id;
+	r_comm->local_comm_id = (uint32_t)comm_id;
 
 	/* Validate received comm ID */
 	if (OFI_UNLIKELY(conn_msg->local_comm_id >= device->num_comm_ids)) {
@@ -3561,7 +3559,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device
 		if (~0 != r_comm->local_comm_id) {
 			ret = nccl_ofi_idpool_free_id(ep->comm_idpool, r_comm->local_comm_id);
 			if (ret != 0) {
-				NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", r_comm->local_comm_id);
+				NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", r_comm->local_comm_id);
 			}
 		}
 		free(r_comm);
@@ -3921,7 +3919,7 @@ static int listen_close(nccl_net_ofi_listen_comm_t *listen_comm)
 	/* Release communicator ID */
 	ret = nccl_ofi_idpool_free_id(((nccl_net_ofi_rdma_ep_t *)base_ep)->comm_idpool, l_comm->comm_id);
 	if (OFI_UNLIKELY(ret != 0)) {
-		NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", l_comm->comm_id);
+		NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", l_comm->comm_id);
 	}
 
 	free(l_comm);
@@ -3969,13 +3967,13 @@ static int listen(nccl_net_ofi_ep_t *base_ep,
 	l_comm->leader_local_ep = first_rail->ofi_ep;
 
 	/* Allocate listen communicator ID */
 	int comm_id = nccl_ofi_idpool_allocate_id(ep->comm_idpool);
 	if (OFI_UNLIKELY(comm_id < 0)) {
 		l_comm->comm_id = ~0;
 		ret = comm_id;
 		goto error;
 	}
-	l_comm->comm_id = (uint64_t)comm_id;
+	l_comm->comm_id = (uint32_t)comm_id;
 	handle->comm_id = l_comm->comm_id;
 
 	/* Add listen comm to ep's lookup array */
@@ -3993,7 +3991,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep,
  error:
 	if (l_comm && ~0 != l_comm->comm_id) {
 		if (0 != nccl_ofi_idpool_free_id(ep->comm_idpool, l_comm->comm_id)) {
-			NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", l_comm->comm_id);
+			NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", l_comm->comm_id);
 		}
 	}
 	free(l_comm);
@@ -4663,7 +4661,7 @@ static int send_close(nccl_net_ofi_rdma_send_comm_t *s_comm)
 	/* Release communicator ID */
 	ret = nccl_ofi_idpool_free_id(ep->comm_idpool, s_comm->local_comm_id);
 	if (OFI_UNLIKELY(ret != 0)) {
-		NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", s_comm->local_comm_id);
+		NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", s_comm->local_comm_id);
 	}
 
 	free(s_comm);
@@ -4726,8 +4724,8 @@ static int blocked_send_close(nccl_net_ofi_send_comm_t *send_comm)
  * @return	Connection information, on success
  *		NULL, on others
  */
-static void prepare_send_connect_message(nccl_net_ofi_rdma_ep_t *ep, int dev_id, uint64_t local_comm_id,
-					 uint64_t remote_comm_id, nccl_net_ofi_conn_handle_t *handle,
+static void prepare_send_connect_message(nccl_net_ofi_rdma_ep_t *ep, int dev_id, uint32_t local_comm_id,
+					 uint32_t remote_comm_id, nccl_net_ofi_conn_handle_t *handle,
 					 nccl_ofi_rdma_connection_info_t *conn_msg)
 {
 	int num_rails = ep->num_rails;
@@ -4918,13 +4916,13 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
 	ret_s_comm->remote_comm_id = handle->comm_id;
 
 	/* Allocate send communicator ID */
 	int comm_id = nccl_ofi_idpool_allocate_id(ep->comm_idpool);
 	if (OFI_UNLIKELY(comm_id < 0)) {
 		ret_s_comm->local_comm_id = ~0;
 		ret = comm_id;
 		goto error;
 	}
-	ret_s_comm->local_comm_id = (uint64_t)comm_id;
+	ret_s_comm->local_comm_id = (uint32_t)comm_id;
 
 	/* Add ourselves to ep's lookup array */
 	set_comm(ep, ret_s_comm->local_comm_id, &ret_s_comm->base.base);
@@ -4977,7 +4975,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
  error:
 	if (ret_s_comm && ~0 != ret_s_comm->local_comm_id) {
 		if (0 != nccl_ofi_idpool_free_id(ep->comm_idpool, ret_s_comm->local_comm_id)) {
-			NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", ret_s_comm->local_comm_id);
+			NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", ret_s_comm->local_comm_id);
 		}
 	}
 	free(ret_s_comm);
@@ -5617,7 +5615,7 @@ static int device_prepare_for_connection(nccl_net_ofi_rdma_device_t *device)
 	nccl_net_ofi_rdma_device_rail_t *begin = device->device_rails;
 	nccl_net_ofi_rdma_device_rail_t *end = device->device_rails + device->num_rails;
 
-	device->num_comm_ids = (uint64_t)NCCL_OFI_RDMA_MAX_COMMS;
+	device->num_comm_ids = (uint32_t)NCCL_OFI_RDMA_MAX_COMMS;
 
 	for (; begin != end; ++begin) {
 		ret = init_device_rail_ofi_resources(begin);
diff --git a/src/nccl_ofi_sendrecv.c b/src/nccl_ofi_sendrecv.c
index f8d80c111..dcbd811d0 100644
--- a/src/nccl_ofi_sendrecv.c
+++ b/src/nccl_ofi_sendrecv.c
@@ -1478,7 +1478,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep,
 	local_ep_name = get_local_address(ep->ofi_ep);
 
 	memcpy(handle->ep_name, local_ep_name, MAX_EP_ADDR);
-	handle->comm_id = tag;
+	handle->comm_id = (uint32_t)tag;
 
 	/* Insert local EP address to AV. This will be used to issue local read operations */
 	num_addrs = fi_av_insert(ep->av, (void *)local_ep_name, 1,
@@ -1727,7 +1727,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
 
 	/* Get tag and remote name from handle */
 	memcpy(&remote_ep_addr, handle->ep_name, MAX_EP_ADDR);
-	memcpy(&tag, &handle->comm_id, sizeof(tag));
+	tag = handle->comm_id;
 	if (tag < 1 || tag > max_tag) {
 		NCCL_OFI_WARN("Received an invalid tag %lu for device %d", tag,
 			      device->base.dev_id);