From 109cd44c7ea2a384ace8a255b8d20c3b9b3dd351 Mon Sep 17 00:00:00 2001 From: Zhihao Lin <36994684+LZHgrla@users.noreply.github.com> Date: Thu, 11 Jan 2024 10:50:36 +0800 Subject: [PATCH] [Fix] Fix dist.collect_results to keep all ranks' elements (#1469) --- mmengine/dist/dist.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/mmengine/dist/dist.py b/mmengine/dist/dist.py index b6dd769f90..1dbedb3430 100644 --- a/mmengine/dist/dist.py +++ b/mmengine/dist/dist.py @@ -13,7 +13,7 @@ from torch._utils import (_flatten_dense_tensors, _take_tensors, _unflatten_dense_tensors) from torch.distributed import ProcessGroup - +from itertools import zip_longest, chain import mmengine from .utils import (get_world_size, get_rank, get_backend, get_dist_info, get_default_group, barrier, get_data_device, @@ -1010,8 +1010,10 @@ def collect_results_cpu(result_part: list, part_list.append(pickle.load(f)) # sort the results ordered_results = [] - for res in zip(*part_list): - ordered_results.extend(list(res)) + zipped_results = zip_longest(*part_list) + ordered_results = [ + i for i in chain.from_iterable(zipped_results) if i is not None + ] # the dataloader may pad some samples ordered_results = ordered_results[:size] # remove tmp dir @@ -1032,8 +1034,10 @@ def _collect_results_device(result_part: list, size: int) -> Optional[list]: if rank == 0: # sort the results ordered_results = [] - for res in zip(*part_list): - ordered_results.extend(list(res)) + zipped_results = zip_longest(*part_list) + ordered_results = [ + i for i in chain.from_iterable(zipped_results) if i is not None + ] # the dataloader may pad some samples ordered_results = ordered_results[:size] return ordered_results