diff --git a/include/Makefile.am b/include/Makefile.am
index 749dbc9f9..ba34b6ced 100644
--- a/include/Makefile.am
+++ b/include/Makefile.am
@@ -27,6 +27,7 @@ noinst_HEADERS = \
 	nccl_ofi_ofiutils.h \
 	nccl_ofi_tracepoint.h \
 	tracing_impl/lttng.h \
+	tracing_impl/nvtx.h \
 	nccl-headers/net.h \
 	nccl-headers/error.h \
 	nccl-headers/nvidia/err.h \
diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h
index d7b468352..8449fdc63 100644
--- a/include/nccl_ofi_rdma.h
+++ b/include/nccl_ofi_rdma.h
@@ -19,6 +19,7 @@ extern "C" {
 #include "nccl_ofi_deque.h"
 #include "nccl_ofi_freelist.h"
 #include "nccl_ofi_idpool.h"
+#include "nccl_ofi_tracepoint.h"
 
 /* Maximum number of rails supported. This defines the size of
  * messages exchanged during connection establishment (linear
@@ -170,6 +171,10 @@ typedef struct {
 	/* Total number of completions. Expect one completion for receiving the
 	 * control message and one completion for each send segment. */
 	int total_num_compls;
+#if HAVE_NVTX_TRACING
+	nvtxRangeId_t trace_id;
+	nvtxRangeId_t seg_trace_id[MAX_NUM_RAILS];
+#endif
 } rdma_req_send_data_t;
 
 /*
@@ -184,6 +189,9 @@ typedef struct {
 	nccl_net_ofi_schedule_t *ctrl_schedule;
 	/* Pointer to recv parent request */
 	nccl_net_ofi_rdma_req_t *recv_req;
+#if HAVE_NVTX_TRACING
+	nvtxRangeId_t trace_id;
+#endif
 } rdma_req_send_ctrl_data_t;
 
 typedef struct {
@@ -224,6 +232,9 @@ typedef struct {
 	 * For eager messages, the second completion will be received
 	 * when the local read into the destination buffer is complete */
 	int total_num_compls;
+#if HAVE_NVTX_TRACING
+	nvtxRangeId_t trace_id;
+#endif
 } rdma_req_recv_data_t;
 
 /*
@@ -403,8 +414,12 @@ typedef struct nccl_net_ofi_rdma_send_comm {
 	 * and `num_init_rails' is adjusted.
 	 */
 	int num_init_rails;
+#if HAVE_NVTX_TRACING
+	nvtxDomainHandle_t nvtx_domain[N_NVTX_DOMAIN_PER_COMM];
+#endif
+
 	/* Array of `num_rails` communicator rails */
 	nccl_net_ofi_rdma_send_comm_rail_t rails[];
 } nccl_net_ofi_rdma_send_comm_t;
 
 /*
@@ -465,6 +480,10 @@ typedef struct nccl_net_ofi_rdma_recv_comm {
 	/* Free list to track control buffers, for sending RDMA control messages */
 	nccl_ofi_freelist_t *ctrl_buff_fl;
 
+#if HAVE_NVTX_TRACING
+	nvtxDomainHandle_t nvtx_domain[N_NVTX_DOMAIN_PER_COMM];
+#endif
+
 	/* Number of rails */
 	int num_rails;
 
@@ -659,6 +678,10 @@ typedef struct nccl_net_ofi_rdma_device {
 
 	/* Memory registration key pool */
 	nccl_ofi_idpool_t key_pool;
+
+#if HAVE_NVTX_TRACING
+	nvtxDomainHandle_t nvtx_domain[MAX_NUM_RAILS];
+#endif
 } nccl_net_ofi_rdma_device_t;
 
 /*
diff --git a/include/nccl_ofi_tracepoint.h b/include/nccl_ofi_tracepoint.h
index 94b48f7ce..8ea67fa62 100644
--- a/include/nccl_ofi_tracepoint.h
+++ b/include/nccl_ofi_tracepoint.h
@@ -7,6 +7,7 @@
 #define NCCL_OFI_TRACEPOINT_H_
 
 #include "config.h"
+#include "tracing_impl/nvtx.h"
 #include "tracing_impl/lttng.h"
 
 /***** SENDRECV PROTOCOL *****/
@@ -27,52 +28,83 @@
 	} while(0)
 
 /***** RDMA PROTOCL *****/
