tuner: added PAT algorithm to match NCCL interface
NCCL 2.23 introduced the PAT algorithm for AllGather and ReduceScatter.
This commit updates the list of algorithms in the tuner to match NCCL's.

Signed-off-by: Amedeo Sapio <asapio@amazon.com>
AmedeoSapio committed Oct 2, 2024
1 parent 90f1756 commit 4346d6e
Showing 3 changed files with 5 additions and 2 deletions.
3 changes: 2 additions & 1 deletion include/nccl-headers/nvidia/tuner.h
@@ -24,14 +24,15 @@ typedef enum {
 ncclNumFuncs = 8
 } ncclFunc_t;

-#define NCCL_NUM_ALGORITHMS 6 // Tree/Ring/CollNet*
+#define NCCL_NUM_ALGORITHMS 7 // Tree/Ring/CollNet*
 #define NCCL_ALGO_UNDEF -1
 #define NCCL_ALGO_TREE 0
 #define NCCL_ALGO_RING 1
 #define NCCL_ALGO_COLLNET_DIRECT 2
 #define NCCL_ALGO_COLLNET_CHAIN 3
 #define NCCL_ALGO_NVLS 4
 #define NCCL_ALGO_NVLS_TREE 5
+#define NCCL_ALGO_PAT 6

 #define NCCL_NUM_PROTOCOLS 3 // Simple/LL/LL128
 #define NCCL_PROTO_UNDEF -1
2 changes: 2 additions & 0 deletions src/tuner/nccl_ofi_tuner.c
@@ -352,6 +352,8 @@ ncclResult_t nccl_ofi_tuner_get_coll_info(void *context,
 protocol = nccl_ofi_tuner_ctx->regions[i].protocol;
 if (table[algorithm][protocol] == NCCL_ALGO_PROTO_IGNORE || algorithm >= numAlgo ||
 protocol >= numProto) {
+/* Either NCCL says this combination is not valid/applicable or the algorithm or protocol is
+ * not in the table, hence it is not supported by this NCCL version. */
 continue;
 }

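Note on the condition in the hunk above: as shown, `table[algorithm][protocol]` is evaluated before `algorithm >= numAlgo` and `protocol >= numProto`, so an out-of-range index would be read before it is rejected. The following is a minimal sketch of the same check with the range tests first. The helper name `should_skip_combination`, the `float (*)[NCCL_NUM_PROTOCOLS]` table type, and the `-1.0` value used for `NCCL_ALGO_PROTO_IGNORE` are assumptions made for illustration and are not taken from this commit.

```c
#include <stdbool.h>

#define NCCL_NUM_PROTOCOLS 3
/* Assumed sentinel value; check the NCCL tuner header for the real definition. */
#define NCCL_ALGO_PROTO_IGNORE (-1.0f)

/*
 * Hypothetical helper (not part of the plugin): returns true when the
 * (algorithm, protocol) pair must be skipped. The range checks run first,
 * so the table is only indexed with values known to be inside it.
 */
static bool should_skip_combination(float (*table)[NCCL_NUM_PROTOCOLS],
                                    int algorithm, int protocol,
                                    int numAlgo, int numProto)
{
    /* Algorithm or protocol is not in the table for this NCCL version. */
    if (algorithm < 0 || algorithm >= numAlgo ||
        protocol < 0 || protocol >= numProto) {
        return true;
    }

    /* NCCL marked this combination as not valid/applicable. */
    return table[algorithm][protocol] == NCCL_ALGO_PROTO_IGNORE;
}
```

With this ordering, a PAT region (algorithm index 6) passed to an NCCL version that reports `numAlgo == 6` is skipped without touching the cost table.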
2 changes: 1 addition & 1 deletion tests/unit/show_tuner_decisions.c
@@ -11,7 +11,7 @@

 #include "nccl_ofi_tuner.h"

-static const char *algo_names[] = { "tree", "ring", "collnet_direct", "collnet_chain", "nvls", "nvlstree" };
+static const char *algo_names[] = { "tree", "ring", "collnet_direct", "collnet_chain", "nvls", "nvlstree" , "pat" };
 static const char *proto_names[] = { "ll", "ll128", "simple" };
 void dummy_logger(ncclDebugLogLevel level, unsigned long flags, const char *file, int line, const char *fmt, ...) { return; };

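Note on `algo_names`: the table has to grow by hand whenever `NCCL_NUM_ALGORITHMS` changes, as it does in this commit. One possible way to catch a mismatch at compile time is a C11 `static_assert` on the array length, sketched below. The helper `algo_name` and its `"undefined"` fallback are hypothetical and not part of the test file; the two macros are copied from the header hunk of this commit.

```c
#include <assert.h> /* static_assert (C11) */

#define NCCL_NUM_ALGORITHMS 7
#define NCCL_ALGO_PAT 6

static const char *algo_names[] = { "tree", "ring", "collnet_direct", "collnet_chain",
                                    "nvls", "nvlstree", "pat" };

/* Fails the build if an algorithm is added to NCCL without a matching name here. */
static_assert(sizeof(algo_names) / sizeof(algo_names[0]) == NCCL_NUM_ALGORITHMS,
              "algo_names must have one entry per NCCL algorithm");

/* Hypothetical lookup helper: returns a printable name, or "undefined"
 * for indices outside the table (for example NCCL_ALGO_UNDEF, which is -1). */
static const char *algo_name(int algorithm)
{
    if (algorithm < 0 || algorithm >= NCCL_NUM_ALGORITHMS) {
        return "undefined";
    }
    return algo_names[algorithm];
}
```

Had the name table been left at six entries while the macro moved to 7, the assertion would stop the build instead of letting the test print from past the end of the array.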
