diff --git a/include/nccl_ofi.h b/include/nccl_ofi.h
index 1bcb6e0ff..4155bfd6c 100644
--- a/include/nccl_ofi.h
+++ b/include/nccl_ofi.h
@@ -484,6 +484,19 @@ struct nccl_net_ofi_recv_comm {
  */
 struct nccl_net_ofi_plugin {
 /* public */
+
+	/**
+	 * Complete initialization of plugin
+	 *
+	 * When a plugin is first created, it should not create any
+	 * network resources -- create is called to understand the
+	 * configuration of the network and see which transports can
+	 * run.  The base code will pick one and call complete_init,
+	 * at which point devices and network resources can be
+	 * created.
+	 */
+	int (*complete_init)(nccl_net_ofi_plugin_t *plugin);
+
 	int (*assign_device)(nccl_net_ofi_plugin_t *plugin,
 			     size_t device_index,
 			     nccl_net_ofi_device_t *device);
diff --git a/include/nccl_ofi_rdma.h b/include/nccl_ofi_rdma.h
index 9620ff5a7..09761d6d9 100644
--- a/include/nccl_ofi_rdma.h
+++ b/include/nccl_ofi_rdma.h
@@ -834,11 +834,22 @@ typedef struct nccl_net_ofi_rdma_device {
 #endif
 } nccl_net_ofi_rdma_device_t;
 
+
+struct nccl_net_ofi_rdma_plugin {
+	nccl_net_ofi_plugin_t base;
+
+	/* topology used by complete_init to create devices; freed in plugin fini */
+	nccl_ofi_topo_t *topo;
+};
+typedef struct nccl_net_ofi_rdma_plugin nccl_net_ofi_rdma_plugin_t;
+
+
 /*
  * @brief	Initialize plugin with rdma protocol structures
  */
 int nccl_net_ofi_rdma_init(const char *provider_filter,
-			   nccl_net_ofi_plugin_t **plugin_p);
+			   nccl_net_ofi_plugin_t **plugin_p,
+			   bool *found_multiple_rails);
 
 #ifdef __cplusplus
 } // End extern "C"
diff --git a/include/nccl_ofi_sendrecv.h b/include/nccl_ofi_sendrecv.h
index 7309b75f6..265d18e6f 100644
--- a/include/nccl_ofi_sendrecv.h
+++ b/include/nccl_ofi_sendrecv.h
@@ -220,6 +220,16 @@ typedef struct nccl_net_ofi_sendrecv_req {
 	nccl_net_ofi_sendrecv_req_direction_t direction;
 } nccl_net_ofi_sendrecv_req_t;
 
+
+struct nccl_net_ofi_sendrecv_plugin {
+	nccl_net_ofi_plugin_t base;
+
+	/* provider list walked by complete_init; released with fi_freeinfo() in plugin fini */
+	struct fi_info *provider_list;
+};
+typedef struct nccl_net_ofi_sendrecv_plugin nccl_net_ofi_sendrecv_plugin_t;
+
+
 /*
  * @brief	Initialize plugin with sendrecv protocol structures
  */
diff --git a/src/nccl_ofi_net.c b/src/nccl_ofi_net.c
index ac4c13044..4066e8fb2 100644
--- a/src/nccl_ofi_net.c
+++ b/src/nccl_ofi_net.c
@@ -58,7 +58,7 @@ bool endpoint_mr = false;
 bool virt_addr_mr = false;
 
 /* Selected communication protocol. */
-const char *nccl_ofi_selected_protocol = "SENDRECV";
+const char *nccl_ofi_selected_protocol = NULL;
 
 /* Allocate one domain per process (0) or per thread (1) */
 int domain_per_thread = 0;
@@ -134,6 +134,7 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p)
 {
 	int ret = 0;
 	const char *provider_filter = NULL;
+	nccl_net_ofi_plugin_t *plugin = NULL;
 
 	NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Initializing " PACKAGE_STRING);
 
@@ -188,34 +189,101 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p)
 		goto exit;
 	}
 