+#define N_NVTX_DOMAIN_PER_COMM 8
+
+#define NCCL_OFI_NVTX_TRACE_PER_COMM 1
+#define NCCL_OFI_NVTX_TRACE_PER_DEV 0
+
 #define NCCL_OFI_TRACE_SEND(dev, size, comm, msg_seq_num, request, nccl_req) do { \
 	lttng_ust_tracepoint(nccl_ofi_plugin, Send, dev, size, comm, msg_seq_num, request, nccl_req); \
-	} while(0)
+	NCCL_OFI_TRACE_SEND_NVTX(dev, size, comm, msg_seq_num, request, nccl_req); \
+} while(0)
+
+#define NCCL_OFI_TRACE_SEND_END(request) do { \
+	NCCL_OFI_TRACE_SEND_END_NVTX(request); \
+} while(0)
 
 #define NCCL_OFI_TRACE_SEND_CTRL_RECV(dev, rail_id, comm, msg_seq_num) do { \
-        lttng_ust_tracepoint(nccl_ofi_plugin, Send_ctrl_recv, dev, rail_id, comm, msg_seq_num); \
-	} while (0)
+	lttng_ust_tracepoint(nccl_ofi_plugin, Send_ctrl_recv, dev, rail_id, comm, msg_seq_num); \
+	NCCL_OFI_TRACE_SEND_CTRL_RECV_NVTX(dev, rail_id, comm, msg_seq_num); \
+} while (0)
+
+#define NCCL_OFI_TRACE_SEND_CTRL_START(dev, rail_id, comm, req, msg_seq_num) do { \
+	NCCL_OFI_TRACE_SEND_CTRL_START_NVTX(dev, rail_id, comm, req, msg_seq_num); \
+} while (0)
+
+#define NCCL_OFI_TRACE_SEND_CTRL_END(dev, rail_id, comm, req, msg_seq_num) do { \
+	NCCL_OFI_TRACE_SEND_CTRL_END_NVTX(dev, rail_id, comm, req, msg_seq_num); \
+} while (0)
 
 #define NCCL_OFI_TRACE_SEND_WRITE_SEG_START(dev, rail_id, size, comm, msg_seq_num, request) do { \
-        lttng_ust_tracepoint(nccl_ofi_plugin, Send_write_segment_start, dev, rail_id, size, comm, msg_seq_num, request); \
-	} while(0)
+	lttng_ust_tracepoint(nccl_ofi_plugin, Send_write_segment_start, dev, rail_id, size, comm, msg_seq_num, request); \
+	NCCL_OFI_TRACE_SEND_WRITE_SEG_START_NVTX(dev, rail_id, size, comm, msg_seq_num, request); \
+} while(0)
 
 #define NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE(dev, rail_id, comm, msg_seq_num, request) do { \
 	lttng_ust_tracepoint(nccl_ofi_plugin, Send_write_segment_complete, dev, rail_id, comm, msg_seq_num, request); \
-	} while(0)
+	NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE_NVTX(dev, rail_id, comm, msg_seq_num, request); \
+} while(0)
 
 #define NCCL_OFI_TRACE_RECV(dev, tag, size, request, nccl_req) do { \
 	lttng_ust_tracepoint(nccl_ofi_plugin, Recv, dev, tag, size, request, nccl_req); \
-	} while(0)
+	NCCL_OFI_TRACE_RECV_NVTX(dev, tag, size, request, nccl_req); \
+} while(0)
+
+#define NCCL_OFI_TRACE_RECV_END(request) do { \
+	NCCL_OFI_TRACE_RECV_END_NVTX(request); \
+} while(0)
 
 #define NCCL_OFI_TRACE_RECV_CTRL_SEND_COMPLETE(request) do { \
-        lttng_ust_tracepoint(nccl_ofi_plugin, Recv_ctrl_send_complete, request); \
-	} while(0)
+	lttng_ust_tracepoint(nccl_ofi_plugin, Recv_ctrl_send_complete, request); \
+} while(0)
 
 #define NCCL_OFI_TRACE_RECV_SEGMENT_COMPLETE(dev, rail_id, size, request) do { \
 	lttng_ust_tracepoint(nccl_ofi_plugin, Recv_segment_complete, dev, rail_id, size, request); \
