diff --git a/include/nccl_ofi.h b/include/nccl_ofi.h index 7890dbd3e..d6eae6a6c 100644 --- a/include/nccl_ofi.h +++ b/include/nccl_ofi.h @@ -191,10 +191,13 @@ typedef struct nccl_ofi_connection_info { uint64_t connect_to_self; nccl_net_ofi_req_t* req; } nccl_ofi_connection_info_t; +/* Since this is a message on the wire, check that it has the expected size */ +_Static_assert(sizeof(nccl_ofi_connection_info_t) == 80, + "Wrong size for SENDRECV connect message"); typedef struct nccl_net_ofi_conn_handle { char ep_name[MAX_EP_ADDR]; - uint64_t comm_id; + uint32_t comm_id; /* Save temporary communicator state when creating send communicator */ save_comm_state_t state; } nccl_net_ofi_conn_handle_t; diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h index 9ec602f73..ee00455b4 100644 --- a/include/nccl_ofi_rdma.h +++ b/include/nccl_ofi_rdma.h @@ -60,6 +60,13 @@ typedef enum nccl_net_ofi_rdma_req_type { NCCL_OFI_RDMA_SEND_CONN_RESP, } nccl_net_ofi_rdma_req_type_t; +typedef enum nccl_ofi_rdma_msg_type { + NCCL_OFI_RDMA_MSG_CONN, + NCCL_OFI_RDMA_MSG_CONN_RESP, + NCCL_OFI_RDMA_MSG_CTRL, + NCCL_OFI_RDMA_MSG_EAGER +} nccl_ofi_rdma_msg_type_t; + /* * @brief Rdma memory registration handle @@ -77,10 +84,23 @@ typedef struct nccl_net_ofi_rdma_mr_handle { /* Contents of ctrl message sent from receiver to sender to advertise destination buffer */ typedef struct nccl_net_ofi_rdma_ctrl_msg { + /* Message type, must be NCCL_OFI_RDMA_MSG_CTRL */ + uint16_t type; + + /* Message sequence number */ + uint16_t msg_seq_num; + + /* A comm identitifer that uniquely identifies the comm + * on the receiver side */ + uint32_t remote_comm_id; + uint64_t buff_addr; uint64_t buff_len; uint64_t buff_mr_key[MAX_NUM_RAILS]; } nccl_net_ofi_rdma_ctrl_msg_t; +/* Since this is a message on the wire, check that it has the expected size */ +_Static_assert(sizeof(nccl_net_ofi_rdma_ctrl_msg_t) == 56, + "Wrong size for RDMA Control message"); /* Structure used to store control messages in a free list */ typedef struct nccl_net_ofi_rdma_ctrl_fl_item { @@ -296,18 +316,30 @@ typedef struct nccl_ofi_rdma_ep_name { * connection information. */ typedef struct nccl_ofi_rdma_connection_info { + /* Message type + * either NCCL_OFI_RDMA_MSG_CONN or NCCL_OFI_RDMA_MSG_CONN_RESP + */ + uint16_t type; + + /* Number of rails */ + uint16_t num_rails; + /* A comm identitifer that uniquely identifies the comm on the sender side. The receiver must use this ID when sending messages to sender */ - uint64_t local_comm_id; + uint32_t local_comm_id; - /* Number of rails */ - int num_rails; + /* A comm identitifer that uniquely identifies the comm + * on the receiver side */ + uint32_t remote_comm_id; /* Array of `MAX_NUM_RAILS` `nccl_ofi_rdma_ep_name_t` * structs. The member `num_rails` indicates the number of * entries that are in use. */ nccl_ofi_rdma_ep_name_t ep_names[MAX_NUM_RAILS]; } nccl_ofi_rdma_connection_info_t; +/* Since this is a message on the wire, check that it has the expected size */ +_Static_assert(sizeof(nccl_ofi_rdma_connection_info_t) == 236, + "Wrong size for RDMA connect message"); /* * @brief Send communicator rail @@ -341,9 +373,10 @@ typedef struct nccl_net_ofi_rdma_send_comm { nccl_ofi_freelist_t *nccl_ofi_reqs_fl; /* Comm ID provided by the local endpoint */ - uint64_t local_comm_id; + uint32_t local_comm_id; + /* Comm ID provided by remote endpoint */ - uint64_t remote_comm_id; + uint32_t remote_comm_id; /* Request to receive connect response message to finalize * connection establishment */ @@ -419,9 +452,10 @@ typedef struct nccl_net_ofi_rdma_recv_comm { nccl_ofi_freelist_t *nccl_ofi_reqs_fl; /* Comm ID provided by the local endpoint */ - uint64_t local_comm_id; + uint32_t local_comm_id; + /* Comm ID provided by remote endpoint */ - uint64_t remote_comm_id; + uint32_t remote_comm_id; /* The flush buffer */ nccl_net_ofi_rdma_flush_buffer_t flush_buff; @@ -447,7 +481,7 @@ typedef struct nccl_net_ofi_rdma_listen_comm { nccl_net_ofi_listen_comm_t base; /* Comm ID provided by local endpoint */ - uint64_t comm_id; + uint32_t comm_id; struct fid_ep *leader_local_ep; /* Communicator created while accept routine is executed */ @@ -623,7 +657,7 @@ typedef struct nccl_net_ofi_rdma_device { char *prov_name; /* Maximum number of supported communicator IDs */ - uint64_t num_comm_ids; + uint32_t num_comm_ids; /* Memory registration key pool */ nccl_ofi_idpool_t key_pool; diff --git a/include/tracepoint.h b/include/tracepoint.h index 9d278162e..d070c68c1 100644 --- a/include/tracepoint.h +++ b/include/tracepoint.h @@ -136,21 +136,21 @@ LTTNG_UST_TRACEPOINT_EVENT( Recv, LTTNG_UST_TP_ARGS( int, dev, - int, tag, + int, comm_id, int, size, void *, request, void *, nccl_req ), LTTNG_UST_TP_FIELDS( lttng_ust_field_integer(int, dev, dev) - lttng_ust_field_integer(int, tag, tag) + lttng_ust_field_integer(int, comm_id, comm_id) lttng_ust_field_integer(int, size, size) lttng_ust_field_integer_hex(uint64_t, request, (uint64_t)request) lttng_ust_field_integer_hex(uint64_t, nccl_req, (uint64_t)nccl_req) ) ) -#define NCCL_OFI_TRACE_RECV(dev, tag, size, request, nccl_req) \ - lttng_ust_tracepoint(nccl_ofi_plugin, Recv, dev, tag, size, request, nccl_req) +#define NCCL_OFI_TRACE_RECV(dev, comm_id, size, request, nccl_req) \ + lttng_ust_tracepoint(nccl_ofi_plugin, Recv, dev, comm_id, size, request, nccl_req) LTTNG_UST_TRACEPOINT_EVENT( nccl_ofi_plugin, diff --git a/src/nccl_ofi_ofiutils.c b/src/nccl_ofi_ofiutils.c index 046fe7566..1ce552c7e 100644 --- a/src/nccl_ofi_ofiutils.c +++ b/src/nccl_ofi_ofiutils.c @@ -259,7 +259,12 @@ int nccl_ofi_ofiutils_init_connection(int api_version, struct fi_info *info, str goto error; } - cq_attr.format = FI_CQ_FORMAT_TAGGED; + if (info->caps & FI_TAGGED) { + cq_attr.format = FI_CQ_FORMAT_TAGGED; + } else { + cq_attr.format = FI_CQ_FORMAT_DATA; + } + ret = fi_cq_open(domain, &cq_attr, cq, NULL); if (OFI_UNLIKELY(ret != 0)) { NCCL_OFI_WARN("Couldn't open CQ. RC: %d, ERROR: %s", diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index 72eaccd95..61eda80b7 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -35,39 +35,26 @@ static pthread_mutex_t topo_file_lock = PTHREAD_MUTEX_INITIALIZER; #define NCCL_OFI_RDMA_MSGBUFF_SIZE 256 /* - * @brief Number of bits used for the tag type - * - * Tag variables are split into two parts, the tag value and the tag - * type. The `NUM_TAG_TYPE_BITS' least significant bits indicate - * the tag type, i.e., data path message, connect message, and connect - * accept message. The more significant bits identify the tag value. - * - * Tag variable bits - * | 50 unused bits | 12-bit tag value | 2-bit tag type | - */ -#define NUM_TAG_TYPE_BITS ((uint64_t)2) - -/* - * @brief Number of bits used for the tag value + * @brief Number of bits used for the communicator ID */ -#define NUM_TAG_VALUE_BITS ((uint64_t)12) +#define NUM_COMM_ID_BITS ((uint64_t)12) /* Maximum number of comms open simultaneously. Eventually this will be runtime-expandable */ -#define NCCL_OFI_RDMA_MAX_COMMS (1 << NUM_TAG_VALUE_BITS) +#define NCCL_OFI_RDMA_MAX_COMMS (1 << NUM_COMM_ID_BITS) /* * @brief Number of bits used for message sequence number * * The immediate data associated with an RDMA write operation is 32 - * bits and is divided into three parts, the segment count, the tag - * value, and the message sequence number (msg_seq_num). The data is - * encoded as follows: + * bits and is divided into three parts, the segment count, the + * communicator ID, and the message sequence number (msg_seq_num). + * The data is encoded as follows: * - * | 4-bit segment count | 12-bit tag value | 16-bit msg_seq_num | + * | 4-bit segment count | 12-bit comm ID | 16-bit msg_seq_num | * * - Segment count: number of RDMA writes that will be delivered as part of this message - * - Tag value: the tag for this communicator, excluding the right two control bits + * - Comm ID: the ID for this communicator * - Message sequence number: message identifier */ #define NUM_MSG_SEQ_NUM_BITS ((uint64_t) 16) @@ -78,14 +65,9 @@ static pthread_mutex_t topo_file_lock = PTHREAD_MUTEX_INITIALIZER; #define NUM_NUM_SEG_BITS ((uint64_t)4) /* - * @brief Tag type bitmask for tag variables + * @brief Communicator ID bitmask */ -#define TAG_TYPE_TAG_MASK (((uint64_t)1 << NUM_TAG_TYPE_BITS) - 1) - -/* - * @brief Tag value bitmask for tag variables - */ -#define TAG_VALUE_TAG_MASK (((uint64_t)1 << NUM_TAG_VALUE_BITS) - 1) +#define COMM_ID_MASK (((uint64_t)1 << NUM_COMM_ID_BITS) - 1) /* * @brief Message sequence number bitmask for immediate data @@ -98,56 +80,11 @@ static pthread_mutex_t topo_file_lock = PTHREAD_MUTEX_INITIALIZER; #define MSG_NUM_SEG_MASK (((uint64_t)1 << NUM_NUM_SEG_BITS) - 1) /* - * @brief Bitmask of tag type that identifies data path messages - */ -#define DATA_MSG_TYPE_MASK ((uint64_t)0) - -/* - * @brief Bitmask of tag type that identifies connect messages - */ -#define CONN_MSG_TYPE_MASK ((uint64_t)1) - -/* - * @brief Bitmask of tag type that identifies connect response messages - */ -#define CONN_RESP_MSG_TYPE_MASK ((uint64_t)2) - -/* - * @brief Return true iff tag type of input tag indicates a data path message - */ -#define IS_DATA_MSG_TYPE(tag) (((tag) & TAG_TYPE_TAG_MASK) == DATA_MSG_TYPE_MASK) - -/* - * @brief Return true iff tag type of input tag indicates a connect message - */ -#define IS_CONN_MSG_TYPE(tag) (((tag) & TAG_TYPE_TAG_MASK) == CONN_MSG_TYPE_MASK) - -/* - * @brief Return true iff tag type of input tag indicates a connect response message - */ -#define IS_CONN_RESP_MSG_TYPE(tag) (((tag) & TAG_TYPE_TAG_MASK) == CONN_RESP_MSG_TYPE_MASK) - -/* - * @brief Return input tag indicating data path message - */ -#define GET_DATA_MSG_TAG(comm_id) (((comm_id) << NUM_TAG_TYPE_BITS) | CONN_MSG_TYPE_MASK) - -/* - * @brief Return input tag indicating connect message - */ -#define GET_CONN_MSG_TAG(comm_id) (((comm_id) << NUM_TAG_TYPE_BITS) | CONN_MSG_TYPE_MASK) - -/* - * @brief Return input tag indicating connect response message - */ -#define GET_CONN_RESP_MSG_TAG(comm_id) (((comm_id) << NUM_TAG_TYPE_BITS) | CONN_RESP_MSG_TYPE_MASK) - -/* - * @brief Extract tag from write completion immediate data + * @brief Extract communicator ID from write completion immediate data * * The immediate data bit format is documented in the definition of NUM_MSG_SEQ_NUM_BITS */ -#define GET_TAG_FROM_IMM(data) ((((data) >> NUM_MSG_SEQ_NUM_BITS)) & TAG_VALUE_TAG_MASK) +#define GET_COMM_ID_FROM_IMM(data) (((data) >> NUM_MSG_SEQ_NUM_BITS) & COMM_ID_MASK) /* * @brief Extract message sequence number from write completion immediate data @@ -161,7 +98,7 @@ static pthread_mutex_t topo_file_lock = PTHREAD_MUTEX_INITIALIZER; * * The immediate data bit format is documented in the definition of NUM_MSG_SEQ_NUM_BITS */ -#define GET_NUM_SEG_FROM_IMM(data) (((data) >> (NUM_MSG_SEQ_NUM_BITS + NUM_TAG_VALUE_BITS)) & MSG_NUM_SEG_MASK) +#define GET_NUM_SEG_FROM_IMM(data) (((data) >> (NUM_MSG_SEQ_NUM_BITS + NUM_COMM_ID_BITS)) & MSG_NUM_SEG_MASK) /* * @brief Build write completion immediate data from comm ID, message seq @@ -170,15 +107,7 @@ static pthread_mutex_t topo_file_lock = PTHREAD_MUTEX_INITIALIZER; * The immediate data bit format is documented in the definition of NUM_MSG_SEQ_NUM_BITS */ #define GET_RDMA_WRITE_IMM_DATA(comm_id, seq, nseg) \ - ((seq) | ((comm_id) << NUM_MSG_SEQ_NUM_BITS) | \ - ((nseg) << (NUM_MSG_SEQ_NUM_BITS + NUM_TAG_VALUE_BITS))) - -/* - * RDMA data-path communication does not use Libfabric tags, but we must use - * tagged APIs since connection establishment uses them. Hence, we use a single - * tag for all data. - */ -#define RDMA_DATA_TAG 0 + ((seq) | ((comm_id) << NUM_MSG_SEQ_NUM_BITS) | ((nseg) << (NUM_MSG_SEQ_NUM_BITS + NUM_COMM_ID_BITS))) /** Global variables **/ @@ -207,8 +136,7 @@ static inline int check_post_bounce_req(nccl_net_ofi_rdma_req_t *bounce_req); /* * @brief Get endpoint communicator with given ID */ -static inline nccl_net_ofi_comm_t *get_comm(nccl_net_ofi_rdma_ep_t *ep, - int64_t local_comm_id) +static inline nccl_net_ofi_comm_t *get_comm(nccl_net_ofi_rdma_ep_t *ep, uint32_t local_comm_id) { assert(local_comm_id < NCCL_OFI_RDMA_MAX_COMMS); return ep->comms[local_comm_id]; @@ -218,18 +146,27 @@ static inline nccl_net_ofi_comm_t *get_comm(nccl_net_ofi_rdma_ep_t *ep, * @brief Set endpoint communicator with given ID */ static inline void set_comm(nccl_net_ofi_rdma_ep_t *ep, - int64_t local_comm_id, + uint32_t local_comm_id, nccl_net_ofi_comm_t *comm) { assert(local_comm_id < NCCL_OFI_RDMA_MAX_COMMS); ep->comms[local_comm_id] = comm; } +/* + * @brief Get endpoint listen communicator with given comm_id + */ +static inline nccl_net_ofi_rdma_listen_comm_t *get_listen_comm(nccl_net_ofi_rdma_ep_t *ep, uint32_t local_comm_id) +{ + nccl_net_ofi_rdma_listen_comm_t *l_comm = (nccl_net_ofi_rdma_listen_comm_t *)get_comm(ep, local_comm_id); + assert(l_comm->base.base.type == NCCL_NET_OFI_LISTEN_COMM); + return l_comm; +} + /* * @brief Get endpoint send communicator with given ID */ -static inline nccl_net_ofi_rdma_send_comm_t *get_send_comm(nccl_net_ofi_rdma_ep_t *ep, - uint64_t local_comm_id) +static inline nccl_net_ofi_rdma_send_comm_t *get_send_comm(nccl_net_ofi_rdma_ep_t *ep, uint32_t local_comm_id) { nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *) get_comm(ep, local_comm_id); @@ -241,7 +178,7 @@ static inline nccl_net_ofi_rdma_send_comm_t *get_send_comm(nccl_net_ofi_rdma_ep_ * @brief Get endpoint recv communicator with given comm_id */ static inline nccl_net_ofi_rdma_recv_comm_t *get_recv_comm(nccl_net_ofi_rdma_ep_t *ep, - uint64_t local_comm_id) + uint32_t local_comm_id) { nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *) get_comm(ep, local_comm_id); @@ -249,6 +186,15 @@ static inline nccl_net_ofi_rdma_recv_comm_t *get_recv_comm(nccl_net_ofi_rdma_ep_ return r_comm; } +/* + * Get connection message from bounce buffer + */ +static inline nccl_ofi_rdma_connection_info_t *get_bounce_connection_msg( + nccl_net_ofi_rdma_bounce_fl_item_t *bounce_fl_item) +{ + return (nccl_ofi_rdma_connection_info_t *)&bounce_fl_item->bounce_msg; +} + /* * Get ctrl message from bounce buffer */ @@ -600,8 +546,8 @@ static inline int get_properties(nccl_net_ofi_device_t *base_dev, * reails have the same speed. */ if (ret == 0) { props->port_speed *= device->num_rails; - _Static_assert(NUM_TAG_VALUE_BITS < 31, - "NUM_TAG_VALUE_BITS must be less than 31 so max_communicators fits in an integer"); + _Static_assert(NUM_COMM_ID_BITS < 31, + "NUM_COMM_ID_BITS must be less than 31 so max_communicators fits in an integer"); props->max_communicators = NCCL_OFI_RDMA_MAX_COMMS; } return ret; @@ -1005,12 +951,12 @@ static inline int decrease_bounce_buff_cnt(nccl_net_ofi_rdma_ep_t *ep, */ static inline int handle_ctrl_recv(nccl_net_ofi_rdma_send_comm_t *s_comm, uint16_t msg_seq_num, - nccl_net_ofi_rdma_req_t *bounce_req, - nccl_net_ofi_rdma_ep_t *ep) + nccl_net_ofi_rdma_req_t *bounce_req) { int ret; nccl_ofi_msgbuff_status_t stat; + nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)s_comm->base.base.ep; nccl_ofi_msgbuff_result_t mb_res = nccl_ofi_msgbuff_insert(s_comm->msgbuff, msg_seq_num, bounce_req, NCCL_OFI_MSGBUFF_BUFF, &stat); @@ -1123,10 +1069,10 @@ static inline int alloc_eager_copy_req(nccl_net_ofi_rdma_req_t *recv_req, nccl_n */ static inline int handle_eager_recv(nccl_net_ofi_rdma_recv_comm_t *r_comm, uint16_t msg_seq_num, - nccl_net_ofi_rdma_req_t *bounce_req, - nccl_net_ofi_rdma_ep_t *ep) + nccl_net_ofi_rdma_req_t *bounce_req) { int ret; + nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base.base.ep; /* Decrease bounce buffer count. It will be incremented again when reposting */ ret = decrease_bounce_buff_cnt(ep, get_bounce_data(bounce_req)->rail); @@ -1140,7 +1086,7 @@ static inline int handle_eager_recv(nccl_net_ofi_rdma_recv_comm_t *r_comm, if (mb_res == NCCL_OFI_MSGBUFF_SUCCESS) { /* Inserted! In this case receiver has not yet called recv() for this message, so - return success and initiate eager read when sender calls send(). */ + return success and initiate eager read when receiver calls recv(). */ return 0; } if (mb_res != NCCL_OFI_MSGBUFF_INVALID_IDX) { @@ -1192,51 +1138,128 @@ static inline int handle_eager_recv(nccl_net_ofi_rdma_recv_comm_t *r_comm, return 0; } +static int finish_connect(nccl_net_ofi_rdma_send_comm_t *s_comm); + /** - * @brief Handle receiving a bounce buffer message. These are either - * RDMA control messages (s_comm) or eager messages (r_comm) + * @brief Handle receiving a bounce buffer message. These are: + * connect messages (l_comm), connect response messages (s_comm), + * RDMA control messages (s_comm), eager messages (r_comm). */ -static inline int handle_bounce_recv(struct fi_cq_tagged_entry *cq_entry, int rail_id) +static inline int handle_bounce_recv(nccl_ofi_rdma_msg_type_t msg_type, nccl_net_ofi_rdma_ep_t *ep, int rail_id, + struct fi_cq_data_entry *cq_entry, nccl_net_ofi_rdma_req_t *bounce_req) { - nccl_net_ofi_rdma_req_t *bounce_req = (nccl_net_ofi_rdma_req_t *)cq_entry->op_context; + int ret; + rdma_req_bounce_data_t *bounce_data = NULL; + nccl_ofi_rdma_connection_info_t *conn_msg = NULL; + nccl_ofi_rdma_connection_info_t *conn_resp_msg = NULL; + nccl_net_ofi_rdma_ctrl_msg_t *ctrl_msg = NULL; + nccl_net_ofi_rdma_listen_comm_t *l_comm = NULL; + nccl_net_ofi_rdma_send_comm_t *s_comm = NULL; + nccl_net_ofi_rdma_recv_comm_t *r_comm = NULL; - if (bounce_req == NULL) { + if (OFI_UNLIKELY(bounce_req == NULL)) { NCCL_OFI_WARN("RECV event had NULL ctx!"); return -EINVAL; } - if (bounce_req->type != NCCL_OFI_RDMA_BOUNCE) { + if (OFI_UNLIKELY(bounce_req->type != NCCL_OFI_RDMA_BOUNCE)) { NCCL_OFI_WARN("Invalid non-bounce request as ctx!"); return -EINVAL; } - uint64_t local_comm_id = GET_TAG_FROM_IMM(cq_entry->data); + bounce_data = get_bounce_data(bounce_req); + bounce_data->recv_len = cq_entry->len; + + switch (msg_type) { + case NCCL_OFI_RDMA_MSG_CONN: + /* CONN receive completion */ + assert(sizeof(nccl_ofi_rdma_connection_info_t) == cq_entry->len); - rdma_req_bounce_data_t *bounce_data = get_bounce_data(bounce_req); + conn_msg = get_bounce_connection_msg(bounce_data->bounce_fl_item); + l_comm = get_listen_comm(ep, conn_msg->remote_comm_id); - bounce_data->recv_len = cq_entry->len; + assert(l_comm->req.comm->type == NCCL_NET_OFI_LISTEN_COMM); + assert((nccl_net_ofi_comm_t *)l_comm == l_comm->req.comm); - nccl_net_ofi_rdma_ep_t *ep = bounce_data->ep; - nccl_net_ofi_comm_t *comm = get_comm(ep, local_comm_id); - uint16_t msg_seq_num = GET_SEQ_NUM_FROM_IMM(cq_entry->data); - - if (comm->type == NCCL_NET_OFI_SEND_COMM) { - /* Control message */ - NCCL_OFI_TRACE_SEND_CTRL_RECV(comm->dev_id, rail_id, comm, msg_seq_num); - nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)comm; - assert(s_comm->local_comm_id == local_comm_id); - assert(bounce_data->recv_len == sizeof(nccl_net_ofi_rdma_ctrl_msg_t)); - - return handle_ctrl_recv(s_comm, msg_seq_num, bounce_req, ep); - } else if (comm->type == NCCL_NET_OFI_RECV_COMM) { - /* Eager message */ - NCCL_OFI_TRACE_EAGER_RECV(comm->dev_id, rail_id, comm, msg_seq_num); - nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)comm; - - return handle_eager_recv(r_comm, msg_seq_num, bounce_req, ep); - } else { - NCCL_OFI_WARN("Wrong comm type"); - return -EINVAL; + /* Copy connection message in the communicator */ + l_comm->conn_msg = *conn_msg; + + ret = inc_req_completion(&l_comm->req, cq_entry->len, 1); + if (OFI_UNLIKELY(ret != 0)) { + goto exit; + } + + /* Attempt to re-post bounce buffer */ + ret = repost_bounce_buff(ep, bounce_req); + if (OFI_UNLIKELY(ret != 0)) { + NCCL_OFI_WARN("Failed to repost bounce buff"); + goto exit; + } + break; + case NCCL_OFI_RDMA_MSG_CONN_RESP: + /* CONN_RESP receive completion */ + assert(sizeof(nccl_ofi_rdma_connection_info_t) == cq_entry->len); + + conn_resp_msg = get_bounce_connection_msg(bounce_data->bounce_fl_item); + s_comm = get_send_comm(ep, conn_resp_msg->remote_comm_id); + + assert(NULL != s_comm->conn_resp_req); + assert(NCCL_NET_OFI_SEND_COMM == s_comm->conn_resp_req->comm->type); + assert((nccl_net_ofi_comm_t *)s_comm == s_comm->conn_resp_req->comm); + + /* Copy connection response message in the communicator */ + s_comm->conn_msg = *conn_resp_msg; + + ret = inc_req_completion(s_comm->conn_resp_req, cq_entry->len, 1); + if (OFI_UNLIKELY(ret != 0)) { + goto exit; + } + + ret = finish_connect(s_comm); + if (OFI_UNLIKELY(ret != 0)) { + goto exit; + } + + /* Attempt to re-post bounce buffer */ + ret = repost_bounce_buff(ep, bounce_req); + if (OFI_UNLIKELY(ret != 0)) { + NCCL_OFI_WARN("Failed to repost bounce buff"); + goto exit; + } + break; + case NCCL_OFI_RDMA_MSG_CTRL: + /* CTRL receive completion */ + assert(sizeof(nccl_net_ofi_rdma_ctrl_msg_t) == cq_entry->len); + + ctrl_msg = get_bounce_ctrl_msg(bounce_data->bounce_fl_item); + s_comm = get_send_comm(ep, ctrl_msg->remote_comm_id); + + NCCL_OFI_TRACE_SEND_CTRL_RECV(r_comm->base.base.dev_id, rail_id, s_comm, ctrl_msg->msg_seq_num); + + ret = handle_ctrl_recv(s_comm, ctrl_msg->msg_seq_num, bounce_req); + if (OFI_UNLIKELY(ret != 0)) { + goto exit; + } + break; + case NCCL_OFI_RDMA_MSG_EAGER: + /* Eager message receive completion */ + + r_comm = get_recv_comm(ep, GET_COMM_ID_FROM_IMM(cq_entry->data)); + + NCCL_OFI_TRACE_EAGER_RECV(r_comm->base.base.dev_id, rail_id, r_comm, + GET_SEQ_NUM_FROM_IMM(cq_entry->data)); + + ret = handle_eager_recv(r_comm, GET_SEQ_NUM_FROM_IMM(cq_entry->data), bounce_req); + if (OFI_UNLIKELY(ret != 0)) { + goto exit; + } + break; + default: + NCCL_OFI_WARN("Recv completion with unexpected type"); + ret = -EINVAL; + goto exit; } +exit: + return ret; } /** @@ -1248,7 +1271,7 @@ static inline int handle_bounce_recv(struct fi_cq_tagged_entry *cq_entry, int ra static inline nccl_net_ofi_rdma_req_t *get_req_from_imm_data (nccl_net_ofi_rdma_ep_t *ep, uint64_t data) { - uint16_t comm_id = GET_TAG_FROM_IMM(data); + uint32_t comm_id = GET_COMM_ID_FROM_IMM(data); nccl_net_ofi_rdma_recv_comm_t *r_comm = get_recv_comm(ep, comm_id); uint16_t msg_seq_num = GET_SEQ_NUM_FROM_IMM(data); @@ -1274,9 +1297,7 @@ static inline nccl_net_ofi_rdma_req_t *get_req_from_imm_data /** * @brief Handle completion for a remote write event */ -static inline int handle_write_comp(struct fi_cq_tagged_entry *cq_entry, - nccl_net_ofi_rdma_ep_t *ep, - int rail_id) +static inline int handle_write_comp(struct fi_cq_data_entry *cq_entry, nccl_net_ofi_rdma_ep_t *ep, int rail_id) { int ret; @@ -1301,8 +1322,6 @@ static inline int handle_write_comp(struct fi_cq_tagged_entry *cq_entry, return 0; } -static int finish_connect(nccl_net_ofi_rdma_send_comm_t *s_comm); - static const char *req_state_str(nccl_net_ofi_rdma_req_state_t state) { switch(state) { @@ -1374,126 +1393,105 @@ static int post_eager_copy(nccl_net_ofi_rdma_req_t *req); * @return 0, on success * error, on others */ -static inline int process_completions(struct fi_cq_tagged_entry *cq_entry, - uint64_t num_cqes, nccl_net_ofi_rdma_ep_t *ep, +static inline int process_completions(struct fi_cq_data_entry *cq_entry, uint64_t num_cqes, nccl_net_ofi_rdma_ep_t *ep, nccl_net_ofi_ep_rail_t *rail) { int ret = 0; nccl_net_ofi_rdma_req_t *req = NULL; uint64_t comp_idx = 0, comp_flags = 0; - for (comp_idx = 0; comp_idx < num_cqes; comp_idx++) { - void *op_ctx = cq_entry[comp_idx].op_context; + rdma_req_send_data_t *send_data = NULL; + uint16_t *msg_type = NULL; + for (comp_idx = 0; comp_idx < num_cqes; comp_idx++) { + /* The context for these operations is req. + * except in the FI_REMOTE_WRITE case where is NULL */ + req = cq_entry[comp_idx].op_context; comp_flags = cq_entry[comp_idx].flags; - - // TODO we don't always have a req in this function. - // NCCL_OFI_TRACE_COMPLETIONS(req, req); + assert(NULL != req || (comp_flags & FI_REMOTE_WRITE)); /** - * Types of completions - * 0. Connect/Accept ctrl : tagged message and connect message or connect response tag type - * 1. Ctrl send complete: recv communicator AND FI_SEND - * 2. Ctrl recv complete: send communicator AND FI_RECV - * 5. fi_write local complete: send communicator AND FI_WRITE - * 6. fi_write remote complete: recv communicator AND FI_REMOTE_WRITE - * 7. flush complete : recv communicator AND FI_READ + * Types of completions: + * 1. SEND: connect, connect response, or control message + * 2. RECV w/o immediate data: connect, connect response, or control message + * 3. RECV w/ immediate data: eager message + * 4. Remote-initiated write + * 5. Local-initiated write + * 6. READ: flush or eager copy */ + if (comp_flags & FI_SEND) { + /* Send completions */ - if (OFI_UNLIKELY((comp_flags & FI_TAGGED) && !IS_DATA_MSG_TYPE(cq_entry[comp_idx].tag))) { - /* Type 0 */ - assert(IS_CONN_MSG_TYPE(cq_entry[comp_idx].tag) || IS_CONN_RESP_MSG_TYPE(cq_entry[comp_idx].tag)); + if (req->type == NCCL_OFI_RDMA_SEND_CONN || req->type == NCCL_OFI_RDMA_SEND_CONN_RESP) { + /* CONN or CONN_RESP send completion */ + ret = inc_req_completion(req, cq_entry[comp_idx].len, 1); - req = op_ctx; - ret = inc_req_completion(req, cq_entry[comp_idx].len, 1); - if (OFI_UNLIKELY(ret != 0)) { - return ret; - } - - if (IS_CONN_RESP_MSG_TYPE(cq_entry[comp_idx].tag) && (comp_flags & FI_RECV)) { - assert(req->comm->type == NCCL_NET_OFI_SEND_COMM); - /* Complete send communicator */ - nccl_net_ofi_rdma_send_comm_t *s_comm = - (nccl_net_ofi_rdma_send_comm_t *)req->comm; - assert(s_comm->conn_resp_req == req); - ret = finish_connect(s_comm); - } - } else if (comp_flags & FI_SEND) { - /* The context for these operations is req. */ - req = op_ctx; - - if (req->type == NCCL_OFI_RDMA_SEND_CTRL) { - /* Type 1 */ + } else if (req->type == NCCL_OFI_RDMA_SEND_CTRL) { + /* CTRL message send completion */ ret = set_send_ctrl_completed(req); - if (OFI_UNLIKELY(ret != 0)) { - return ret; - } - } else if (req->type == NCCL_OFI_RDMA_SEND) { - rdma_req_send_data_t *send_data = get_send_data(req); + } else if (req->type == NCCL_OFI_RDMA_SEND) { + /* Eager message send completion */ + send_data = get_send_data(req); assert(send_data->eager); - ret = inc_req_completion(req, 0, send_data->total_num_compls); - if (OFI_UNLIKELY(ret != 0)) { - goto exit; - } + } else { - /* Type 3 */ - NCCL_OFI_WARN("Send complete from unexpected req type"); + NCCL_OFI_WARN("Send completion from unexpected request type"); ret = -EINVAL; - goto exit; } } else if (comp_flags & FI_RECV) { - /* This is a bounce buffer receive event. It could be a - ctrl message receive (send comm) or an eager message - receive (recv comm) */ - ret = handle_bounce_recv(&cq_entry[comp_idx], rail->rail_id); + /* Receive completions */ + + if (!(comp_flags & FI_REMOTE_CQ_DATA)) { + /* CONN, CONN_RESP, or CTRL message */ + msg_type = (uint16_t *)cq_entry[comp_idx].buf; + ret = handle_bounce_recv(*msg_type, ep, rail->rail_id, &cq_entry[comp_idx], req); + + } else { + /* Eager message receive completion */ + ret = handle_bounce_recv(NCCL_OFI_RDMA_MSG_EAGER, ep, rail->rail_id, + &cq_entry[comp_idx], req); + } } else if (comp_flags & FI_REMOTE_WRITE) { - /* Type 6: Remote-initiated write is complete */ + /* Remote-initiated write is complete */ ret = handle_write_comp(&cq_entry[comp_idx], ep, rail->rail_id); - } else if (comp_flags & FI_WRITE) { - /* Type 5: Local-initiated write is complete */ - req = op_ctx; - rdma_req_send_data_t *send_data = get_send_data(req); - NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE(req->dev_id, rail->rail_id, req->comm, req->msg_seq_num, req); + } else if (comp_flags & FI_WRITE) { + /* Local-initiated write is complete */ + NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE(req->dev_id, rail->rail_id, req->comm, req->msg_seq_num, + req); + send_data = get_send_data(req); ret = inc_req_completion(req, 0, send_data->total_num_compls); - if (OFI_UNLIKELY(ret != 0)) { - goto exit; - } - } else if (comp_flags & FI_READ) { - req = op_ctx; + } else if (comp_flags & FI_READ) { switch (req->type) { case NCCL_OFI_RDMA_FLUSH: { /* fi_read flush is complete */ + rdma_req_flush_data_t *flush_data = get_flush_data(req); ret = inc_req_completion(req, 0, flush_data->schedule->num_xfer_infos); - if (OFI_UNLIKELY(ret != 0)) { - goto exit; - } break; } case NCCL_OFI_RDMA_EAGER_COPY: { ret = set_eager_copy_completed(req); - if (OFI_UNLIKELY(ret != 0)) { - goto exit; - } break; } default: NCCL_OFI_WARN("Read complete from unexpected request type!"); ret = -EINVAL; - goto exit; } } else { - NCCL_OFI_WARN("Unexpected comp_flags on cq event"); + NCCL_OFI_WARN("Unexpected comp_flags on cq event 0x%016X", comp_flags); ret = -EINVAL; + } + + if (OFI_UNLIKELY(ret != 0)) { goto exit; } } - exit: +exit: return ret; } @@ -1673,16 +1671,15 @@ static int process_pending_reqs(nccl_net_ofi_rdma_ep_t *ep) static int ofi_process_cq_rail(nccl_net_ofi_rdma_ep_t *ep, nccl_net_ofi_ep_rail_t *rail) { - struct fi_cq_tagged_entry cqe_tagged_buffers[cq_read_count]; + struct fi_cq_data_entry cqe_buffers[cq_read_count]; ssize_t rc = 0; int ret = 0; while (true) { /* Receive completions for the given endpoint */ - rc = fi_cq_read(rail->cq, cqe_tagged_buffers, cq_read_count); + rc = fi_cq_read(rail->cq, cqe_buffers, cq_read_count); if (rc > 0) { - ret = process_completions( - cqe_tagged_buffers, rc, ep, rail); + ret = process_completions(cqe_buffers, rc, ep, rail); if (OFI_UNLIKELY(ret != 0)) goto exit; } else if (OFI_UNLIKELY(rc == -FI_EAVAIL)) { @@ -2207,7 +2204,7 @@ static int finish_connect(nccl_net_ofi_rdma_send_comm_t *s_comm) /* Validate received comm ID */ if (OFI_UNLIKELY(conn_resp->local_comm_id >= device->num_comm_ids)) { - NCCL_OFI_WARN("Received an invalid communicator ID %lu for device %d", conn_resp->local_comm_id, + NCCL_OFI_WARN("Received an invalid communicator ID %u for device %d", conn_resp->local_comm_id, dev_id); return -EINVAL; } @@ -2356,7 +2353,7 @@ static int prepare_recv_conn_req(nccl_net_ofi_rdma_listen_comm_t *l_comm) req->type = NCCL_OFI_RDMA_RECV_CONN; req->free = free_invalid; req->base.test = test; - req->state = NCCL_OFI_RDMA_REQ_CREATED; + req->state = NCCL_OFI_RDMA_REQ_PENDING; req->comm = &l_comm->base.base; req->dev_id = l_comm->base.base.dev_id; /* Initialize mutex for request access */ @@ -2369,102 +2366,6 @@ static int prepare_recv_conn_req(nccl_net_ofi_rdma_listen_comm_t *l_comm) return 0; } -/* - * @brief Post a request to receive peer connect response message and - * process completion queue in case posting the receive fails - * - * @param s_comm - * Send communicator with initalized first communicator rail - * @param device - * Device of send communicator - * @param ep - * Endpoint of send communicator - * - * @return 0, on successful posting of receive request - * -FI_EAGAIN, on lack of provider resources to post receive request - * error, others - */ -static int post_recv_conn_resp(nccl_net_ofi_rdma_send_comm_t *s_comm, - nccl_net_ofi_rdma_device_t *device, - nccl_net_ofi_rdma_ep_t *ep) -{ - ssize_t rc = 0; - int ret = 0; - int dev_id = s_comm->base.base.dev_id; - assert(s_comm && s_comm->num_rails > 0); - nccl_net_ofi_rdma_send_comm_rail_t *comm_rail = get_send_comm_rail(s_comm, 0); - nccl_net_ofi_rdma_req_t *req = s_comm->conn_resp_req; - - /* Post a buffer for receiving connect response requests */ - rc = fi_trecv(comm_rail->local_ep, &s_comm->conn_msg, - sizeof(nccl_ofi_rdma_connection_info_t), - NULL, comm_rail->remote_addr, - GET_CONN_RESP_MSG_TAG(s_comm->local_comm_id), - 0, req); - if (rc == -FI_EAGAIN) { - /* - * Process completions so that you have enough - * resources for posting receive buffer - */ - ret = ofi_process_cq(ep); - if (OFI_UNLIKELY(ret != 0)) - return ret; - } - else if (rc != 0) - NCCL_OFI_WARN("Unable to post a buffer for receving connect responses for dev %d. RC: %zd, ERROR: %s", - dev_id, rc, fi_strerror(-rc)); - - return rc; -} - -/* - * @brief Post a request to receive peer connect message and - * process completion queue in case posting the receive failed - * - * @param l_comm - * Listen communicator - * @param device - * Device of listen communicator - * @param ep - * Endpoint of listen communicator - * - * @return 0, on successful posting of receive request - * -FI_EAGAIN, on lack of provider resources to post receive request - * error, others - */ -static int post_recv_conn(nccl_net_ofi_rdma_listen_comm_t *l_comm, - nccl_net_ofi_rdma_device_t *device, - nccl_net_ofi_rdma_ep_t *ep) -{ - ssize_t rc = 0; - int ret = 0; - int dev_id = l_comm->base.base.dev_id; - - /* Post a buffer for receiving connection requests */ - l_comm->req.state = NCCL_OFI_RDMA_REQ_PENDING; - rc = fi_trecv(l_comm->leader_local_ep, &l_comm->conn_msg, sizeof(nccl_ofi_rdma_connection_info_t), - NULL, FI_ADDR_UNSPEC, GET_CONN_MSG_TAG(l_comm->comm_id), - 0, &l_comm->req); - if (rc == -FI_EAGAIN) { - l_comm->req.state = NCCL_OFI_RDMA_REQ_CREATED; - /* - * Process completions so that you have enough - * resources for posting receive buffer - */ - ret = ofi_process_cq(ep); - if (OFI_UNLIKELY(ret != 0)) { - return ret; - } - } - else if (rc != 0) { - l_comm->req.state = NCCL_OFI_RDMA_REQ_CREATED; - NCCL_OFI_WARN("Unable to post a buffer for receving connections for dev %d. RC: %zd, ERROR: %s", - dev_id, rc, fi_strerror(-rc)); - } - - return rc; -} - /* * @brief Deregister libfabric memory registration of rails * @@ -2853,8 +2754,12 @@ static inline int insert_send_ctrl_req( return -ENOTSUP; } + ctrl_fl_item->ctrl_msg.type = NCCL_OFI_RDMA_MSG_CTRL; + ctrl_fl_item->ctrl_msg.remote_comm_id = r_comm->remote_comm_id; + ctrl_fl_item->ctrl_msg.msg_seq_num = msg_seq_num; ctrl_fl_item->ctrl_msg.buff_addr = (uint64_t)buff; ctrl_fl_item->ctrl_msg.buff_len = size; + int rail_id = 0; for (; rail_id < r_comm->num_rails; rail_id++) { ctrl_fl_item->ctrl_msg.buff_mr_key[rail_id] = fi_mr_key(buff_mr_handle->mr[rail_id]); @@ -3331,7 +3236,7 @@ static int recv_close(nccl_net_ofi_recv_comm_t *recv_comm) /* Release communicator ID */ ret = nccl_ofi_idpool_free_id(ep->comm_idpool, r_comm->local_comm_id); if (OFI_UNLIKELY(ret != 0)) { - NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", r_comm->local_comm_id); + NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", r_comm->local_comm_id); } free(r_comm); @@ -3518,8 +3423,7 @@ static inline nccl_net_ofi_rdma_recv_comm_t *calloc_rdma_recv_comm(int num_rails * @return Receive communicator object, on success * NULL, on error */ -static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen_comm_t *l_comm, - nccl_net_ofi_rdma_device_t *device, +static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device_t *device, nccl_net_ofi_rdma_ep_t *ep, nccl_ofi_rdma_connection_info_t *conn_msg) { @@ -3558,7 +3462,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen r_comm->local_comm_id = ~0; goto error; } - r_comm->local_comm_id = (uint64_t)comm_id; + r_comm->local_comm_id = (uint32_t)comm_id; /* Validate received comm ID */ if (OFI_UNLIKELY(conn_msg->local_comm_id >= device->num_comm_ids)) { @@ -3655,7 +3559,7 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_listen if (~0 != r_comm->local_comm_id) { ret = nccl_ofi_idpool_free_id(ep->comm_idpool, r_comm->local_comm_id); if (ret != 0) { - NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", r_comm->local_comm_id); + NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", r_comm->local_comm_id); } } free(r_comm); @@ -3690,6 +3594,8 @@ static int prepare_conn_resp(nccl_net_ofi_rdma_ep_t *ep, return -EINVAL; } + conn_resp->type = NCCL_OFI_RDMA_MSG_CONN_RESP; + /* Set number of rails to be sent back to remote for verification */ conn_resp->num_rails = num_rails; @@ -3735,9 +3641,8 @@ static int post_send_conn_resp(nccl_net_ofi_rdma_recv_comm_t *r_comm, nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail = get_recv_comm_rail(r_comm, 0); req->state = NCCL_OFI_RDMA_REQ_PENDING; - rc = fi_tsend(comm_rail->local_ep, (void *)conn_resp, - sizeof(nccl_ofi_rdma_connection_info_t), NULL, comm_rail->remote_addr, - GET_CONN_RESP_MSG_TAG(r_comm->remote_comm_id), req); + rc = fi_send(comm_rail->local_ep, (void *)conn_resp, sizeof(nccl_ofi_rdma_connection_info_t), NULL, + comm_rail->remote_addr, req); if (rc == -FI_EAGAIN) { req->state = NCCL_OFI_RDMA_REQ_CREATED; @@ -3875,7 +3780,7 @@ static int accept(nccl_net_ofi_listen_comm_t *listen_comm, } /* Prepare receive communicator object for the received peer connection */ - r_comm = prepare_recv_comm(l_comm, device, ep, conn_msg); + r_comm = prepare_recv_comm(device, ep, conn_msg); if (OFI_UNLIKELY(r_comm == NULL)) { ret = -EINVAL; goto exit; @@ -3898,6 +3803,9 @@ static int accept(nccl_net_ofi_listen_comm_t *listen_comm, /* Set r_comm's (local) comm ID to be sent back to remote */ conn_msg->local_comm_id = r_comm->local_comm_id; + /* Send r_comm's remote comm ID */ + conn_msg->remote_comm_id = r_comm->remote_comm_id; + /* COMM_SEND_CONN: Send connect response message to remote */ ret = post_send_conn_resp(r_comm, conn_msg, device, ep, req); if (ret == -FI_EAGAIN) { @@ -3969,12 +3877,13 @@ static int accept(nccl_net_ofi_listen_comm_t *listen_comm, ret = -EINVAL; } - exit: - + exit:; /* Close receive communicator in case listen operation failed */ - ret = close_listen_recv_comm(l_comm); - - return ret; + int close_ret = close_listen_recv_comm(l_comm); + if (close_ret) { + NCCL_OFI_WARN("Failed to close listen communicator"); + } + return ret ? ret : close_ret; } static int listen_close(nccl_net_ofi_listen_comm_t *listen_comm) @@ -4010,7 +3919,7 @@ static int listen_close(nccl_net_ofi_listen_comm_t *listen_comm) /* Release communicator ID */ ret = nccl_ofi_idpool_free_id(((nccl_net_ofi_rdma_ep_t *)base_ep)->comm_idpool, l_comm->comm_id); if (OFI_UNLIKELY(ret != 0)) { - NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", l_comm->comm_id); + NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", l_comm->comm_id); } free(l_comm); @@ -4025,7 +3934,6 @@ static int listen(nccl_net_ofi_ep_t *base_ep, { int ret = 0; nccl_net_ofi_rdma_listen_comm_t *l_comm = NULL; - bool first_post = true; nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)base_ep; nccl_net_ofi_ep_rail_t *first_rail = get_rail(ep, 0); @@ -4065,29 +3973,17 @@ static int listen(nccl_net_ofi_ep_t *base_ep, ret = comm_id; goto error; } - l_comm->comm_id = (uint64_t)comm_id; + l_comm->comm_id = (uint32_t)comm_id; handle->comm_id = l_comm->comm_id; + /* Add listen comm to ep's lookup array */ + set_comm(ep, l_comm->comm_id, &l_comm->base.base); + /* Prepare receive request to accept connections */ ret = prepare_recv_conn_req(l_comm); if (ret != 0) goto error; - /* Post connect message to receive peer connections until posting succeeds */ - do { - ret = post_recv_conn(l_comm, device, ep); - if (ret == -FI_EAGAIN) { - if (first_post) { - first_post = false; - NCCL_OFI_WARN("Unable to post receive of connect message for dev %d. Trying again until success.", - dev_id); - } - // Try again - } else if (ret != 0) { - goto error; - } - } while (ret == -FI_EAGAIN); - *listen_comm = &l_comm->base; goto exit; @@ -4095,7 +3991,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, error: if (l_comm && ~0 != l_comm->comm_id) { if (0 != nccl_ofi_idpool_free_id(ep->comm_idpool, l_comm->comm_id)) { - NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", l_comm->comm_id); + NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", l_comm->comm_id); } } free(l_comm); @@ -4251,13 +4147,11 @@ static int post_rdma_eager_send(nccl_net_ofi_rdma_req_t *req, ssize_t rc; /* Post eager send */ - rc = fi_tsenddata(comm_rail->local_ep, send_data->buff + xfer_info->offset, - xfer_info->msg_size, desc, send_data->wdata, comm_rail->remote_addr, - RDMA_DATA_TAG, req); + rc = fi_senddata(comm_rail->local_ep, send_data->buff + xfer_info->offset, xfer_info->msg_size, desc, + send_data->wdata, comm_rail->remote_addr, req); if ((rc != 0) && (rc != -FI_EAGAIN)) { - NCCL_OFI_WARN("fi_tsenddata failed; RC: %zd, Error: %s", - rc, fi_strerror(-rc)); + NCCL_OFI_WARN("fi_senddata failed; RC: %zd, Error: %s", rc, fi_strerror(-rc)); } else if (rc == 0) { /* TODO: use a better trace for eager send? */ NCCL_OFI_TRACE_SEND_WRITE_SEG_START(req->dev_id, rail_id, xfer_info->msg_size, req->comm, req->msg_seq_num, req); @@ -4282,9 +4176,8 @@ static int post_bounce_buffer(nccl_net_ofi_rdma_req_t *req, bounce_fl_item); req->state = NCCL_OFI_RDMA_REQ_CREATED; - ssize_t rc = fi_trecv(ep_rail->ofi_ep, &bounce_fl_item->bounce_msg, - bounce_data->buff_len, desc, FI_ADDR_UNSPEC, - RDMA_DATA_TAG, 0, req); + ssize_t rc = + fi_recv(ep_rail->ofi_ep, &bounce_fl_item->bounce_msg, bounce_data->buff_len, desc, FI_ADDR_UNSPEC, req); if ((rc != 0) && (rc != -FI_EAGAIN)) { NCCL_OFI_WARN("Error posting bounce buffer. RC: %zd, Error: %s", rc, fi_strerror(-rc)); @@ -4389,11 +4282,8 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req) assert(xfer_info->rail_id < mr_handle->num_rails); void *desc = fi_mr_desc(mr_handle->mr[xfer_info->rail_id]); - uint64_t data = GET_RDMA_WRITE_IMM_DATA(r_comm->remote_comm_id, req->msg_seq_num, 0); - - ssize_t rc = fi_tsenddata(comm_rail->local_ep, &ctrl_fl_item->ctrl_msg, - sizeof(nccl_net_ofi_rdma_ctrl_msg_t), desc, - data, comm_rail->remote_addr, RDMA_DATA_TAG, req); + ssize_t rc = fi_send(comm_rail->local_ep, &ctrl_fl_item->ctrl_msg, sizeof(nccl_net_ofi_rdma_ctrl_msg_t), desc, + comm_rail->remote_addr, req); if ((rc != 0) && (rc != -FI_EAGAIN)) { NCCL_OFI_WARN("Error posting RDMA ctrl request. RC: %zd, Error: %s", @@ -4771,7 +4661,7 @@ static int send_close(nccl_net_ofi_rdma_send_comm_t *s_comm) /* Release communicator ID */ ret = nccl_ofi_idpool_free_id(ep->comm_idpool, s_comm->local_comm_id); if (OFI_UNLIKELY(ret != 0)) { - NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", s_comm->local_comm_id); + NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", s_comm->local_comm_id); } free(s_comm); @@ -4834,16 +4724,20 @@ static int blocked_send_close(nccl_net_ofi_send_comm_t *send_comm) * @return Connection information, on success * NULL, on others */ -static void prepare_send_connect_message(nccl_net_ofi_rdma_ep_t *ep, - int dev_id, uint64_t local_comm_id, - nccl_net_ofi_conn_handle_t *handle, - nccl_ofi_rdma_connection_info_t* conn_msg) +static void prepare_send_connect_message(nccl_net_ofi_rdma_ep_t *ep, int dev_id, uint32_t local_comm_id, + uint32_t remote_comm_id, nccl_net_ofi_conn_handle_t *handle, + nccl_ofi_rdma_connection_info_t *conn_msg) { int num_rails = ep->num_rails; + conn_msg->type = NCCL_OFI_RDMA_MSG_CONN; + /* Send s_comm's local comm ID to be transferred to receiver */ conn_msg->local_comm_id = local_comm_id; + /* Send s_comm's remote comm ID */ + conn_msg->remote_comm_id = remote_comm_id; + /* Set number of rails to be sent back to remote for verification */ conn_msg->num_rails = num_rails; @@ -5012,7 +4906,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, ret_s_comm->base.close = blocked_send_close; ret_s_comm->next_msg_seq_num = 0; - /* Store tag from handle in communicator */ + /* Store communicator ID from handle in communicator */ if (OFI_UNLIKELY(handle->comm_id >= device->num_comm_ids)) { NCCL_OFI_WARN("Received an invalid communicator ID %lu for device %d", handle->comm_id, dev_id); @@ -5028,7 +4922,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, ret = comm_id; goto error; } - ret_s_comm->local_comm_id = (uint64_t)comm_id; + ret_s_comm->local_comm_id = (uint32_t)comm_id; /* Add ourselves to ep's lookup array */ set_comm(ep, ret_s_comm->local_comm_id, &ret_s_comm->base.base); @@ -5063,7 +4957,8 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, } /* Allocate and initialize connect message */ - prepare_send_connect_message(ep, dev_id, ret_s_comm->local_comm_id, handle, &ret_s_comm->conn_msg); + prepare_send_connect_message(ep, dev_id, ret_s_comm->local_comm_id, ret_s_comm->remote_comm_id, handle, + &ret_s_comm->conn_msg); /* Allocate message buffer */ ret_s_comm->msgbuff = nccl_ofi_msgbuff_init(NCCL_OFI_RDMA_MSGBUFF_SIZE); @@ -5080,7 +4975,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, error: if (ret_s_comm && ~0 != ret_s_comm->local_comm_id) { if (0 != nccl_ofi_idpool_free_id(ep->comm_idpool, ret_s_comm->local_comm_id)) { - NCCL_OFI_WARN("Error freeing communicator ID %"PRIu64"", ret_s_comm->local_comm_id); + NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", ret_s_comm->local_comm_id); } } free(ret_s_comm); @@ -5166,9 +5061,8 @@ static int post_send_conn(nccl_net_ofi_rdma_send_comm_t *s_comm, * providers can support it, so that need for completion check * can be lifted. */ - rc = fi_tsend(comm_rail->local_ep, (void *)&s_comm->conn_msg, - sizeof(nccl_ofi_rdma_connection_info_t), NULL, comm_rail->remote_addr, - GET_CONN_MSG_TAG(s_comm->remote_comm_id), req); + rc = fi_send(comm_rail->local_ep, (void *)&s_comm->conn_msg, sizeof(nccl_ofi_rdma_connection_info_t), NULL, + comm_rail->remote_addr, req); if (rc == -FI_EAGAIN) { /* @@ -5192,7 +5086,7 @@ static int post_send_conn(nccl_net_ofi_rdma_send_comm_t *s_comm, * * The connect functionality is split into two steps. This function * implements the first step in a nonblocking manner. The first step - * performs (a) create send comminicator with only the first + * performs (a) create send communicator with only the first * communicator rail being initalized, (b) post send operation to send * connect message to remote, containing local endpoint addresses, (c) * wait until message is delivered, (d) post receive operation to @@ -5267,6 +5161,14 @@ static int connect(nccl_net_ofi_ep_t *base_ep, comm_state->stage = COMM_SEND_CONN; case COMM_SEND_CONN: + + /* Prepare request to receive connect response message */ + s_comm->conn_resp_req = prepare_recv_conn_resp_req(s_comm); + if (OFI_UNLIKELY(s_comm->conn_resp_req == NULL)) { + send_close(s_comm); + return -EINVAL; + } + /* COMM_SEND_CONN: Post a connect message to send peer connections */ ret = post_send_conn(s_comm, device, ep, req); if (ret == -FI_EAGAIN) { @@ -5317,25 +5219,12 @@ static int connect(nccl_net_ofi_ep_t *base_ep, comm_state->req = NULL; req = NULL; - /* Prepare request to receive connect response message */ - s_comm->conn_resp_req = prepare_recv_conn_resp_req(s_comm); - if (OFI_UNLIKELY(s_comm->conn_resp_req == NULL)) { - send_close(s_comm); - return -EINVAL; - } - comm_state->stage = COMM_RECV_CONN; case COMM_RECV_CONN: /* COMM_RECV_CONN: Receive connect response message from remote */ - ret = post_recv_conn_resp(s_comm, device, ep); - if (ret == -FI_EAGAIN) { - return 0; - } else if (ret != 0) { - send_close(s_comm); - return ret; - } + assert(s_comm && s_comm->num_rails > 0); /* Progress our engine to get completions. If the * connect response message has arrived, the @@ -5422,9 +5311,8 @@ static int ep_rail_init(nccl_net_ofi_rdma_ep_t *ep, { int ret = 0; - ret = nccl_ofi_ofiutils_init_connection(FI_VERSION(1, 18), - dev_rail->info, dev_rail->domain, - &ep_rail->ofi_ep, &ep_rail->av, &ep_rail->cq); + ret = nccl_ofi_ofiutils_init_connection(FI_VERSION(1, 18), dev_rail->info, dev_rail->domain, &ep_rail->ofi_ep, + &ep_rail->av, &ep_rail->cq); if (ret != 0) { return ret; } @@ -5598,8 +5486,8 @@ static int get_ep(nccl_net_ofi_device_t *base_dev, /* Initialize reference count */ ep->ref_cnt = 0; - ep->bounce_buff_size = NCCL_OFI_MAX(sizeof(nccl_net_ofi_rdma_ctrl_msg_t), - eager_max_size); + ep->bounce_buff_size = NCCL_OFI_MAX(NCCL_OFI_MAX(sizeof(nccl_net_ofi_rdma_ctrl_msg_t), eager_max_size), + sizeof(nccl_ofi_rdma_connection_info_t)); /* Store endpoint in thread-local variable */ pthread_setspecific(device->ep_key, (void *)ep); @@ -5717,43 +5605,6 @@ static int init_device_rail_ofi_resources(nccl_net_ofi_rdma_device_rail_t *rail_ return ret; } -/* - * @brief Calculate maximum number of comm IDs per device - * - * @param device - * Rdma device - * - * @return 0, on success - * -EINVAL, on error - */ -static int calculate_num_comm_ids(nccl_net_ofi_rdma_device_t *device) -{ - int ret = 0; - int ofi_tag_leading_zeroes = 0, ofi_tag_bits_for_ring_id = 64; - nccl_net_ofi_rdma_device_rail_t *dev_rail = get_device_rail(device, 0); - - /* Determine if any tag bits are used by provider */ - while (!((dev_rail->info->ep_attr->mem_tag_format << ofi_tag_leading_zeroes++) & - (uint64_t) OFI_HIGHEST_TAG_BIT) && - (ofi_tag_bits_for_ring_id >= MIN_TAG_BITS_FOR_RING_ID)) { - ofi_tag_bits_for_ring_id--; - } - - if (OFI_UNLIKELY(ofi_tag_bits_for_ring_id < MIN_TAG_BITS_FOR_RING_ID)) { - NCCL_OFI_WARN("Provider %s does not provide enough tag bits %d for ring ID. Minimum required is %d", - dev_rail->info->fabric_attr->prov_name, - ofi_tag_bits_for_ring_id, - MIN_TAG_BITS_FOR_RING_ID); - ret = -EINVAL; - } else { - /* Set maximum tag information; Reserving 2 bits for control information */ - /* RDMA write protocol has maximum 12-bit tag due to 32-bit immediate data restriction */ - device->num_comm_ids = (uint64_t)NCCL_OFI_RDMA_MAX_COMMS; - } - - return ret; -} - /* * @brief Allocates and initializes various libfabric resources to make rdma * device ready for endpoint creation. @@ -5764,10 +5615,7 @@ static int device_prepare_for_connection(nccl_net_ofi_rdma_device_t *device) nccl_net_ofi_rdma_device_rail_t *begin = device->device_rails; nccl_net_ofi_rdma_device_rail_t *end = device->device_rails + device->num_rails; - ret = calculate_num_comm_ids(device); - if (ret != 0) { - return ret; - } + device->num_comm_ids = (uint32_t)NCCL_OFI_RDMA_MAX_COMMS; for (; begin != end; ++begin) { ret = init_device_rail_ofi_resources(begin); @@ -5776,7 +5624,7 @@ static int device_prepare_for_connection(nccl_net_ofi_rdma_device_t *device) } } - return 0; + return ret; } /* @@ -5882,7 +5730,7 @@ static void get_hints(struct fi_info *hints) hints->caps = 0; /* Primary Capabilities */ - hints->caps = FI_MSG | FI_RMA | FI_TAGGED | FI_HMEM; + hints->caps = FI_MSG | FI_RMA | FI_HMEM; /* Primary Modifiers. Explicitly do not request any primary * modifiers, as we need send/recv, read, and write @@ -5895,8 +5743,8 @@ static void get_hints(struct fi_info *hints) hints->mode = 0; - hints->tx_attr->msg_order = FI_ORDER_SAS; - hints->rx_attr->msg_order = FI_ORDER_SAS; + hints->tx_attr->msg_order = FI_ORDER_NONE; + hints->rx_attr->msg_order = FI_ORDER_NONE; hints->ep_attr->type = FI_EP_RDM; @@ -6145,7 +5993,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter, goto exit; - error:; + error: if (base_devs) { for (nccl_net_ofi_device_t **base_dev = base_devs; base_dev != base_devs + num_devs; ++base_dev) { nccl_net_ofi_rdma_device_t *device = diff --git a/src/nccl_ofi_sendrecv.c b/src/nccl_ofi_sendrecv.c index ebcb3fb95..dcbd811d0 100644 --- a/src/nccl_ofi_sendrecv.c +++ b/src/nccl_ofi_sendrecv.c @@ -1478,7 +1478,7 @@ static int listen(nccl_net_ofi_ep_t *base_ep, local_ep_name = get_local_address(ep->ofi_ep); memcpy(handle->ep_name, local_ep_name, MAX_EP_ADDR); - handle->comm_id = tag; + handle->comm_id = (uint32_t)tag; /* Insert local EP address to AV. This will be used to issue local read operations */ num_addrs = fi_av_insert(ep->av, (void *)local_ep_name, 1, @@ -1727,7 +1727,7 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle, /* Get tag and remote name from handle */ memcpy(&remote_ep_addr, handle->ep_name, MAX_EP_ADDR); - memcpy(&tag, &handle->comm_id, sizeof(tag)); + memcpy(&tag, &handle->comm_id, sizeof(handle->comm_id)); if (tag < 1 || tag > max_tag) { NCCL_OFI_WARN("Received an invalid tag %lu for device %d", tag, device->base.dev_id); @@ -2105,9 +2105,8 @@ static int get_ep(nccl_net_ofi_device_t *base_dev, } if (ep->ref_cnt == 0) { - ret = nccl_ofi_ofiutils_init_connection(selected_api_version, device->info, - device->domain, &ep->ofi_ep, &ep->av, - &ep->cq); + ret = nccl_ofi_ofiutils_init_connection(selected_api_version, device->info, device->domain, &ep->ofi_ep, + &ep->av, &ep->cq); if (ret != 0) { goto unlock; }