mirror of https://github.com/open-mmlab/mmengine.git
[Feature]: Add evaluator base class. (#41)
* [Feature]: Add evaluator base class.
* solve comments
* update
* fix
parent 9437ebea67
commit d0bcb83e41
@@ -222,7 +222,9 @@ MMEngine's registries support cross-project calls, i.e. in one project you can use
- WEIGHT_INITIALIZERS: tools for weight initialization
- OPTIMIZERS: registers all `optimizer`s in PyTorch as well as custom `optimizer`s
- OPTIMIZER_CONSTRUCTORS: constructors of optimizers
- PARAM_SCHEDULERS: various parameter schedulers, such as `MultiStepLR`
- TASK_UTILS: task-specific components, such as `AnchorGenerator` and `BboxCoder`
- EVALUATORS: evaluators used to validate model accuracy

Below we take the OpenMMLab open-source projects as an example to show how to call modules across projects.
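The full cross-project walkthrough lies outside this hunk. As a minimal, hypothetical sketch of the registry pattern described in the list above, applied to the EVALUATORS registry added by this commit (the `ToyEvaluator` name and its metric are illustrative only, not part of the diff):

```python
# Hypothetical sketch: register a custom evaluator in the EVALUATORS root
# registry, then build it from a config dict.
from mmengine.evaluator import BaseEvaluator
from mmengine.registry import EVALUATORS


@EVALUATORS.register_module()
class ToyEvaluator(BaseEvaluator):

    def process(self, data_samples, predictions):
        self.results.append(predictions)

    def compute_metrics(self, results):
        return dict(num_batches=len(results))


evaluator = EVALUATORS.build(dict(type='ToyEvaluator'))
```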
mmengine/evaluator/__init__.py (new file, 6 lines)
@@ -0,0 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseEvaluator
from .builder import build_evaluator
from .composed_evaluator import ComposedEvaluator

__all__ = ['BaseEvaluator', 'ComposedEvaluator', 'build_evaluator']
mmengine/evaluator/base.py (new file, 210 lines)
@@ -0,0 +1,210 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import pickle
import shutil
import tempfile
import warnings
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Union

import torch
import torch.distributed as dist

from mmengine.utils import mkdir_or_exist


class BaseEvaluator(metaclass=ABCMeta):
    """Base class for an evaluator.

    The evaluator first processes each batch of data_samples and
    predictions, and appends the processed results into the results list.
    Then it collects all results together from all ranks if distributed
    training is used. Finally, it computes the metrics of the entire dataset.

    Args:
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
    """

    def __init__(self, collect_device: str = 'cpu') -> None:
        self._dataset_meta: Union[None, dict] = None
        self.collect_device = collect_device
        self.results: List[Any] = []

        rank, world_size = get_dist_info()
        self.rank = rank
        self.world_size = world_size

    @property
    def dataset_meta(self) -> Optional[dict]:
        return self._dataset_meta

    @dataset_meta.setter
    def dataset_meta(self, dataset_meta: dict) -> None:
        self._dataset_meta = dataset_meta

    @abstractmethod
    def process(self, data_samples: dict, predictions: dict) -> None:
        """Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_samples (dict): The data samples from the dataset.
            predictions (dict): The output of the model.
        """

    @abstractmethod
    def compute_metrics(self, results: list) -> dict:
        """Compute the metrics from processed results.

        Args:
            results (list): The processed results of each batch.

        Returns:
            dict: The computed metrics. The keys are the names of the metrics,
                and the values are corresponding results.
        """

    def evaluate(self, size: int) -> dict:
        """Evaluate the model performance of the whole dataset after
        processing all batches.

        Args:
            size (int): Length of the entire validation dataset. When batch
                size > 1, the dataloader may pad some data samples to make
                sure all ranks have the same length of dataset slice. The
                ``collect_results`` function will drop the padded data based
                on this size.

        Returns:
            metrics (dict): Evaluation metrics dict on the val dataset. The
                keys are the names of the metrics, and the values are
                corresponding results.
        """
        if len(self.results) == 0:
            warnings.warn(
                f'{self.__class__.__name__} got empty `self.results`. Please '
                'ensure that the processed results are properly added into '
                '`self.results` in the `process` method.')

        if self.world_size == 1:
            # non-distributed
            results = self.results
        else:
            results = collect_results(self.results, size, self.collect_device)

        if self.rank == 0:
            # TODO: replace with mmengine.dist.master_only
            metrics = [self.compute_metrics(results)]
        else:
            metrics = [None]  # type: ignore
        # TODO: replace with mmengine.dist.broadcast
        if self.world_size > 1:
            # broadcast_object_list fills ``metrics`` in place on every rank
            dist.broadcast_object_list(metrics)

        # reset the results list
        self.results.clear()
        return metrics[0]


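For illustration, a minimal hypothetical subclass of the API above, showing the process/compute_metrics/evaluate flow in the non-distributed case. `ToyAccuracy` and the sample/prediction dict layout are assumptions made for this sketch, not part of the commit.

```python
from mmengine.evaluator import BaseEvaluator


class ToyAccuracy(BaseEvaluator):
    """Hypothetical evaluator used only to illustrate the base API."""

    def process(self, data_samples, predictions):
        # Store one correctness flag per sample; metrics are computed over
        # the whole dataset later in compute_metrics().
        gts = data_samples['gt_label']
        preds = predictions['pred_label']
        self.results.extend(int(p == g) for p, g in zip(preds, gts))

    def compute_metrics(self, results):
        return dict(accuracy=sum(results) / len(results))


evaluator = ToyAccuracy()
evaluator.process(
    data_samples=dict(gt_label=[0, 1, 1]),
    predictions=dict(pred_label=[0, 1, 0]))
print(evaluator.evaluate(size=3))  # {'accuracy': 0.666...}
```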
# TODO: replace with mmengine.dist.get_dist_info
def get_dist_info():
    """Return (rank, world_size); falls back to (0, 1) when torch.distributed
    is unavailable or not initialized."""
    if dist.is_available() and dist.is_initialized():
        rank = dist.get_rank()
        world_size = dist.get_world_size()
    else:
        rank = 0
        world_size = 1
    return rank, world_size


# TODO: replace with mmengine.dist.collect_results
def collect_results(results, size, device='cpu'):
    """Collect results in distributed environments."""
    # TODO: replace with mmengine.dist.collect_results
    if device == 'gpu':
        return collect_results_gpu(results, size)
    elif device == 'cpu':
        return collect_results_cpu(results, size)
    else:
        raise NotImplementedError(
            f"device must be 'cpu' or 'gpu', but got {device}")


# TODO: replace with mmengine.dist.collect_results
def collect_results_cpu(result_part, size, tmpdir=None):
    """Collect results on CPU by dumping each rank's partial results to a
    shared temporary directory."""
    rank, world_size = get_dist_info()
    # create a tmp dir if it is not specified
    if tmpdir is None:
        MAX_LEN = 512
        # 32 is whitespace
        dir_tensor = torch.full((MAX_LEN, ),
                                32,
                                dtype=torch.uint8,
                                device='cuda')
        if rank == 0:
            mkdir_or_exist('.dist_test')
            tmpdir = tempfile.mkdtemp(dir='.dist_test')
            tmpdir = torch.tensor(
                bytearray(tmpdir.encode()), dtype=torch.uint8, device='cuda')
            dir_tensor[:len(tmpdir)] = tmpdir
        dist.broadcast(dir_tensor, 0)
        tmpdir = dir_tensor.cpu().numpy().tobytes().decode().rstrip()
    else:
        mkdir_or_exist(tmpdir)
    # dump the part result to the dir
    with open(osp.join(tmpdir, f'part_{rank}.pkl'), 'wb') as f:
        pickle.dump(result_part, f, protocol=2)
    dist.barrier()
    # collect all parts
    if rank != 0:
        return None
    else:
        # load results of all parts from tmp dir
        part_list = []
        for i in range(world_size):
            with open(osp.join(tmpdir, f'part_{i}.pkl'), 'rb') as f:
                part_list.append(pickle.load(f))
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        # remove tmp dir
        shutil.rmtree(tmpdir)
        return ordered_results


# TODO: replace with mmengine.dist.collect_results
def collect_results_gpu(result_part, size):
    """Collect results on GPU by all-gathering pickled partial results."""
    rank, world_size = get_dist_info()
    # dump result part to tensor with pickle
    part_tensor = torch.tensor(
        bytearray(pickle.dumps(result_part)), dtype=torch.uint8, device='cuda')
    # gather all result part tensor shape
    shape_tensor = torch.tensor(part_tensor.shape, device='cuda')
    shape_list = [shape_tensor.clone() for _ in range(world_size)]
    dist.all_gather(shape_list, shape_tensor)
    # padding result part tensor to max length
    shape_max = torch.tensor(shape_list).max()
    part_send = torch.zeros(shape_max, dtype=torch.uint8, device='cuda')
    part_send[:shape_tensor[0]] = part_tensor
    part_recv_list = [
        part_tensor.new_zeros(shape_max) for _ in range(world_size)
    ]
    # gather all result part
    dist.all_gather(part_recv_list, part_send)

    if rank == 0:
        part_list = []
        for recv, shape in zip(part_recv_list, shape_list):
            part_list.append(
                pickle.loads(recv[:shape[0]].cpu().numpy().tobytes()))
        # sort the results
        ordered_results = []
        for res in zip(*part_list):
            ordered_results.extend(list(res))
        # the dataloader may pad some samples
        ordered_results = ordered_results[:size]
        return ordered_results
mmengine/evaluator/builder.py (new file, 16 lines)
@@ -0,0 +1,16 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

from ..registry import EVALUATORS
from .composed_evaluator import ComposedEvaluator


def build_evaluator(cfg: Union[dict, list]) -> object:
    """Build function of evaluator.

    When the evaluator config is a list, it will automatically build composed
    evaluators.
    """
    if isinstance(cfg, list):
        evaluators = [EVALUATORS.build(_cfg) for _cfg in cfg]
        return ComposedEvaluator(evaluators=evaluators)
    else:
        return EVALUATORS.build(cfg)
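A usage sketch for this builder, assuming two evaluator classes have already been registered in the EVALUATORS registry; the names `ToyAccuracy` and `ToyRecall` are hypothetical.

```python
from mmengine.evaluator import ComposedEvaluator, build_evaluator

# A single config dict builds a single evaluator instance.
single = build_evaluator(dict(type='ToyAccuracy'))

# A list of config dicts is wrapped into a ComposedEvaluator.
composed = build_evaluator([dict(type='ToyAccuracy'), dict(type='ToyRecall')])
assert isinstance(composed, ComposedEvaluator)
```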
mmengine/evaluator/composed_evaluator.py (new file, 73 lines)
@@ -0,0 +1,73 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Sequence, Union

from .base import BaseEvaluator


class ComposedEvaluator:
    """Wrapper class to compose multiple :class:`BaseEvaluator` instances.

    Args:
        evaluators (Sequence[BaseEvaluator]): The evaluators to compose.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
    """

    def __init__(self,
                 evaluators: Sequence[BaseEvaluator],
                 collect_device='cpu'):
        self._dataset_meta: Union[None, dict] = None
        self.collect_device = collect_device
        self.evaluators = evaluators

    @property
    def dataset_meta(self) -> Optional[dict]:
        return self._dataset_meta

    @dataset_meta.setter
    def dataset_meta(self, dataset_meta: dict) -> None:
        self._dataset_meta = dataset_meta
        for evaluator in self.evaluators:
            evaluator.dataset_meta = dataset_meta

    def process(self, data_samples: dict, predictions: dict):
        """Invoke the process method of each wrapped evaluator.

        Args:
            data_samples (dict): The data samples from the dataset.
            predictions (dict): The output of the model.
        """
        for evaluator in self.evaluators:
            evaluator.process(data_samples, predictions)

    def evaluate(self, size: int) -> dict:
        """Invoke the evaluate method of each wrapped evaluator and collect
        the metrics into a single dict.

        Args:
            size (int): Length of the entire validation dataset. When batch
                size > 1, the dataloader may pad some data samples to make
                sure all ranks have the same length of dataset slice. The
                ``collect_results`` function will drop the padded data based
                on this size.

        Returns:
            metrics (dict): Evaluation metrics of all wrapped evaluators. The
                keys are the names of the metrics, and the values are
                corresponding results.
        """
        metrics = {}
        for evaluator in self.evaluators:
            _metrics = evaluator.evaluate(size)

            # Check metric name conflicts
            for name in _metrics.keys():
                if name in metrics:
                    raise ValueError(
                        'There are multiple evaluators with the same metric '
                        f'name {name}')

            metrics.update(_metrics)
        return metrics
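The merging behaviour can be illustrated with a small self-contained sketch; the two evaluator classes and their metric keys below are hypothetical, not part of this commit.

```python
from mmengine.evaluator import BaseEvaluator, ComposedEvaluator


class CountSamples(BaseEvaluator):

    def process(self, data_samples, predictions):
        self.results.extend(predictions['pred_label'])

    def compute_metrics(self, results):
        return dict(num_samples=len(results))


class MeanPrediction(BaseEvaluator):

    def process(self, data_samples, predictions):
        self.results.extend(predictions['pred_label'])

    def compute_metrics(self, results):
        return dict(mean_pred=sum(results) / len(results))


evaluator = ComposedEvaluator(evaluators=[CountSamples(), MeanPrediction()])
# process() is forwarded to every wrapped evaluator.
evaluator.process(data_samples=dict(), predictions=dict(pred_label=[0, 1, 1]))
print(evaluator.evaluate(size=3))
# {'num_samples': 3, 'mean_pred': 0.666...}; if two wrapped evaluators
# reported the same metric name, evaluate() would raise a ValueError.
```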
mmengine/registry/__init__.py
@@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .registry import Registry, build_from_cfg
-from .root import (DATA_SAMPLERS, DATASETS, HOOKS, MODELS,
+from .root import (DATA_SAMPLERS, DATASETS, EVALUATORS, HOOKS, MODELS,
                    OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, PARAM_SCHEDULERS,
                    RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS,
                    WEIGHT_INITIALIZERS)
@@ -8,5 +8,6 @@ from .root import (DATA_SAMPLERS, DATASETS, HOOKS, MODELS,
__all__ = [
    'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
    'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
-    'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS'
+    'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS',
+    'EVALUATORS'
]
mmengine/registry/root.py
@@ -34,3 +34,6 @@ PARAM_SCHEDULERS = Registry('parameter scheduler')

# manage task-specific modules like anchor generators and box coders
TASK_UTILS = Registry('task util')
+
+# manage all kinds of evaluators for computing metrics
+EVALUATORS = Registry('evaluator')