[Refactor] Use mmengine distributed in evaluator (#123)
* [Refactor] Use mmengine distributed in evaluator
* remove 'TODO' comment

parent 4d49de7d81
commit 6d73b6cdf2
@@ -1,17 +1,11 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import os.path as osp
-import pickle
-import shutil
-import tempfile
 import warnings
 from abc import ABCMeta, abstractmethod
 from typing import Any, List, Optional, Sequence, Tuple, Union
 
-import torch
-import torch.distributed as dist
-
 from mmengine.data import BaseDataSample
-from mmengine.utils import mkdir_or_exist
+from mmengine.dist import (broadcast_object_list, collect_results,
+                           is_main_process)
 
 
 class BaseEvaluator(metaclass=ABCMeta):
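
For context, the three mmengine.dist helpers adopted above can be exercised standalone. A minimal sketch (illustrative, not part of the diff; the toy data and the `demo.py` name are assumptions) that runs under e.g. `torchrun --nproc_per_node=2 demo.py` and falls back to single-process behaviour when no process group is initialized:

    from mmengine.dist import (broadcast_object_list, collect_results,
                               is_main_process)

    # Each rank contributes its partial results; rank 0 receives the full list.
    part_results = [{'pred': 1}]  # toy per-rank results
    results = collect_results(part_results, size=2, device='cpu')

    if is_main_process():
        metrics = [{'accuracy': 1.0}]  # toy metrics, computed on rank 0 only
    else:
        metrics = [None]
    broadcast_object_list(metrics)  # in-place: every rank now sees metrics[0]
    print(metrics[0])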
@@ -43,11 +37,6 @@ class BaseEvaluator(metaclass=ABCMeta):
         self._dataset_meta: Union[None, dict] = None
         self.collect_device = collect_device
         self.results: List[Any] = []
-
-        rank, world_size = get_dist_info()
-        self.rank = rank
-        self.world_size = world_size
-
         self.prefix = prefix or self.default_prefix
         if self.prefix is None:
             warnings.warn('The prefix is not set in evaluator class '
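
With the cached `rank`/`world_size` attributes removed, the evaluator queries the distributed context on demand. A one-line sketch (assuming `mmengine.dist.get_dist_info`, which returns `(0, 1)` when no process group is initialized):

    from mmengine.dist import get_dist_info

    rank, world_size = get_dist_info()  # (0, 1) outside a process group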
@@ -108,131 +97,22 @@
                 'ensure that the processed results are properly added into '
                 '`self._results` in `process` method.')
 
-        if self.world_size == 1:
-            # non-distributed
-            results = self.results
-        else:
-            results = collect_results(self.results, size, self.collect_device)
+        results = collect_results(self.results, size, self.collect_device)
 
-        if self.rank == 0:
-            # TODO: replace with mmengine.dist.master_only
-            metrics = self.compute_metrics(results)
+        if is_main_process():
+            _metrics = self.compute_metrics(results)  # type: ignore
             # Add prefix to metric names
             if self.prefix:
-                metrics = {
+                _metrics = {
                     '/'.join((self.prefix, k)): v
-                    for k, v in metrics.items()
+                    for k, v in _metrics.items()
                 }
-            metrics = [metrics]  # type: ignore
+            metrics = [_metrics]
         else:
             metrics = [None]  # type: ignore
 
-        # TODO: replace with mmengine.dist.broadcast
-        if self.world_size > 1:
-            metrics = dist.broadcast_object_list(metrics)
+        broadcast_object_list(metrics)
 
         # reset the results list
         self.results.clear()
         return metrics[0]
-
-
-# TODO: replace with mmengine.dist.get_dist_info
-def get_dist_info():
-    if dist.is_available() and dist.is_initialized():
-        rank = dist.get_rank()
-        world_size = dist.get_world_size()
-    else:
-        rank = 0
-        world_size = 1
-    return rank, world_size
-
-
-# TODO: replace with mmengine.dist.collect_results
-def collect_results(results, size, device='cpu'):
-    """Collect results in distributed environments."""
-    # TODO: replace with mmengine.dist.collect_results
-    if device == 'gpu':
-        return collect_results_gpu(results, size)
-    elif device == 'cpu':
-        return collect_results_cpu(results, size)
-    else:
-        raise NotImplementedError(f"device must be 'cpu' or 'gpu', but got {device}")
-
-
-# TODO: replace with mmengine.dist.collect_results
-def collect_results_cpu(result_part, size, tmpdir=None):
-    rank, world_size = get_dist_info()
-    # create a tmp dir if it is not specified
-    if tmpdir is None:
-        MAX_LEN = 512
-        # 32 is whitespace
-        dir_tensor = torch.full((MAX_LEN, ),
-                                32,
-                                dtype=torch.uint8,
-                                device='cuda')
-        if rank == 0:
-            mkdir_or_exist('.dist_test')
-            tmpdir = tempfile.mkdtemp(dir='.dist_test')
-            tmpdir = torch.tensor(
-                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
-            dir_tensor[:len(tmpdir)] = tmpdir
-        dist.broadcast(dir_tensor, 0)
-        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
-    else:
-        mkdir_or_exist(tmpdir)
-    # dump the part result to the dir
-    with open(osp.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f:
-        pickle.dump(result_part, f, protocol=2)
-    dist.barrier()
-    # collect all parts
-    if rank != 0:
-        return None
-    else:
-        # load results of all parts from tmp dir
-        part_list = []
-        for i in range(world_size):
-            with open(osp.join(tmpdir, f'part_{i}.pkl'), 'rb') as f:
-                part_list.append(pickle.load(f))
-        # sort the results
-        ordered_results = []
-        for res in zip(*part_list):
-            ordered_results.extend(list(res))
-        # the dataloader may pad some samples
-        ordered_results = ordered_results[:size]
-        # remove tmp dir
-        shutil.rmtree(tmpdir)
-        return ordered_results
-
-
-# TODO: replace with mmengine.dist.collect_results
-def collect_results_gpu(result_part, size):
-    rank, world_size = get_dist_info()
-    # dump result part to tensor with pickle
-    part_tensor = torch.tensor(
-        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
-    # gather all result part tensor shape
-    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
-    shape_list = [shape_tensor.clone() for _ in range(world_size)]
-    dist.all_gather(shape_list, shape_tensor)
-    # padding result part tensor to max length
-    shape_max = torch.tensor(shape_list).max()
-    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
-    part_send[:shape_tensor[0]] = part_tensor
-    part_recv_list = [
-        part_tensor.new_zeros(shape_max) for _ in range(world_size)
-    ]
-    # gather all result part
-    dist.all_gather(part_recv_list, part_send)
-
-    if rank == 0:
-        part_list = []
-        for recv, shape in zip(part_recv_list, shape_list):
-            part_list.append(
-                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
-        # sort the results
-        ordered_results = []
-        for res in zip(*part_list):
-            ordered_results.extend(list(res))
-        # the dataloader may pad some samples
-        ordered_results = ordered_results[:size]
-        return ordered_results
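
The removed `collect_results_gpu` hand-rolls object gathering: it pickles each rank's results into a uint8 CUDA tensor, pads to the longest part, and `all_gather`s the raw bytes. Since PyTorch 1.8 the same effect is available from `torch.distributed.all_gather_object`; a rough equivalent sketch (the `gather_results_gpu` name is hypothetical, and each rank's part is assumed to be a list):

    import torch.distributed as dist

    def gather_results_gpu(result_part, size):
        # all_gather_object pickles and pads under the hood, replacing the
        # manual byte-tensor handling in the removed helper above.
        part_list = [None] * dist.get_world_size()
        dist.all_gather_object(part_list, result_part)
        if dist.get_rank() != 0:
            return None
        # interleave per-rank parts, then drop samples the dataloader padded
        ordered = [item for group in zip(*part_list) for item in group]
        return ordered[:size]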
@@ -1,3 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Any, Dict
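
To see how the refactored flow is consumed, a hypothetical subclass sketch (the `Accuracy` name, the toy metric, and the `process` signature are illustrative assumptions, not from this commit):

    from typing import Any, List, Sequence

    class Accuracy(BaseEvaluator):
        default_prefix = 'acc'

        def process(self, data_samples: Sequence, predictions: Sequence) -> None:
            # accumulate per-batch comparisons into self.results
            for gt, pred in zip(data_samples, predictions):
                self.results.append(int(gt == pred))

        def compute_metrics(self, results: List[Any]) -> dict:
            return {'accuracy': sum(results) / len(results)}

    # evaluate(size) then collects per-rank results, computes metrics on the
    # main process, prefixes keys ('acc/accuracy'), and broadcasts the dict so
    # every rank gets the same return value.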