[Refactor] Refactor Evaluator to Metric. (#152)

* [Refactor] Refactor Evaluator to Metric.

* update

* fix lint

* fix doc

* fix lint

* resolve comments

* resolve comments

* remove collect_device from evaluator

* rename
RangiLyu 2022-04-01 15:06:38 +08:00 committed by GitHub
parent 2fdca03f19
commit 2d80367893
15 changed files with 293 additions and 226 deletions


@@ -40,7 +40,7 @@ validation_cfg=dict(
        dict(type='Accuracy', top_k=1),  # use the classification accuracy metric
        dict(type='F1Score')  # use the F1-score metric
    ],
-    main_metric='accuracy'
+    main_metric='accuracy',
    interval=10,
    by_epoch=True,
)
@@ -94,13 +94,14 @@ validation_cfg=dict(
The implementation is as follows:

```python
-from mmengine.evaluator import BaseEvaluator
-from mmengine.registry import EVALUATORS
+from mmengine.evaluator import BaseMetric
+from mmengine.registry import METRICS
import numpy as np


-@EVALUATORS.register_module()
-class Accuracy(BaseEvaluator):
+@METRICS.register_module()
+class Accuracy(BaseMetric):
    """ Accuracy Evaluator

    Default prefix: ACC
@@ -111,24 +112,24 @@ class Accuracy(BaseEvaluator):
    default_prefix = 'ACC'

-    def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
-                predictions: Sequence[BaseDataElement]):
+    def process(self, data_batch: Sequence[Tuple[Any, dict]],
+                predictions: Sequence[dict]):
        """Process one batch of data and predictions. The processed
        results should be stored in `self.results`, which will be used
        to compute the metrics when all batches have been processed.

        Args:
-            data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data
+            data_batch (Sequence[Tuple[Any, dict]]): A batch of data
                from the dataloader.
-            predictions (Sequence[BaseDataElement]): A batch of outputs from
+            predictions (Sequence[dict]): A batch of outputs from
                the model.
        """
        # extract the classification predictions and ground-truth labels
-        result = dict(
-            'pred': predictions.pred_label,
-            'gt': data_samples.gt_label
-        )
+        result = {
+            'pred': predictions['pred_label'],
+            'gt': data_batch['gt_label']
+        }

        # store the results of the current batch in self.results
        self.results.append(result)
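
The hunk above stops at `process`. For orientation only, a minimal sketch of the matching `compute_metrics` step, assuming the `pred`/`gt` fields collected above (this method is not part of the diff):

```python
    def compute_metrics(self, results: list) -> dict:
        """Aggregate the per-batch results collected by ``process``."""
        preds = np.concatenate([np.atleast_1d(res['pred']) for res in results])
        gts = np.concatenate([np.atleast_1d(res['gt']) for res in results])
        # accuracy = fraction of predictions that match the ground-truth labels
        accuracy = float((preds == gts).mean())
        return dict(accuracy=accuracy)
```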


@@ -225,7 +225,7 @@ MMEngine registries support cross-project calls, i.e. in one project you can use
 - OPTIMIZERS: registers all PyTorch `optimizer`s as well as custom optimizers
 - OPTIMIZER_CONSTRUCTORS: constructors of optimizers
 - PARAM_SCHEDULERS: various parameter schedulers, such as `MultiStepLR`
-- EVALUATORS: evaluators used to validate model accuracy
+- METRICS: evaluation metrics used to validate model accuracy
 - TASK_UTILS: task-specific components, such as `AnchorGenerator` and `BboxCoder`
 - VISUALIZERS: manages drawing modules, such as `DetVisualizer`, which can draw prediction boxes on images
 - WRITERS: backends that store training logs, such as `LocalWriter` and `TensorboardWriter`


@@ -497,6 +497,13 @@ class BaseDataElement:
        new_data.set_data(data)
        return new_data

+    def to_dict(self) -> dict:
+        """Convert BaseDataElement to dict."""
+        return {
+            k: v.to_dict() if isinstance(v, BaseDataElement) else v
+            for k, v in self.items()
+        }
+
    def __repr__(self) -> str:

        def _addindent(s_: str, num_spaces: int) -> str:
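
The new `Evaluator` below relies on this helper to hand plain dicts to the metrics. A minimal usage sketch (field names are illustrative, not taken from the diff):

```python
from mmengine.data import BaseDataElement

pred_instances = BaseDataElement(score=0.9)
sample = BaseDataElement(pred_instances=pred_instances, label=1)

as_dict = sample.to_dict()
# nested BaseDataElement values are converted recursively
assert isinstance(as_dict, dict)
assert isinstance(as_dict['pred_instances'], dict)
assert as_dict['label'] == 1
```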


@@ -1,9 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
-from .base import BaseEvaluator
-from .builder import build_evaluator
-from .composed_evaluator import ComposedEvaluator
+from .evaluator import Evaluator
+from .metric import BaseMetric
from .utils import get_metric_value

-__all__ = [
-    'BaseEvaluator', 'ComposedEvaluator', 'build_evaluator', 'get_metric_value'
-]
+__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value']


@@ -1,27 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Union
-
-from ..registry import EVALUATORS
-from .base import BaseEvaluator
-from .composed_evaluator import ComposedEvaluator
-
-
-def build_evaluator(
-        cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
-    """Build function of evaluator.
-
-    When the evaluator config is a list, it will automatically build composed
-    evaluators.
-
-    Args:
-        cfg (dict | list): Config of evaluator. When the config is a list, it
-            will automatically build composed evaluators.
-
-    Returns:
-        BaseEvaluator or ComposedEvaluator: The built evaluator.
-    """
-    if isinstance(cfg, list):
-        evaluators = [EVALUATORS.build(_cfg) for _cfg in cfg]
-        return ComposedEvaluator(evaluators=evaluators)
-    else:
-        return EVALUATORS.build(cfg)


@@ -1,76 +0,0 @@
-# Copyright (c) OpenMMLab. All rights reserved.
-from typing import Any, Optional, Sequence, Tuple, Union
-
-from mmengine.data import BaseDataElement
-from .base import BaseEvaluator
-
-
-class ComposedEvaluator:
-    """Wrapper class to compose multiple :class:`BaseEvaluator` instances.
-
-    Args:
-        evaluators (Sequence[BaseEvaluator]): The evaluators to compose.
-        collect_device (str): Device name used for collecting results from
-            different ranks during distributed training. Must be 'cpu' or
-            'gpu'. Defaults to 'cpu'.
-    """
-
-    def __init__(self,
-                 evaluators: Sequence[BaseEvaluator],
-                 collect_device='cpu'):
-        self._dataset_meta: Union[None, dict] = None
-        self.collect_device = collect_device
-        self.evaluators = evaluators
-
-    @property
-    def dataset_meta(self) -> Optional[dict]:
-        return self._dataset_meta
-
-    @dataset_meta.setter
-    def dataset_meta(self, dataset_meta: dict) -> None:
-        self._dataset_meta = dataset_meta
-        for evaluator in self.evaluators:
-            evaluator.dataset_meta = dataset_meta
-
-    def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
-                predictions: Sequence[BaseDataElement]):
-        """Invoke process method of each wrapped evaluator.
-
-        Args:
-            data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data
-                from the dataloader.
-            predictions (Sequence[BaseDataElement]): A batch of outputs from
-                the model.
-        """
-        for evalutor in self.evaluators:
-            evalutor.process(data_batch, predictions)
-
-    def evaluate(self, size: int) -> dict:
-        """Invoke evaluate method of each wrapped evaluator and collect the
-        metrics dict.
-
-        Args:
-            size (int): Length of the entire validation dataset. When batch
-                size > 1, the dataloader may pad some data samples to make
-                sure all ranks have the same length of dataset slice. The
-                ``collect_results`` function will drop the padded data base on
-                this size.
-
-        Returns:
-            dict: Evaluation metrics of all wrapped evaluators. The keys are
-            the names of the metrics, and the values are corresponding results.
-        """
-        metrics = {}
-        for evaluator in self.evaluators:
-            _metrics = evaluator.evaluate(size)
-
-            # Check metric name conflicts
-            for name in _metrics.keys():
-                if name in metrics:
-                    raise ValueError(
-                        'There are multiple evaluators with the same metric '
-                        f'name {name}')
-
-            metrics.update(_metrics)
-
-        return metrics


@@ -0,0 +1,131 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Any, Iterator, List, Optional, Sequence, Tuple, Union
+
+from mmengine.data import BaseDataElement
+from ..registry.root import METRICS
+from .metric import BaseMetric
+
+
+class Evaluator:
+    """Wrapper class to compose multiple :class:`BaseMetric` instances.
+
+    Args:
+        metrics (dict or BaseMetric or Sequence): The config of metrics.
+    """
+
+    def __init__(self, metrics: Union[dict, BaseMetric, Sequence]):
+        self._dataset_meta: Optional[dict] = None
+        if not isinstance(metrics, Sequence):
+            metrics = [metrics]
+        self.metrics: List[BaseMetric] = []
+        for metric in metrics:
+            if isinstance(metric, BaseMetric):
+                self.metrics.append(metric)
+            elif isinstance(metric, dict):
+                self.metrics.append(METRICS.build(metric))
+            else:
+                raise TypeError('metric should be a dict or a BaseMetric, '
+                                f'but got {metric}.')
+
+    @property
+    def dataset_meta(self) -> Optional[dict]:
+        return self._dataset_meta
+
+    @dataset_meta.setter
+    def dataset_meta(self, dataset_meta: dict) -> None:
+        self._dataset_meta = dataset_meta
+        for metric in self.metrics:
+            metric.dataset_meta = dataset_meta
+
+    def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
+                predictions: Sequence[BaseDataElement]):
+        """Convert ``BaseDataSample`` to dict and invoke process method of
+        each metric.
+
+        Args:
+            data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data
+                from the dataloader.
+            predictions (Sequence[BaseDataElement]): A batch of outputs from
+                the model.
+        """
+        _data_batch = []
+        for input, data in data_batch:
+            if isinstance(data, BaseDataElement):
+                _data_batch.append((input, data.to_dict()))
+            else:
+                _data_batch.append((input, data))
+        _predictions = []
+        for pred in predictions:
+            if isinstance(pred, BaseDataElement):
+                _predictions.append(pred.to_dict())
+            else:
+                _predictions.append(pred)
+
+        for metric in self.metrics:
+            metric.process(_data_batch, _predictions)
+
+    def evaluate(self, size: int) -> dict:
+        """Invoke ``evaluate`` method of each metric and collect the metrics
+        dictionary.
+
+        Args:
+            size (int): Length of the entire validation dataset. When batch
+                size > 1, the dataloader may pad some data samples to make
+                sure all ranks have the same length of dataset slice. The
+                ``collect_results`` function will drop the padded data based
+                on this size.
+
+        Returns:
+            dict: Evaluation results of all metrics. The keys are the names
+            of the metrics, and the values are corresponding results.
+        """
+        metrics = {}
+        for metric in self.metrics:
+            _results = metric.evaluate(size)
+
+            # Check metric name conflicts
+            for name in _results.keys():
+                if name in metrics:
+                    raise ValueError(
+                        'There are multiple evaluation results with the same '
+                        f'metric name {name}. Please make sure all metrics '
+                        'have different prefixes.')
+
+            metrics.update(_results)
+        return metrics
+
+    def offline_evaluate(self,
+                         data: Sequence,
+                         predictions: Sequence,
+                         chunk_size: int = 1):
+        """Offline evaluate the dumped predictions on the given data.
+
+        Args:
+            data (Sequence): All data of the validation set.
+            predictions (Sequence): All predictions of the model on the
+                validation set.
+            chunk_size (int): The number of data samples and predictions to be
+                processed in a batch.
+        """
+
+        # support chunking iterable objects
+        def get_chunks(seq: Iterator, chunk_size=1):
+            stop = False
+            while not stop:
+                chunk = []
+                for _ in range(chunk_size):
+                    try:
+                        chunk.append(next(seq))
+                    except StopIteration:
+                        stop = True
+                        break
+                if chunk:
+                    yield chunk
+
+        size = 0
+        for data_chunk, pred_chunk in zip(
+                get_chunks(iter(data), chunk_size),
+                get_chunks(iter(predictions), chunk_size)):
+            size += len(data_chunk)
+            self.process(data_chunk, pred_chunk)
+        return self.evaluate(size)
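
To make the new flow concrete, here is a hedged end-to-end sketch; the `DummyAccuracy` metric and the toy data are illustrative and not part of this PR:

```python
from mmengine.evaluator import BaseMetric, Evaluator


class DummyAccuracy(BaseMetric):
    """Illustrative metric working on plain dict samples."""

    default_prefix = 'dummy'

    def process(self, data_batch, predictions):
        for (_, data), pred in zip(data_batch, predictions):
            self.results.append(int(pred['pred'] == data['label']))

    def compute_metrics(self, results):
        return dict(accuracy=sum(results) / len(results))


# A single BaseMetric, a config dict, or a list of either can be wrapped.
evaluator = Evaluator(DummyAccuracy())

# Offline evaluation of dumped predictions; chunk_size mimics the batch size
# used when feeding `process`.
all_data = [(None, dict(label=i % 2)) for i in range(10)]
all_predictions = [dict(pred=0) for _ in range(10)]
print(evaluator.offline_evaluate(all_data, all_predictions, chunk_size=4))
```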


@@ -3,20 +3,19 @@ import warnings
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Sequence, Tuple, Union

-from mmengine.data import BaseDataElement
from mmengine.dist import (broadcast_object_list, collect_results,
                           is_main_process)


-class BaseEvaluator(metaclass=ABCMeta):
-    """Base class for an evaluator.
+class BaseMetric(metaclass=ABCMeta):
+    """Base class for a metric.

-    The evaluator first processes each batch of data_samples and
-    predictions, and appends the processed results in to the results list.
-    Then it collects all results together from all ranks if distributed
-    training is used. Finally, it computes the metrics of the entire dataset.
+    The metric first processes each batch of data_samples and predictions,
+    and appends the processed results to the results list. Then it
+    collects all results together from all ranks if distributed training
+    is used. Finally, it computes the metrics of the entire dataset.

-    A subclass of class:`BaseEvaluator` should assign a meaningful value to the
+    A subclass of class:`BaseMetric` should assign a meaningful value to the
    class attribute `default_prefix`. See the argument `prefix` for details.

    Args:
@@ -39,7 +38,7 @@ class BaseEvaluator(metaclass=ABCMeta):
        self.results: List[Any] = []
        self.prefix = prefix or self.default_prefix
        if self.prefix is None:
-            warnings.warn('The prefix is not set in evaluator class '
+            warnings.warn('The prefix is not set in metric class '
                          f'{self.__class__.__name__}.')

    @property
@@ -51,16 +50,16 @@ class BaseEvaluator(metaclass=ABCMeta):
        self._dataset_meta = dataset_meta

    @abstractmethod
-    def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
-                predictions: Sequence[BaseDataElement]) -> None:
+    def process(self, data_batch: Sequence[Tuple[Any, dict]],
+                predictions: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
-            data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data
+            data_batch (Sequence[Tuple[Any, dict]]): A batch of data
                from the dataloader.
-            predictions (Sequence[BaseDataElement]): A batch of outputs from
+            predictions (Sequence[dict]): A batch of outputs from
                the model.
        """
@@ -84,7 +83,7 @@ class BaseEvaluator(metaclass=ABCMeta):
            size (int): Length of the entire validation dataset. When batch
                size > 1, the dataloader may pad some data samples to make
                sure all ranks have the same length of dataset slice. The
-                ``collect_results`` function will drop the padded data base on
+                ``collect_results`` function will drop the padded data based on
                this size.

        Returns:
@@ -93,9 +92,9 @@ class BaseEvaluator(metaclass=ABCMeta):
        """
        if len(self.results) == 0:
            warnings.warn(
-                f'{self.__class__.__name__} got empty `self._results`. Please '
+                f'{self.__class__.__name__} got empty `self.results`. Please '
                'ensure that the processed results are properly added into '
-                '`self._results` in `process` method.')
+                '`self.results` in `process` method.')

        results = collect_results(self.results, size, self.collect_device)
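
Since `collect_device` now lives on each metric rather than on the evaluator wrapper (the commit message notes removing it from the evaluator), a per-metric config could look like the sketch below; the metric names are illustrative and `prefix` is the constructor argument referred to in the docstring above:

```python
# hypothetical evaluator config after this refactor
val_evaluator = [
    dict(type='Accuracy', top_k=1, collect_device='cpu', prefix='ACC'),
    dict(type='F1Score', collect_device='gpu'),
]
```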


@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .default_scope import DefaultScope
from .registry import Registry, build_from_cfg
-from .root import (DATA_SAMPLERS, DATASETS, EVALUATORS, HOOKS, LOOPS,
+from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
                   MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
                   PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS,
                   TRANSFORMS, VISUALIZERS, WEIGHT_INITIALIZERS, WRITERS)
@@ -10,6 +10,6 @@ __all__ = [
    'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
    'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
    'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS',
-    'EVALUATORS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS',
+    'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS',
    'DefaultScope'
]


@@ -35,8 +35,8 @@ OPTIMIZERS = Registry('optimizer')
OPTIMIZER_CONSTRUCTORS = Registry('optimizer constructor')
# manage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry('parameter scheduler')
-# manage all kinds of evaluators for computing metrics
-EVALUATORS = Registry('evaluator')
+# manage all kinds of metrics
+METRICS = Registry('metric')
# manage task-specific modules like anchor generators and box coders
TASK_UTILS = Registry('task util')


@@ -5,7 +5,7 @@ import torch
from torch.utils.data import DataLoader

from mmengine.data import BaseDataElement
-from mmengine.evaluator import BaseEvaluator, build_evaluator
+from mmengine.evaluator import Evaluator
from mmengine.registry import LOOPS
from mmengine.utils import is_list_of
from .base_loop import BaseLoop
@@ -165,19 +165,19 @@ class ValLoop(BaseLoop):
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
-        evaluator (BaseEvaluator or dict or list): Used for computing metrics.
+        evaluator (Evaluator or dict or list): Used for computing metrics.
        interval (int): Validation interval. Defaults to 1.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
-                 evaluator: Union[BaseEvaluator, Dict, List],
+                 evaluator: Union[Evaluator, Dict, List],
                 interval: int = 1) -> None:
        super().__init__(runner, dataloader)

        if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
-            self.evaluator = build_evaluator(evaluator)  # type: ignore
+            self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
        else:
            self.evaluator = evaluator  # type: ignore
@@ -228,15 +228,15 @@ class TestLoop(BaseLoop):
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
-        evaluator (BaseEvaluator or dict or list): Used for computing metrics.
+        evaluator (Evaluator or dict or list): Used for computing metrics.
    """

    def __init__(self, runner, dataloader: Union[DataLoader, Dict],
-                 evaluator: Union[BaseEvaluator, Dict, List]):
+                 evaluator: Union[Evaluator, Dict, List]):
        super().__init__(runner, dataloader)

        if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
-            self.evaluator = build_evaluator(evaluator)  # type: ignore
+            self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
        else:
            self.evaluator = evaluator  # type: ignore


@@ -23,8 +23,7 @@ from mmengine.config import Config, ConfigDict
from mmengine.data import pseudo_collate, worker_init_fn
from mmengine.dist import (broadcast, get_dist_info, init_dist, master_only,
                           sync_random_seed)
-from mmengine.evaluator import (BaseEvaluator, ComposedEvaluator,
-                                build_evaluator)
+from mmengine.evaluator import Evaluator
from mmengine.hooks import Hook
from mmengine.logging import MessageHub, MMLogger
from mmengine.model import is_model_wrapper
@@ -41,7 +40,6 @@ from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
from .priority import Priority, get_priority

-EvaluatorType = Union[BaseEvaluator, ComposedEvaluator]
ConfigType = Union[Dict, Config, ConfigDict]
@@ -211,8 +209,8 @@ class Runner:
        test_cfg: Optional[Dict] = None,
        optimizer: Optional[Union[Optimizer, Dict]] = None,
        param_scheduler: Optional[Union[_ParamScheduler, Dict, List]] = None,
-        val_evaluator: Optional[Union[EvaluatorType, Dict, List]] = None,
-        test_evaluator: Optional[Union[EvaluatorType, Dict, List]] = None,
+        val_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
+        test_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
        default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None,
        custom_hooks: Optional[List[Union[Hook, Dict]]] = None,
        load_from: Optional[str] = None,
@@ -804,37 +802,35 @@ class Runner:
        return param_schedulers

    def build_evaluator(
-            self, evaluator: Union[Dict, List[Dict],
-                                   EvaluatorType]) -> EvaluatorType:
+            self, evaluator: Union[Dict, List[Dict], Evaluator]) -> Evaluator:
        """Build evaluator.

        Examples of ``evaluator``::

-            evaluator = dict(type='ToyEvaluator')
+            evaluator = dict(type='ToyMetric')

            # evaluator can also be a list of dict
            evaluator = [
-                dict(type='ToyEvaluator1'),
+                dict(type='ToyMetric1'),
                dict(type='ToyEvaluator2')
            ]

        Args:
-            evaluator (BaseEvaluator or ComposedEvaluator or dict or list):
+            evaluator (Evaluator or dict or list):
                An Evaluator object or a config dict or list of config dict
-                used to build evaluators.
+                used to build an Evaluator.

        Returns:
-            BaseEvaluator or ComposedEvaluator: Evaluators build from
-            ``evaluator``.
+            Evaluator: Evaluator built from ``evaluator``.
        """
-        if isinstance(evaluator, (BaseEvaluator, ComposedEvaluator)):
+        if isinstance(evaluator, Evaluator):
            return evaluator
        elif isinstance(evaluator, dict) or is_list_of(evaluator, dict):
-            return build_evaluator(evaluator)  # type: ignore
+            return Evaluator(evaluator)  # type: ignore
        else:
            raise TypeError(
-                'evaluator should be one of dict, list of dict, BaseEvaluator '
-                f'and ComposedEvaluator, but got {evaluator}')
+                'evaluator should be one of dict, list of dict, and Evaluator'
+                f', but got {evaluator}')

    def build_dataloader(self, dataloader: Union[DataLoader,
                                                 Dict]) -> DataLoader:
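
For reference, a hedged sketch of what the refactored `build_evaluator` accepts, assuming a constructed `runner`; the metric names are illustrative:

```python
# dict, list of dict, and Evaluator instances are all accepted; dicts and
# lists of dicts are wrapped into a single Evaluator.
runner.build_evaluator(dict(type='Accuracy', top_k=1))
runner.build_evaluator([dict(type='Accuracy'), dict(type='F1Score')])
runner.build_evaluator(Evaluator(dict(type='Accuracy')))  # returned as-is
```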


@@ -417,3 +417,12 @@ class TestBaseDataElement(TestCase):
        # test_items
        assert len(dict(instances.items())) == len(dict(data.items()))

+    def test_to_dict(self):
+        metainfo, data = self.setup_data()
+        instances = BaseDataElement(metainfo=metainfo, **data)
+        dict_instances = instances.to_dict()
+
+        # test convert BaseDataElement to dict
+        assert isinstance(dict_instances, dict)
+        assert isinstance(dict_instances['gt_instances'], dict)
+        assert isinstance(dict_instances['pred_instances'], dict)


@@ -6,12 +6,12 @@ from unittest import TestCase
import numpy as np

from mmengine.data import BaseDataElement
-from mmengine.evaluator import BaseEvaluator, build_evaluator, get_metric_value
-from mmengine.registry import EVALUATORS
+from mmengine.evaluator import BaseMetric, Evaluator, get_metric_value
+from mmengine.registry import METRICS


-@EVALUATORS.register_module()
-class ToyEvaluator(BaseEvaluator):
+@METRICS.register_module()
+class ToyMetric(BaseMetric):
    """Evaluator that calculates the metric `accuracy` from predictions and
    labels. Alternatively, this evaluator can return arbitrary dummy metrics
    set in the config.
@@ -39,8 +39,8 @@ class ToyEvaluator(BaseEvaluator):
    def process(self, data_batch, predictions):
        results = [{
-            'pred': pred.pred,
-            'label': data[1].label
+            'pred': pred.get('pred'),
+            'label': data[1].get('label')
        } for pred, data in zip(predictions, data_batch)]
        self.results.extend(results)
@@ -61,13 +61,13 @@ class ToyEvaluator(BaseEvaluator):
        return metrics


-@EVALUATORS.register_module()
-class NonPrefixedEvaluator(BaseEvaluator):
+@METRICS.register_module()
+class NonPrefixedMetric(BaseMetric):
    """Evaluator with unassigned `default_prefix` to test the warning
    information."""

-    def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
-                predictions: Sequence[BaseDataElement]) -> None:
+    def process(self, data_batch: Sequence[Tuple[Any, dict]],
+                predictions: Sequence[dict]) -> None:
        pass

    def compute_metrics(self, results: list) -> dict:
@@ -85,11 +85,11 @@ def generate_test_results(size, batch_size, pred, label):
        yield (data_batch, predictions)


-class TestBaseEvaluator(TestCase):
+class TestEvaluator(TestCase):

-    def test_single_evaluator(self):
-        cfg = dict(type='ToyEvaluator')
-        evaluator = build_evaluator(cfg)
+    def test_single_metric(self):
+        cfg = dict(type='ToyMetric')
+        evaluator = Evaluator(cfg)

        size = 10
        batch_size = 4
@@ -103,18 +103,18 @@ class TestBaseEvaluator(TestCase):
        self.assertEqual(metrics['Toy/size'], size)

        # Test empty results
-        cfg = dict(type='ToyEvaluator', dummy_metrics=dict(accuracy=1.0))
-        evaluator = build_evaluator(cfg)
-        with self.assertWarnsRegex(UserWarning, 'got empty `self._results`.'):
+        cfg = dict(type='ToyMetric', dummy_metrics=dict(accuracy=1.0))
+        evaluator = Evaluator(cfg)
+        with self.assertWarnsRegex(UserWarning, 'got empty `self.results`.'):
            evaluator.evaluate(0)

-    def test_composed_evaluator(self):
+    def test_composed_metrics(self):
        cfg = [
-            dict(type='ToyEvaluator'),
-            dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0))
+            dict(type='ToyMetric'),
+            dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0))
        ]

-        evaluator = build_evaluator(cfg)
+        evaluator = Evaluator(cfg)

        size = 10
        batch_size = 4
@@ -129,14 +129,13 @@ class TestBaseEvaluator(TestCase):
        self.assertAlmostEqual(metrics['Toy/mAP'], 0.0)
        self.assertEqual(metrics['Toy/size'], size)

-    def test_ambiguate_metric(self):
+    def test_ambiguous_metric(self):
        cfg = [
-            dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0)),
-            dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0))
+            dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0)),
+            dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0))
        ]

-        evaluator = build_evaluator(cfg)
+        evaluator = Evaluator(cfg)

        size = 10
        batch_size = 4
@@ -147,28 +146,42 @@ class TestBaseEvaluator(TestCase):
        with self.assertRaisesRegex(
                ValueError,
-                'There are multiple evaluators with the same metric name'):
+                'There are multiple evaluation results with the same metric '
+                'name'):
            _ = evaluator.evaluate(size=size)

    def test_dataset_meta(self):
        dataset_meta = dict(classes=('cat', 'dog'))

        cfg = [
-            dict(type='ToyEvaluator'),
-            dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0))
+            dict(type='ToyMetric'),
+            dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0))
        ]

-        evaluator = build_evaluator(cfg)
+        evaluator = Evaluator(cfg)

        evaluator.dataset_meta = dataset_meta
        self.assertDictEqual(evaluator.dataset_meta, dataset_meta)
-        for _evaluator in evaluator.evaluators:
+        for _evaluator in evaluator.metrics:
            self.assertDictEqual(_evaluator.dataset_meta, dataset_meta)

+    def test_collect_device(self):
+        cfg = [
+            dict(type='ToyMetric', collect_device='cpu'),
+            dict(
+                type='ToyMetric',
+                collect_device='gpu',
+                dummy_metrics=dict(mAP=0.0))
+        ]
+
+        evaluator = Evaluator(cfg)
+        self.assertEqual(evaluator.metrics[0].collect_device, 'cpu')
+        self.assertEqual(evaluator.metrics[1].collect_device, 'gpu')
+
    def test_prefix(self):
-        cfg = dict(type='NonPrefixedEvaluator')
+        cfg = dict(type='NonPrefixedMetric')
        with self.assertWarnsRegex(UserWarning, 'The prefix is not set'):
-            _ = build_evaluator(cfg)
+            _ = Evaluator(cfg)

    def test_get_metric_value(self):
@@ -208,3 +221,14 @@ class TestBaseEvaluator(TestCase):
        indicator = 'metric_2'  # unmatched indicator
        with self.assertRaisesRegex(ValueError, 'can not match any metric'):
            _ = get_metric_value(indicator, metrics)
+
+    def test_offline_evaluate(self):
+        cfg = dict(type='ToyMetric')
+        evaluator = Evaluator(cfg)
+
+        size = 10
+
+        all_data = [(np.zeros((3, 10, 10)), BaseDataElement(label=1))
+                    for _ in range(size)]
+        all_predictions = [BaseDataElement(pred=0) for _ in range(size)]
+        evaluator.offline_evaluate(all_data, all_predictions)


@@ -13,16 +13,14 @@ from torch.utils.data import DataLoader, Dataset

from mmengine.config import Config
from mmengine.data import DefaultSampler
-from mmengine.evaluator import (BaseEvaluator, ComposedEvaluator,
-                                build_evaluator)
+from mmengine.evaluator import BaseMetric, Evaluator
from mmengine.hooks import (Hook, IterTimerHook, LoggerHook, OptimizerHook,
                            ParamSchedulerHook)
from mmengine.hooks.checkpoint_hook import CheckpointHook
from mmengine.logging import MessageHub, MMLogger
from mmengine.optim.scheduler import MultiStepLR, StepLR
-from mmengine.registry import (DATASETS, EVALUATORS, HOOKS, LOOPS,
-                               MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
-                               Registry)
+from mmengine.registry import (DATASETS, HOOKS, LOOPS, METRICS, MODEL_WRAPPERS,
+                               MODELS, PARAM_SCHEDULERS, Registry)
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
                             Runner, TestLoop, ValLoop)
from mmengine.runner.priority import Priority, get_priority
@@ -80,8 +78,8 @@ class ToyDataset(Dataset):
        return self.data[index], self.label[index]


-@EVALUATORS.register_module()
-class ToyEvaluator1(BaseEvaluator):
+@METRICS.register_module()
+class ToyMetric1(BaseMetric):

    def __init__(self, collect_device='cpu', dummy_metrics=None):
        super().__init__(collect_device=collect_device)
@@ -95,8 +93,8 @@ class ToyEvaluator1(BaseEvaluator):
        return dict(acc=1)


-@EVALUATORS.register_module()
-class ToyEvaluator2(BaseEvaluator):
+@METRICS.register_module()
+class ToyMetric2(BaseMetric):

    def __init__(self, collect_device='cpu', dummy_metrics=None):
        super().__init__(collect_device=collect_device)
@@ -145,7 +143,7 @@ class CustomValLoop(BaseLoop):
        self._runner = runner

        if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
-            self.evaluator = build_evaluator(evaluator)  # type: ignore
+            self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
        else:
            self.evaluator = evaluator
@@ -161,7 +159,7 @@ class CustomTestLoop(BaseLoop):
        self._runner = runner

        if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
-            self.evaluator = build_evaluator(evaluator)  # type: ignore
+            self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
        else:
            self.evaluator = evaluator
@@ -197,8 +195,8 @@ class TestRunner(TestCase):
                num_workers=0),
            optimizer=dict(type='SGD', lr=0.01),
            param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
-            val_evaluator=dict(type='ToyEvaluator1'),
-            test_evaluator=dict(type='ToyEvaluator1'),
+            val_evaluator=dict(type='ToyMetric1'),
+            test_evaluator=dict(type='ToyMetric1'),
            train_cfg=dict(by_epoch=True, max_epochs=3),
            val_cfg=dict(interval=1),
            test_cfg=dict(),
@@ -355,14 +353,14 @@ class TestRunner(TestCase):
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
        self.assertIsInstance(runner.val_loop, BaseLoop)
        self.assertIsInstance(runner.val_loop.dataloader, DataLoader)
-        self.assertIsInstance(runner.val_loop.evaluator, ToyEvaluator1)
+        self.assertIsInstance(runner.val_loop.evaluator, Evaluator)

        # After calling runner.test(), test_dataloader should be initialized
        self.assertIsInstance(runner.test_loop, dict)
        runner.test()
        self.assertIsInstance(runner.test_loop, BaseLoop)
        self.assertIsInstance(runner.test_loop.dataloader, DataLoader)
-        self.assertIsInstance(runner.test_loop.evaluator, ToyEvaluator1)
+        self.assertIsInstance(runner.test_loop.evaluator, Evaluator)

        # 4. initialize runner with objects rather than config
        model = ToyModel()
@@ -385,10 +383,10 @@ class TestRunner(TestCase):
            param_scheduler=MultiStepLR(optimizer, milestones=[1, 2]),
            val_cfg=dict(interval=1),
            val_dataloader=val_dataloader,
-            val_evaluator=ToyEvaluator1(),
+            val_evaluator=ToyMetric1(),
            test_cfg=dict(),
            test_dataloader=test_dataloader,
-            test_evaluator=ToyEvaluator1(),
+            test_evaluator=ToyMetric1(),
            default_hooks=dict(param_scheduler=toy_hook),
            custom_hooks=[toy_hook2],
            experiment_name='test_init14')
@@ -585,20 +583,28 @@ class TestRunner(TestCase):
        runner = Runner.from_cfg(cfg)

        # input is a BaseEvaluator or ComposedEvaluator object
-        evaluator = ToyEvaluator1()
+        evaluator = Evaluator(ToyMetric1())
        self.assertEqual(id(runner.build_evaluator(evaluator)), id(evaluator))

-        evaluator = ComposedEvaluator([ToyEvaluator1(), ToyEvaluator2()])
+        evaluator = Evaluator([ToyMetric1(), ToyMetric2()])
        self.assertEqual(id(runner.build_evaluator(evaluator)), id(evaluator))

-        # input is a dict or list of dict
-        evaluator = dict(type='ToyEvaluator1')
-        self.assertIsInstance(runner.build_evaluator(evaluator), ToyEvaluator1)
+        # input is a dict
+        evaluator = dict(type='ToyMetric1')
+        self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator)

-        # input is a invalid type
-        evaluator = [dict(type='ToyEvaluator1'), dict(type='ToyEvaluator2')]
-        self.assertIsInstance(
-            runner.build_evaluator(evaluator), ComposedEvaluator)
+        # input is a list of dict
+        evaluator = [dict(type='ToyMetric1'), dict(type='ToyMetric2')]
+        self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator)
+
+        # test collect device
+        evaluator = [
+            dict(type='ToyMetric1', collect_device='cpu'),
+            dict(type='ToyMetric2', collect_device='gpu')
+        ]
+        _evaluator = runner.build_evaluator(evaluator)
+        self.assertEqual(_evaluator.metrics[0].collect_device, 'cpu')
+        self.assertEqual(_evaluator.metrics[1].collect_device, 'gpu')

    def test_build_dataloader(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)