-	} while(0)
+	NCCL_OFI_TRACE_RECV_SEGMENT_COMPLETE_NVTX(dev, rail_id, size, request); \
+} while(0)
 
 #define NCCL_OFI_TRACE_EAGER_RECV(dev, rail_id, comm, msg_seq_num) do { \
-        lttng_ust_tracepoint(nccl_ofi_plugin, Eager_recv, dev, rail_id, comm, msg_seq_num); \
-	} while(0)
+	lttng_ust_tracepoint(nccl_ofi_plugin, Eager_recv, dev, rail_id, comm, msg_seq_num); \
+	NCCL_OFI_TRACE_EAGER_RECV_NVTX(dev, rail_id, comm, msg_seq_num); \
+} while(0)
 
 #define NCCL_OFI_TRACE_COMPLETIONS(request,ctx) do { \
 	lttng_ust_tracepoint(nccl_ofi_plugin, ProcessCompletions, request,ctx); \
-	} while(0)
+} while(0)
 
 #define NCCL_OFI_TRACE_FLUSH(request, nccl_req) do { \
 	lttng_ust_tracepoint(nccl_ofi_plugin, Flush, request, nccl_req); \
-	} while(0)
+	NCCL_OFI_TRACE_FLUSH_NVTX(request, nccl_req); \
+} while(0)
 
 #define NCCL_OFI_TRACE_PENDING_INSERT(request) do { \
 	lttng_ust_tracepoint(nccl_ofi_plugin, Pending_queue_insert, request); \
-	} while(0)
+	NCCL_OFI_TRACE_PENDING_INSERT_NVTX(request); \
+} while(0)
 
 #define NCCL_OFI_TRACE_PENDING_REMOVE(request) do { \
 	lttng_ust_tracepoint(nccl_ofi_plugin, Pending_queue_remove, request); \
-	} while(0)
+	NCCL_OFI_TRACE_PENDING_REMOVE_NVTX(request); \
+} while(0)
 
 #endif /* NCCL_OFI_TRACEPOINT_H_ */
