Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switch to untagged send/recv and remove SAS ordering requirement in the RDMA protocol #361

Merged
merged 6 commits into from
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion include/nccl_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -191,10 +191,13 @@ typedef struct nccl_ofi_connection_info {
uint64_t connect_to_self;
nccl_net_ofi_req_t* req;
} nccl_ofi_connection_info_t;
/* Since this is a message on the wire, check that it has the expected size */
_Static_assert(sizeof(nccl_ofi_connection_info_t) == 80,
"Wrong size for SENDRECV connect message");

typedef struct nccl_net_ofi_conn_handle {
char ep_name[MAX_EP_ADDR];
uint64_t comm_id;
uint32_t comm_id;
/* Save temporary communicator state when creating send communicator */
save_comm_state_t state;
} nccl_net_ofi_conn_handle_t;
Expand Down
52 changes: 43 additions & 9 deletions include/nccl_ofi_rdma.h
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,13 @@ typedef enum nccl_net_ofi_rdma_req_type {
NCCL_OFI_RDMA_SEND_CONN_RESP,
} nccl_net_ofi_rdma_req_type_t;

/*
 * @brief	Message type discriminator for RDMA protocol messages
 *
 * Stored in the leading `type` field of the message structs:
 * nccl_ofi_rdma_connection_info_t carries either MSG_CONN or
 * MSG_CONN_RESP, and nccl_net_ofi_rdma_ctrl_msg_t carries MSG_CTRL.
 * The numeric values are spelled out so that the numbering is visible
 * at a glance; they match the implicit numbering of the declaration
 * order.
 */
typedef enum nccl_ofi_rdma_msg_type {
	/* Connect message */
	NCCL_OFI_RDMA_MSG_CONN = 0,
	/* Connect response message */
	NCCL_OFI_RDMA_MSG_CONN_RESP = 1,
	/* Control message */
	NCCL_OFI_RDMA_MSG_CTRL = 2,
	/* Eager message (NOTE(review): its layout is not visible in this
	 * part of the header -- confirm against the sender path) */
	NCCL_OFI_RDMA_MSG_EAGER = 3,
} nccl_ofi_rdma_msg_type_t;

/*
* @brief Rdma memory registration handle

Expand All @@ -77,10 +84,23 @@ typedef struct nccl_net_ofi_rdma_mr_handle {
/* Contents of ctrl message sent from receiver to sender to advertise
destination buffer.

This struct is a message on the wire: the _Static_assert that follows
the definition pins its size, so fields must not be added, removed,
reordered or retyped without updating that check. */
typedef struct nccl_net_ofi_rdma_ctrl_msg {
/* Message type, must be NCCL_OFI_RDMA_MSG_CTRL */
uint16_t type;

/* Message sequence number */
uint16_t msg_seq_num;

/* A comm identifier that uniquely identifies the comm
* on the receiver side */
uint32_t remote_comm_id;

/* Advertised destination buffer.
* NOTE(review): presumably the buffer's address -- confirm against
* the code that fills in the control message */
uint64_t buff_addr;
/* NOTE(review): presumably the length of the advertised buffer in
* bytes -- confirm */
uint64_t buff_len;
/* MAX_NUM_RAILS entries.
* NOTE(review): presumably one memory registration key per rail for
* the advertised buffer -- confirm */
uint64_t buff_mr_key[MAX_NUM_RAILS];
} nccl_net_ofi_rdma_ctrl_msg_t;
/* Since this is a message on the wire, check that it has the expected size */
_Static_assert(sizeof(nccl_net_ofi_rdma_ctrl_msg_t) == 56,
"Wrong size for RDMA Control message");

