Skip to content

Commit

Permalink
Improve protocol selection logic
Browse files Browse the repository at this point in the history
Change the protocol selection logic when the platform file does
not specify a protocol.  Instead of always defaulting to sendrecv
if the user / platform file didn't specify a protocol, try to figure
out when rdma is a good default.  We still are conservative with
enabling the rdma protocol, to avoid changing the default in as
many places as possible.

With the change, the protocol selection order is:

  1. if the user set NCCL_OFI_PROTOCOL, use that.
  2. if the platform init set nccl_ofi_selected_protocol, use that.
  3. If the rdma protocol reports multiple nics per device and
     initialized successfully, use that.
  4. If the sendrecv protocol initialized successfully, use that.
  5. If the rdma protocol initialized successfully, use that.

Signed-off-by: Brian Barrett <bbarrett@amazon.com>
  • Loading branch information
bwbarrett committed Sep 20, 2024
1 parent cbf5115 commit 159bfed
Show file tree
Hide file tree
Showing 6 changed files with 251 additions and 130 deletions.
13 changes: 13 additions & 0 deletions include/nccl_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -484,6 +484,19 @@ struct nccl_net_ofi_recv_comm {
*/
struct nccl_net_ofi_plugin {
/* public */

/**
* Complete initialization of plugin
*
* When a plugin is first created, it should not create any
* network resources -- create is called to understand the
* configuration of the network and see which transports can
* run. The base code will pick one and call complete_init,
* at which point devices and network resources can be
* created.
*/
int (*complete_init)(nccl_net_ofi_plugin_t *plugin);

int (*assign_device)(nccl_net_ofi_plugin_t *plugin,
size_t device_index, nccl_net_ofi_device_t *device);

Expand Down
12 changes: 11 additions & 1 deletion include/nccl_ofi_rdma.h
Original file line number Diff line number Diff line change
Expand Up @@ -834,11 +834,21 @@ typedef struct nccl_net_ofi_rdma_device {
#endif
} nccl_net_ofi_rdma_device_t;


/*
 * @brief	Plugin structure for the rdma protocol
 *
 * Extends the generic plugin structure with the state that is
 * specific to the rdma protocol.
 */
typedef struct nccl_net_ofi_rdma_plugin {
	/* Generic plugin structure; declared as the first member. */
	nccl_net_ofi_plugin_t base;

	/*
	 * Topology handle associated with this plugin.
	 * NOTE(review): presumably filled in while the plugin is created
	 * and consumed when initialization is completed -- confirm
	 * against the rdma implementation.
	 */
	nccl_ofi_topo_t *topo;
} nccl_net_ofi_rdma_plugin_t;


/*
* @brief Initialize plugin with rdma protocol structures
*/
int nccl_net_ofi_rdma_init(const char *provider_filter,
nccl_net_ofi_plugin_t **plugin_p);
nccl_net_ofi_plugin_t **plugin_p,
bool *found_multi_rail);

#ifdef __cplusplus
} // End extern "C"
Expand Down
9 changes: 9 additions & 0 deletions include/nccl_ofi_sendrecv.h
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,15 @@ typedef struct nccl_net_ofi_sendrecv_req {
nccl_net_ofi_sendrecv_req_direction_t direction;
} nccl_net_ofi_sendrecv_req_t;


/*
 * @brief	Plugin structure for the sendrecv protocol
 *
 * Extends the generic plugin structure with the state that is
 * specific to the sendrecv protocol.
 */
typedef struct nccl_net_ofi_sendrecv_plugin {
	/* Generic plugin structure; declared as the first member. */
	nccl_net_ofi_plugin_t base;

	/*
	 * Libfabric provider info associated with this plugin.
	 * NOTE(review): presumably the list of providers found while the
	 * plugin is created, kept until initialization is completed --
	 * confirm against the sendrecv implementation.
	 */
	struct fi_info *provider_list;
} nccl_net_ofi_sendrecv_plugin_t;


/*
* @brief Initialize plugin with sendrecv protocol structures
*/
Expand Down
98 changes: 81 additions & 17 deletions src/nccl_ofi_net.c
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ bool endpoint_mr = false;
bool virt_addr_mr = false;

/* Selected communication protocol. */
const char *nccl_ofi_selected_protocol = "SENDRECV";
const char *nccl_ofi_selected_protocol = NULL;

/* Allocate one domain per process (0) or per thread (1) */
int domain_per_thread = 0;
Expand Down Expand Up @@ -134,6 +134,7 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p)
{
int ret = 0;
const char *provider_filter = NULL;
nccl_net_ofi_plugin_t *plugin;

NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Initializing " PACKAGE_STRING);

Expand Down Expand Up @@ -188,34 +189,95 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p)
goto exit;
}

