[Refactor] Refactor Evaluator to Metric. (#152)

* [Refactor] Refactor Evaluator to Metric.

* update

* fix lint

* fix doc

* fix lint

* resolve comments

* resolve comments

* remove collect_device from evaluator

* rename
RangiLyu 2022-04-01 15:06:38 +08:00 committed by GitHub
parent 2fdca03f19
commit 2d80367893
15 changed files with 293 additions and 226 deletions

@ -40,7 +40,7 @@ validation_cfg=dict(
dict(type='Accuracy', top_k=1), # use the classification accuracy evaluator
dict(type='F1Score') # use the F1_score evaluator
],
main_metric='accuracy'
main_metric='accuracy',
interval=10,
by_epoch=True,
)
@ -94,13 +94,14 @@ validation_cfg=dict(
The specific implementation is as follows:
```python
from mmengine.evaluator import BaseEvaluator
from mmengine.registry import EVALUATORS
from mmengine.evaluator import BaseMetric
from mmengine.registry import METRICS
import numpy as np
@EVALUATORS.register_module()
class Accuracy(BaseEvaluator):
@METRICS.register_module()
class Accuracy(BaseMetric):
""" Accuracy Evaluator
Default prefix: ACC
@ -111,24 +112,24 @@ class Accuracy(BaseEvaluator):
default_prefix = 'ACC'
def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
predictions: Sequence[BaseDataElement]):
def process(self, data_batch: Sequence[Tuple[Any, dict]],
predictions: Sequence[dict]):
"""Process one batch of data and predictions. The processed
Results should be stored in `self.results`, which will be used
to computed the metrics when all batches have been processed.
Args:
data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data
data_batch (Sequence[Tuple[Any, dict]]): A batch of data
from the dataloader.
predictions (Sequence[BaseDataElement]): A batch of outputs from
predictions (Sequence[dict]): A batch of outputs from
the model.
"""
# take the classification predictions and the category labels
result = dict(
'pred': predictions.pred_label,
'gt': data_samples.gt_label
)
result = {
'pred': predictions['pred_label'],
'gt': data_batch['gt_label']
}
# store the results of the current batch into self.results
self.results.append(result)
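    # Note: a complete metric also needs a `compute_metrics` method that turns
    # the accumulated `self.results` into the final metrics dict. The sketch
    # below is illustrative (not shown in this hunk) and assumes each entry
    # stored by `process` carries the 'pred' and 'gt' keys used above.
    def compute_metrics(self, results: list) -> dict:
        """Compute the accuracy from all processed results (sketch)."""
        preds = np.concatenate([np.atleast_1d(res['pred']) for res in results])
        gts = np.concatenate([np.atleast_1d(res['gt']) for res in results])
        return dict(accuracy=float((preds == gts).mean()))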

@ -225,7 +225,7 @@ MMEngine's registries support cross-project calls, i.e., one project can use
- OPTIMIZERS: registers all of PyTorch's `optimizer`s as well as custom `optimizer`s
- OPTIMIZER_CONSTRUCTORS: constructors for optimizers
- PARAM_SCHEDULERS: various parameter schedulers, such as `MultiStepLR`
- EVALUATORS: evaluators used to validate model accuracy
- METRICS: evaluation metrics used to validate model accuracy
- TASK_UTILS: task-specific components, such as `AnchorGenerator`, `BboxCoder`
- VISUALIZERS: manages visualization modules, e.g. `DetVisualizer` can draw predicted boxes on images
- WRITERS: backends for storing training logs, such as `LocalWriter`, `TensorboardWriter`

@ -497,6 +497,13 @@ class BaseDataElement:
new_data.set_data(data)
return new_data
def to_dict(self) -> dict:
"""Convert BaseDataElement to dict."""
return {
k: v.to_dict() if isinstance(v, BaseDataElement) else v
for k, v in self.items()
}
def __repr__(self) -> str:
def _addindent(s_: str, num_spaces: int) -> str:
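The new `to_dict` helper recursively converts a `BaseDataElement` (including nested elements) into a plain dict; this is what lets the `Evaluator` wrapper below hand plain dictionaries to each metric. A minimal usage sketch, with field names borrowed from the unit test added later in this commit:

```python
from mmengine.data import BaseDataElement

sample = BaseDataElement(
    gt_instances=BaseDataElement(label=1),
    pred_instances=BaseDataElement(label=0),
)

as_dict = sample.to_dict()
assert isinstance(as_dict, dict)
# nested BaseDataElement fields are converted recursively
assert isinstance(as_dict['gt_instances'], dict)
```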

@ -1,9 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseEvaluator
from .builder import build_evaluator
from .composed_evaluator import ComposedEvaluator
from .evaluator import Evaluator
from .metric import BaseMetric
from .utils import get_metric_value
__all__ = [
'BaseEvaluator', 'ComposedEvaluator', 'build_evaluator', 'get_metric_value'
]
__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value']

@ -1,27 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union
from ..registry import EVALUATORS
from .base import BaseEvaluator
from .composed_evaluator import ComposedEvaluator
def build_evaluator(
cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
"""Build function of evaluator.
When the evaluator config is a list, it will automatically build composed
evaluators.
Args:
cfg (dict | list): Config of evaluator. When the config is a list, it
will automatically build composed evaluators.
Returns:
BaseEvaluator or ComposedEvaluator: The built evaluator.
"""
if isinstance(cfg, list):
evaluators = [EVALUATORS.build(_cfg) for _cfg in cfg]
return ComposedEvaluator(evaluators=evaluators)
else:
return EVALUATORS.build(cfg)

@ -1,76 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Optional, Sequence, Tuple, Union
from mmengine.data import BaseDataElement
from .base import BaseEvaluator
class ComposedEvaluator:
"""Wrapper class to compose multiple :class:`BaseEvaluator` instances.
Args:
evaluators (Sequence[BaseEvaluator]): The evaluators to compose.
collect_device (str): Device name used for collecting results from
different ranks during distributed training. Must be 'cpu' or
'gpu'. Defaults to 'cpu'.
"""
def __init__(self,
evaluators: Sequence[BaseEvaluator],
collect_device='cpu'):
self._dataset_meta: Union[None, dict] = None
self.collect_device = collect_device
self.evaluators = evaluators
@property
def dataset_meta(self) -> Optional[dict]:
return self._dataset_meta
@dataset_meta.setter
def dataset_meta(self, dataset_meta: dict) -> None:
self._dataset_meta = dataset_meta
for evaluator in self.evaluators:
evaluator.dataset_meta = dataset_meta
def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
predictions: Sequence[BaseDataElement]):
"""Invoke process method of each wrapped evaluator.
Args:
data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data
from the dataloader.
predictions (Sequence[BaseDataElement]): A batch of outputs from
the model.
"""
for evalutor in self.evaluators:
evalutor.process(data_batch, predictions)
def evaluate(self, size: int) -> dict:
"""Invoke evaluate method of each wrapped evaluator and collect the
metrics dict.
Args:
size (int): Length of the entire validation dataset. When batch
size > 1, the dataloader may pad some data samples to make
sure all ranks have the same length of dataset slice. The
``collect_results`` function will drop the padded data base on
this size.
Returns:
dict: Evaluation metrics of all wrapped evaluators. The keys are
the names of the metrics, and the values are corresponding results.
"""
metrics = {}
for evaluator in self.evaluators:
_metrics = evaluator.evaluate(size)
# Check metric name conflicts
for name in _metrics.keys():
if name in metrics:
raise ValueError(
'There are multiple evaluators with the same metric '
f'name {name}')
metrics.update(_metrics)
return metrics

@ -0,0 +1,131 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Iterator, List, Optional, Sequence, Tuple, Union
from mmengine.data import BaseDataElement
from ..registry.root import METRICS
from .metric import BaseMetric
class Evaluator:
"""Wrapper class to compose multiple :class:`BaseMetric` instances.
Args:
metrics (dict or BaseMetric or Sequence): The config of metrics.
"""
def __init__(self, metrics: Union[dict, BaseMetric, Sequence]):
self._dataset_meta: Optional[dict] = None
if not isinstance(metrics, Sequence):
metrics = [metrics]
self.metrics: List[BaseMetric] = []
for metric in metrics:
if isinstance(metric, BaseMetric):
self.metrics.append(metric)
elif isinstance(metric, dict):
self.metrics.append(METRICS.build(metric))
else:
raise TypeError('metric should be a dict or a BaseMetric, '
f'but got {metric}.')
@property
def dataset_meta(self) -> Optional[dict]:
return self._dataset_meta
@dataset_meta.setter
def dataset_meta(self, dataset_meta: dict) -> None:
self._dataset_meta = dataset_meta
for metric in self.metrics:
metric.dataset_meta = dataset_meta
def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
predictions: Sequence[BaseDataElement]):
"""Convert ``BaseDataSample`` to dict and invoke process method of each
metric.
Args:
data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data
from the dataloader.
predictions (Sequence[BaseDataElement]): A batch of outputs from
the model.
"""
_data_batch = []
for input, data in data_batch:
if isinstance(data, BaseDataElement):
_data_batch.append((input, data.to_dict()))
else:
_data_batch.append((input, data))
_predictions = []
for pred in predictions:
if isinstance(pred, BaseDataElement):
_predictions.append(pred.to_dict())
else:
_predictions.append(pred)
for metric in self.metrics:
metric.process(_data_batch, _predictions)
def evaluate(self, size: int) -> dict:
"""Invoke ``evaluate`` method of each metric and collect the metrics
dictionary.
Args:
size (int): Length of the entire validation dataset. When batch
size > 1, the dataloader may pad some data samples to make
sure all ranks have the same length of dataset slice. The
``collect_results`` function will drop the padded data based on
this size.
Returns:
dict: Evaluation results of all metrics. The keys are the names
of the metrics, and the values are corresponding results.
"""
metrics = {}
for metric in self.metrics:
_results = metric.evaluate(size)
# Check metric name conflicts
for name in _results.keys():
if name in metrics:
raise ValueError(
'There are multiple evaluation results with the same '
f'metric name {name}. Please make sure all metrics '
'have different prefixes.')
metrics.update(_results)
return metrics
def offline_evaluate(self,
data: Sequence,
predictions: Sequence,
chunk_size: int = 1):
"""Offline evaluate the dumped predictions on the given data .
Args:
data (Sequence): All data of the validation set.
predictions (Sequence): All predictions of the model on the
validation set.
chunk_size (int): The number of data samples and predictions to be
processed in a batch.
"""
# support chunking iterable objects
def get_chunks(seq: Iterator, chunk_size=1):
stop = False
while not stop:
chunk = []
for _ in range(chunk_size):
try:
chunk.append(next(seq))
except StopIteration:
stop = True
break
if chunk:
yield chunk
size = 0
for data_chunk, pred_chunk in zip(
get_chunks(iter(data), chunk_size),
get_chunks(iter(predictions), chunk_size)):
size += len(data_chunk)
self.process(data_chunk, pred_chunk)
return self.evaluate(size)
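As a quick illustration of the new wrapper, the sketch below runs `Evaluator.offline_evaluate` the same way the unit test added later in this commit does (`ToyMetric` is that test's helper metric, not a built-in):

```python
import numpy as np

from mmengine.data import BaseDataElement
from mmengine.evaluator import Evaluator

# Build the wrapper from a metric config; 'ToyMetric' must be registered in METRICS.
evaluator = Evaluator(dict(type='ToyMetric'))

size = 10
all_data = [(np.zeros((3, 10, 10)), BaseDataElement(label=1)) for _ in range(size)]
all_predictions = [BaseDataElement(pred=1) for _ in range(size)]

# Chunked offline evaluation over dumped predictions; BaseDataElement inputs
# are converted to dicts by Evaluator.process before reaching each metric.
metrics = evaluator.offline_evaluate(all_data, all_predictions, chunk_size=4)
# With matching preds and labels this yields something like
# {'Toy/accuracy': 1.0, 'Toy/size': 10}
```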

@ -3,20 +3,19 @@ import warnings
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Sequence, Tuple, Union
from mmengine.data import BaseDataElement
from mmengine.dist import (broadcast_object_list, collect_results,
is_main_process)
class BaseEvaluator(metaclass=ABCMeta):
"""Base class for an evaluator.
class BaseMetric(metaclass=ABCMeta):
"""Base class for a metric.
The evaluator first processes each batch of data_samples and
predictions, and appends the processed results in to the results list.
Then it collects all results together from all ranks if distributed
training is used. Finally, it computes the metrics of the entire dataset.
The metric first processes each batch of data_samples and predictions,
and appends the processed results to the results list. Then it
collects all results together from all ranks if distributed training
is used. Finally, it computes the metrics of the entire dataset.
A subclass of class:`BaseEvaluator` should assign a meaningful value to the
A subclass of class:`BaseMetric` should assign a meaningful value to the
class attribute `default_prefix`. See the argument `prefix` for details.
Args:
@ -39,7 +38,7 @@ class BaseEvaluator(metaclass=ABCMeta):
self.results: List[Any] = []
self.prefix = prefix or self.default_prefix
if self.prefix is None:
warnings.warn('The prefix is not set in evaluator class '
warnings.warn('The prefix is not set in metric class '
f'{self.__class__.__name__}.')
@property
@ -51,16 +50,16 @@ class BaseEvaluator(metaclass=ABCMeta):
self._dataset_meta = dataset_meta
@abstractmethod
def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
predictions: Sequence[BaseDataElement]) -> None:
def process(self, data_batch: Sequence[Tuple[Any, dict]],
predictions: Sequence[dict]) -> None:
"""Process one batch of data samples and predictions. The processed
results should be stored in ``self.results``, which will be used to
compute the metrics when all batches have been processed.
Args:
data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data
data_batch (Sequence[Tuple[Any, dict]]): A batch of data
from the dataloader.
predictions (Sequence[BaseDataElement]): A batch of outputs from
predictions (Sequence[dict]): A batch of outputs from
the model.
"""
@ -84,7 +83,7 @@ class BaseEvaluator(metaclass=ABCMeta):
size (int): Length of the entire validation dataset. When batch
size > 1, the dataloader may pad some data samples to make
sure all ranks have the same length of dataset slice. The
``collect_results`` function will drop the padded data base on
``collect_results`` function will drop the padded data based on
this size.
Returns:
@ -93,9 +92,9 @@ class BaseEvaluator(metaclass=ABCMeta):
"""
if len(self.results) == 0:
warnings.warn(
f'{self.__class__.__name__} got empty `self._results`. Please '
f'{self.__class__.__name__} got empty `self.results`. Please '
'ensure that the processed results are properly added into '
'`self._results` in `process` method.')
'`self.results` in `process` method.')
results = collect_results(self.results, size, self.collect_device)
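With this refactor a concrete metric only implements `process` and `compute_metrics`; cross-rank result collection and metric-name prefixing stay in `BaseMetric`. A minimal sketch under that assumption (the class, its registry name, and the 'pred'/'gt' keys are illustrative, not part of this commit):

```python
from typing import Any, Sequence, Tuple

from mmengine.evaluator import BaseMetric
from mmengine.registry import METRICS


@METRICS.register_module()
class MeanAbsoluteError(BaseMetric):
    """Illustrative metric: mean absolute error between 'pred' and 'gt'."""

    default_prefix = 'MAE'

    def process(self, data_batch: Sequence[Tuple[Any, dict]],
                predictions: Sequence[dict]) -> None:
        for (_, data), pred in zip(data_batch, predictions):
            # store only what compute_metrics needs
            self.results.append(abs(pred['pred'] - data['gt']))

    def compute_metrics(self, results: list) -> dict:
        # final keys get the 'MAE/' prefix applied by BaseMetric.evaluate
        return dict(mae=sum(results) / len(results))
```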

@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .default_scope import DefaultScope
from .registry import Registry, build_from_cfg
from .root import (DATA_SAMPLERS, DATASETS, EVALUATORS, HOOKS, LOOPS,
from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS,
TRANSFORMS, VISUALIZERS, WEIGHT_INITIALIZERS, WRITERS)
@ -10,6 +10,6 @@ __all__ = [
'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS',
'EVALUATORS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS',
'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS',
'DefaultScope'
]

@ -35,8 +35,8 @@ OPTIMIZERS = Registry('optimizer')
OPTIMIZER_CONSTRUCTORS = Registry('optimizer constructor')
# mangage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry('parameter scheduler')
# manage all kinds of evaluators for computing metrics
EVALUATORS = Registry('evaluator')
# manage all kinds of metrics
METRICS = Registry('metric')
# manage task-specific modules like anchor generators and box coders
TASK_UTILS = Registry('task util')

@ -5,7 +5,7 @@ import torch
from torch.utils.data import DataLoader
from mmengine.data import BaseDataElement
from mmengine.evaluator import BaseEvaluator, build_evaluator
from mmengine.evaluator import Evaluator
from mmengine.registry import LOOPS
from mmengine.utils import is_list_of
from .base_loop import BaseLoop
@ -165,19 +165,19 @@ class ValLoop(BaseLoop):
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
evaluator (BaseEvaluator or dict or list): Used for computing metrics.
evaluator (Evaluator or dict or list): Used for computing metrics.
interval (int): Validation interval. Defaults to 1.
"""
def __init__(self,
runner,
dataloader: Union[DataLoader, Dict],
evaluator: Union[BaseEvaluator, Dict, List],
evaluator: Union[Evaluator, Dict, List],
interval: int = 1) -> None:
super().__init__(runner, dataloader)
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
self.evaluator = build_evaluator(evaluator) # type: ignore
self.evaluator = runner.build_evaluator(evaluator) # type: ignore
else:
self.evaluator = evaluator # type: ignore
@ -228,15 +228,15 @@ class TestLoop(BaseLoop):
runner (Runner): A reference of runner.
dataloader (Dataloader or dict): A dataloader object or a dict to
build a dataloader.
evaluator (BaseEvaluator or dict or list): Used for computing metrics.
evaluator (Evaluator or dict or list): Used for computing metrics.
"""
def __init__(self, runner, dataloader: Union[DataLoader, Dict],
evaluator: Union[BaseEvaluator, Dict, List]):
evaluator: Union[Evaluator, Dict, List]):
super().__init__(runner, dataloader)
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
self.evaluator = build_evaluator(evaluator) # type: ignore
self.evaluator = runner.build_evaluator(evaluator) # type: ignore
else:
self.evaluator = evaluator # type: ignore

@ -23,8 +23,7 @@ from mmengine.config import Config, ConfigDict
from mmengine.data import pseudo_collate, worker_init_fn
from mmengine.dist import (broadcast, get_dist_info, init_dist, master_only,
sync_random_seed)
from mmengine.evaluator import (BaseEvaluator, ComposedEvaluator,
build_evaluator)
from mmengine.evaluator import Evaluator
from mmengine.hooks import Hook
from mmengine.logging import MessageHub, MMLogger
from mmengine.model import is_model_wrapper
@ -41,7 +40,6 @@ from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
from .priority import Priority, get_priority
EvaluatorType = Union[BaseEvaluator, ComposedEvaluator]
ConfigType = Union[Dict, Config, ConfigDict]
@ -211,8 +209,8 @@ class Runner:
test_cfg: Optional[Dict] = None,
optimizer: Optional[Union[Optimizer, Dict]] = None,
param_scheduler: Optional[Union[_ParamScheduler, Dict, List]] = None,
val_evaluator: Optional[Union[EvaluatorType, Dict, List]] = None,
test_evaluator: Optional[Union[EvaluatorType, Dict, List]] = None,
val_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
test_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None,
custom_hooks: Optional[List[Union[Hook, Dict]]] = None,
load_from: Optional[str] = None,
@ -804,37 +802,35 @@ class Runner:
return param_schedulers
def build_evaluator(
self, evaluator: Union[Dict, List[Dict],
EvaluatorType]) -> EvaluatorType:
self, evaluator: Union[Dict, List[Dict], Evaluator]) -> Evaluator:
"""Build evaluator.
Examples of ``evaluator``::
evaluator = dict(type='ToyEvaluator')
evaluator = dict(type='ToyMetric')
# evaluator can also be a list of dict
evaluator = [
dict(type='ToyEvaluator1'),
dict(type='ToyMetric1'),
dict(type='ToyEvaluator2')
]
Args:
evaluator (BaseEvaluator or ComposedEvaluator or dict or list):
evaluator (Evaluator or dict or list):
An Evaluator object or a config dict or list of config dict
used to build evaluators.
used to build an Evaluator.
Returns:
BaseEvaluator or ComposedEvaluator: Evaluators build from
``evaluator``.
Evaluator: Evaluator build from ``evaluator``.
"""
if isinstance(evaluator, (BaseEvaluator, ComposedEvaluator)):
if isinstance(evaluator, Evaluator):
return evaluator
elif isinstance(evaluator, dict) or is_list_of(evaluator, dict):
return build_evaluator(evaluator) # type: ignore
return Evaluator(evaluator) # type: ignore
else:
raise TypeError(
'evaluator should be one of dict, list of dict, BaseEvaluator '
f'and ComposedEvaluator, but got {evaluator}')
'evaluator should be one of dict, list of dict, and Evaluator'
f', but got {evaluator}')
def build_dataloader(self, dataloader: Union[DataLoader,
Dict]) -> DataLoader:

@ -417,3 +417,12 @@ class TestBaseDataElement(TestCase):
# test_items
assert len(dict(instances.items())) == len(dict(data.items()))
def test_to_dict(self):
metainfo, data = self.setup_data()
instances = BaseDataElement(metainfo=metainfo, **data)
dict_instances = instances.to_dict()
# test convert BaseDataElement to dict
assert isinstance(dict_instances, dict)
assert isinstance(dict_instances['gt_instances'], dict)
assert isinstance(dict_instances['pred_instances'], dict)

@ -6,12 +6,12 @@ from unittest import TestCase
import numpy as np
from mmengine.data import BaseDataElement
from mmengine.evaluator import BaseEvaluator, build_evaluator, get_metric_value
from mmengine.registry import EVALUATORS
from mmengine.evaluator import BaseMetric, Evaluator, get_metric_value
from mmengine.registry import METRICS
@EVALUATORS.register_module()
class ToyEvaluator(BaseEvaluator):
@METRICS.register_module()
class ToyMetric(BaseMetric):
"""Evaluaotr that calculates the metric `accuracy` from predictions and
labels. Alternatively, this evaluator can return arbitrary dummy metrics
set in the config.
@ -39,8 +39,8 @@ class ToyEvaluator(BaseEvaluator):
def process(self, data_batch, predictions):
results = [{
'pred': pred.pred,
'label': data[1].label
'pred': pred.get('pred'),
'label': data[1].get('label')
} for pred, data in zip(predictions, data_batch)]
self.results.extend(results)
@ -61,13 +61,13 @@ class ToyEvaluator(BaseEvaluator):
return metrics
@EVALUATORS.register_module()
class NonPrefixedEvaluator(BaseEvaluator):
@METRICS.register_module()
class NonPrefixedMetric(BaseMetric):
"""Evaluator with unassigned `default_prefix` to test the warning
information."""
def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
predictions: Sequence[BaseDataElement]) -> None:
def process(self, data_batch: Sequence[Tuple[Any, dict]],
predictions: Sequence[dict]) -> None:
pass
def compute_metrics(self, results: list) -> dict:
@ -85,11 +85,11 @@ def generate_test_results(size, batch_size, pred, label):
yield (data_batch, predictions)
class TestBaseEvaluator(TestCase):
class TestEvaluator(TestCase):
def test_single_evaluator(self):
cfg = dict(type='ToyEvaluator')
evaluator = build_evaluator(cfg)
def test_single_metric(self):
cfg = dict(type='ToyMetric')
evaluator = Evaluator(cfg)
size = 10
batch_size = 4
@ -103,18 +103,18 @@ class TestBaseEvaluator(TestCase):
self.assertEqual(metrics['Toy/size'], size)
# Test empty results
cfg = dict(type='ToyEvaluator', dummy_metrics=dict(accuracy=1.0))
evaluator = build_evaluator(cfg)
with self.assertWarnsRegex(UserWarning, 'got empty `self._results`.'):
cfg = dict(type='ToyMetric', dummy_metrics=dict(accuracy=1.0))
evaluator = Evaluator(cfg)
with self.assertWarnsRegex(UserWarning, 'got empty `self.results`.'):
evaluator.evaluate(0)
def test_composed_evaluator(self):
def test_composed_metrics(self):
cfg = [
dict(type='ToyEvaluator'),
dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0))
dict(type='ToyMetric'),
dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0))
]
evaluator = build_evaluator(cfg)
evaluator = Evaluator(cfg)
size = 10
batch_size = 4
@ -129,14 +129,13 @@ class TestBaseEvaluator(TestCase):
self.assertAlmostEqual(metrics['Toy/mAP'], 0.0)
self.assertEqual(metrics['Toy/size'], size)
def test_ambiguate_metric(self):
def test_ambiguous_metric(self):
cfg = [
dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0)),
dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0))
dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0)),
dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0))
]
evaluator = build_evaluator(cfg)
evaluator = Evaluator(cfg)
size = 10
batch_size = 4
@ -147,28 +146,42 @@ class TestBaseEvaluator(TestCase):
with self.assertRaisesRegex(
ValueError,
'There are multiple evaluators with the same metric name'):
'There are multiple evaluation results with the same metric '
'name'):
_ = evaluator.evaluate(size=size)
def test_dataset_meta(self):
dataset_meta = dict(classes=('cat', 'dog'))
cfg = [
dict(type='ToyEvaluator'),
dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0))
dict(type='ToyMetric'),
dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0))
]
evaluator = build_evaluator(cfg)
evaluator = Evaluator(cfg)
evaluator.dataset_meta = dataset_meta
self.assertDictEqual(evaluator.dataset_meta, dataset_meta)
for _evaluator in evaluator.evaluators:
for _evaluator in evaluator.metrics:
self.assertDictEqual(_evaluator.dataset_meta, dataset_meta)
def test_collect_device(self):
cfg = [
dict(type='ToyMetric', collect_device='cpu'),
dict(
type='ToyMetric',
collect_device='gpu',
dummy_metrics=dict(mAP=0.0))
]
evaluator = Evaluator(cfg)
self.assertEqual(evaluator.metrics[0].collect_device, 'cpu')
self.assertEqual(evaluator.metrics[1].collect_device, 'gpu')
def test_prefix(self):
cfg = dict(type='NonPrefixedEvaluator')
cfg = dict(type='NonPrefixedMetric')
with self.assertWarnsRegex(UserWarning, 'The prefix is not set'):
_ = build_evaluator(cfg)
_ = Evaluator(cfg)
def test_get_metric_value(self):
@ -208,3 +221,14 @@ class TestBaseEvaluator(TestCase):
indicator = 'metric_2' # unmatched indicator
with self.assertRaisesRegex(ValueError, 'can not match any metric'):
_ = get_metric_value(indicator, metrics)
def test_offline_evaluate(self):
cfg = dict(type='ToyMetric')
evaluator = Evaluator(cfg)
size = 10
all_data = [(np.zeros((3, 10, 10)), BaseDataElement(label=1))
for _ in range(size)]
all_predictions = [BaseDataElement(pred=0) for _ in range(size)]
evaluator.offline_evaluate(all_data, all_predictions)

@ -13,16 +13,14 @@ from torch.utils.data import DataLoader, Dataset
from mmengine.config import Config
from mmengine.data import DefaultSampler
from mmengine.evaluator import (BaseEvaluator, ComposedEvaluator,
build_evaluator)
from mmengine.evaluator import BaseMetric, Evaluator
from mmengine.hooks import (Hook, IterTimerHook, LoggerHook, OptimizerHook,
ParamSchedulerHook)
from mmengine.hooks.checkpoint_hook import CheckpointHook
from mmengine.logging import MessageHub, MMLogger
from mmengine.optim.scheduler import MultiStepLR, StepLR
from mmengine.registry import (DATASETS, EVALUATORS, HOOKS, LOOPS,
MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
Registry)
from mmengine.registry import (DATASETS, HOOKS, LOOPS, METRICS, MODEL_WRAPPERS,
MODELS, PARAM_SCHEDULERS, Registry)
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
Runner, TestLoop, ValLoop)
from mmengine.runner.priority import Priority, get_priority
@ -80,8 +78,8 @@ class ToyDataset(Dataset):
return self.data[index], self.label[index]
@EVALUATORS.register_module()
class ToyEvaluator1(BaseEvaluator):
@METRICS.register_module()
class ToyMetric1(BaseMetric):
def __init__(self, collect_device='cpu', dummy_metrics=None):
super().__init__(collect_device=collect_device)
@ -95,8 +93,8 @@ class ToyEvaluator1(BaseEvaluator):
return dict(acc=1)
@EVALUATORS.register_module()
class ToyEvaluator2(BaseEvaluator):
@METRICS.register_module()
class ToyMetric2(BaseMetric):
def __init__(self, collect_device='cpu', dummy_metrics=None):
super().__init__(collect_device=collect_device)
@ -145,7 +143,7 @@ class CustomValLoop(BaseLoop):
self._runner = runner
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
self.evaluator = build_evaluator(evaluator) # type: ignore
self.evaluator = runner.build_evaluator(evaluator) # type: ignore
else:
self.evaluator = evaluator
@ -161,7 +159,7 @@ class CustomTestLoop(BaseLoop):
self._runner = runner
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
self.evaluator = build_evaluator(evaluator) # type: ignore
self.evaluator = runner.build_evaluator(evaluator) # type: ignore
else:
self.evaluator = evaluator
@ -197,8 +195,8 @@ class TestRunner(TestCase):
num_workers=0),
optimizer=dict(type='SGD', lr=0.01),
param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
val_evaluator=dict(type='ToyEvaluator1'),
test_evaluator=dict(type='ToyEvaluator1'),
val_evaluator=dict(type='ToyMetric1'),
test_evaluator=dict(type='ToyMetric1'),
train_cfg=dict(by_epoch=True, max_epochs=3),
val_cfg=dict(interval=1),
test_cfg=dict(),
@ -355,14 +353,14 @@ class TestRunner(TestCase):
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
self.assertIsInstance(runner.val_loop, BaseLoop)
self.assertIsInstance(runner.val_loop.dataloader, DataLoader)
self.assertIsInstance(runner.val_loop.evaluator, ToyEvaluator1)
self.assertIsInstance(runner.val_loop.evaluator, Evaluator)
# After calling runner.test(), test_dataloader should be initialized
self.assertIsInstance(runner.test_loop, dict)
runner.test()
self.assertIsInstance(runner.test_loop, BaseLoop)
self.assertIsInstance(runner.test_loop.dataloader, DataLoader)
self.assertIsInstance(runner.test_loop.evaluator, ToyEvaluator1)
self.assertIsInstance(runner.test_loop.evaluator, Evaluator)
# 4. initialize runner with objects rather than config
model = ToyModel()
@ -385,10 +383,10 @@ class TestRunner(TestCase):
param_scheduler=MultiStepLR(optimizer, milestones=[1, 2]),
val_cfg=dict(interval=1),
val_dataloader=val_dataloader,
val_evaluator=ToyEvaluator1(),
val_evaluator=ToyMetric1(),
test_cfg=dict(),
test_dataloader=test_dataloader,
test_evaluator=ToyEvaluator1(),
test_evaluator=ToyMetric1(),
default_hooks=dict(param_scheduler=toy_hook),
custom_hooks=[toy_hook2],
experiment_name='test_init14')
@ -585,20 +583,28 @@ class TestRunner(TestCase):
runner = Runner.from_cfg(cfg)
# input is a BaseEvaluator or ComposedEvaluator object
evaluator = ToyEvaluator1()
evaluator = Evaluator(ToyMetric1())
self.assertEqual(id(runner.build_evaluator(evaluator)), id(evaluator))
evaluator = ComposedEvaluator([ToyEvaluator1(), ToyEvaluator2()])
evaluator = Evaluator([ToyMetric1(), ToyMetric2()])
self.assertEqual(id(runner.build_evaluator(evaluator)), id(evaluator))
# input is a dict or list of dict
evaluator = dict(type='ToyEvaluator1')
self.assertIsInstance(runner.build_evaluator(evaluator), ToyEvaluator1)
# input is a dict
evaluator = dict(type='ToyMetric1')
self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator)
# input is a invalid type
evaluator = [dict(type='ToyEvaluator1'), dict(type='ToyEvaluator2')]
self.assertIsInstance(
runner.build_evaluator(evaluator), ComposedEvaluator)
# input is a list of dict
evaluator = [dict(type='ToyMetric1'), dict(type='ToyMetric2')]
self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator)
# test collect device
evaluator = [
dict(type='ToyMetric1', collect_device='cpu'),
dict(type='ToyMetric2', collect_device='gpu')
]
_evaluator = runner.build_evaluator(evaluator)
self.assertEqual(_evaluator.metrics[0].collect_device, 'cpu')
self.assertEqual(_evaluator.metrics[1].collect_device, 'gpu')
def test_build_dataloader(self):
cfg = copy.deepcopy(self.epoch_based_cfg)