From d6b7909dc6b8d197ac002c15f639d2606f9c9037 Mon Sep 17 00:00:00 2001 From: Arun Karthik Date: Tue, 10 Sep 2024 09:56:21 +0000 Subject: [PATCH] Add Multiplexed-round-robin scheduler This commit modifies the scheduler algorithm to round-robin the payload_msg on each QP, or between pairs, triplets, or quadruplets of QPs, depending on the message size relative to min_stripe_size. --- include/nccl_ofi_param.h | 7 +- include/nccl_ofi_scheduler.h | 22 +++--- src/nccl_ofi_rdma.c | 10 +-- src/nccl_ofi_scheduler.c | 72 +++++++---------- tests/unit/scheduler.c | 146 +++++++++++++++++++++++++++-------- 5 files changed, 165 insertions(+), 92 deletions(-) diff --git a/include/nccl_ofi_param.h b/include/nccl_ofi_param.h index 7ae619c35..0cc2155ad 100644 --- a/include/nccl_ofi_param.h +++ b/include/nccl_ofi_param.h @@ -179,9 +179,10 @@ OFI_NCCL_PARAM_INT(disable_native_rdma_check, "DISABLE_NATIVE_RDMA_CHECK", 0); OFI_NCCL_PARAM_INT(disable_gdr_required_check, "DISABLE_GDR_REQUIRED_CHECK", 0); /* - * Maximum size of a message in bytes before message is multiplexed - */ -OFI_NCCL_PARAM_INT(round_robin_threshold, "ROUND_ROBIN_THRESHOLD", (256 * 1024)); + * Minimum data size in bytes before the message is to be striped across multiple + * rails. +*/ +OFI_NCCL_PARAM_INT(min_stripe_size, "MIN_STRIPE_SIZE", (210 * 1024)); /* * Minimum bounce buffers posted per endpoint. 
The plugin will attempt to post diff --git a/include/nccl_ofi_scheduler.h b/include/nccl_ofi_scheduler.h index e7e7a828c..1a32e850e 100644 --- a/include/nccl_ofi_scheduler.h +++ b/include/nccl_ofi_scheduler.h @@ -91,11 +91,13 @@ typedef struct nccl_net_ofi_threshold_scheduler { nccl_net_ofi_scheduler_t base; /* Round robin counter */ unsigned int rr_counter; + /* Multiplex round robin counter*/ + unsigned int mux_rr_counter; /* Lock for round robin counter */ pthread_mutex_t rr_lock; - /* Maximum size of a message in bytes before message is + /* Minimum size of the message in bytes before message is * multiplexed */ - size_t rr_threshold; + size_t min_stripe_size; } nccl_net_ofi_threshold_scheduler_t; /* @@ -109,16 +111,15 @@ void nccl_net_ofi_release_schedule(nccl_net_ofi_scheduler_t *scheduler, * * @param num_rails * Number of rails - * @param rr_threshold - * Maximum size of a message in bytes before message is multiplexed - * + * @param min_stripe_size + * Minimum size of a message in bytes before message is multiplexed * @return Scheduler, on success * NULL, on error * @return 0, on success * non-zero, on error */ int nccl_net_ofi_threshold_scheduler_init(int num_rails, - size_t rr_threshold, + size_t min_stripe_size, nccl_net_ofi_scheduler_t **scheduler); /* @@ -127,14 +128,17 @@ int nccl_net_ofi_threshold_scheduler_init(int num_rails, * A mininal stripe size `max_stripe_size' is calculated (multiple of * `align') that is sufficient to assign the whole message. Rails are * filled from low id to large id. The last rail may get assigned less - * data. + * data. 
The number of rails are calculated based on the ratio of + * (`data_size` / `min_stripe_size`) */ -void nccl_net_ofi_set_multiplexing_schedule(size_t size, +void nccl_net_ofi_set_multiplexed_round_robin_schedule(nccl_net_ofi_threshold_scheduler_t *scheduler, + size_t size, int num_rails, + int num_stripes, size_t align, nccl_net_ofi_schedule_t *schedule); -#ifdef __cplusplus +#ifdef _cplusplus } // End extern "C" #endif diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index 063a7bac2..a8df47033 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -6940,7 +6940,7 @@ nccl_net_ofi_rdma_device_release(nccl_net_ofi_device_t *base_device) static nccl_net_ofi_rdma_device_t * nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin, int dev_id, struct fi_info *info_list, - nccl_ofi_topo_t *topo, size_t rr_threshold) + nccl_ofi_topo_t *topo, size_t min_strip_size) { int ret; @@ -6983,7 +6983,7 @@ nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin, /* Create scheduler */ ret = nccl_net_ofi_threshold_scheduler_init(length, - rr_threshold, + min_strip_size, &device->scheduler); if (ret != 0) { goto error; @@ -7224,7 +7224,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter, int num_devs = 0; struct fi_info *provider_list = NULL; unsigned int num_providers; - size_t rr_threshold = ofi_nccl_round_robin_threshold(); + size_t min_stripe_size = ofi_nccl_min_stripe_size(); nccl_net_ofi_plugin_t *plugin = NULL; nccl_ofi_topo_t *topo = NULL; struct fi_info *hints; @@ -7269,7 +7269,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter, } if (ofi_nccl_eager_max_size() < 0 || - ofi_nccl_eager_max_size() > rr_threshold) { + ofi_nccl_eager_max_size() > min_stripe_size) { NCCL_OFI_WARN("Invalid value for EAGER_MAX_SIZE"); ret = ncclInvalidArgument; goto error; @@ -7365,7 +7365,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter, nccl_net_ofi_rdma_device_t *device = nccl_net_ofi_rdma_device_create(plugin, dev_id, info_list, topo, - 
rr_threshold); + min_stripe_size); if (device == NULL) { NCCL_OFI_WARN("Device creation failed"); ret = -ENOMEM; diff --git a/src/nccl_ofi_scheduler.c b/src/nccl_ofi_scheduler.c index bac026e80..e73dff915 100644 --- a/src/nccl_ofi_scheduler.c +++ b/src/nccl_ofi_scheduler.c @@ -22,10 +22,25 @@ static inline size_t sizeof_schedule(int num_rails) + num_rails * sizeof(nccl_net_ofi_xfer_info_t); } -void nccl_net_ofi_set_multiplexing_schedule(size_t size, int num_rails, - size_t align, - nccl_net_ofi_schedule_t *schedule) +void nccl_net_ofi_set_multiplexed_round_robin_schedule(nccl_net_ofi_threshold_scheduler_t *scheduler, + size_t size, + int num_rails, + int num_stripes, + size_t align, + nccl_net_ofi_schedule_t *schedule) { + if (OFI_UNLIKELY(num_rails == 0)) return; + + /* TODO Number of stripes cannot be greater than number of rails */ + int rail_id; + nccl_net_ofi_mutex_lock(&scheduler->rr_lock); + + /* Retieve and increment multiplex-round-robin counter; wrap around if required */ + rail_id = scheduler->rr_counter; + scheduler->rr_counter = (scheduler->rr_counter + num_stripes) % num_rails; + + nccl_net_ofi_mutex_unlock(&scheduler->rr_lock); + /* Number of bytes left to assign */ size_t left = size; /* Offset into message */ @@ -35,50 +50,23 @@ void nccl_net_ofi_set_multiplexing_schedule(size_t size, int num_rails, schedule->num_xfer_infos = 0; - if (OFI_UNLIKELY(num_rails == 0)) return; - - max_stripe_size = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL(size, num_rails), align) * align; + max_stripe_size = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL(size, num_stripes), align) * align; /* Compute stripes and assign to rails */ - for (int rail_id = 0; rail_id != num_rails && left > 0; ++rail_id) { + for (int stripe_idx = 0; stripe_idx < num_stripes && left > 0; ++stripe_idx) { size_t stripe_size = NCCL_OFI_MIN(left, max_stripe_size); - schedule->rail_xfer_infos[rail_id].rail_id = rail_id; - schedule->rail_xfer_infos[rail_id].offset = offset; - 
schedule->rail_xfer_infos[rail_id].msg_size = stripe_size; + schedule->rail_xfer_infos[stripe_idx].rail_id = rail_id; + schedule->rail_xfer_infos[stripe_idx].offset = offset; + schedule->rail_xfer_infos[stripe_idx].msg_size = stripe_size; schedule->num_xfer_infos++; offset += stripe_size; left -= stripe_size; + rail_id = (rail_id + 1) % num_rails; } } -/* - * @brief Assign message round-robin - */ -static inline int set_round_robin_schedule(nccl_net_ofi_threshold_scheduler_t *scheduler, - size_t size, - int num_rails, - nccl_net_ofi_schedule_t *schedule) -{ - int rail_id; - - nccl_net_ofi_mutex_lock(&scheduler->rr_lock); - - /* Retieve and increment round-robin counter; wrap around if required */ - rail_id = (scheduler->rr_counter)++; - scheduler->rr_counter = scheduler->rr_counter == num_rails ? 0 : scheduler->rr_counter; - - nccl_net_ofi_mutex_unlock(&scheduler->rr_lock); - - schedule->num_xfer_infos = 1; - schedule->rail_xfer_infos[0].rail_id = rail_id; - schedule->rail_xfer_infos[0].offset = 0; - schedule->rail_xfer_infos[0].msg_size = size; - - return 0; -} - /* * @brief Assign message round-robin or multiplex message depending on its size * @@ -91,12 +79,8 @@ static inline int set_schedule_by_threshold(nccl_net_ofi_threshold_scheduler_t * nccl_net_ofi_schedule_t *schedule) { int ret = 0; - if (size > scheduler->rr_threshold) { - nccl_net_ofi_set_multiplexing_schedule(size, num_rails, - align, schedule); - } else { - ret = set_round_robin_schedule(scheduler, size, num_rails, schedule); - } + int num_stripes = (int) NCCL_OFI_MIN(NCCL_OFI_DIV_CEIL(size, scheduler->min_stripe_size), num_rails); + nccl_net_ofi_set_multiplexed_round_robin_schedule(scheduler, size, num_rails, num_stripes, align, schedule); return ret; } @@ -238,7 +222,7 @@ int scheduler_init(int num_rails, nccl_net_ofi_scheduler_t *scheduler) } int nccl_net_ofi_threshold_scheduler_init(int num_rails, - size_t rr_threshold, + size_t min_stripe_size, nccl_net_ofi_scheduler_t **scheduler_p) { int ret = 
0; @@ -261,7 +245,7 @@ int nccl_net_ofi_threshold_scheduler_init(int num_rails, scheduler->base.get_schedule = get_threshold_schedule; scheduler->base.fini = threshold_scheduler_fini; scheduler->rr_counter = 0; - scheduler->rr_threshold = rr_threshold; + scheduler->min_stripe_size = min_stripe_size; ret = nccl_net_ofi_mutex_init(&scheduler->rr_lock, NULL); if (ret) { diff --git a/tests/unit/scheduler.c b/tests/unit/scheduler.c index 1949161f3..05fec3b77 100644 --- a/tests/unit/scheduler.c +++ b/tests/unit/scheduler.c @@ -16,16 +16,24 @@ int create_multiplexed(size_t size, int num_rails, + int num_stripes, size_t align, nccl_net_ofi_schedule_t **schedule_p) { nccl_net_ofi_schedule_t *schedule = (nccl_net_ofi_schedule_t *)malloc( sizeof(nccl_net_ofi_schedule_t) + num_rails * sizeof(nccl_net_ofi_xfer_info_t)); + nccl_net_ofi_scheduler_t *scheduler_p; + size_t min_stripe_size = 4098; if (!schedule) { NCCL_OFI_WARN("Could not allocate schedule"); return -ENOMEM; } - nccl_net_ofi_set_multiplexing_schedule(size, num_rails, align, schedule); + if (nccl_net_ofi_threshold_scheduler_init(num_rails, min_stripe_size, &scheduler_p)) { + NCCL_OFI_WARN("Failed to initialize threshold scheduler"); + return -1; + } + nccl_net_ofi_threshold_scheduler_t * scheduler = (nccl_net_ofi_threshold_scheduler_t *)scheduler_p; + nccl_net_ofi_set_multiplexed_round_robin_schedule(scheduler, size, num_rails, num_stripes, align, schedule); *schedule_p = schedule; return 0; } @@ -83,13 +91,15 @@ int test_multiplexing_schedule() } size_t size; int num_rails; + int num_stripes; size_t align; int ret = 0; size = 1; num_rails = 0; + num_stripes = 0; align = 1; - ret = create_multiplexed(size, num_rails, align, &schedule); + ret = create_multiplexed(size, num_rails, num_stripes, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); free(ref_schedule); @@ -111,8 +121,9 @@ int test_multiplexing_schedule() /* No data */ size = 0; num_rails = 1; + num_stripes = 1; align = 1; - 
ret = create_multiplexed(size, num_rails, align, &schedule); + ret = create_multiplexed(size, num_rails, num_stripes, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); free(ref_schedule); @@ -130,8 +141,9 @@ int test_multiplexing_schedule() /* Data size = align - 1 */ size = 1; num_rails = 1; + num_stripes = 1; align = 2; - ret = create_multiplexed(size, num_rails, align, &schedule); + ret = create_multiplexed(size, num_rails, num_stripes, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); free(ref_schedule); @@ -152,8 +164,9 @@ int test_multiplexing_schedule() /* Data size = align */ size = 2; num_rails = 1; + num_stripes = 1; align = 2; - ret = create_multiplexed(size, num_rails, align, &schedule); + ret = create_multiplexed(size, num_rails, num_stripes, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); free(ref_schedule); @@ -174,8 +187,9 @@ int test_multiplexing_schedule() /* Data size = align + 1 */ size = 3; num_rails = 1; + num_stripes = 1; align = 2; - ret = create_multiplexed(size, num_rails, align, &schedule); + ret = create_multiplexed(size, num_rails, num_stripes, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); free(ref_schedule); @@ -200,8 +214,9 @@ int test_multiplexing_schedule() /* No data */ size = 0; num_rails = 3; + num_stripes = 3; align = 1; - ret = create_multiplexed(size, num_rails, align, &schedule); + ret = create_multiplexed(size, num_rails, num_stripes, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); free(ref_schedule); @@ -220,7 +235,7 @@ int test_multiplexing_schedule() num_rails = 3; align = 3; size = 4 * align - 1; - ret = create_multiplexed(size, num_rails, align, &schedule); + ret = create_multiplexed(size, num_rails, num_stripes, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); free(ref_schedule); @@ -245,7 +260,7 
@@ int test_multiplexing_schedule() num_rails = 3; align = 3; size = 4 * align; - ret = create_multiplexed(size, num_rails, align, &schedule); + ret = create_multiplexed(size, num_rails, num_stripes, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); free(ref_schedule); @@ -270,7 +285,7 @@ int test_multiplexing_schedule() num_rails = 3; align = 3; size = 4 * align + 1; - ret = create_multiplexed(size, num_rails, align, &schedule); + ret = create_multiplexed(size, num_rails,num_stripes, align, &schedule); if (ret) { NCCL_OFI_WARN("Failed to create multiplexed schedule"); free(ref_schedule); @@ -302,32 +317,31 @@ int test_multiplexing_schedule() int test_threshold_scheduler() { nccl_net_ofi_schedule_t *schedule; - int num_rails = 2; + int num_rails = 4; int ret = 0; - size_t rr_threshold = 8192; + size_t min_stripe_size = 4096; + size_t msg_size = 0; nccl_net_ofi_schedule_t *ref_schedule = (nccl_net_ofi_schedule_t *)malloc( sizeof(nccl_net_ofi_schedule_t) + num_rails * sizeof(nccl_net_ofi_xfer_info_t)); nccl_net_ofi_scheduler_t *scheduler; - if (nccl_net_ofi_threshold_scheduler_init(num_rails, rr_threshold, &scheduler)) { + if (nccl_net_ofi_threshold_scheduler_init(num_rails, min_stripe_size, &scheduler)) { NCCL_OFI_WARN("Failed to initialize threshold scheduler"); free(ref_schedule); return -1; } - /* Verify that message with more than `rr_threshold' bytes is multiplexed */ - schedule = scheduler->get_schedule(scheduler, rr_threshold + 1, num_rails); + /* Verify that message with less than `min_stripe_size' bytes are assigned round-robin */ + schedule = scheduler->get_schedule(scheduler, min_stripe_size - 1, num_rails); if (!schedule) { NCCL_OFI_WARN("Failed to get schedule"); free(ref_schedule); return -1; } - ref_schedule->num_xfer_infos = 2; + msg_size = min_stripe_size - 1; + ref_schedule->num_xfer_infos = 1; ref_schedule->rail_xfer_infos[0].rail_id = 0; ref_schedule->rail_xfer_infos[0].offset = 0; - 
ref_schedule->rail_xfer_infos[0].msg_size = rr_threshold / 2 + 128; - ref_schedule->rail_xfer_infos[1].rail_id = 1; - ref_schedule->rail_xfer_infos[1].offset = rr_threshold / 2 + 128; - ref_schedule->rail_xfer_infos[1].msg_size = rr_threshold / 2- 127; + ref_schedule->rail_xfer_infos[0].msg_size = msg_size; ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); @@ -336,17 +350,16 @@ int test_threshold_scheduler() } nccl_net_ofi_release_schedule(scheduler, schedule); - /* Verify that three messages with `rr_threshold' bytes are assigned round robin */ - schedule = scheduler->get_schedule(scheduler, rr_threshold, num_rails); + schedule = scheduler->get_schedule(scheduler, min_stripe_size - 1, num_rails); if (!schedule) { NCCL_OFI_WARN("Failed to get schedule"); free(ref_schedule); return -1; } ref_schedule->num_xfer_infos = 1; - ref_schedule->rail_xfer_infos[0].rail_id = 0; + ref_schedule->rail_xfer_infos[0].rail_id = 1; ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = rr_threshold; + ref_schedule->rail_xfer_infos[0].msg_size = min_stripe_size - 1; ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); @@ -355,16 +368,24 @@ int test_threshold_scheduler() } nccl_net_ofi_release_schedule(scheduler, schedule); - schedule = scheduler->get_schedule(scheduler, rr_threshold, num_rails); + /* Verify that messages with greater than the `min_stripe_size' but less than 2x `min_stripe_size` + * bytes are assigned 2 rail multiplexing */ + schedule = scheduler->get_schedule(scheduler, min_stripe_size + 1, num_rails); + msg_size = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL(min_stripe_size + 1, 2), 128) * 128; if (!schedule) { NCCL_OFI_WARN("Failed to get schedule"); free(ref_schedule); return -1; } - ref_schedule->num_xfer_infos = 1; - ref_schedule->rail_xfer_infos[0].rail_id = 1; + ref_schedule->num_xfer_infos = 2; + ref_schedule->rail_xfer_infos[0].rail_id = 2; 
ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = rr_threshold; + ref_schedule->rail_xfer_infos[0].msg_size = msg_size; + + ref_schedule->rail_xfer_infos[1].rail_id = 3; + ref_schedule->rail_xfer_infos[1].offset = msg_size; + ref_schedule->rail_xfer_infos[1].msg_size = min_stripe_size + 1 - msg_size; + ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); @@ -373,16 +394,21 @@ int test_threshold_scheduler() } nccl_net_ofi_release_schedule(scheduler, schedule); - schedule = scheduler->get_schedule(scheduler, rr_threshold, num_rails); + schedule = scheduler->get_schedule(scheduler, min_stripe_size + 1, num_rails); if (!schedule) { NCCL_OFI_WARN("Failed to get schedule"); free(ref_schedule); return -1; } - ref_schedule->num_xfer_infos = 1; + ref_schedule->num_xfer_infos = 2; ref_schedule->rail_xfer_infos[0].rail_id = 0; ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = rr_threshold; + ref_schedule->rail_xfer_infos[0].msg_size = msg_size; + + ref_schedule->rail_xfer_infos[1].rail_id = 1; + ref_schedule->rail_xfer_infos[1].offset = msg_size; + ref_schedule->rail_xfer_infos[1].msg_size = min_stripe_size + 1 - msg_size; + ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); @@ -391,6 +417,64 @@ int test_threshold_scheduler() } nccl_net_ofi_release_schedule(scheduler, schedule); + /* Verify that messages with greater than the 2x `min_stripe_size' but less than or equal to + * 3x `min_stripe_size` bytes are assigned 3 rail multiplexing */ + schedule = scheduler->get_schedule(scheduler, (min_stripe_size * 2) + 1, num_rails); + msg_size = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL((min_stripe_size * 2) + 1, 3), 128) * 128; + if (!schedule) { + NCCL_OFI_WARN("Failed to get schedule"); + free(ref_schedule); + return -1; + } + ref_schedule->num_xfer_infos = 3; + ref_schedule->rail_xfer_infos[0].rail_id = 2; + 
ref_schedule->rail_xfer_infos[0].offset = 0; + ref_schedule->rail_xfer_infos[0].msg_size = msg_size; + ref_schedule->rail_xfer_infos[1].rail_id = 3; + ref_schedule->rail_xfer_infos[1].offset = msg_size; + ref_schedule->rail_xfer_infos[1].msg_size = msg_size; + ref_schedule->rail_xfer_infos[2].rail_id = 0; + ref_schedule->rail_xfer_infos[2].offset = 2 * msg_size; + ref_schedule->rail_xfer_infos[2].msg_size = (min_stripe_size * 2) + 1 - (2 * msg_size); + ret = verify_schedule(schedule, ref_schedule); + if (ret) { + NCCL_OFI_WARN("Verification failed"); + free(ref_schedule); + return ret; + } + nccl_net_ofi_release_schedule(scheduler, schedule); + + /* Verify that messages with greater than the 3x `min_stripe_size' are assigned 4 rail multiplexing */ + schedule = scheduler->get_schedule(scheduler, (min_stripe_size * 3) + 1, num_rails); + msg_size = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL((min_stripe_size * 3) + 1, 4), 128) * 128; + if (!schedule) { + NCCL_OFI_WARN("Failed to get schedule"); + free(ref_schedule); + return -1; + } + ref_schedule->num_xfer_infos = 4; + ref_schedule->rail_xfer_infos[0].rail_id = 1; + ref_schedule->rail_xfer_infos[0].offset = 0; + ref_schedule->rail_xfer_infos[0].msg_size = msg_size; + ref_schedule->rail_xfer_infos[1].rail_id = 2; + ref_schedule->rail_xfer_infos[1].offset = msg_size; + ref_schedule->rail_xfer_infos[1].msg_size = msg_size; + ref_schedule->rail_xfer_infos[2].rail_id = 3; + ref_schedule->rail_xfer_infos[2].offset = 2 * msg_size; + ref_schedule->rail_xfer_infos[2].msg_size = msg_size; + ref_schedule->rail_xfer_infos[3].rail_id = 0; + ref_schedule->rail_xfer_infos[3].offset = 3 * msg_size; + ref_schedule->rail_xfer_infos[3].msg_size = (min_stripe_size * 3) + 1 - (3 * msg_size); + ret = verify_schedule(schedule, ref_schedule); + if (ret) { + NCCL_OFI_WARN("Verification failed"); + free(ref_schedule); + return ret; + } + nccl_net_ofi_release_schedule(scheduler, schedule); + + NCCL_OFI_WARN("4 Rail Mulitplexing scheduler test 
Successful"); + ret = scheduler->fini(scheduler); if (ret) { NCCL_OFI_WARN("Failed to destroy threshold scheduler");