diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h index 5c391b513..7902fb2b7 100644 --- a/include/nccl_ofi_rdma.h +++ b/include/nccl_ofi_rdma.h @@ -498,6 +498,8 @@ typedef struct nccl_net_ofi_rdma_send_comm { nccl_net_ofi_send_comm_t base; uint64_t num_inflight_reqs; + uint64_t num_inflight_writes; + nccl_ofi_freelist_t *nccl_ofi_reqs_fl; /* Comm ID provided by the local endpoint */ diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index 65466806a..f6c2d9e9d 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -2074,6 +2074,18 @@ static inline int free_send_req(nccl_net_ofi_rdma_req_t *req, send_data = get_send_data(req); + if (!send_data->eager) { + /* free is going to be called inside of test(), which will + happen at a time when NCCL guarantees no other thread will + be accessing the communicator. So no mutex protection is + required if we do it here. It would be better to do this as + soon as we get the CQE for this request, but that would + require atomics or locks, which isn't worth it today. But + if we ever refactor the locking strategy, we should revisit + this. */ + (s_comm->num_inflight_writes)--; + } + if (send_data->schedule) { nccl_net_ofi_rdma_device_t *device = rdma_req_get_device(req); nccl_net_ofi_release_schedule(device->scheduler, send_data->schedule); @@ -5857,7 +5869,7 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t /* Determine if this should be sent eagerly. 
*/ eager = false; - if ((!have_ctrl && (size_t)size <= eager_max_size) || (size == 0)) { + if ((!have_ctrl && (size_t)size <= eager_max_size && s_comm->num_inflight_writes == 0) || (size == 0)) { eager = true; } @@ -5897,6 +5909,10 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t */ (s_comm->num_inflight_reqs)++; + if (!eager) { + (s_comm->num_inflight_writes)++; + } + NCCL_OFI_TRACE_SEND(req->dev_id, size, s_comm, msg_seq_num, req, base_req); /* Try posting RDMA write for received RDMA control messages */ @@ -5958,6 +5974,7 @@ static int send_close_deferred(nccl_net_ofi_send_comm_t *send_comm) ret = -EINVAL; goto exit; } + assert (s_comm->num_inflight_writes == 0); s_comm->comm_active = false;