-	/* Select and initialize protocol data structure.
-	 * platform_init() may change the default, so this must occur
-	 * after the platform init call.
+	/* This is ugly, but here's the basic protocol selection
+	 * logic:
+	 *   1. if the user set NCCL_OFI_PROTOCOL, use that.
+	 *   2. if the platform init set nccl_ofi_selected_protocol,
+	 *      use that.
+	 *   3. If the rdma protocol reports multiple nics per device
+	 *      and initialized successfully, use that.
+	 *   4. If the sendrecv protocol initialized successfully, use
+	 *      that
+	 *   5. If the rdma protocol initialized successfully, use
+	 *      that.
 	 */
 	if (ofi_nccl_protocol()) {
 		nccl_ofi_selected_protocol = ofi_nccl_protocol();
 		NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Using transport protocol %s (user set)",
 			      nccl_ofi_selected_protocol);
-	} else {
-		NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Using transport protocol %s",
+	} else if (nccl_ofi_selected_protocol != NULL) {
+		NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Using transport protocol %s (platform set)",
 			      nccl_ofi_selected_protocol);
 	}
 
-	if (0 == strcasecmp(nccl_ofi_selected_protocol, "SENDRECV")) {
-		ret = nccl_net_ofi_sendrecv_init(provider_filter, plugin_p);
-		if (ret != 0) {
-			NCCL_OFI_WARN("Failed to initialize sendrecv protocol");
+	if (nccl_ofi_selected_protocol != NULL) {
+		bool dummy;
+
+		if (0 == strcasecmp(nccl_ofi_selected_protocol, "SENDRECV")) {
+			ret = nccl_net_ofi_sendrecv_init(provider_filter, &plugin);
+			if (ret != 0) {
+				NCCL_OFI_WARN("Failed to initialize sendrecv protocol");
+				goto exit;
+			}
+		} else if (0 == strcasecmp(nccl_ofi_selected_protocol, "RDMA")) {
+			ret = nccl_net_ofi_rdma_init(provider_filter, &plugin, &dummy);
+			if (ret != 0) {
+				NCCL_OFI_WARN("Failed to initialize rdma protocol");
+				goto exit;
+			}
+		} else {
+			NCCL_OFI_WARN("Unable to find plugin protocol %s", nccl_ofi_selected_protocol);
+			ret = -ENOTSUP;
 			goto exit;
 		}
-	} else if (0 == strcasecmp(nccl_ofi_selected_protocol, "RDMA")) {
-		ret = nccl_net_ofi_rdma_init(provider_filter, plugin_p);
+	} else {
+		bool have_multiple_rails = false;
+		nccl_net_ofi_plugin_t *rdma_plugin = NULL, *sendrecv_plugin = NULL;
+
+		ret = nccl_net_ofi_rdma_init(provider_filter, &rdma_plugin, &have_multiple_rails);
 		if (ret != 0) {
-			NCCL_OFI_WARN("Failed to initialize rdma protocol");
+			NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET,
+				       "Failed to initialize rdma protocol: %s", fi_strerror(-ret));
+			have_multiple_rails = false;
+			rdma_plugin = NULL;
+		}
+
+		if (!have_multiple_rails) {
+			ret = nccl_net_ofi_sendrecv_init(provider_filter, &sendrecv_plugin);
+			if (ret != 0) {
+				NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET,
+					       "Failed to initialize sendrecv protocol: %s", fi_strerror(-ret));
+				sendrecv_plugin = NULL;
+			}
+		}
+
+		if (have_multiple_rails && rdma_plugin != NULL) {
+			nccl_ofi_selected_protocol = "RDMA";
+			plugin = rdma_plugin;
+			if (sendrecv_plugin != NULL) {
+				sendrecv_plugin->release_plugin(sendrecv_plugin);
+			}
+		} else if (sendrecv_plugin != NULL) {
+			nccl_ofi_selected_protocol = "SENDRECV";
+			plugin = sendrecv_plugin;
+			if (rdma_plugin != NULL) {
+				rdma_plugin->release_plugin(rdma_plugin);
+			}
+		} else if (rdma_plugin != NULL) {
+			/* step 5: single-rail rdma is the last resort when sendrecv failed */
+			nccl_ofi_selected_protocol = "RDMA";
+			plugin = rdma_plugin;
+		}
+
+		if (nccl_ofi_selected_protocol == NULL) {
+			NCCL_OFI_WARN("Unable to find a protocol that worked. Failing initialization.");
+			ret = -EINVAL;
 			goto exit;
 		}
-	} else {
-		NCCL_OFI_WARN("Unable to find plugin protocol %s", nccl_ofi_selected_protocol);
-		ret = -ENOTSUP;
+
+		NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Using transport protocol %s",
+			      nccl_ofi_selected_protocol);
+	}
+
+	ret = plugin->complete_init(plugin);
+	if (ret != 0) {
+		NCCL_OFI_WARN("Failed to complete initialization of %s protocol", nccl_ofi_selected_protocol);
+		plugin->release_plugin(plugin);
+		plugin = NULL;
 		goto exit;
 	}
 
