diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h index dd8f25008..807b27659 100644 --- a/include/nccl_ofi_rdma.h +++ b/include/nccl_ofi_rdma.h @@ -498,6 +498,8 @@ typedef struct nccl_net_ofi_rdma_send_comm { nccl_net_ofi_send_comm_t base; uint64_t num_inflight_reqs; + uint64_t num_inflight_writes; + nccl_ofi_freelist_t *nccl_ofi_reqs_fl; /* Comm ID provided by the local endpoint */ diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index 9923e65b4..ce95fbdf2 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -2058,6 +2058,15 @@ static inline int free_write_req(nccl_net_ofi_rdma_req_t *req, assert(req->type == NCCL_OFI_RDMA_WRITE); nccl_net_ofi_rdma_send_comm_t *s_comm = (nccl_net_ofi_rdma_send_comm_t *)req->comm; + /* free is going to be called inside of test(), which will + happen in a time when NCCL guarantees no other thread will + be accessing the communicator. So no mutex protections are + required if we do it here. Better would be to do this as + soon as we get the CQE for this request, but that would + require atomics or locks, which isn't worth it today. But + if we ever refactor the locking strategy, we should revisit + this. */ + (s_comm->num_inflight_writes)--; return free_base_req(&s_comm->num_inflight_reqs, s_comm->nccl_ofi_reqs_fl, req, dec_inflight_reqs); } @@ -5880,7 +5889,7 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t /* Determine if this should be sent eagerly. */ eager = false; - if ((!have_ctrl && (size_t)size <= eager_max_size) || (size == 0)) { + if ((!have_ctrl && (size_t)size <= eager_max_size && s_comm->num_inflight_writes == 0) || (size == 0)) { eager = true; } @@ -5920,6 +5929,10 @@ static int send(nccl_net_ofi_send_comm_t *send_comm, void *data, int size, int t */ (s_comm->num_inflight_reqs)++; + if (!eager) { + (s_comm->num_inflight_writes)++; + } + NCCL_OFI_TRACE_SEND(req->dev_id, size, s_comm, msg_seq_num, req, base_req); /* Try posting RDMA write for received RDMA control messages */