[Fix] Fix dist.collect_results to keep all ranks' elements (#1469)
parent
b51bf60964
commit
109cd44c7e
|
@@ -13,7 +13,7 @@ from torch import distributed as torch_dist
|
|||
from torch._utils import (_flatten_dense_tensors, _take_tensors,
|
||||
_unflatten_dense_tensors)
|
||||
from torch.distributed import ProcessGroup
|
||||
|
||||
from itertools import zip_longest, chain
|
||||
import mmengine
|
||||
from .utils import (get_world_size, get_rank, get_backend, get_dist_info,
|
||||
get_default_group, barrier, get_data_device,
|
||||
|
@@ -1010,8 +1010,10 @@ def collect_results_cpu(result_part: list,
|
|||
part_list.append(pickle.load(f))
|
||||
# sort the results
|
||||
ordered_results = []
|
||||
for res in zip(*part_list):
|
||||
ordered_results.extend(list(res))
|
||||
zipped_results = zip_longest(*part_list)
|
||||
ordered_results = [
|
||||
i for i in chain.from_iterable(zipped_results) if i is not None
|
||||
]
|
||||
# the dataloader may pad some samples
|
||||
ordered_results = ordered_results[:size]
|
||||
# remove tmp dir
|
||||
|
@@ -1032,8 +1034,10 @@ def _collect_results_device(result_part: list, size: int) -> Optional[list]:
|
|||
if rank == 0:
|
||||
# sort the results
|
||||
ordered_results = []
|
||||
for res in zip(*part_list):
|
||||
ordered_results.extend(list(res))
|
||||
zipped_results = zip_longest(*part_list)
|
||||
ordered_results = [
|
||||
i for i in chain.from_iterable(zipped_results) if i is not None
|
||||
]
|
||||
# the dataloader may pad some samples
|
||||
ordered_results = ordered_results[:size]
|
||||
return ordered_results
|
||||
|
|
Loading…
Reference in New Issue