diff --git a/include/nccl_ofi_param.h b/include/nccl_ofi_param.h index 7ae619c35..0daac630c 100644 --- a/include/nccl_ofi_param.h +++ b/include/nccl_ofi_param.h @@ -179,9 +179,9 @@ OFI_NCCL_PARAM_INT(disable_native_rdma_check, "DISABLE_NATIVE_RDMA_CHECK", 0); OFI_NCCL_PARAM_INT(disable_gdr_required_check, "DISABLE_GDR_REQUIRED_CHECK", 0); /* - * Maximum size of a message in bytes before message is multiplexed - */ -OFI_NCCL_PARAM_INT(round_robin_threshold, "ROUND_ROBIN_THRESHOLD", (256 * 1024)); + * Messages sized larger than this threshold will be striped across multiple rails +*/ +OFI_NCCL_PARAM_INT(min_stripe_size, "MIN_STRIPE_SIZE", (256 * 1024)); /* * Minimum bounce buffers posted per endpoint. The plugin will attempt to post diff --git a/include/nccl_ofi_scheduler.h b/include/nccl_ofi_scheduler.h index e7e7a828c..f7e582e54 100644 --- a/include/nccl_ofi_scheduler.h +++ b/include/nccl_ofi_scheduler.h @@ -93,9 +93,9 @@ typedef struct nccl_net_ofi_threshold_scheduler { unsigned int rr_counter; /* Lock for round robin counter */ pthread_mutex_t rr_lock; - /* Maximum size of a message in bytes before message is + /* Minimum size of the message in bytes before message is * multiplexed */ - size_t rr_threshold; + size_t min_stripe_size; } nccl_net_ofi_threshold_scheduler_t; /* @@ -109,16 +109,15 @@ void nccl_net_ofi_release_schedule(nccl_net_ofi_scheduler_t *scheduler, * * @param num_rails * Number of rails - * @param rr_threshold - * Maximum size of a message in bytes before message is multiplexed - * + * @param min_stripe_size + * Minimum size of a message in bytes before message is multiplexed * @return Scheduler, on success * NULL, on error * @return 0, on success * non-zero, on error */ int nccl_net_ofi_threshold_scheduler_init(int num_rails, - size_t rr_threshold, + size_t min_stripe_size, nccl_net_ofi_scheduler_t **scheduler); /* @@ -127,9 +126,11 @@ int nccl_net_ofi_threshold_scheduler_init(int num_rails, * A mininal stripe size `max_stripe_size' is calculated (multiple of * `align') that is sufficient to assign the whole message. Rails are * filled from low id to large id. The last rail may get assigned less - * data. + * data. The number of rails are calculated based on the ratio of + * (`data_size` / `min_stripe_size`) */ -void nccl_net_ofi_set_multiplexing_schedule(size_t size, +int nccl_net_ofi_set_schedule_by_threshold(nccl_net_ofi_threshold_scheduler_t *scheduler, + size_t size, int num_rails, size_t align, nccl_net_ofi_schedule_t *schedule); diff --git a/src/nccl_ofi_rdma.c b/src/nccl_ofi_rdma.c index ee61a885e..50b0845e1 100644 --- a/src/nccl_ofi_rdma.c +++ b/src/nccl_ofi_rdma.c @@ -6940,7 +6940,7 @@ nccl_net_ofi_rdma_device_release(nccl_net_ofi_device_t *base_device) static nccl_net_ofi_rdma_device_t * nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin, int dev_id, struct fi_info *info_list, - nccl_ofi_topo_t *topo, size_t rr_threshold) + nccl_ofi_topo_t *topo, size_t min_strip_size) { int ret; @@ -6983,7 +6983,7 @@ nccl_net_ofi_rdma_device_create(nccl_net_ofi_plugin_t *plugin, /* Create scheduler */ ret = nccl_net_ofi_threshold_scheduler_init(length, - rr_threshold, + min_strip_size, &device->scheduler); if (ret != 0) { goto error; @@ -7196,7 +7196,7 @@ static inline int nccl_net_ofi_rdma_plugin_complete_init(nccl_net_ofi_plugin_t * nccl_net_ofi_rdma_device_t *device = nccl_net_ofi_rdma_device_create(&rdma_plugin->base, dev_id, info_list, rdma_plugin->topo, - ofi_nccl_round_robin_threshold()); + ofi_nccl_min_stripe_size()); if (device == NULL) { NCCL_OFI_WARN("Device creation failed"); return -ENOMEM; @@ -7318,7 +7318,7 @@ int nccl_net_ofi_rdma_init(const char *provider_filter, } if (ofi_nccl_eager_max_size() < 0 || - ofi_nccl_eager_max_size() > ofi_nccl_round_robin_threshold()) { + ofi_nccl_eager_max_size() > ofi_nccl_min_stripe_size()) { NCCL_OFI_WARN("Invalid value for EAGER_MAX_SIZE"); ret = ncclInvalidArgument; goto error; diff --git a/src/nccl_ofi_scheduler.c b/src/nccl_ofi_scheduler.c index bac026e80..094cde358 100644 --- a/src/nccl_ofi_scheduler.c +++ b/src/nccl_ofi_scheduler.c @@ -22,80 +22,54 @@ static inline size_t sizeof_schedule(int num_rails) + num_rails * sizeof(nccl_net_ofi_xfer_info_t); } -void nccl_net_ofi_set_multiplexing_schedule(size_t size, int num_rails, +int nccl_net_ofi_set_schedule_by_threshold(nccl_net_ofi_threshold_scheduler_t *scheduler, + size_t size, + int num_rails, size_t align, nccl_net_ofi_schedule_t *schedule) { + int ret = 0; + + /* Number of stripes is atleast 1 for zero-sized messages and at most equal to num of rails */ + int num_stripes = (int) NCCL_OFI_MAX(1, NCCL_OFI_MIN(NCCL_OFI_DIV_CEIL(size, scheduler->min_stripe_size), num_rails)); + if (OFI_UNLIKELY(num_rails == 0)) { + return -1; + } + + assert(num_stripes <= num_rails); + + int curr_rail_id, next_rail_id; + nccl_net_ofi_mutex_lock(&scheduler->rr_lock); + + /* Retieve and increment multiplex-round-robin counter; wrap around if required */ + curr_rail_id = scheduler->rr_counter; + next_rail_id = (curr_rail_id + num_stripes) % num_rails; + scheduler->rr_counter = next_rail_id; + + nccl_net_ofi_mutex_unlock(&scheduler->rr_lock); + /* Number of bytes left to assign */ size_t left = size; /* Offset into message */ size_t offset = 0; - /* Maximum size of a stripe */ - size_t max_stripe_size = 0; - - schedule->num_xfer_infos = 0; - if (OFI_UNLIKELY(num_rails == 0)) return; + /* Calculate max stripe size as a multiple of 128 for alignment. + * Split message size across stripes, ensuring each stripe is within max_stripe_size and LL128 aligned */ + size_t max_stripe_size = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL(size, num_stripes), align) * align; - max_stripe_size = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL(size, num_rails), align) * align; + schedule->num_xfer_infos = num_stripes; /* Compute stripes and assign to rails */ - for (int rail_id = 0; rail_id != num_rails && left > 0; ++rail_id) { + for (int stripe_idx = 0; stripe_idx < num_stripes; ++stripe_idx) { size_t stripe_size = NCCL_OFI_MIN(left, max_stripe_size); - schedule->rail_xfer_infos[rail_id].rail_id = rail_id; - schedule->rail_xfer_infos[rail_id].offset = offset; - schedule->rail_xfer_infos[rail_id].msg_size = stripe_size; + schedule->rail_xfer_infos[stripe_idx].rail_id = curr_rail_id; + schedule->rail_xfer_infos[stripe_idx].offset = offset; + schedule->rail_xfer_infos[stripe_idx].msg_size = stripe_size; - schedule->num_xfer_infos++; offset += stripe_size; left -= stripe_size; - } -} - -/* - * @brief Assign message round-robin - */ -static inline int set_round_robin_schedule(nccl_net_ofi_threshold_scheduler_t *scheduler, - size_t size, - int num_rails, - nccl_net_ofi_schedule_t *schedule) -{ - int rail_id; - - nccl_net_ofi_mutex_lock(&scheduler->rr_lock); - - /* Retieve and increment round-robin counter; wrap around if required */ - rail_id = (scheduler->rr_counter)++; - scheduler->rr_counter = scheduler->rr_counter == num_rails ? 0 : scheduler->rr_counter; - - nccl_net_ofi_mutex_unlock(&scheduler->rr_lock); - - schedule->num_xfer_infos = 1; - schedule->rail_xfer_infos[0].rail_id = rail_id; - schedule->rail_xfer_infos[0].offset = 0; - schedule->rail_xfer_infos[0].msg_size = size; - - return 0; -} - -/* - * @brief Assign message round-robin or multiplex message depending on its size - * - * Messages larger than `threshold' are multiplexed. Smaller messages are assigned round-robin. - */ -static inline int set_schedule_by_threshold(nccl_net_ofi_threshold_scheduler_t *scheduler, - size_t size, - int num_rails, - size_t align, - nccl_net_ofi_schedule_t *schedule) -{ - int ret = 0; - if (size > scheduler->rr_threshold) { - nccl_net_ofi_set_multiplexing_schedule(size, num_rails, - align, schedule); - } else { - ret = set_round_robin_schedule(scheduler, size, num_rails, schedule); + curr_rail_id = (curr_rail_id + 1) % num_rails; } return ret; } @@ -146,7 +120,7 @@ static nccl_net_ofi_schedule_t *get_threshold_schedule(nccl_net_ofi_scheduler_t NCCL_OFI_WARN("Failed to allocate schedule"); return NULL; } - ret = set_schedule_by_threshold(scheduler, size, num_rails, align, + ret = nccl_net_ofi_set_schedule_by_threshold(scheduler, size, num_rails, align, schedule); if (OFI_UNLIKELY(ret)) { nccl_net_ofi_release_schedule(scheduler_p, schedule); @@ -238,7 +212,7 @@ int scheduler_init(int num_rails, nccl_net_ofi_scheduler_t *scheduler) } int nccl_net_ofi_threshold_scheduler_init(int num_rails, - size_t rr_threshold, + size_t min_stripe_size, nccl_net_ofi_scheduler_t **scheduler_p) { int ret = 0; @@ -261,7 +235,7 @@ int nccl_net_ofi_threshold_scheduler_init(int num_rails, scheduler->base.get_schedule = get_threshold_schedule; scheduler->base.fini = threshold_scheduler_fini; scheduler->rr_counter = 0; - scheduler->rr_threshold = rr_threshold; + scheduler->min_stripe_size = min_stripe_size; ret = nccl_net_ofi_mutex_init(&scheduler->rr_lock, NULL); if (ret) { diff --git a/tests/unit/scheduler.c b/tests/unit/scheduler.c index 1d535423d..03f3e060f 100644 --- a/tests/unit/scheduler.c +++ b/tests/unit/scheduler.c @@ -14,22 +14,6 @@ #include "test-common.h" #include "nccl_ofi_scheduler.h" -int create_multiplexed(size_t size, - int num_rails, - size_t align, - nccl_net_ofi_schedule_t **schedule_p) -{ - nccl_net_ofi_schedule_t *schedule = (nccl_net_ofi_schedule_t *)malloc( - sizeof(nccl_net_ofi_schedule_t) + num_rails * sizeof(nccl_net_ofi_xfer_info_t)); - if (!schedule) { - NCCL_OFI_WARN("Could not allocate schedule"); - return -ENOMEM; - } - nccl_net_ofi_set_multiplexing_schedule(size, num_rails, align, schedule); - *schedule_p = schedule; - return 0; -} - int verify_xfer_info(nccl_net_ofi_xfer_info_t *xfer, nccl_net_ofi_xfer_info_t *ref_xfer, int xfer_id) { int ret = ref_xfer->rail_id != xfer->rail_id @@ -72,331 +56,221 @@ int verify_schedule(nccl_net_ofi_schedule_t *schedule, nccl_net_ofi_schedule_t * return ret; } -int test_multiplexing_schedule() -{ - nccl_net_ofi_schedule_t *schedule = NULL; - nccl_net_ofi_schedule_t *ref_schedule = (nccl_net_ofi_schedule_t *)malloc( - sizeof(nccl_net_ofi_schedule_t) + 3 * sizeof(nccl_net_ofi_xfer_info_t)); - if (!ref_schedule) { - NCCL_OFI_WARN("Could not allocate schedule"); - return -ENOMEM; - } - size_t size; - int num_rails; - size_t align; +int create_ref_schedule(nccl_net_ofi_schedule_t **schedule, int num_xfer_infos) { int ret = 0; + *schedule = (nccl_net_ofi_schedule_t *)malloc( + sizeof(nccl_net_ofi_xfer_info_t) + num_xfer_infos * sizeof(nccl_net_ofi_xfer_info_t)); - size = 1; - num_rails = 0; - align = 1; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); - free(ref_schedule); - return ret; - } - ref_schedule->num_xfer_infos = 0; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); - return ret; - } - free(schedule); - - /************************/ - /* Test one rail */ - /************************/ - - /* No data */ - size = 0; - num_rails = 1; - align = 1; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); - free(ref_schedule); - return ret; - } - ref_schedule->num_xfer_infos = 0; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); - return ret; - } - free(schedule); - - /* Data size = align - 1 */ - size = 1; - num_rails = 1; - align = 2; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); - free(ref_schedule); - return ret; - } - ref_schedule->num_xfer_infos = 1; - ref_schedule->rail_xfer_infos[0].rail_id = 0; - ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = size; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); - return ret; + if (!(*schedule)) { + NCCL_OFI_WARN("Could not allocate schedule"); + return -ENOMEM; } - free(schedule); + (*schedule)->num_xfer_infos = num_xfer_infos; - /* Data size = align */ - size = 2; - num_rails = 1; - align = 2; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); - free(ref_schedule); - return ret; - } - ref_schedule->num_xfer_infos = 1; - ref_schedule->rail_xfer_infos[0].rail_id = 0; - ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = size; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); - return ret; - } - free(schedule); + return ret; +} - /* Data size = align + 1 */ - size = 3; - num_rails = 1; - align = 2; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); - free(ref_schedule); - return ret; - } - ref_schedule->num_xfer_infos = 1; - ref_schedule->rail_xfer_infos[0].rail_id = 0; - ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = size; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); - return ret; - } - free(schedule); +int set_ref_schedule(nccl_net_ofi_schedule_t *schedule, int index, int rail_id, int offset, int msg_size) { + int ret = 0; + if (index >= schedule->num_xfer_infos) { + NCCL_OFI_WARN("Index out of bounds"); + return -EINVAL; + } - /************************/ - /* Test three rail */ - /************************/ + schedule->rail_xfer_infos[index].rail_id = rail_id; + schedule->rail_xfer_infos[index].offset = offset; + schedule->rail_xfer_infos[index].msg_size = msg_size; - /* No data */ - size = 0; - num_rails = 3; - align = 1; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); - free(ref_schedule); - return ret; - } - ref_schedule->num_xfer_infos = 0; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); - return ret; - } - free(schedule); + return ret; +} - /* Data size = 4 * align - 1 */ - num_rails = 3; - align = 3; - size = 4 * align - 1; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); - free(ref_schedule); - return ret; - } - ref_schedule->num_xfer_infos = 2; - ref_schedule->rail_xfer_infos[0].rail_id = 0; - ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = 2 * align; - ref_schedule->rail_xfer_infos[1].rail_id = 1; - ref_schedule->rail_xfer_infos[1].offset = 2 * align; - ref_schedule->rail_xfer_infos[1].msg_size = 2 * align - 1; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); +int test_multiplexer(nccl_net_ofi_scheduler_t *scheduler, + int num_rails, + size_t msg_size, + int num_stripes, + int *rail_id, + int *offset, + size_t *msg_size_per_stripe) { + int ret = 0; + nccl_net_ofi_schedule_t *ref_schedule; + nccl_net_ofi_schedule_t *schedule = NULL; + if (create_ref_schedule(&ref_schedule, num_stripes)) { return ret; - } - free(schedule); + }; - /* Data size = 4 * align */ - num_rails = 3; - align = 3; - size = 4 * align; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); + schedule = scheduler->get_schedule(scheduler, msg_size, num_rails); + if (!schedule) { + NCCL_OFI_WARN("Failed to get schedule"); free(ref_schedule); return ret; } - ref_schedule->num_xfer_infos = 2; - ref_schedule->rail_xfer_infos[0].rail_id = 0; - ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = 2 * align; - ref_schedule->rail_xfer_infos[1].rail_id = 1; - ref_schedule->rail_xfer_infos[1].offset = 2 * align; - ref_schedule->rail_xfer_infos[1].msg_size = 2 * align; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); - return ret; + for (int idx = 0; idx < num_stripes; idx++) { + set_ref_schedule(ref_schedule, idx, rail_id[idx], offset[idx], msg_size_per_stripe[idx]); } - free(schedule); - /* Data size = 4 * align + 1 */ - num_rails = 3; - align = 3; - size = 4 * align + 1; - ret = create_multiplexed(size, num_rails, align, &schedule); - if (ret) { - NCCL_OFI_WARN("Failed to create multiplexed schedule"); - free(ref_schedule); - return ret; - } - ref_schedule->num_xfer_infos = 3; - ref_schedule->rail_xfer_infos[0].rail_id = 0; - ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = 2 * align; - ref_schedule->rail_xfer_infos[1].rail_id = 1; - ref_schedule->rail_xfer_infos[1].offset = 2 * align; - ref_schedule->rail_xfer_infos[1].msg_size = 2 * align; - ref_schedule->rail_xfer_infos[2].rail_id = 2; - ref_schedule->rail_xfer_infos[2].offset = 4 * align; - ref_schedule->rail_xfer_infos[2].msg_size = 1; ret = verify_schedule(schedule, ref_schedule); if (ret) { NCCL_OFI_WARN("Verification failed"); free(ref_schedule); return ret; } - free(schedule); - + nccl_net_ofi_release_schedule(scheduler, schedule); free(ref_schedule); - - return 0; + return ret; } -int test_threshold_scheduler() -{ - nccl_net_ofi_schedule_t *schedule; - int num_rails = 2; +int test_threshold_scheduler() { + size_t min_stripe_size = 4096; + size_t align = 128; + int num_rails = 4; + int num_stripes = 0; int ret = 0; - size_t rr_threshold = 8192; - nccl_net_ofi_schedule_t *ref_schedule = (nccl_net_ofi_schedule_t *)malloc( - sizeof(nccl_net_ofi_schedule_t) + num_rails * sizeof(nccl_net_ofi_xfer_info_t)); + nccl_net_ofi_scheduler_t *scheduler; - if (nccl_net_ofi_threshold_scheduler_init(num_rails, rr_threshold, &scheduler)) { + if (nccl_net_ofi_threshold_scheduler_init(num_rails, min_stripe_size, &scheduler)) { NCCL_OFI_WARN("Failed to initialize threshold scheduler"); - free(ref_schedule); - return -1; - } - - /* Verify that message with more than `rr_threshold' bytes is multiplexed */ - schedule = scheduler->get_schedule(scheduler, rr_threshold + 1, num_rails); - if (!schedule) { - NCCL_OFI_WARN("Failed to get schedule"); - free(ref_schedule); - return -1; - } - ref_schedule->num_xfer_infos = 2; - ref_schedule->rail_xfer_infos[0].rail_id = 0; - ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = rr_threshold / 2 + 128; - ref_schedule->rail_xfer_infos[1].rail_id = 1; - ref_schedule->rail_xfer_infos[1].offset = rr_threshold / 2 + 128; - ref_schedule->rail_xfer_infos[1].msg_size = rr_threshold / 2- 127; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); - return ret; - } - nccl_net_ofi_release_schedule(scheduler, schedule); - - /* Verify that three messages with `rr_threshold' bytes are assigned round robin */ - schedule = scheduler->get_schedule(scheduler, rr_threshold, num_rails); - if (!schedule) { - NCCL_OFI_WARN("Failed to get schedule"); - free(ref_schedule); - return -1; - } - ref_schedule->num_xfer_infos = 1; - ref_schedule->rail_xfer_infos[0].rail_id = 0; - ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = rr_threshold; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); - return ret; - } - nccl_net_ofi_release_schedule(scheduler, schedule); - - schedule = scheduler->get_schedule(scheduler, rr_threshold, num_rails); - if (!schedule) { - NCCL_OFI_WARN("Failed to get schedule"); - free(ref_schedule); - return -1; - } - ref_schedule->num_xfer_infos = 1; - ref_schedule->rail_xfer_infos[0].rail_id = 1; - ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = rr_threshold; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); return ret; } - nccl_net_ofi_release_schedule(scheduler, schedule); - schedule = scheduler->get_schedule(scheduler, rr_threshold, num_rails); - if (!schedule) { - NCCL_OFI_WARN("Failed to get schedule"); - free(ref_schedule); - return -1; - } - ref_schedule->num_xfer_infos = 1; - ref_schedule->rail_xfer_infos[0].rail_id = 0; - ref_schedule->rail_xfer_infos[0].offset = 0; - ref_schedule->rail_xfer_infos[0].msg_size = rr_threshold; - ret = verify_schedule(schedule, ref_schedule); - if (ret) { - NCCL_OFI_WARN("Verification failed"); - free(ref_schedule); - return ret; + /* To ensure that the LL128 alignment is maintained below message sizes are tested between the multiple of `min_stripe_size` + 1. min_stripe_size + 1 + 2. min_stripe_size + align - 1 + 3. min_stripe_size + align + 4. min_stripe_size + align + 1 + 5. 2*min_stripe_size - 1 + 6. 2*min_stripe_size + */ + + /* Verify that message with less than or equal to `min_stripe_size' bytes is assigned + * round-robin. Verify that zero-sized messages is also assigned one rail and follow + * round-robin algorithm */ + num_stripes = 1; + size_t msg_sizes_1[6] = {0, (min_stripe_size / 2) + align - 1, (min_stripe_size / 2) + align, + (min_stripe_size / 2) + align + 1, min_stripe_size - 1, min_stripe_size}; + size_t msg_size_per_stripe_1[6][1] = {{msg_sizes_1[0]}, {msg_sizes_1[1]}, {msg_sizes_1[2]}, {msg_sizes_1[3]}, {msg_sizes_1[4]}, {msg_sizes_1[5]}}; + int rail_ids_1[6][1] = {{0}, {1}, {2}, {3}, {0}, {1}}; /* In round-robin for each iteration a new rail-id is used */ + int offsets_1[6][1] = {{0}, {0}, {0}, {0}, {0}, {0}}; /* Offset remaines 0 in round robin */ + for (int iter = 0; iter < 6; iter++) { + ret = test_multiplexer(scheduler, num_rails, msg_sizes_1[iter], num_stripes, rail_ids_1[iter], offsets_1[iter], msg_size_per_stripe_1[iter]); + if (ret) { + NCCL_OFI_WARN("Verification failed"); + return ret; + } + } + + /* Verify that messages with greater than the `min_stripe_size' but less than 2x `min_stripe_size` + * bytes are assigned 2 rail multiplexing */ + num_stripes = 2; + size_t msg_sizes_2[6] = {min_stripe_size + 1, min_stripe_size + align - 1, min_stripe_size + align, + min_stripe_size + align + 1, (2 * min_stripe_size) - 1, (2 * min_stripe_size)}; + size_t stripe_size[6]; + size_t remaining_stripe_size[6]; + for(int iter = 0; iter < 6; iter++){ + stripe_size[iter] = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL(msg_sizes_2[iter], num_stripes), align) * align; + remaining_stripe_size[iter] = msg_sizes_2[iter] - stripe_size[iter]; + } + + /* For each message ensure that two rails are used. Also ensure that the rail-id pairs + * are round-robin between each schedule */ + int rail_ids_2[6][2] = {{2, 3}, + {0, 1}, + {2, 3}, + {0, 1}, + {2, 3}, + {0, 1}}; + int offsets_2[6][2] = {{0, stripe_size[0]}, + {0, stripe_size[1]}, + {0, stripe_size[2]}, + {0, stripe_size[3]}, + {0, stripe_size[4]}, + {0, stripe_size[5]}}; + size_t msg_size_per_stripe_2[6][2] = {{stripe_size[0], remaining_stripe_size[0]}, + {stripe_size[1], remaining_stripe_size[1]}, + {stripe_size[2], remaining_stripe_size[2]}, + {stripe_size[3], remaining_stripe_size[3]}, + {stripe_size[4], remaining_stripe_size[4]}, + {stripe_size[5], remaining_stripe_size[5]}}; + for (int iter = 0; iter < 6; iter++) { + ret = test_multiplexer(scheduler, num_rails, msg_sizes_2[iter], num_stripes, rail_ids_2[iter], offsets_2[iter], msg_size_per_stripe_2[iter]); + if (ret) { + NCCL_OFI_WARN("Verification failed"); + return ret; + } + } + + /* Verify that messages with greater than the 2x `min_stripe_size' but less than or equal to + * 3x `min_stripe_size` bytes are assigned 3 rail multiplexing */ + num_stripes = 3; + size_t msg_sizes_3[6] = {(2 * min_stripe_size) + 1, (2 * min_stripe_size) + align - 1, (2 * min_stripe_size) + align, + (2 * min_stripe_size) + align + 1, (3 * min_stripe_size) - 1, (3 * min_stripe_size)}; + for (int iter = 0; iter < 6; iter++){ + stripe_size[iter] = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL(msg_sizes_3[iter], num_stripes), align) * align; + remaining_stripe_size[iter] = msg_sizes_3[iter] - (2 * stripe_size[iter]); + } + /* For each message ensure that three rails are used. Also ensure that the rail-id triplets + * are round-robin between each schedule */ + int rail_ids_3[6][3] = {{2, 3, 0}, + {1, 2, 3}, + {0, 1, 2}, + {3, 0, 1}, + {2, 3, 0}, + {1, 2, 3}}; + int offsets_3[6][3] = {{0, stripe_size[0], stripe_size[0] * 2}, + {0, stripe_size[1], stripe_size[1] * 2}, + {0, stripe_size[2], stripe_size[2] * 2}, + {0, stripe_size[3], stripe_size[3] * 2}, + {0, stripe_size[4], stripe_size[4] * 2}, + {0, stripe_size[5], stripe_size[5] * 2}}; + size_t msg_size_per_stripe_3[6][3] = {{stripe_size[0], stripe_size[0], remaining_stripe_size[0]}, + {stripe_size[1], stripe_size[1], remaining_stripe_size[1]}, + {stripe_size[2], stripe_size[2], remaining_stripe_size[2]}, + {stripe_size[3], stripe_size[3], remaining_stripe_size[3]}, + {stripe_size[4], stripe_size[4], remaining_stripe_size[4]}, + {stripe_size[5], stripe_size[5], remaining_stripe_size[5]}}; + + for (int iter = 0; iter < 6; iter++) { + ret = test_multiplexer(scheduler, num_rails, msg_sizes_3[iter], num_stripes, rail_ids_3[iter], offsets_3[iter], msg_size_per_stripe_3[iter]); + if (ret) { + NCCL_OFI_WARN("Verification failed"); + return ret; + } + } + + /* Verify that messages with greater than the 3x `min_stripe_size' are assigned 4 rail multiplexing */ + num_stripes = 4; + size_t msg_sizes_4[6] = {(3 * min_stripe_size) + 1, (3 * min_stripe_size) + align - 1, (3 * min_stripe_size) + align, + (3 * min_stripe_size) + align + 1, (4 * min_stripe_size) - 1, (4 * min_stripe_size)}; + for (int iter = 0; iter < 6; iter++){ + stripe_size[iter] = NCCL_OFI_DIV_CEIL(NCCL_OFI_DIV_CEIL(msg_sizes_4[iter], num_stripes), align) * align; + remaining_stripe_size[iter] = msg_sizes_4[iter] - (3 * stripe_size[iter]); + } + /* For each message ensure that all four rails are used. */ + int rail_ids_4[6][4] = {{0, 1, 2, 3}, {0, 1, 2, 3}, {0, 1, 2, 3}, {0, 1, 2, 3}, {0, 1, 2, 3}, {0, 1, 2, 3}}; + int offsets_4[6][4] = {{0, stripe_size[0], stripe_size[0] * 2, stripe_size[0] * 3}, + {0, stripe_size[1], stripe_size[1] * 2, stripe_size[1] * 3}, + {0, stripe_size[2], stripe_size[2] * 2, stripe_size[2] * 3}, + {0, stripe_size[3], stripe_size[3] * 2, stripe_size[3] * 3}, + {0, stripe_size[4], stripe_size[4] * 2, stripe_size[4] * 3}, + {0, stripe_size[5], stripe_size[5] * 2, stripe_size[5] * 3}}; + size_t msg_size_per_stripe_4[6][4] = {{stripe_size[0], stripe_size[0], stripe_size[0], remaining_stripe_size[0]}, + {stripe_size[1], stripe_size[1], stripe_size[1], remaining_stripe_size[1]}, + {stripe_size[2], stripe_size[2], stripe_size[2], remaining_stripe_size[2]}, + {stripe_size[3], stripe_size[3], stripe_size[3], remaining_stripe_size[3]}, + {stripe_size[4], stripe_size[4], stripe_size[4], remaining_stripe_size[4]}, + {stripe_size[5], stripe_size[5], stripe_size[5], remaining_stripe_size[5]}}; + + for (int iter = 0; iter < 6; iter++) { + ret = test_multiplexer(scheduler, num_rails, msg_sizes_4[iter], num_stripes, rail_ids_4[iter], offsets_4[iter], msg_size_per_stripe_4[iter]); + if (ret) { + NCCL_OFI_WARN("Verification failed"); + return ret; + } } - nccl_net_ofi_release_schedule(scheduler, schedule); ret = scheduler->fini(scheduler); if (ret) { NCCL_OFI_WARN("Failed to destroy threshold scheduler"); } - free(ref_schedule); - return 0; } @@ -406,7 +280,7 @@ int main(int argc, char *argv[]) ofi_log_function = logger; system_page_size = 4096; - ret = test_multiplexing_schedule() || test_threshold_scheduler(); + ret = test_threshold_scheduler(); /** Success!? **/ return ret;