Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Changes to support the v8 interface of the NCCL plugin interface #365

Merged
merged 3 commits into from
Mar 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions include/Makefile.am
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ noinst_HEADERS = \
nccl-headers/nvidia/net_v5.h \
nccl-headers/nvidia/net_v6.h \
nccl-headers/nvidia/net_v7.h \
nccl-headers/nvidia/net_v8.h \
nccl-headers/nvidia/types.h \
nccl-headers/nvidia/tuner.h \
nccl-headers/neuron/net.h \
Expand Down
3 changes: 2 additions & 1 deletion include/nccl-headers/nvidia/net.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,14 @@
#define NCCL_PTR_DMABUF 0x4

// Maximum number of requests per comm object
#define NCCL_NET_MAX_REQUESTS 8
#define NCCL_NET_MAX_REQUESTS 32

typedef enum {NCCL_LOG_NONE=0, NCCL_LOG_VERSION=1, NCCL_LOG_WARN=2, NCCL_LOG_INFO=3, NCCL_LOG_ABORT=4, NCCL_LOG_TRACE=5} ncclDebugLogLevel;
typedef enum {NCCL_INIT=1, NCCL_COLL=2, NCCL_P2P=4, NCCL_SHM=8, NCCL_NET=16, NCCL_GRAPH=32, NCCL_TUNING=64, NCCL_ENV=128, NCCL_ALLOC=256, NCCL_CALL=512, NCCL_ALL=~0} ncclDebugLogSubSys;

typedef void (*ncclDebugLogger_t)(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...);

#include "net_v8.h"
#include "net_v7.h"
#include "net_v6.h"
#include "net_v5.h"
Expand Down
1 change: 1 addition & 0 deletions include/nccl-headers/nvidia/net_device.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ typedef struct {
int needsProxyProgress;
} ncclNetDeviceHandle_v7_t;

typedef ncclNetDeviceHandle_v7_t ncclNetDeviceHandle_v8_t;
typedef ncclNetDeviceHandle_v7_t ncclNetDeviceHandle_t;

#endif
2 changes: 0 additions & 2 deletions include/nccl-headers/nvidia/net_v7.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@ typedef struct {
int netDeviceVersion; // Version number for network offload
} ncclNetProperties_v7_t;

typedef ncclNetProperties_v7_t ncclNetProperties_t;