@@ -230,7 +298,7 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p)
 	 * resources. This initialization happens once per process, and thus it
 	 * does not matter which device is used to create the endpoint.
 	 */
-	nccl_net_ofi_device_t *device = (*plugin_p)->get_device(*plugin_p, 0);
+	nccl_net_ofi_device_t *device = plugin->get_device(plugin, 0);
 	nccl_net_ofi_ep_t *base_ep = NULL;
 
 	ret = device->get_ep(device, &base_ep);
@@ -254,6 +322,8 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p)
 		goto exit;
 	}
 
+	*plugin_p = plugin;
+
  exit:
 	if (ret != 0) {
 		NCCL_OFI_WARN(PACKAGE_NAME " initialization failed");
diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c
index dabe925cd..ee61a885e 100644
--- a/src/nccl_ofi_rdma.c
+++ b/src/nccl_ofi_rdma.c
@@ -7133,6 +7133,12 @@ static void get_hints(struct fi_info *hints)
 static inline int nccl_net_ofi_rdma_plugin_fini(nccl_net_ofi_plugin_t *plugin)
 {
 	int ret, last_error = 0;
+	nccl_net_ofi_rdma_plugin_t *rdma_plugin = (nccl_net_ofi_rdma_plugin_t *)plugin;
+
+	if (rdma_plugin->topo != NULL) {
+		nccl_ofi_topo_free(rdma_plugin->topo);
+		rdma_plugin->topo = NULL;
+	}
 
 	ret = nccl_ofi_deque_finalize(r_comm_cleanup_list);
 	if (ret != 0) {
@@ -7161,19 +7167,66 @@ static inline int nccl_net_ofi_rdma_plugin_fini(nccl_net_ofi_plugin_t *plugin)
 }
 
 
+static inline int nccl_net_ofi_rdma_plugin_complete_init(nccl_net_ofi_plugin_t *plugin)
+{
+	nccl_net_ofi_rdma_plugin_t *rdma_plugin = (nccl_net_ofi_rdma_plugin_t *)plugin;
+	nccl_ofi_topo_data_iterator_t data_iter;
+	int ret;
+
+	/* Initialize user data iterator */
+	ret = nccl_ofi_topo_set_to_begin(rdma_plugin->topo, &data_iter);
+	if (ret != 0) {
+		NCCL_OFI_WARN("Failed to set iterator to begin of user data vector");
+		return ret;
+	}
+
+	/* Allocate and initialize nccl_net devices */
+	for (int dev_id = 0 ; dev_id != rdma_plugin->base.p_num_devs ; ++dev_id) {
+		struct fi_info *info_list;
+
+		/* Retrieve NIC info list from topology */
+		info_list = nccl_ofi_topo_next_info_list(&data_iter);
+		/* Verify NIC info list from topology */
+		if (!info_list) {
+			NCCL_OFI_WARN("Unable to retrieve next NIC info list from topology");
+			return -EINVAL;
+		}
+
+		/* Allocate device */
+		nccl_net_ofi_rdma_device_t *device =
+			nccl_net_ofi_rdma_device_create(&rdma_plugin->base, dev_id,
+							info_list, rdma_plugin->topo,
+							ofi_nccl_round_robin_threshold());
+		if (device == NULL) {
+			NCCL_OFI_WARN("Device creation failed");
+			return -ENOMEM;
+		}
+
+		ret = plugin->assign_device(plugin, dev_id, &device->base);
+		if (ret != 0) {
+			NCCL_OFI_WARN("Assigning device %d failed", dev_id);
+			return ret;
+		}
+	}
+
+	return 0;
+}
+
+
 static inline int nccl_net_ofi_rdma_plugin_create(size_t num_devices,
-						  nccl_net_ofi_plugin_t **plugin_p)
+						  nccl_ofi_topo_t *topo,
+						  nccl_net_ofi_rdma_plugin_t **plugin_p)
 {
 	int ret;
-	nccl_net_ofi_plugin_t *plugin = NULL;
+	nccl_net_ofi_rdma_plugin_t *plugin = NULL;
 
-	plugin = (nccl_net_ofi_plugin_t*)malloc(sizeof(nccl_net_ofi_plugin_t));
+	plugin = (nccl_net_ofi_rdma_plugin_t*)calloc(1, sizeof(nccl_net_ofi_rdma_plugin_t));
 	if (plugin == NULL) {
 		NCCL_OFI_WARN("Unable to allocate nccl_net_ofi_plugin_t");
 		return -ENOMEM;
 	}
 
-	ret = nccl_net_ofi_plugin_init(plugin, num_devices);
+	ret = nccl_net_ofi_plugin_init(&plugin->base, num_devices);
 	if (ret != 0) {
 		NCCL_OFI_WARN("Initializing base plugin failed: %s",
 			      strerror(-ret));
@@ -7193,7 +7246,10 @@ static inline int nccl_net_ofi_rdma_plugin_create(size_t num_devices,
 		goto error;
 	}
 
-	plugin->release_plugin = nccl_net_ofi_rdma_plugin_fini;
+	plugin->topo = topo;
+
+	plugin->base.release_plugin = nccl_net_ofi_rdma_plugin_fini;
+	plugin->base.complete_init = nccl_net_ofi_rdma_plugin_complete_init;
 
 	*plugin_p = plugin;
 
@@ -7201,16 +7257,7 @@ static inline int nccl_net_ofi_rdma_plugin_create(size_t num_devices,
 
  error:
 	if (plugin) {
-		if (s_comm_cleanup_list) {
-			nccl_ofi_deque_finalize(s_comm_cleanup_list);
-			s_comm_cleanup_list = NULL;
-		}
-		if (r_comm_cleanup_list) {
-			nccl_ofi_deque_finalize(r_comm_cleanup_list);
-			r_comm_cleanup_list = NULL;
-		}
-		plugin->release_plugin(plugin);
-		free(plugin);
+		plugin->base.release_plugin(&plugin->base);
 		plugin = NULL;
 	}
 	return ret;
@@ -7218,17 +7265,19 @@ static inline int nccl_net_ofi_rdma_plugin_create(size_t num_devices,
 
 
 int nccl_net_ofi_rdma_init(const char *provider_filter,
-			   nccl_net_ofi_plugin_t **plugin_p)
+			   nccl_net_ofi_plugin_t **plugin_p,
+			   bool *found_multiple_rails)
 {
 	int ret = 0;
 	int num_devs = 0;
 	struct fi_info *provider_list = NULL;
 	unsigned int num_providers;
-	size_t rr_threshold = ofi_nccl_round_robin_threshold();
-	nccl_net_ofi_plugin_t *plugin = NULL;
+	nccl_net_ofi_rdma_plugin_t *plugin = NULL;
 	nccl_ofi_topo_t *topo = NULL;
 	struct fi_info *hints;
 
+	*found_multiple_rails = false;
+
 	hints = fi_allocinfo();
 	if (hints == NULL) {
 		NCCL_OFI_WARN("Allocation of fi_info failed");
@@ -7259,7 +7308,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter,
 	ret = nccl_net_ofi_query_provider_capabilities(provider_list, num_providers);
 	if (ret != 0) {
 		NCCL_OFI_WARN("Querying provider capabilities failed: %d", ret);
-		goto exit;
+		goto error;
 	}
 
 	if (endpoint_mr) {
@@ -7269,7 +7318,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter,
 	}
 
 	if (ofi_nccl_eager_max_size() < 0 ||
-	    ofi_nccl_eager_max_size() > rr_threshold) {
+	    ofi_nccl_eager_max_size() > ofi_nccl_round_robin_threshold()) {
 		NCCL_OFI_WARN("Invalid value for EAGER_MAX_SIZE");
 		ret = ncclInvalidArgument;
 		goto error;
@@ -7288,7 +7337,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter,
 	if (!topo) {
 		NCCL_OFI_WARN("Failed to create NCCL OFI topology");
 		ret = -ENOTSUP;
-		goto exit;
+		goto error;
 	}
 
 	ret = nccl_ofi_topo_group(topo);
@@ -7309,6 +7358,10 @@ int nccl_net_ofi_rdma_init(const char *provider_filter,
 		goto error;
 	}
 
+	if (topo->max_group_size > 1) {
+		*found_multiple_rails = true;
+	}
+
 	/**
 	 * NCCL's topology detection will set NIC PCIe link speed based on the
 	 * "leader" NIC for the GPU. For multi-rail platforms, we increase the
@@ -7334,65 +7387,26 @@ int nccl_net_ofi_rdma_init(const char *provider_filter,
 		goto error;
 	}
 
-	ret = nccl_net_ofi_rdma_plugin_create(num_devs, &plugin);
+	ret = nccl_net_ofi_rdma_plugin_create(num_devs, topo, &plugin);
 	if (ret != 0) {
 		NCCL_OFI_WARN("Unable to allocate nccl_net_ofi_plugin_t");
-		goto exit;
-	}
-
-	/* Initialize user data iterator */
-	nccl_ofi_topo_data_iterator_t data_iter;
-	ret = nccl_ofi_topo_set_to_begin(topo, &data_iter);
-	if (ret != 0) {
-		NCCL_OFI_WARN("Failed to set iterator to begin of user data vector");
 		goto error;
 	}
 
-	/* Allocate and initialize nccl_net devices */
-	for (int dev_id = 0 ; dev_id != num_devs ; ++dev_id) {
-		struct fi_info *info_list;
-
-		/* Retrieve NIC info list from topology */
-		info_list = nccl_ofi_topo_next_info_list(&data_iter);
-		/* Verify NIC info list from topology */
-		if (!info_list) {
-			NCCL_OFI_WARN("Unable to retrieve next NIC info list from topology");
-			ret = -EINVAL;
-			goto error;
-		}
-
-		/* Allocate device */
-		nccl_net_ofi_rdma_device_t *device =
-			nccl_net_ofi_rdma_device_create(plugin, dev_id,
-							info_list, topo,
-							rr_threshold);
-		if (device == NULL) {
-			NCCL_OFI_WARN("Device creation failed");
-			ret = -ENOMEM;
-			goto error;
-		}
+	*plugin_p = &plugin->base;
 
-		ret = plugin->assign_device(plugin, dev_id, &device->base);
-		if (ret != 0) {
-			NCCL_OFI_WARN("Assigning device %d failed", dev_id);
-			goto error;
-		}
-	}
-
-	goto exit;
+	return ret;
 
  error:
 	if (plugin != NULL) {
-		plugin->release_plugin(plugin);
+		/* a created plugin owns topo; its release path frees it */
+		plugin->base.release_plugin(&plugin->base);
 		plugin = NULL;
+		topo = NULL;
+	}
+	if (topo != NULL) {
+		nccl_ofi_topo_free(topo);
 	}
 
- exit:
-	if (topo != NULL) {
-		nccl_ofi_topo_free(topo);
-	}
-
-	*plugin_p = plugin;
-
 	return ret;
 }
diff --git a/src/nccl_ofi_sendrecv.c b/src/nccl_ofi_sendrecv.c
index 67b3f3968..c7dd4d7cf 100644
--- a/src/nccl_ofi_sendrecv.c
+++ b/src/nccl_ofi_sendrecv.c
@@ -2529,6 +2529,11 @@ static void get_hints(struct fi_info *hints, int req_gdr)
 static int nccl_net_ofi_sendrecv_plugin_fini(nccl_net_ofi_plugin_t *plugin)
 {
 	int ret, last_error = 0;
+	nccl_net_ofi_sendrecv_plugin_t *sendrecv_plugin = (nccl_net_ofi_sendrecv_plugin_t *)plugin;
+
+	if (sendrecv_plugin->provider_list != NULL) {
+		fi_freeinfo(sendrecv_plugin->provider_list);
+	}
 
 	ret = nccl_net_ofi_plugin_fini(plugin);
 	if (ret != 0) {
@@ -2545,26 +2550,67 @@ static int nccl_net_ofi_sendrecv_plugin_fini(nccl_net_ofi_plugin_t *plugin)
 }
 
 
+static inline int nccl_net_ofi_sendrecv_plugin_complete_init(nccl_net_ofi_plugin_t *plugin)
+{
+	nccl_net_ofi_sendrecv_plugin_t *sendrecv_plugin = (nccl_net_ofi_sendrecv_plugin_t *)plugin;
+	struct fi_info *info;
+	int dev_id = 0;
+	int ret;
+
+	/* Allocate and initialize nccl_net devices */
+	info = sendrecv_plugin->provider_list;
+	while (dev_id != sendrecv_plugin->base.p_num_devs) {
+		if (!info) {
+			NCCL_OFI_WARN("Insufficient Libfabric devices found");
+			return -EINVAL;
+		}
+
+		nccl_net_ofi_sendrecv_device_t *device =
+			nccl_net_ofi_sendrecv_device_create(plugin, dev_id, info);
+		if (device == NULL) {
+			NCCL_OFI_WARN("Unable to allocate device %i", dev_id);
+			return -ENOMEM;
+		}
+
+		ret = plugin->assign_device(plugin, dev_id, &device->base);
+		if (ret != 0) {
+			NCCL_OFI_WARN("Assigning device %d failed", dev_id);
+			return ret;
+		}
+
+		dev_id++;
+		info = info->next;
+	}
+
+	return 0;
+}
+
+
 static int nccl_net_ofi_sendrecv_plugin_create(size_t num_devices,
-					       nccl_net_ofi_plugin_t **plugin_p)
+					       struct fi_info *provider_list,
+					       nccl_net_ofi_sendrecv_plugin_t **plugin_p)
 {
 	int ret;
-	nccl_net_ofi_plugin_t *plugin = NULL;
+	nccl_net_ofi_sendrecv_plugin_t *plugin = NULL;
 
-	plugin = (nccl_net_ofi_plugin_t *)malloc(sizeof(nccl_net_ofi_plugin_t));
+	plugin = (nccl_net_ofi_sendrecv_plugin_t *)calloc(1, sizeof(nccl_net_ofi_sendrecv_plugin_t));
 	if (plugin == NULL) {
 		NCCL_OFI_WARN("Unable to allocate nccl_net_ofi_plugin_t");
 		return -ENOMEM;
 	}
 
-	ret = nccl_net_ofi_plugin_init(plugin, num_devices);
+	ret = nccl_net_ofi_plugin_init(&plugin->base, num_devices);
 	if (ret != 0) {
 		NCCL_OFI_WARN("Initializing base plugin failed: %s",
 			      strerror(-ret));
+		free(plugin);
 		return ret;
 	}
 
-	plugin->release_plugin = nccl_net_ofi_sendrecv_plugin_fini;
+	plugin->provider_list = provider_list;
+
+	plugin->base.release_plugin = nccl_net_ofi_sendrecv_plugin_fini;
+	plugin->base.complete_init = nccl_net_ofi_sendrecv_plugin_complete_init;
 
 	*plugin_p = plugin;
 
@@ -2576,10 +2622,9 @@ int nccl_net_ofi_sendrecv_init(const char *provider_filter,
 			       nccl_net_ofi_plugin_t **plugin_p)
 {
 	int ret = 0;
-	int dev_id = 0;
-	struct fi_info *provider_list = NULL, *info;
+	struct fi_info *provider_list = NULL;
 	unsigned int num_providers;
-	nccl_net_ofi_plugin_t *plugin = NULL;
+	nccl_net_ofi_sendrecv_plugin_t *plugin = NULL;
 	struct fi_info *hints;
 
 	hints = fi_allocinfo();
@@ -2685,7 +2730,7 @@ int nccl_net_ofi_sendrecv_init(const char *provider_filter,
 			if (!tmp) {
 				NCCL_OFI_WARN("DUP_CONNS fi_dupinfo failed.");
 				ret = -ENOMEM;
-				goto exit;
+				goto error;
 			}
 			/* just in case */
 			tmp->next = NULL;
@@ -2713,52 +2758,24 @@ int nccl_net_ofi_sendrecv_init(const char *provider_filter,
 	ret = nccl_net_ofi_query_provider_capabilities(provider_list, num_providers);
 	if (ret != 0) {
 		NCCL_OFI_WARN("Querying provider capabilities failed: %d", ret);
-		goto exit;
+		goto error;
 	}
 
-	ret = nccl_net_ofi_sendrecv_plugin_create(num_providers, &plugin);
+	ret = nccl_net_ofi_sendrecv_plugin_create(num_providers, provider_list, &plugin);
 	if (ret != 0) {
 		NCCL_OFI_WARN("Unable to allocate nccl_net_ofi_plugin_t");
-		goto exit;
+		goto error;
 	}
 
-	/* Allocate and initialize nccl_net devices */
-	info = provider_list;
-	while (dev_id != num_providers) {
-		if (!info) {
-			NCCL_OFI_WARN("Insufficient Libfabric devices found");
-			ret = -EINVAL;
-			goto error;
-		}
-
-		nccl_net_ofi_sendrecv_device_t *device =
-			nccl_net_ofi_sendrecv_device_create(plugin, dev_id, info);
-		if (device == NULL) {
-			NCCL_OFI_WARN("Unable to allocate device %i", dev_id);
-			ret = -ENOMEM;
-			goto error;
-		}
-
-		ret = plugin->assign_device(plugin, dev_id, &device->base);
-		if (ret != 0) {
-			NCCL_OFI_WARN("Assigning device %d failed", dev_id);
-			goto error;
-		}
+	*plugin_p = &plugin->base;
 
-		dev_id++;
-		info = info->next;
-	}
-
-	goto exit;
+	return ret;
 
  error:
 	if (plugin != NULL) {
-		plugin->release_plugin(plugin);
+		plugin->base.release_plugin(&plugin->base);
 		plugin = NULL;
 	}
 
- exit:
-	*plugin_p = plugin;
-
 	return ret;
 }