/* Select and initialize protocol data structure.
* platform_init() may change the default, so this must occur
* after the platform init call.
/* This is ugly, but here's the basic protocol selection
* logic:
* 1. if the user set NCCL_OFI_PROTOCOL, use that.
* 2. if the platform init set nccl_ofi_selected_protocol,
* use that.
* 3. If the rdma protocol reports multiple nics per device
* and initialized successfully, use that.
* 4. If the sendrecv protocol initialized successfully, use
* that
* 5. If the rdma protocol initialized successfully, use
* that.
*/
if (ofi_nccl_protocol()) {
nccl_ofi_selected_protocol = ofi_nccl_protocol();
NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Using transport protocol %s (user set)",
nccl_ofi_selected_protocol);
} else {
NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Using transport protocol %s",
} else if (nccl_ofi_selected_protocol != NULL) {
NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Using transport protocol %s (platform set)",
nccl_ofi_selected_protocol);
}

if (0 == strcasecmp(nccl_ofi_selected_protocol, "SENDRECV")) {
ret = nccl_net_ofi_sendrecv_init(provider_filter, plugin_p);
if (ret != 0) {
NCCL_OFI_WARN("Failed to initialize sendrecv protocol");
if (nccl_ofi_selected_protocol != NULL) {
bool dummy;

if (0 == strcasecmp(nccl_ofi_selected_protocol, "SENDRECV")) {
ret = nccl_net_ofi_sendrecv_init(provider_filter, &plugin);
if (ret != 0) {
NCCL_OFI_WARN("Failed to initialize sendrecv protocol");
goto exit;
}
} else if (0 == strcasecmp(nccl_ofi_selected_protocol, "RDMA")) {
ret = nccl_net_ofi_rdma_init(provider_filter, &plugin, &dummy);
if (ret != 0) {
NCCL_OFI_WARN("Failed to initialize rdma protocol");
goto exit;
}
} else {
NCCL_OFI_WARN("Unable to find plugin protocol %s", nccl_ofi_selected_protocol);
ret = -ENOTSUP;
goto exit;
}
} else if (0 == strcasecmp(nccl_ofi_selected_protocol, "RDMA")) {
ret = nccl_net_ofi_rdma_init(provider_filter, plugin_p);
} else {
bool have_multiple_rails = false;
nccl_net_ofi_plugin_t *rdma_plugin = NULL, *sendrecv_plugin = NULL;

ret = nccl_net_ofi_rdma_init(provider_filter, &rdma_plugin, &have_multiple_rails);
if (ret != 0) {
NCCL_OFI_WARN("Failed to initialize rdma protocol");
NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET,
"Failed to initialize rdma protocol: %s", fi_strerror(-ret));
have_multiple_rails = false;
rdma_plugin = NULL;
}

if (!have_multiple_rails) {
ret = nccl_net_ofi_sendrecv_init(provider_filter, &sendrecv_plugin);
if (ret != 0) {
NCCL_OFI_TRACE(NCCL_INIT | NCCL_NET,
"Failed to initialized rdma protocol: %s", fi_strerror(-ret));
sendrecv_plugin = NULL;
}
}

if (have_multiple_rails && rdma_plugin != NULL) {
nccl_ofi_selected_protocol = "RDMA";
plugin = rdma_plugin;
if (sendrecv_plugin != NULL) {
sendrecv_plugin->release_plugin(sendrecv_plugin);
}
} else {
nccl_ofi_selected_protocol = "SENDRECV";
plugin = sendrecv_plugin;
if (rdma_plugin != NULL) {
rdma_plugin->release_plugin(rdma_plugin);
}
}

if (nccl_ofi_selected_protocol == NULL) {
NCCL_OFI_WARN("Unable to find a protocol that worked. Failing initialization.");
ret = -EINVAL;
goto exit;
}
} else {
NCCL_OFI_WARN("Unable to find plugin protocol %s", nccl_ofi_selected_protocol);
ret = -ENOTSUP;

NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Using transport protocol %s",
nccl_ofi_selected_protocol);
}

ret = plugin->complete_init(plugin);
if (ret != 0) {
NCCL_OFI_WARN("Failed to initialize rdma protocol");
goto exit;
}

Expand All @@ -230,7 +292,7 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p)
* resources. This initialization happens once per process, and thus it
* does not matter which device is used to create the endpoint.
*/
nccl_net_ofi_device_t *device = (*plugin_p)->get_device(*plugin_p, 0);
nccl_net_ofi_device_t *device = plugin->get_device(plugin, 0);
nccl_net_ofi_ep_t *base_ep = NULL;

ret = device->get_ep(device, &base_ep);
Expand All @@ -254,6 +316,8 @@ int nccl_net_ofi_create_plugin(nccl_net_ofi_plugin_t **plugin_p)
goto exit;
}

*plugin_p = plugin;

exit:
if (ret != 0) {
NCCL_OFI_WARN(PACKAGE_NAME " initialization failed");
Expand Down
Loading

0 comments on commit 159bfed

Please sign in to comment.