/* Structure used to store control messages in a free list */
typedef struct nccl_net_ofi_rdma_ctrl_fl_item {
Expand Down Expand Up @@ -296,18 +316,30 @@ typedef struct nccl_ofi_rdma_ep_name {
* connection information.
*/
typedef struct nccl_ofi_rdma_connection_info {
/* Message type
* either NCCL_OFI_RDMA_MSG_CONN or NCCL_OFI_RDMA_MSG_CONN_RESP
*/
uint16_t type;

/* Number of rails */
uint16_t num_rails;

/* A comm identifier that uniquely identifies the comm on the sender
side. The receiver must use this ID when sending messages to the sender */
uint64_t local_comm_id;
uint32_t local_comm_id;

/* Number of rails */
int num_rails;
/* A comm identifier that uniquely identifies the comm
* on the receiver side */
uint32_t remote_comm_id;

/* Array of `MAX_NUM_RAILS` `nccl_ofi_rdma_ep_name_t`
* structs. The member `num_rails` indicates the number of
* entries that are in use. */
nccl_ofi_rdma_ep_name_t ep_names[MAX_NUM_RAILS];
} nccl_ofi_rdma_connection_info_t;
/* Since this is a message on the wire, check that it has the expected size */
_Static_assert(sizeof(nccl_ofi_rdma_connection_info_t) == 236,
"Wrong size for RDMA connect message");

/*
* @brief Send communicator rail
Expand Down Expand Up @@ -341,9 +373,10 @@ typedef struct nccl_net_ofi_rdma_send_comm {
nccl_ofi_freelist_t *nccl_ofi_reqs_fl;

/* Comm ID provided by the local endpoint */
uint64_t local_comm_id;
uint32_t local_comm_id;

/* Comm ID provided by remote endpoint */
uint64_t remote_comm_id;
uint32_t remote_comm_id;

/* Request to receive connect response message to finalize
* connection establishment */
Expand Down Expand Up @@ -419,9 +452,10 @@ typedef struct nccl_net_ofi_rdma_recv_comm {
nccl_ofi_freelist_t *nccl_ofi_reqs_fl;

/* Comm ID provided by the local endpoint */
uint64_t local_comm_id;
uint32_t local_comm_id;

/* Comm ID provided by remote endpoint */
uint64_t remote_comm_id;
uint32_t remote_comm_id;

/* The flush buffer */
nccl_net_ofi_rdma_flush_buffer_t flush_buff;
Expand All @@ -447,7 +481,7 @@ typedef struct nccl_net_ofi_rdma_listen_comm {
nccl_net_ofi_listen_comm_t base;

/* Comm ID provided by local endpoint */
uint64_t comm_id;
uint32_t comm_id;
struct fid_ep *leader_local_ep;

/* Communicator created while accept routine is executed */
Expand Down Expand Up @@ -623,7 +657,7 @@ typedef struct nccl_net_ofi_rdma_device {
char *prov_name;

/* Maximum number of supported communicator IDs */
uint64_t num_comm_ids;
uint32_t num_comm_ids;

/* Memory registration key pool */
nccl_ofi_idpool_t key_pool;
Expand Down
8 changes: 4 additions & 4 deletions include/tracepoint.h
Original file line number Diff line number Diff line change
Expand Up @@ -136,21 +136,21 @@ LTTNG_UST_TRACEPOINT_EVENT(
Recv,
LTTNG_UST_TP_ARGS(
int, dev,
int, tag,
int, comm_id,
int, size,
void *, request,
void *, nccl_req
),
LTTNG_UST_TP_FIELDS(
lttng_ust_field_integer(int, dev, dev)
lttng_ust_field_integer(int, tag, tag)
lttng_ust_field_integer(int, comm_id, comm_id)
lttng_ust_field_integer(int, size, size)
lttng_ust_field_integer_hex(uint64_t, request, (uint64_t)request)
lttng_ust_field_integer_hex(uint64_t, nccl_req, (uint64_t)nccl_req)
)
)
#define NCCL_OFI_TRACE_RECV(dev, tag, size, request, nccl_req) \
lttng_ust_tracepoint(nccl_ofi_plugin, Recv, dev, tag, size, request, nccl_req)
#define NCCL_OFI_TRACE_RECV(dev, comm_id, size, request, nccl_req) \
lttng_ust_tracepoint(nccl_ofi_plugin, Recv, dev, comm_id, size, request, nccl_req)

LTTNG_UST_TRACEPOINT_EVENT(
nccl_ofi_plugin,
Expand Down
7 changes: 6 additions & 1 deletion src/nccl_ofi_ofiutils.c
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,12 @@ int nccl_ofi_ofiutils_init_connection(int api_version, struct fi_info *info, str
goto error;
}

cq_attr.format = FI_CQ_FORMAT_TAGGED;
if (info->caps & FI_TAGGED) {
cq_attr.format = FI_CQ_FORMAT_TAGGED;
} else {
cq_attr.format = FI_CQ_FORMAT_DATA;
}

ret = fi_cq_open(domain, &cq_attr, cq, NULL);
if (OFI_UNLIKELY(ret != 0)) {
NCCL_OFI_WARN("Couldn't open CQ. RC: %d, ERROR: %s",
Expand Down
Loading
Loading