\ No newline at end of file
diff --git a/include/tracing_impl/nvtx.h b/include/tracing_impl/nvtx.h
new file mode 100644
index 000000000..b957865dd
--- /dev/null
+++ b/include/tracing_impl/nvtx.h
@@ -0,0 +1,229 @@
+/*
+ * Copyright (c) 2022-2024 Amazon.com, Inc. or its affiliates. All rights reserved.
+ */
+
+#ifndef NVTX_H
+#define NVTX_H
+
+#if HAVE_NVTX_TRACING
+#include <stdbool.h>
+#include <stddef.h>
+#include <stdint.h>
+
+#include "nvToolsExt.h"
+
+/* Emit an instantaneous marker `name` with ARGB `color` in `domain`. */
+static inline void nvtx_mark_domain(nvtxDomainHandle_t domain, const char *name, uint32_t color)
+{
+	const nvtxEventAttributes_t eventAttrib = {
+		.version = NVTX_VERSION,
+		.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE,
+		.colorType = NVTX_COLOR_ARGB,
+		.color = color,
+		.messageType = NVTX_MESSAGE_TYPE_ASCII,
+		.message = { .ascii = name },
+	};
+	nvtxDomainMarkEx(domain, &eventAttrib);
+}
+
+/*
+ * Open a range `name` with ARGB `color`: in `domain` when `have_domain` is
+ * true, otherwise through the domain-less nvtxRangeStartEx(). Returns the
+ * range id to pass to nvtx_end_domain() or nvtx_end() respectively.
+ */
+static inline nvtxRangeId_t nvtx_start_domain(bool have_domain, nvtxDomainHandle_t domain, const char *name, uint32_t color)
+{
+	const nvtxEventAttributes_t eventAttrib = {
+		.version = NVTX_VERSION,
+		.size = NVTX_EVENT_ATTRIB_STRUCT_SIZE,
+		.colorType = NVTX_COLOR_ARGB,
+		.color = color,
+		.messageType = NVTX_MESSAGE_TYPE_ASCII,
+		.message = { .ascii = name },
+	};
+	if (have_domain) {
+		return nvtxDomainRangeStartEx(domain, &eventAttrib);
+	}
+	return nvtxRangeStartEx(&eventAttrib);
+}
+
+/* Open a range `name` with ARGB `color` outside of any explicit domain. */
+static inline nvtxRangeId_t nvtx_start(const char *name, uint32_t color)
+{
+	return nvtx_start_domain(false, NULL, name, color);
+}
+
+/* Close range `id` that was opened in `domain`. */
+static inline void nvtx_end_domain(nvtxDomainHandle_t domain, nvtxRangeId_t id)
+{
+	nvtxDomainRangeEnd(domain, id);
+}
+
+/* Close range `id` that was opened by nvtx_start(). */
+static inline void nvtx_end(nvtxRangeId_t id)
+{
+	nvtxRangeEnd(id);
+}
+
+/*
+ * NVTX counterparts of the NCCL_OFI_TRACE_* tracepoints.
+ *
+ * With NCCL_OFI_NVTX_TRACE_PER_COMM set, an event goes to the communicator
+ * domain nvtx_domain[msg_seq_num % N_NVTX_DOMAIN_PER_COMM]; with
+ * NCCL_OFI_NVTX_TRACE_PER_DEV set, per-rail events go to the device domain
+ * nvtx_domain[rail_id]. Range ids live in the request's trace_id and
+ * seg_trace_id fields between the matching start and end macros.
+ * Macro arguments may be evaluated more than once.
+ */
+
+/* Open the "Send" range of `request`. */
+#define NCCL_OFI_TRACE_SEND_NVTX(dev, size, comm, msg_seq_num, request, nccl_req) do { \
+	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
+		nvtxDomainHandle_t handle = ((nccl_net_ofi_rdma_send_comm_t *)(comm)) \
+			->nvtx_domain[(msg_seq_num) % N_NVTX_DOMAIN_PER_COMM]; \
+		get_send_data(request)->trace_id = nvtx_start_domain(true, handle, "Send", 0xeb9234); \
+	} \
+} while (0)
+
+/* Close the "Send" range of `request`. */
+#define NCCL_OFI_TRACE_SEND_END_NVTX(request) do { \
+	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
+		nvtxDomainHandle_t handle = ((nccl_net_ofi_rdma_send_comm_t *)((request)->comm)) \
+			->nvtx_domain[(request)->msg_seq_num % N_NVTX_DOMAIN_PER_COMM]; \
+		nvtx_end_domain(handle, get_send_data(request)->trace_id); \
+	} \
+} while(0)
+
+/* Mark the arrival of a control message on send communicator `comm`. */
+#define NCCL_OFI_TRACE_SEND_CTRL_RECV_NVTX(dev, rail_id, comm, msg_seq_num) do { \
+	nvtxDomainHandle_t handle; \
+	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
+		handle = ((nccl_net_ofi_rdma_send_comm_t *)(comm))->nvtx_domain[(msg_seq_num) % N_NVTX_DOMAIN_PER_COMM]; \
+		nvtx_mark_domain(handle, "Send_ctrl_recv", 0x00ffff); \
+	} \
+	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
+		handle = ((nccl_net_ofi_rdma_device_t *)((comm)->base.base.ep->device))->nvtx_domain[(rail_id)]; \
+		nvtx_mark_domain(handle, "Send_ctrl_recv", 0x00ffff); \
+	} \
+} while (0)
+
+/* Open the "Send_ctrl_start" range of control request `req`. */
+#define NCCL_OFI_TRACE_SEND_CTRL_START_NVTX(dev, rail_id, comm, req, msg_seq_num) do { \
+	nvtxDomainHandle_t handle; \
+	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
+		handle = ((nccl_net_ofi_rdma_recv_comm_t *)(comm))->nvtx_domain[(msg_seq_num) % N_NVTX_DOMAIN_PER_COMM]; \
+		get_send_ctrl_data(req)->trace_id = nvtx_start_domain(true, handle, "Send_ctrl_start", 0x00ffff); \
+	} \
+	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
+		handle = ((nccl_net_ofi_rdma_device_t *)((comm)->ep->device))->nvtx_domain[(rail_id)]; \
+		get_send_ctrl_data(req)->trace_id = nvtx_start_domain(true, handle, "Send_ctrl_start", 0x00ffff); \
+	} \
+} while (0)
+
+/* Close the range of control request `req` stored in its trace_id. */
+#define NCCL_OFI_TRACE_SEND_CTRL_END_NVTX(dev, rail_id, comm, req, msg_seq_num) do { \
+	nvtxDomainHandle_t handle; \
+	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
+		handle = ((nccl_net_ofi_rdma_recv_comm_t *)(comm))->nvtx_domain[(msg_seq_num) % N_NVTX_DOMAIN_PER_COMM]; \
+		nvtx_end_domain(handle, get_send_ctrl_data(req)->trace_id); \
+	} \
+	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
+		handle = ((nccl_net_ofi_rdma_device_t *)((comm)->ep->device))->nvtx_domain[(rail_id)]; \
+		nvtx_end_domain(handle, get_send_ctrl_data(req)->trace_id); \
+	} \
+} while (0)
+
+/* Open the "Send_write_seg" range of `request` for rail `rail_id`. */
+#define NCCL_OFI_TRACE_SEND_WRITE_SEG_START_NVTX(dev, rail_id, size, comm, msg_seq_num, request) do { \
+	nvtxDomainHandle_t handle; \
+	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
+		handle = ((nccl_net_ofi_rdma_send_comm_t *)(comm))->nvtx_domain[(msg_seq_num) % N_NVTX_DOMAIN_PER_COMM]; \
+		get_send_data(request)->seg_trace_id[(rail_id)] = nvtx_start_domain(true, handle, "Send_write_seg", 0xff0000); \
+	} \
+	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
+		handle = ((nccl_net_ofi_rdma_device_t *)((comm)->ep->device))->nvtx_domain[(rail_id)]; \
+		get_send_data(request)->seg_trace_id[(rail_id)] = nvtx_start_domain(true, handle, "Send_write_seg", 0xff0000); \
+	} \
+} while(0)
+
+/* Close the segment range of `request` stored in seg_trace_id[rail_id]. */
+#define NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE_NVTX(dev, rail_id, comm, msg_seq_num, request) do { \
+	nvtxDomainHandle_t handle; \
+	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
+		handle = ((nccl_net_ofi_rdma_send_comm_t *)(comm))->nvtx_domain[(msg_seq_num) % N_NVTX_DOMAIN_PER_COMM]; \
+		nvtx_end_domain(handle, get_send_data(request)->seg_trace_id[(rail_id)]); \
+	} \
+	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
+		handle = ((nccl_net_ofi_rdma_device_t *)((comm)->ep->device))->nvtx_domain[(rail_id)]; \
+		nvtx_end_domain(handle, get_send_data(request)->seg_trace_id[(rail_id)]); \
+	} \
+} while(0)
+
+/* Open the "Recv" range of `request`, in the same domain RECV_END closes it in. */
+#define NCCL_OFI_TRACE_RECV_NVTX(dev, tag, size, request, nccl_req) do { \
+	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
+		nvtxDomainHandle_t handle = ((nccl_net_ofi_rdma_recv_comm_t *)((request)->comm)) \
+			->nvtx_domain[(request)->msg_seq_num % N_NVTX_DOMAIN_PER_COMM]; \
+		get_recv_data(request)->trace_id = nvtx_start_domain(true, handle, "Recv", 0x34EB37); \
+	} \
+} while(0)
+
+/* Close the "Recv" range of `request`. */
+#define NCCL_OFI_TRACE_RECV_END_NVTX(request) do { \
+	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
+		nvtxDomainHandle_t handle = ((nccl_net_ofi_rdma_recv_comm_t *)((request)->comm)) \
+			->nvtx_domain[(request)->msg_seq_num % N_NVTX_DOMAIN_PER_COMM]; \
+		nvtx_end_domain(handle, get_recv_data(request)->trace_id); \
+	} \
+} while(0)
+
+/* Mark the completion of one received segment of `request`. */
+#define NCCL_OFI_TRACE_RECV_SEGMENT_COMPLETE_NVTX(dev, rail_id, size, request) do { \
+	nvtxDomainHandle_t handle; \
+	if (NCCL_OFI_NVTX_TRACE_PER_COMM) { \
+		handle = ((nccl_net_ofi_rdma_recv_comm_t *)((request)->comm))->nvtx_domain[(request)->msg_seq_num % N_NVTX_DOMAIN_PER_COMM]; \
+		nvtx_mark_domain(handle, "Recv_segment_complete", 0xff0000); \
+	} \
+	if (NCCL_OFI_NVTX_TRACE_PER_DEV) { \
+		handle = ((nccl_net_ofi_rdma_device_t *)((request)->comm->ep->device))->nvtx_domain[(rail_id)]; \
+		nvtx_mark_domain(handle, "Recv_segment_complete", 0xff0000); \
+	} \
+} while(0)
+
+/* The markers below are not tied to a communicator and pass a NULL domain. */
+#define NCCL_OFI_TRACE_EAGER_RECV_NVTX(dev, rail_id, comm, msg_seq_num) do { \
+	nvtx_mark_domain(NULL, "Eager_recv", 0x0000FF); \
+} while(0)
+
+#define NCCL_OFI_TRACE_FLUSH_NVTX(request, nccl_req) do { \
+	nvtx_mark_domain(NULL, "Flush", 0xA52A2A); \
+} while(0)
+
+#define NCCL_OFI_TRACE_PENDING_INSERT_NVTX(request) do { \
+	nvtx_mark_domain(NULL, "Pending_insert", 0xFF8C00); \
+} while(0)
+
+#define NCCL_OFI_TRACE_PENDING_REMOVE_NVTX(request) do { \
+	nvtx_mark_domain(NULL, "Pending_remove", 0xFF8C00); \
+} while(0)
+
+#else
+
+#define NCCL_OFI_TRACE_SEND_NVTX(...)
+#define NCCL_OFI_TRACE_SEND_END_NVTX(...)
+#define NCCL_OFI_TRACE_SEND_CTRL_RECV_NVTX(...)
+#define NCCL_OFI_TRACE_SEND_CTRL_START_NVTX(...)
+#define NCCL_OFI_TRACE_SEND_CTRL_END_NVTX(...)
+#define NCCL_OFI_TRACE_SEND_WRITE_SEG_START_NVTX(...)
+#define NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE_NVTX(...)
+#define NCCL_OFI_TRACE_RECV_NVTX(...)
+#define NCCL_OFI_TRACE_RECV_END_NVTX(...)
+#define NCCL_OFI_TRACE_RECV_SEGMENT_COMPLETE_NVTX(...)
+#define NCCL_OFI_TRACE_EAGER_RECV_NVTX(...)
+#define NCCL_OFI_TRACE_FLUSH_NVTX(...)
+#define NCCL_OFI_TRACE_PENDING_INSERT_NVTX(...)
+#define NCCL_OFI_TRACE_PENDING_REMOVE_NVTX(...)
+
+#endif
+
+#endif /* NVTX_H */
diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c
index 8c7a8eff4..72b02b4fe 100644
--- a/src/nccl_ofi_rdma.c
+++ b/src/nccl_ofi_rdma.c
@@ -1233,7 +1233,7 @@ static inline int handle_bounce_recv(nccl_ofi_rdma_msg_type_t msg_type, nccl_net
 		ctrl_msg = get_bounce_ctrl_msg(bounce_data->bounce_fl_item);
 		s_comm = get_send_comm(ep, ctrl_msg->remote_comm_id);
 
-		NCCL_OFI_TRACE_SEND_CTRL_RECV(r_comm->base.base.dev_id, rail_id, s_comm, ctrl_msg->msg_seq_num);
+		NCCL_OFI_TRACE_SEND_CTRL_RECV(s_comm->base.base.dev_id, rail_id, s_comm, ctrl_msg->msg_seq_num);
 
 		ret = handle_ctrl_recv(s_comm, ctrl_msg->msg_seq_num, bounce_req);
 		if (OFI_UNLIKELY(ret != 0)) {
@@ -1428,6 +1428,7 @@ static inline int process_completions(struct fi_cq_data_entry *cq_entry, uint64_
 
 		} else if (req->type == NCCL_OFI_RDMA_SEND_CTRL) {
 			/* CTRL message send completion */
+			NCCL_OFI_TRACE_SEND_CTRL_END(req->dev_id, rail->rail_id, req->comm, req, req->msg_seq_num);
 			ret = set_send_ctrl_completed(req);
 
 		} else if (req->type == NCCL_OFI_RDMA_SEND) {
@@ -2307,6 +2308,12 @@ static int test(nccl_net_ofi_req_t *base_req, int *done, int *size)
 			}
 		}
 
+		if (req->type == NCCL_OFI_RDMA_SEND) {
+			NCCL_OFI_TRACE_SEND_END(req);
+		} else if (req->type == NCCL_OFI_RDMA_RECV) {
+			NCCL_OFI_TRACE_RECV_END(req);
+		}
+
 		assert(req->free);
 		req->free(req, true);
 	} else if (OFI_UNLIKELY(req->state == NCCL_OFI_RDMA_REQ_ERROR)) {
@@ -3229,6 +3236,15 @@ static int recv_close(nccl_net_ofi_recv_comm_t *recv_comm)
 		goto exit;
 	}
 