typedef struct {
// Name of the network (mainly for logs)
const char* name;
Expand Down
83 changes: 83 additions & 0 deletions include/nccl-headers/nvidia/net_v8.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
/*
* Copyright (c) 2017-2022, NVIDIA CORPORATION. All rights reserved.
*/

#ifndef NCCL_NET_V8_H_
#define NCCL_NET_V8_H_

#include "net_device.h"

typedef struct {
char* name; // Used mostly for logging.
char* pciPath; // Path to the PCI device in /sys.
uint64_t guid; // Unique identifier for the NIC chip. Important for
// cards with multiple PCI functions (Physical or virtual).
int ptrSupport; // [NCCL_PTR_HOST|NCCL_PTR_CUDA|NCCL_PTR_DMABUF]
int regIsGlobal; // regMr is not tied to a particular comm
int speed; // Port speed in Mbps.
int port; // Port number.
float latency; // Network latency
int maxComms; // Maximum number of comms we can create
int maxRecvs; // Maximum number of grouped receives.
ncclNetDeviceType netDeviceType; // Network offload type
int netDeviceVersion; // Version number for network offload
} ncclNetProperties_v8_t;

typedef ncclNetProperties_v8_t ncclNetProperties_t;

typedef struct {
// Name of the network (mainly for logs)
const char* name;
// Initialize the network.
ncclResult_t (*init)(ncclDebugLogger_t logFunction);
// Return the number of adapters.
ncclResult_t (*devices)(int* ndev);
// Get various device properties.
ncclResult_t (*getProperties)(int dev, ncclNetProperties_v8_t* props);
// Create a receiving object and provide a handle to connect to it. The
// handle can be up to NCCL_NET_HANDLE_MAXSIZE bytes and will be exchanged
// between ranks to create a connection.
ncclResult_t (*listen)(int dev, void* handle, void** listenComm);
// Connect to a handle and return a sending comm object for that peer.
// This call must not block for the connection to be established, and instead
// should return successfully with sendComm == NULL with the expectation that
// it will be called again until sendComm != NULL.
// If *sendDevComm points to a valid object, then NCCL is requesting device offload for this connection
ncclResult_t (*connect)(int dev, void* handle, void** sendComm, ncclNetDeviceHandle_v8_t** sendDevComm);
// Finalize connection establishment after remote peer has called connect.
// This call must not block for the connection to be established, and instead
// should return successfully with recvComm == NULL with the expectation that
// it will be called again until recvComm != NULL.
// If *recvDevComm points to a valid object, then NCCL is requesting device offload for this connection
ncclResult_t (*accept)(void* listenComm, void** recvComm, ncclNetDeviceHandle_v8_t** recvDevComm);
// Register/Deregister memory. Comm can be either a sendComm or a recvComm.
// Type is either NCCL_PTR_HOST or NCCL_PTR_CUDA.
ncclResult_t (*regMr)(void* comm, void* data, size_t size, int type, void** mhandle);
/* DMA-BUF support */
ncclResult_t (*regMrDmaBuf)(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle);
ncclResult_t (*deregMr)(void* comm, void* mhandle);
// Asynchronous send to a peer.
// May return request == NULL if the call cannot be performed (or would block)
ncclResult_t (*isend)(void* sendComm, void* data, int size, int tag, void* mhandle, void** request);
// Asynchronous recv from a peer.
// May return request == NULL if the call cannot be performed (or would block)
ncclResult_t (*irecv)(void* recvComm, int n, void** data, int* sizes, int* tags, void** mhandles, void** request);
// Perform a flush/fence to make sure all data received with NCCL_PTR_CUDA is
// visible to the GPU
ncclResult_t (*iflush)(void* recvComm, int n, void** data, int* sizes, void** mhandles, void** request);
// Test whether a request is complete. If size is not NULL, it returns the
// number of bytes sent/received.
ncclResult_t (*test)(void* request, int* done, int* sizes);
// Close and free send/recv comm objects
ncclResult_t (*closeSend)(void* sendComm);
ncclResult_t (*closeRecv)(void* recvComm);
ncclResult_t (*closeListen)(void* listenComm);

// Copy the given mhandle to a dptr in a format usable by this plugin's device code
ncclResult_t (*getDeviceMr)(void* comm, void* mhandle, void** dptr_mhandle);

// Notify the plugin that a recv has completed by the device
ncclResult_t (*irecvConsumed)(void* recvComm, int n, void* request);
} ncclNet_v8_t;

#endif // end include guard
8 changes: 8 additions & 0 deletions include/nccl_ofi.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,6 +179,12 @@ typedef enum nccl_ofi_comm_stage {
COMM_CONNECTED,
} nccl_ofi_comm_stage_t;

/* Determines which object a provider associates MRs with */
typedef enum nccl_ofi_mr_scope {
NCCL_OFI_MR_SCOPE_DOMAIN = 0,
NCCL_OFI_MR_SCOPE_ENDPOINT
} nccl_ofi_mr_scope_t;

typedef struct save_comm_state {
nccl_net_ofi_comm_t *comm;
nccl_net_ofi_req_t *req;
Expand Down Expand Up @@ -222,6 +228,8 @@ typedef struct nccl_ofi_properties {
unsigned int max_communicators;
/** Maximum number of grouped receives */
unsigned int max_group_receives;
/** Scope of a memory region registered with a provider **/
nccl_ofi_mr_scope_t mr_scope;
} nccl_ofi_properties_t;

/**
Expand Down
6 changes: 3 additions & 3 deletions include/nccl_ofi_api.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,9 @@ ncclResult_t nccl_net_ofi_connect(int dev, void* handle, void** sendComm);
ncclResult_t nccl_net_ofi_connect_v4(int dev, void* handle, void** sendComm);
ncclResult_t nccl_net_ofi_accept(void *listenComm, void **recvComm);
ncclResult_t nccl_net_ofi_accept_v4(void* listenComm, void** recvComm);
ncclResult_t nccl_net_ofi_regMr(void *comm, void *data, int size, int type,
void **mhandle);
ncclResult_t nccl_net_ofi_regMr_sizet(void *comm, void *data, size_t size, int type,
ncclResult_t nccl_net_ofi_regMr_v7(void *comm, void *data, int size, int type,
void **mhandle);
ncclResult_t nccl_net_ofi_regMr(void *comm, void *data, size_t size, int type,
void **mhandle);
ncclResult_t nccl_net_ofi_regMrDmaBuf(void* comm, void* data, size_t size, int type, uint64_t offset, int fd, void** mhandle);
ncclResult_t nccl_net_ofi_deregMr(void *comm, void *mhandle);
Expand Down
43 changes: 6 additions & 37 deletions src/nccl_ofi_api.c
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,13 @@ ncclResult_t nccl_net_ofi_connect_v4(int dev, void* handle, void** sendComm)
return ret;
}

ncclResult_t nccl_net_ofi_regMr_v7(void *comm, void *data, int size, int type,
void **mhandle)
{
return nccl_net_ofi_regMr(comm, data, (size_t)size, type, mhandle);
}

ncclResult_t nccl_net_ofi_regMr(void *comm, void *data, int size, int type,
ncclResult_t nccl_net_ofi_regMr(void *comm, void *data, size_t size, int type,
void **mhandle)
{
int ret = 0;
Expand Down Expand Up @@ -365,42 +370,6 @@ ncclResult_t nccl_net_ofi_regMr(void *comm, void *data, int size, int type,
return nccl_net_ofi_retval_translate(ret);
}


ncclResult_t nccl_net_ofi_regMr_sizet(void *comm, void *data, size_t size, int type,
void **mhandle)
{
/* Retrieve and validate comm */
nccl_net_ofi_comm_t *base_comm =
(nccl_net_ofi_comm_t *)comm;
if (OFI_UNLIKELY(base_comm == NULL)) {
NCCL_OFI_WARN("Invalid comm object provided");
return ncclInternalError;
}

int ret = 0;

switch (base_comm->type) {
case NCCL_NET_OFI_SEND_COMM:;
nccl_net_ofi_send_comm_t *send_comm =
(nccl_net_ofi_send_comm_t *)base_comm;
ret = send_comm->regMr(send_comm, data, size, type, mhandle);
break;
case NCCL_NET_OFI_RECV_COMM:;
nccl_net_ofi_recv_comm_t *recv_comm =
(nccl_net_ofi_recv_comm_t *)base_comm;
ret = recv_comm->regMr(recv_comm, data, size, type, mhandle);
break;
default:
NCCL_OFI_WARN("Unexpected communicator type. Communicator type: %d",
base_comm->type);
ret = -EINVAL;
break;
}

return nccl_net_ofi_retval_translate(ret);
}


ncclResult_t nccl_net_ofi_deregMr(void *comm, void *mhandle)
{
/* Retrieve and validate comm */
Expand Down
2 changes: 1 addition & 1 deletion src/nccl_ofi_interface_neuron.c
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ const ncclNet_v4_t ncclNetPlugin_v4 = {
.listen = nccl_net_ofi_listen_v4,
.connect = nccl_net_ofi_connect_v4,
.accept = nccl_net_ofi_accept_v4,
.regMr = nccl_net_ofi_regMr_sizet,
.regMr = nccl_net_ofi_regMr,
.deregMr = nccl_net_ofi_deregMr,
.isend = nccl_net_ofi_isend_v4,
.irecv = nccl_net_ofi_irecv_v4,
Expand Down
72 changes: 66 additions & 6 deletions src/nccl_ofi_interface_nvidia.c
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,44 @@
#include "nccl_ofi.h"
#include "nccl_ofi_api.h"

static ncclResult_t getProperties_v8(int dev_id, ncclNetProperties_v8_t* props)
{
nccl_ofi_properties_t ofi_properties;
ncclResult_t ret = nccl_net_ofi_get_properties(dev_id, &ofi_properties);
if (ret != ncclSuccess) {
return ret;
}

props->name = ofi_properties.name;
props->pciPath = ofi_properties.pci_path;
props->guid = ofi_properties.guid;
props->ptrSupport = NCCL_PTR_HOST;
if (ofi_properties.hmem_support) {
props->ptrSupport |= NCCL_PTR_CUDA;
}
if (ofi_properties.dmabuf_support) {
props->ptrSupport |= NCCL_PTR_DMABUF;
}

/*
* NCCL uses regIsGlobal to determine support for User Registrations via
* the NCCL API. If providers tie MRs to endpoints, the plugin can not
* support this model (since NCCL maintains a per-domain registration
* cache which requires (domain-)global registrations.
*/
if (ofi_properties.mr_scope == NCCL_OFI_MR_SCOPE_DOMAIN)
props->regIsGlobal = 1;

props->speed = ofi_properties.port_speed;
props->port = ofi_properties.port_number;
props->latency = ofi_properties.latency;
props->maxComms = ofi_properties.max_communicators;
props->maxRecvs = ofi_properties.max_group_receives;
props->netDeviceType = NCCL_NET_DEVICE_HOST;
props->netDeviceVersion = NCCL_NET_DEVICE_INVALID_VERSION;

return ncclSuccess;
}

static ncclResult_t getProperties_v7(int dev_id, ncclNetProperties_v7_t *props)
{
Expand Down Expand Up @@ -139,7 +177,7 @@ const ncclNet_v2_t ncclNetPlugin_v2 = {
.listen = nccl_net_ofi_listen_v4,
.connect = nccl_net_ofi_connect_v4,
.accept = nccl_net_ofi_accept_v4,
.regMr = nccl_net_ofi_regMr,
.regMr = nccl_net_ofi_regMr_v7,
.deregMr = nccl_net_ofi_deregMr,
.isend = nccl_net_ofi_isend_v4,
.irecv = nccl_net_ofi_irecv_v4,
Expand All @@ -158,7 +196,7 @@ const ncclNet_v3_t ncclNetPlugin_v3 = {
.listen = nccl_net_ofi_listen_v4,
.connect = nccl_net_ofi_connect_v4,
.accept = nccl_net_ofi_accept_v4,
.regMr = nccl_net_ofi_regMr,
.regMr = nccl_net_ofi_regMr_v7,
.deregMr = nccl_net_ofi_deregMr,
.isend = nccl_net_ofi_isend_v4,
.irecv = nccl_net_ofi_irecv_v4,
Expand All @@ -177,7 +215,7 @@ const ncclNet_v4_t ncclNetPlugin_v4 = {
.listen = nccl_net_ofi_listen_v4,
.connect = nccl_net_ofi_connect_v4,
.accept = nccl_net_ofi_accept_v4,
.regMr = nccl_net_ofi_regMr,
.regMr = nccl_net_ofi_regMr_v7,
.deregMr = nccl_net_ofi_deregMr,
.isend = nccl_net_ofi_isend_v4,
.irecv = nccl_net_ofi_irecv_v4,
Expand All @@ -196,7 +234,7 @@ const ncclNet_v5_t ncclNetPlugin_v5 = {
.listen = nccl_net_ofi_listen,
.connect = nccl_net_ofi_connect,
.accept = nccl_net_ofi_accept,
.regMr = nccl_net_ofi_regMr,
.regMr = nccl_net_ofi_regMr_v7,
.deregMr = nccl_net_ofi_deregMr,
.isend = nccl_net_ofi_isend,
.irecv = nccl_net_ofi_irecv,
Expand All @@ -215,7 +253,7 @@ const ncclNet_v6_t ncclNetPlugin_v6 = {
.listen = nccl_net_ofi_listen,
.connect = nccl_net_ofi_connect,
.accept = nccl_net_ofi_accept,
.regMr = nccl_net_ofi_regMr,
.regMr = nccl_net_ofi_regMr_v7,
.regMrDmaBuf = nccl_net_ofi_regMrDmaBuf,
.deregMr = nccl_net_ofi_deregMr,
.isend = nccl_net_ofi_isend,
Expand All @@ -235,7 +273,7 @@ const ncclNet_v7_t ncclNetPlugin_v7 = {
.listen = nccl_net_ofi_listen,
.connect = connect_v7,
.accept = accept_v7,
.regMr = nccl_net_ofi_regMr,
.regMr = nccl_net_ofi_regMr_v7,
.regMrDmaBuf = nccl_net_ofi_regMrDmaBuf,
.deregMr = nccl_net_ofi_deregMr,
.isend = nccl_net_ofi_isend,
Expand All @@ -248,3 +286,25 @@ const ncclNet_v7_t ncclNetPlugin_v7 = {
.getDeviceMr = NULL,
.irecvConsumed = NULL,
};

const ncclNet_v8_t ncclNetPlugin_v8 = {
.name = "AWS Libfabric",
.init = nccl_net_ofi_init,
.devices = nccl_net_ofi_devices,
.getProperties = getProperties_v8,
.listen = nccl_net_ofi_listen,
.connect = connect_v7,
.accept = accept_v7,
.regMr = nccl_net_ofi_regMr,
.regMrDmaBuf = nccl_net_ofi_regMrDmaBuf,
.deregMr = nccl_net_ofi_deregMr,
.isend = nccl_net_ofi_isend,
.irecv = nccl_net_ofi_irecv,
.iflush = nccl_net_ofi_iflush,
.test = nccl_net_ofi_test,
.closeSend = nccl_net_ofi_closeSend,
.closeRecv = nccl_net_ofi_closeRecv,
.closeListen = nccl_net_ofi_closeListen,
.getDeviceMr = NULL,
.irecvConsumed = NULL,
};
12 changes: 12 additions & 0 deletions src/nccl_ofi_net.c
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,18 @@ int nccl_net_ofi_info_properties(struct fi_info *nic_prov, int dev_id, int num_d
dev_props.name = strdup(nic_info->device_attr->name);
}

/*
* Determine the scope of MRs for providers to report global
* registration support to NCCL
*/
if (nic_prov->domain_attr->mr_mode & FI_MR_ENDPOINT) {
dev_props.mr_scope = NCCL_OFI_MR_SCOPE_ENDPOINT;
NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Libfabric provider associates MRs with endpoints");
} else {
dev_props.mr_scope = NCCL_OFI_MR_SCOPE_DOMAIN;
NCCL_OFI_INFO(NCCL_INIT | NCCL_NET, "Libfabric provider associates MRs with domains");
}

/* Speed reported in Mbps */
dev_props.port_speed = nic_info->link_attr->speed / (1e6);

Expand Down
2 changes: 1 addition & 1 deletion tests/functional/nccl_connection.c
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ int main(int argc, char* argv[])
nccl_net_ofi_send_comm_t *sComm = NULL;
nccl_net_ofi_listen_comm_t *lComm = NULL;
nccl_net_ofi_recv_comm_t *rComm = NULL;
ncclNetDeviceHandle_v7_t *s_ignore, *r_ignore;
ncclNetDeviceHandle_v8_t *s_ignore, *r_ignore;
char src_handle[NCCL_NET_HANDLE_MAXSIZE] = {0};
char handle[NCCL_NET_HANDLE_MAXSIZE] = {0};
test_nccl_net_t *extNet = NULL;
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/nccl_message_transfer.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ int main(int argc, char* argv[])
nccl_net_ofi_listen_comm_t *lComm = NULL;
nccl_net_ofi_recv_comm_t *rComm = NULL;
test_nccl_net_t *extNet = NULL;
ncclNetDeviceHandle_v7_t *s_ignore, *r_ignore;
ncclNetDeviceHandle_v8_t *s_ignore, *r_ignore;
char src_handle[NCCL_NET_HANDLE_MAXSIZE] = {0};

ofi_log_function = logger;
Expand Down
2 changes: 1 addition & 1 deletion tests/functional/ring.c
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ int main(int argc, char *argv[])
char handle[NCCL_NET_HANDLE_MAXSIZE] = {0};
char src_handle_prev[NCCL_NET_HANDLE_MAXSIZE] = {0};
char src_handle_next[NCCL_NET_HANDLE_MAXSIZE] = {0};
ncclNetDeviceHandle_v7_t *s_ignore, *r_ignore;
ncclNetDeviceHandle_v8_t *s_ignore, *r_ignore;
test_nccl_net_t *extNet = NULL;

ofi_log_function = logger;
Expand Down
Loading
Loading