[Fix] Fix dist.collect_results to keep all ranks' elements (#1469)

pull/1411/head^2
Zhihao Lin 2024-01-11 10:50:36 +08:00 committed by GitHub
parent b51bf60964
commit 109cd44c7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed file with 9 additions and 5 deletions

14
mmengine/dist/dist.py vendored
View File

@@ -13,7 +13,7 @@ from torch import distributed as torch_dist
from torch._utils import (_flatten_dense_tensors, _take_tensors,
_unflatten_dense_tensors)
from torch.distributed import ProcessGroup
from itertools import zip_longest, chain
import mmengine
from .utils import (get_world_size, get_rank, get_backend, get_dist_info,
get_default_group, barrier, get_data_device,
@@ -1010,8 +1010,10 @@ def collect_results_cpu(result_part: list,
part_list.append(pickle.load(f))
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
zipped_results = zip_longest(*part_list)
ordered_results = [
i for i in chain.from_iterable(zipped_results) if i is not None
]
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
# remove tmp dir
@@ -1032,8 +1034,10 @@ def _collect_results_device(result_part: list, size: int) -> Optional[list]:
if rank == 0:
# sort the results
ordered_results = []
for res in zip(*part_list):
ordered_results.extend(list(res))
zipped_results = zip_longest(*part_list)
ordered_results = [
i for i in chain.from_iterable(zipped_results) if i is not None
]
# the dataloader may pad some samples
ordered_results = ordered_results[:size]
return ordered_results