+#if HAVE_NVTX_TRACING
+	/* Destroy the per-communicator NVTX domains */
+	if (NCCL_OFI_NVTX_TRACE_PER_COMM) {
+		for (int i = 0; i < N_NVTX_DOMAIN_PER_COMM; ++i) {
+			nvtxDomainDestroy(r_comm->nvtx_domain[i]);
+		}
+	}
+#endif
+
 	/* Not strictly necessary, but why leave dangling pointers? */
 	nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *) base_ep;
 	set_comm(ep, r_comm->local_comm_id, NULL);
@@ -3547,6 +3563,17 @@ static nccl_net_ofi_rdma_recv_comm_t *prepare_recv_comm(nccl_net_ofi_rdma_device
 		return NULL;
 	}
 
+#if HAVE_NVTX_TRACING
+	if (NCCL_OFI_NVTX_TRACE_PER_COMM) {
+		for (int i = 0; i < N_NVTX_DOMAIN_PER_COMM; ++i) {
+			/* Create nvtx domain */
+			char name[64];
+			snprintf(name, sizeof(name), "aws-ofi-nccl r_comm %p_%d", (void *)r_comm, i);
+			r_comm->nvtx_domain[i] = nvtxDomainCreateA(name);
+		}
+	}
+#endif
+
 	return r_comm;
 
  error:
@@ -4282,6 +4309,7 @@ static int post_rdma_ctrl(nccl_net_ofi_rdma_req_t *req)
 	assert(xfer_info->rail_id < mr_handle->num_rails);
 	void *desc = fi_mr_desc(mr_handle->mr[xfer_info->rail_id]);
 
+	NCCL_OFI_TRACE_SEND_CTRL_START(req->dev_id, xfer_info->rail_id, req->comm, req, req->msg_seq_num);
 	ssize_t rc = fi_send(comm_rail->local_ep, &ctrl_fl_item->ctrl_msg,
 			     sizeof(nccl_net_ofi_rdma_ctrl_msg_t), desc,
 			     comm_rail->remote_addr, req);
@@ -4664,6 +4692,15 @@ static int send_close(nccl_net_ofi_rdma_send_comm_t *s_comm)
 		NCCL_OFI_WARN("Error freeing communicator ID %"PRIu32"", s_comm->local_comm_id);
 	}
 
+#if HAVE_NVTX_TRACING
+	/* Destroy the per-communicator NVTX domains */
+	if (NCCL_OFI_NVTX_TRACE_PER_COMM) {
+		for (int i = 0; i < N_NVTX_DOMAIN_PER_COMM; ++i) {
+			nvtxDomainDestroy(s_comm->nvtx_domain[i]);
+		}
+	}
+#endif
+
 	free(s_comm);
 
  exit:
@@ -4969,6 +5006,16 @@ static inline int create_send_comm(nccl_net_ofi_conn_handle_t *handle,
 		goto error;
 	}
 
+#if HAVE_NVTX_TRACING
+	if (NCCL_OFI_NVTX_TRACE_PER_COMM) {
+		for (int i = 0; i < N_NVTX_DOMAIN_PER_COMM; ++i) {
+			/* Create nvtx domain */
+			char name[64];
+			snprintf(name, sizeof(name), "aws-ofi-nccl s_comm %p_%d", (void *)ret_s_comm, i);
+			ret_s_comm->nvtx_domain[i] = nvtxDomainCreateA(name);
+		}
+	}
+#endif
 	*s_comm = ret_s_comm;
 	return ret;
 
@@ -5990,6 +6037,18 @@ int nccl_net_ofi_rdma_init(const char *provider_filter,
 		if (ret != 0) {
 			goto error;
 		}
+
+		/* NVTX domain */
+#if HAVE_NVTX_TRACING
+		if (NCCL_OFI_NVTX_TRACE_PER_DEV) {
+			for (int i = 0; i < device->num_rails; ++i) {
+				/* Create nvtx domain */
+				char name[64];
+				snprintf(name, sizeof(name), "aws-ofi-nccl dev %d_%d", dev_id, i);
+				device->nvtx_domain[i] = nvtxDomainCreateA(name);
+			}
+		}
+#endif
 	}
 
 	goto exit;