[Refactor] Refactor Evaluator to Metric. (#152)
* [Refactor] Refactor Evaluator to Metric.
* update
* fix lint
* fix doc
* fix lint
* resolve comments
* resolve comments
* remove collect_device from evaluator
* rename
This commit is contained in: parent 2fdca03f19, commit 2d80367893
@@ -40,7 +40,7 @@ validation_cfg=dict(
        dict(type='Accuracy', top_k=1),  # use the classification accuracy metric
        dict(type='F1Score')  # use the F1-score metric
    ],
    main_metric='accuracy'
    main_metric='accuracy',
    interval=10,
    by_epoch=True,
)
@@ -94,13 +94,14 @@ validation_cfg=dict(
The specific implementation is as follows:

```python
from mmengine.evaluator import BaseEvaluator
from mmengine.registry import EVALUATORS
from mmengine.evaluator import BaseMetric
from mmengine.registry import METRICS

import numpy as np

@EVALUATORS.register_module()
class Accuracy(BaseEvaluator):

@METRICS.register_module()
class Accuracy(BaseMetric):
    """ Accuracy Evaluator

    Default prefix: ACC
@@ -111,24 +112,24 @@ class Accuracy(BaseEvaluator):

    default_prefix = 'ACC'

    def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
                predictions: Sequence[BaseDataElement]):
    def process(self, data_batch: Sequence[Tuple[Any, dict]],
                predictions: Sequence[dict]):
        """Process one batch of data and predictions. The processed
        results should be stored in `self.results`, which will be used
        to compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data
            data_batch (Sequence[Tuple[Any, dict]]): A batch of data
                from the dataloader.
            predictions (Sequence[BaseDataElement]): A batch of outputs from
            predictions (Sequence[dict]): A batch of outputs from
                the model.
        """

        # extract the classification predictions and the category labels
        result = dict(
            'pred': predictions.pred_label,
            'gt': data_samples.gt_label
        )
        result = {
            'pred': predictions['pred_label'],
            'gt': data_batch['gt_label']
        }

        # store the results of the current batch into self.results
        self.results.append(result)
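The excerpt above stops after `process`. A `BaseMetric` subclass also needs a `compute_metrics` method that turns the accumulated `self.results` into a metrics dict; the following is a minimal sketch of how it could look for this `Accuracy` example (continuing the class above, and assuming each stored `'pred'`/`'gt'` entry is a label array or scalar):

```python
    def compute_metrics(self, results: list) -> dict:
        """Compute the accuracy from the results of all processed batches."""
        preds = np.concatenate([np.atleast_1d(res['pred']) for res in results])
        gts = np.concatenate([np.atleast_1d(res['gt']) for res in results])
        # fraction of predictions that match the ground-truth labels
        return dict(accuracy=float((preds == gts).mean()))
```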
@@ -225,7 +225,7 @@ MMEngine registries support cross-project calls, i.e. one project can use
- OPTIMIZERS: registers all `optimizer`s in PyTorch as well as custom `optimizer`s
- OPTIMIZER_CONSTRUCTORS: constructors for optimizers
- PARAM_SCHEDULERS: all kinds of parameter schedulers, such as `MultiStepLR`
- EVALUATORS: evaluators used to validate model accuracy
- METRICS: evaluation metrics used to validate model accuracy
- TASK_UTILS: task-specific components, such as `AnchorGenerator`, `BboxCoder`
- VISUALIZERS: manages drawing modules, e.g. `DetVisualizer` can draw predicted boxes on an image
- WRITERS: backends for storing training logs, such as `LocalWriter`, `TensorboardWriter`
@@ -497,6 +497,13 @@ class BaseDataElement:
        new_data.set_data(data)
        return new_data

    def to_dict(self) -> dict:
        """Convert BaseDataElement to dict."""
        return {
            k: v.to_dict() if isinstance(v, BaseDataElement) else v
            for k, v in self.items()
        }

    def __repr__(self) -> str:

        def _addindent(s_: str, num_spaces: int) -> str:
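The new `to_dict` is what later allows `Evaluator.process` to hand plain dicts to metrics. A small sketch of the expected behaviour (the field names are illustrative, and it assumes keyword arguments become data fields returned by `items()`):

```python
from mmengine.data import BaseDataElement

gt = BaseDataElement(label=1)
sample = BaseDataElement(gt_instances=gt, img_id=0)

d = sample.to_dict()
assert isinstance(d, dict)
assert isinstance(d['gt_instances'], dict)  # nested elements are converted recursively
```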
@@ -1,9 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .base import BaseEvaluator
from .builder import build_evaluator
from .composed_evaluator import ComposedEvaluator
from .evaluator import Evaluator
from .metric import BaseMetric
from .utils import get_metric_value

__all__ = [
    'BaseEvaluator', 'ComposedEvaluator', 'build_evaluator', 'get_metric_value'
]
__all__ = ['BaseMetric', 'Evaluator', 'get_metric_value']
@@ -1,27 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Union

from ..registry import EVALUATORS
from .base import BaseEvaluator
from .composed_evaluator import ComposedEvaluator


def build_evaluator(
        cfg: Union[dict, list]) -> Union[BaseEvaluator, ComposedEvaluator]:
    """Build function of evaluator.

    When the evaluator config is a list, it will automatically build composed
    evaluators.

    Args:
        cfg (dict | list): Config of evaluator. When the config is a list, it
            will automatically build composed evaluators.

    Returns:
        BaseEvaluator or ComposedEvaluator: The built evaluator.
    """
    if isinstance(cfg, list):
        evaluators = [EVALUATORS.build(_cfg) for _cfg in cfg]
        return ComposedEvaluator(evaluators=evaluators)
    else:
        return EVALUATORS.build(cfg)
@@ -1,76 +0,0 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Optional, Sequence, Tuple, Union

from mmengine.data import BaseDataElement
from .base import BaseEvaluator


class ComposedEvaluator:
    """Wrapper class to compose multiple :class:`BaseEvaluator` instances.

    Args:
        evaluators (Sequence[BaseEvaluator]): The evaluators to compose.
        collect_device (str): Device name used for collecting results from
            different ranks during distributed training. Must be 'cpu' or
            'gpu'. Defaults to 'cpu'.
    """

    def __init__(self,
                 evaluators: Sequence[BaseEvaluator],
                 collect_device='cpu'):
        self._dataset_meta: Union[None, dict] = None
        self.collect_device = collect_device
        self.evaluators = evaluators

    @property
    def dataset_meta(self) -> Optional[dict]:
        return self._dataset_meta

    @dataset_meta.setter
    def dataset_meta(self, dataset_meta: dict) -> None:
        self._dataset_meta = dataset_meta
        for evaluator in self.evaluators:
            evaluator.dataset_meta = dataset_meta

    def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
                predictions: Sequence[BaseDataElement]):
        """Invoke process method of each wrapped evaluator.

        Args:
            data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data
                from the dataloader.
            predictions (Sequence[BaseDataElement]): A batch of outputs from
                the model.
        """

        for evalutor in self.evaluators:
            evalutor.process(data_batch, predictions)

    def evaluate(self, size: int) -> dict:
        """Invoke evaluate method of each wrapped evaluator and collect the
        metrics dict.

        Args:
            size (int): Length of the entire validation dataset. When batch
                size > 1, the dataloader may pad some data samples to make
                sure all ranks have the same length of dataset slice. The
                ``collect_results`` function will drop the padded data base on
                this size.

        Returns:
            dict: Evaluation metrics of all wrapped evaluators. The keys are
            the names of the metrics, and the values are corresponding results.
        """
        metrics = {}
        for evaluator in self.evaluators:
            _metrics = evaluator.evaluate(size)

            # Check metric name conflicts
            for name in _metrics.keys():
                if name in metrics:
                    raise ValueError(
                        'There are multiple evaluators with the same metric '
                        f'name {name}')

            metrics.update(_metrics)
        return metrics
131 mmengine/evaluator/evaluator.py (new file)
@@ -0,0 +1,131 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Any, Iterator, List, Optional, Sequence, Tuple, Union

from mmengine.data import BaseDataElement
from ..registry.root import METRICS
from .metric import BaseMetric


class Evaluator:
    """Wrapper class to compose multiple :class:`BaseMetric` instances.

    Args:
        metrics (dict or BaseMetric or Sequence): The config of metrics.
    """

    def __init__(self, metrics: Union[dict, BaseMetric, Sequence]):
        self._dataset_meta: Optional[dict] = None
        if not isinstance(metrics, Sequence):
            metrics = [metrics]
        self.metrics: List[BaseMetric] = []
        for metric in metrics:
            if isinstance(metric, BaseMetric):
                self.metrics.append(metric)
            elif isinstance(metric, dict):
                self.metrics.append(METRICS.build(metric))
            else:
                raise TypeError('metric should be a dict or a BaseMetric, '
                                f'but got {metric}.')

    @property
    def dataset_meta(self) -> Optional[dict]:
        return self._dataset_meta

    @dataset_meta.setter
    def dataset_meta(self, dataset_meta: dict) -> None:
        self._dataset_meta = dataset_meta
        for metric in self.metrics:
            metric.dataset_meta = dataset_meta

    def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
                predictions: Sequence[BaseDataElement]):
        """Convert ``BaseDataSample`` to dict and invoke process method of each
        metric.

        Args:
            data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data
                from the dataloader.
            predictions (Sequence[BaseDataElement]): A batch of outputs from
                the model.
        """
        _data_batch = []
        for input, data in data_batch:
            if isinstance(data, BaseDataElement):
                _data_batch.append((input, data.to_dict()))
            else:
                _data_batch.append((input, data))
        _predictions = []
        for pred in predictions:
            if isinstance(pred, BaseDataElement):
                _predictions.append(pred.to_dict())
            else:
                _predictions.append(pred)

        for metric in self.metrics:
            metric.process(_data_batch, _predictions)

    def evaluate(self, size: int) -> dict:
        """Invoke ``evaluate`` method of each metric and collect the metrics
        dictionary.

        Args:
            size (int): Length of the entire validation dataset. When batch
                size > 1, the dataloader may pad some data samples to make
                sure all ranks have the same length of dataset slice. The
                ``collect_results`` function will drop the padded data based on
                this size.

        Returns:
            dict: Evaluation results of all metrics. The keys are the names
            of the metrics, and the values are corresponding results.
        """
        metrics = {}
        for metric in self.metrics:
            _results = metric.evaluate(size)

            # Check metric name conflicts
            for name in _results.keys():
                if name in metrics:
                    raise ValueError(
                        'There are multiple evaluation results with the same '
                        f'metric name {name}. Please make sure all metrics '
                        'have different prefixes.')

            metrics.update(_results)
        return metrics

    def offline_evaluate(self,
                         data: Sequence,
                         predictions: Sequence,
                         chunk_size: int = 1):
        """Offline evaluate the dumped predictions on the given data.

        Args:
            data (Sequence): All data of the validation set.
            predictions (Sequence): All predictions of the model on the
                validation set.
            chunk_size (int): The number of data samples and predictions to be
                processed in a batch.
        """

        # support chunking iterable objects
        def get_chunks(seq: Iterator, chunk_size=1):
            stop = False
            while not stop:
                chunk = []
                for _ in range(chunk_size):
                    try:
                        chunk.append(next(seq))
                    except StopIteration:
                        stop = True
                        break
                if chunk:
                    yield chunk

        size = 0
        for data_chunk, pred_chunk in zip(
                get_chunks(iter(data), chunk_size),
                get_chunks(iter(predictions), chunk_size)):
            size += len(data_chunk)
            self.process(data_chunk, pred_chunk)
        return self.evaluate(size)
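Taken together, `Evaluator` replaces both `build_evaluator` and `ComposedEvaluator`: it accepts a single metric config, a list of configs, or ready-made `BaseMetric` instances. A rough usage sketch (the `Accuracy` metric, the `data_loader` iterable, and `num_val_samples` are illustrative assumptions, not part of this commit):

```python
from mmengine.evaluator import Evaluator

# 'Accuracy' is assumed to be registered in METRICS by a downstream repo,
# e.g. the metric from the documentation example earlier in this diff.
evaluator = Evaluator([dict(type='Accuracy')])
evaluator.dataset_meta = dict(classes=('cat', 'dog'))

# Each iteration yields (data_batch, predictions); BaseDataElement inputs
# are converted to plain dicts by Evaluator.process before reaching metrics.
for data_batch, predictions in data_loader:
    evaluator.process(data_batch, predictions)

metrics = evaluator.evaluate(size=num_val_samples)  # e.g. {'ACC/accuracy': 0.9}
```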
@@ -3,20 +3,19 @@ import warnings
from abc import ABCMeta, abstractmethod
from typing import Any, List, Optional, Sequence, Tuple, Union

from mmengine.data import BaseDataElement
from mmengine.dist import (broadcast_object_list, collect_results,
                           is_main_process)


class BaseEvaluator(metaclass=ABCMeta):
    """Base class for an evaluator.
class BaseMetric(metaclass=ABCMeta):
    """Base class for a metric.

    The evaluator first processes each batch of data_samples and
    predictions, and appends the processed results in to the results list.
    Then it collects all results together from all ranks if distributed
    training is used. Finally, it computes the metrics of the entire dataset.
    The metric first processes each batch of data_samples and predictions,
    and appends the processed results to the results list. Then it
    collects all results together from all ranks if distributed training
    is used. Finally, it computes the metrics of the entire dataset.

    A subclass of class:`BaseEvaluator` should assign a meaningful value to the
    A subclass of class:`BaseMetric` should assign a meaningful value to the
    class attribute `default_prefix`. See the argument `prefix` for details.

    Args:
@@ -39,7 +38,7 @@ class BaseEvaluator(metaclass=ABCMeta):
        self.results: List[Any] = []
        self.prefix = prefix or self.default_prefix
        if self.prefix is None:
            warnings.warn('The prefix is not set in evaluator class '
            warnings.warn('The prefix is not set in metric class '
                          f'{self.__class__.__name__}.')

    @property
@@ -51,16 +50,16 @@ class BaseEvaluator(metaclass=ABCMeta):
        self._dataset_meta = dataset_meta

    @abstractmethod
    def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
                predictions: Sequence[BaseDataElement]) -> None:
    def process(self, data_batch: Sequence[Tuple[Any, dict]],
                predictions: Sequence[dict]) -> None:
        """Process one batch of data samples and predictions. The processed
        results should be stored in ``self.results``, which will be used to
        compute the metrics when all batches have been processed.

        Args:
            data_batch (Sequence[Tuple[Any, BaseDataElement]]): A batch of data
            data_batch (Sequence[Tuple[Any, dict]]): A batch of data
                from the dataloader.
            predictions (Sequence[BaseDataElement]): A batch of outputs from
            predictions (Sequence[dict]): A batch of outputs from
                the model.
        """

@@ -84,7 +83,7 @@ class BaseEvaluator(metaclass=ABCMeta):
            size (int): Length of the entire validation dataset. When batch
                size > 1, the dataloader may pad some data samples to make
                sure all ranks have the same length of dataset slice. The
                ``collect_results`` function will drop the padded data base on
                ``collect_results`` function will drop the padded data based on
                this size.

        Returns:
@@ -93,9 +92,9 @@ class BaseEvaluator(metaclass=ABCMeta):
        """
        if len(self.results) == 0:
            warnings.warn(
                f'{self.__class__.__name__} got empty `self._results`. Please '
                f'{self.__class__.__name__} got empty `self.results`. Please '
                'ensure that the processed results are properly added into '
                '`self._results` in `process` method.')
                '`self.results` in `process` method.')

        results = collect_results(self.results, size, self.collect_device)
@@ -1,7 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .default_scope import DefaultScope
from .registry import Registry, build_from_cfg
from .root import (DATA_SAMPLERS, DATASETS, EVALUATORS, HOOKS, LOOPS,
from .root import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS, METRICS,
                   MODEL_WRAPPERS, MODELS, OPTIMIZER_CONSTRUCTORS, OPTIMIZERS,
                   PARAM_SCHEDULERS, RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS,
                   TRANSFORMS, VISUALIZERS, WEIGHT_INITIALIZERS, WRITERS)
@@ -10,6 +10,6 @@ __all__ = [
    'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
    'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
    'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS',
    'EVALUATORS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS',
    'METRICS', 'MODEL_WRAPPERS', 'LOOPS', 'WRITERS', 'VISUALIZERS',
    'DefaultScope'
]
@@ -35,8 +35,8 @@ OPTIMIZERS = Registry('optimizer')
OPTIMIZER_CONSTRUCTORS = Registry('optimizer constructor')
# manage all kinds of parameter schedulers like `MultiStepLR`
PARAM_SCHEDULERS = Registry('parameter scheduler')
# manage all kinds of evaluators for computing metrics
EVALUATORS = Registry('evaluator')
# manage all kinds of metrics
METRICS = Registry('metric')

# manage task-specific modules like anchor generators and box coders
TASK_UTILS = Registry('task util')
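With `METRICS` replacing `EVALUATORS` in the root registry, downstream metrics register against `METRICS` and can then be built from config dicts. A minimal sketch (the `DummyMetric` class is made up for illustration):

```python
from mmengine.evaluator import BaseMetric
from mmengine.registry import METRICS


@METRICS.register_module()
class DummyMetric(BaseMetric):
    default_prefix = 'dummy'

    def process(self, data_batch, predictions):
        # keep one entry per prediction so evaluate() has something to count
        self.results.extend({'pred': pred} for pred in predictions)

    def compute_metrics(self, results):
        return dict(num_samples=len(results))


metric = METRICS.build(dict(type='DummyMetric'))  # built from a config dict
```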
@@ -5,7 +5,7 @@ import torch
from torch.utils.data import DataLoader

from mmengine.data import BaseDataElement
from mmengine.evaluator import BaseEvaluator, build_evaluator
from mmengine.evaluator import Evaluator
from mmengine.registry import LOOPS
from mmengine.utils import is_list_of
from .base_loop import BaseLoop
@@ -165,19 +165,19 @@ class ValLoop(BaseLoop):
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (BaseEvaluator or dict or list): Used for computing metrics.
        evaluator (Evaluator or dict or list): Used for computing metrics.
        interval (int): Validation interval. Defaults to 1.
    """

    def __init__(self,
                 runner,
                 dataloader: Union[DataLoader, Dict],
                 evaluator: Union[BaseEvaluator, Dict, List],
                 evaluator: Union[Evaluator, Dict, List],
                 interval: int = 1) -> None:
        super().__init__(runner, dataloader)

        if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
            self.evaluator = build_evaluator(evaluator)  # type: ignore
            self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
        else:
            self.evaluator = evaluator  # type: ignore

@@ -228,15 +228,15 @@ class TestLoop(BaseLoop):
        runner (Runner): A reference of runner.
        dataloader (Dataloader or dict): A dataloader object or a dict to
            build a dataloader.
        evaluator (BaseEvaluator or dict or list): Used for computing metrics.
        evaluator (Evaluator or dict or list): Used for computing metrics.
    """

    def __init__(self, runner, dataloader: Union[DataLoader, Dict],
                 evaluator: Union[BaseEvaluator, Dict, List]):
                 evaluator: Union[Evaluator, Dict, List]):
        super().__init__(runner, dataloader)

        if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
            self.evaluator = build_evaluator(evaluator)  # type: ignore
            self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
        else:
            self.evaluator = evaluator  # type: ignore

@@ -23,8 +23,7 @@ from mmengine.config import Config, ConfigDict
from mmengine.data import pseudo_collate, worker_init_fn
from mmengine.dist import (broadcast, get_dist_info, init_dist, master_only,
                           sync_random_seed)
from mmengine.evaluator import (BaseEvaluator, ComposedEvaluator,
                                build_evaluator)
from mmengine.evaluator import Evaluator
from mmengine.hooks import Hook
from mmengine.logging import MessageHub, MMLogger
from mmengine.model import is_model_wrapper
@@ -41,7 +40,6 @@ from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
from .loops import EpochBasedTrainLoop, IterBasedTrainLoop, TestLoop, ValLoop
from .priority import Priority, get_priority

EvaluatorType = Union[BaseEvaluator, ComposedEvaluator]
ConfigType = Union[Dict, Config, ConfigDict]


@@ -211,8 +209,8 @@ class Runner:
        test_cfg: Optional[Dict] = None,
        optimizer: Optional[Union[Optimizer, Dict]] = None,
        param_scheduler: Optional[Union[_ParamScheduler, Dict, List]] = None,
        val_evaluator: Optional[Union[EvaluatorType, Dict, List]] = None,
        test_evaluator: Optional[Union[EvaluatorType, Dict, List]] = None,
        val_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
        test_evaluator: Optional[Union[Evaluator, Dict, List]] = None,
        default_hooks: Optional[Dict[str, Union[Hook, Dict]]] = None,
        custom_hooks: Optional[List[Union[Hook, Dict]]] = None,
        load_from: Optional[str] = None,
@@ -804,37 +802,35 @@ class Runner:
        return param_schedulers

    def build_evaluator(
            self, evaluator: Union[Dict, List[Dict],
                                   EvaluatorType]) -> EvaluatorType:
            self, evaluator: Union[Dict, List[Dict], Evaluator]) -> Evaluator:
        """Build evaluator.

        Examples of ``evaluator``::

            evaluator = dict(type='ToyEvaluator')
            evaluator = dict(type='ToyMetric')

            # evaluator can also be a list of dict
            evaluator = [
                dict(type='ToyEvaluator1'),
                dict(type='ToyMetric1'),
                dict(type='ToyEvaluator2')
            ]

        Args:
            evaluator (BaseEvaluator or ComposedEvaluator or dict or list):
            evaluator (Evaluator or dict or list):
                An Evaluator object or a config dict or list of config dict
                used to build evaluators.
                used to build an Evaluator.

        Returns:
            BaseEvaluator or ComposedEvaluator: Evaluators build from
            ``evaluator``.
            Evaluator: Evaluator build from ``evaluator``.
        """
        if isinstance(evaluator, (BaseEvaluator, ComposedEvaluator)):
        if isinstance(evaluator, Evaluator):
            return evaluator
        elif isinstance(evaluator, dict) or is_list_of(evaluator, dict):
            return build_evaluator(evaluator)  # type: ignore
            return Evaluator(evaluator)  # type: ignore
        else:
            raise TypeError(
                'evaluator should be one of dict, list of dict, BaseEvaluator '
                f'and ComposedEvaluator, but got {evaluator}')
                'evaluator should be one of dict, list of dict, and Evaluator'
                f', but got {evaluator}')

    def build_dataloader(self, dataloader: Union[DataLoader,
                                                 Dict]) -> DataLoader:

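As the updated tests later in this diff show, `build_evaluator` now normalizes every accepted input into a single `Evaluator`. A sketch of the three cases, assuming `runner` is an already constructed `Runner` and `ToyMetric1`/`ToyMetric2` are the toy metrics registered in those tests:

```python
# a single config dict
evaluator = runner.build_evaluator(dict(type='ToyMetric1'))
assert isinstance(evaluator, Evaluator)

# a list of config dicts composes several metrics into one Evaluator
evaluator = runner.build_evaluator([dict(type='ToyMetric1'),
                                    dict(type='ToyMetric2')])
assert isinstance(evaluator, Evaluator)

# an existing Evaluator instance is returned unchanged
ready_made = Evaluator(ToyMetric1())
assert runner.build_evaluator(ready_made) is ready_made
```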
@@ -417,3 +417,12 @@ class TestBaseDataElement(TestCase):

        # test_items
        assert len(dict(instances.items())) == len(dict(data.items()))

    def test_to_dict(self):
        metainfo, data = self.setup_data()
        instances = BaseDataElement(metainfo=metainfo, **data)
        dict_instances = instances.to_dict()
        # test convert BaseDataElement to dict
        assert isinstance(dict_instances, dict)
        assert isinstance(dict_instances['gt_instances'], dict)
        assert isinstance(dict_instances['pred_instances'], dict)
@@ -6,12 +6,12 @@ from unittest import TestCase
import numpy as np

from mmengine.data import BaseDataElement
from mmengine.evaluator import BaseEvaluator, build_evaluator, get_metric_value
from mmengine.registry import EVALUATORS
from mmengine.evaluator import BaseMetric, Evaluator, get_metric_value
from mmengine.registry import METRICS


@EVALUATORS.register_module()
class ToyEvaluator(BaseEvaluator):
@METRICS.register_module()
class ToyMetric(BaseMetric):
    """Evaluaotr that calculates the metric `accuracy` from predictions and
    labels. Alternatively, this evaluator can return arbitrary dummy metrics
    set in the config.
@@ -39,8 +39,8 @@ class ToyEvaluator(BaseEvaluator):

    def process(self, data_batch, predictions):
        results = [{
            'pred': pred.pred,
            'label': data[1].label
            'pred': pred.get('pred'),
            'label': data[1].get('label')
        } for pred, data in zip(predictions, data_batch)]
        self.results.extend(results)

@@ -61,13 +61,13 @@ class ToyEvaluator(BaseEvaluator):
        return metrics


@EVALUATORS.register_module()
class NonPrefixedEvaluator(BaseEvaluator):
@METRICS.register_module()
class NonPrefixedMetric(BaseMetric):
    """Evaluator with unassigned `default_prefix` to test the warning
    information."""

    def process(self, data_batch: Sequence[Tuple[Any, BaseDataElement]],
                predictions: Sequence[BaseDataElement]) -> None:
    def process(self, data_batch: Sequence[Tuple[Any, dict]],
                predictions: Sequence[dict]) -> None:
        pass

    def compute_metrics(self, results: list) -> dict:
@@ -85,11 +85,11 @@ def generate_test_results(size, batch_size, pred, label):
    yield (data_batch, predictions)


class TestBaseEvaluator(TestCase):
class TestEvaluator(TestCase):

    def test_single_evaluator(self):
        cfg = dict(type='ToyEvaluator')
        evaluator = build_evaluator(cfg)
    def test_single_metric(self):
        cfg = dict(type='ToyMetric')
        evaluator = Evaluator(cfg)

        size = 10
        batch_size = 4
@@ -103,18 +103,18 @@ class TestBaseEvaluator(TestCase):
        self.assertEqual(metrics['Toy/size'], size)

        # Test empty results
        cfg = dict(type='ToyEvaluator', dummy_metrics=dict(accuracy=1.0))
        evaluator = build_evaluator(cfg)
        with self.assertWarnsRegex(UserWarning, 'got empty `self._results`.'):
        cfg = dict(type='ToyMetric', dummy_metrics=dict(accuracy=1.0))
        evaluator = Evaluator(cfg)
        with self.assertWarnsRegex(UserWarning, 'got empty `self.results`.'):
            evaluator.evaluate(0)

    def test_composed_evaluator(self):
    def test_composed_metrics(self):
        cfg = [
            dict(type='ToyEvaluator'),
            dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0))
            dict(type='ToyMetric'),
            dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0))
        ]

        evaluator = build_evaluator(cfg)
        evaluator = Evaluator(cfg)

        size = 10
        batch_size = 4
@@ -129,14 +129,13 @@ class TestBaseEvaluator(TestCase):
        self.assertAlmostEqual(metrics['Toy/mAP'], 0.0)
        self.assertEqual(metrics['Toy/size'], size)

    def test_ambiguate_metric(self):

    def test_ambiguous_metric(self):
        cfg = [
            dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0)),
            dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0))
            dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0)),
            dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0))
        ]

        evaluator = build_evaluator(cfg)
        evaluator = Evaluator(cfg)

        size = 10
        batch_size = 4
@@ -147,28 +146,42 @@ class TestBaseEvaluator(TestCase):

        with self.assertRaisesRegex(
                ValueError,
                'There are multiple evaluators with the same metric name'):
                'There are multiple evaluation results with the same metric '
                'name'):
            _ = evaluator.evaluate(size=size)

    def test_dataset_meta(self):
        dataset_meta = dict(classes=('cat', 'dog'))

        cfg = [
            dict(type='ToyEvaluator'),
            dict(type='ToyEvaluator', dummy_metrics=dict(mAP=0.0))
            dict(type='ToyMetric'),
            dict(type='ToyMetric', dummy_metrics=dict(mAP=0.0))
        ]

        evaluator = build_evaluator(cfg)
        evaluator = Evaluator(cfg)
        evaluator.dataset_meta = dataset_meta

        self.assertDictEqual(evaluator.dataset_meta, dataset_meta)
        for _evaluator in evaluator.evaluators:
        for _evaluator in evaluator.metrics:
            self.assertDictEqual(_evaluator.dataset_meta, dataset_meta)

    def test_collect_device(self):
        cfg = [
            dict(type='ToyMetric', collect_device='cpu'),
            dict(
                type='ToyMetric',
                collect_device='gpu',
                dummy_metrics=dict(mAP=0.0))
        ]

        evaluator = Evaluator(cfg)
        self.assertEqual(evaluator.metrics[0].collect_device, 'cpu')
        self.assertEqual(evaluator.metrics[1].collect_device, 'gpu')

    def test_prefix(self):
        cfg = dict(type='NonPrefixedEvaluator')
        cfg = dict(type='NonPrefixedMetric')
        with self.assertWarnsRegex(UserWarning, 'The prefix is not set'):
            _ = build_evaluator(cfg)
            _ = Evaluator(cfg)

    def test_get_metric_value(self):

@@ -208,3 +221,14 @@ class TestBaseEvaluator(TestCase):
        indicator = 'metric_2'  # unmatched indicator
        with self.assertRaisesRegex(ValueError, 'can not match any metric'):
            _ = get_metric_value(indicator, metrics)

    def test_offline_evaluate(self):
        cfg = dict(type='ToyMetric')
        evaluator = Evaluator(cfg)

        size = 10

        all_data = [(np.zeros((3, 10, 10)), BaseDataElement(label=1))
                    for _ in range(size)]
        all_predictions = [BaseDataElement(pred=0) for _ in range(size)]
        evaluator.offline_evaluate(all_data, all_predictions)
@@ -13,16 +13,14 @@ from torch.utils.data import DataLoader, Dataset

from mmengine.config import Config
from mmengine.data import DefaultSampler
from mmengine.evaluator import (BaseEvaluator, ComposedEvaluator,
                                build_evaluator)
from mmengine.evaluator import BaseMetric, Evaluator
from mmengine.hooks import (Hook, IterTimerHook, LoggerHook, OptimizerHook,
                            ParamSchedulerHook)
from mmengine.hooks.checkpoint_hook import CheckpointHook
from mmengine.logging import MessageHub, MMLogger
from mmengine.optim.scheduler import MultiStepLR, StepLR
from mmengine.registry import (DATASETS, EVALUATORS, HOOKS, LOOPS,
                               MODEL_WRAPPERS, MODELS, PARAM_SCHEDULERS,
                               Registry)
from mmengine.registry import (DATASETS, HOOKS, LOOPS, METRICS, MODEL_WRAPPERS,
                               MODELS, PARAM_SCHEDULERS, Registry)
from mmengine.runner import (BaseLoop, EpochBasedTrainLoop, IterBasedTrainLoop,
                             Runner, TestLoop, ValLoop)
from mmengine.runner.priority import Priority, get_priority
@@ -80,8 +78,8 @@ class ToyDataset(Dataset):
        return self.data[index], self.label[index]


@EVALUATORS.register_module()
class ToyEvaluator1(BaseEvaluator):
@METRICS.register_module()
class ToyMetric1(BaseMetric):

    def __init__(self, collect_device='cpu', dummy_metrics=None):
        super().__init__(collect_device=collect_device)
@@ -95,8 +93,8 @@ class ToyEvaluator1(BaseEvaluator):
        return dict(acc=1)


@EVALUATORS.register_module()
class ToyEvaluator2(BaseEvaluator):
@METRICS.register_module()
class ToyMetric2(BaseMetric):

    def __init__(self, collect_device='cpu', dummy_metrics=None):
        super().__init__(collect_device=collect_device)
@@ -145,7 +143,7 @@ class CustomValLoop(BaseLoop):
        self._runner = runner

        if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
            self.evaluator = build_evaluator(evaluator)  # type: ignore
            self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
        else:
            self.evaluator = evaluator

@@ -161,7 +159,7 @@ class CustomTestLoop(BaseLoop):
        self._runner = runner

        if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
            self.evaluator = build_evaluator(evaluator)  # type: ignore
            self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
        else:
            self.evaluator = evaluator

@@ -197,8 +195,8 @@ class TestRunner(TestCase):
                num_workers=0),
            optimizer=dict(type='SGD', lr=0.01),
            param_scheduler=dict(type='MultiStepLR', milestones=[1, 2]),
            val_evaluator=dict(type='ToyEvaluator1'),
            test_evaluator=dict(type='ToyEvaluator1'),
            val_evaluator=dict(type='ToyMetric1'),
            test_evaluator=dict(type='ToyMetric1'),
            train_cfg=dict(by_epoch=True, max_epochs=3),
            val_cfg=dict(interval=1),
            test_cfg=dict(),
@@ -355,14 +353,14 @@ class TestRunner(TestCase):
        self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
        self.assertIsInstance(runner.val_loop, BaseLoop)
        self.assertIsInstance(runner.val_loop.dataloader, DataLoader)
        self.assertIsInstance(runner.val_loop.evaluator, ToyEvaluator1)
        self.assertIsInstance(runner.val_loop.evaluator, Evaluator)

        # After calling runner.test(), test_dataloader should be initialized
        self.assertIsInstance(runner.test_loop, dict)
        runner.test()
        self.assertIsInstance(runner.test_loop, BaseLoop)
        self.assertIsInstance(runner.test_loop.dataloader, DataLoader)
        self.assertIsInstance(runner.test_loop.evaluator, ToyEvaluator1)
        self.assertIsInstance(runner.test_loop.evaluator, Evaluator)

        # 4. initialize runner with objects rather than config
        model = ToyModel()
@@ -385,10 +383,10 @@ class TestRunner(TestCase):
            param_scheduler=MultiStepLR(optimizer, milestones=[1, 2]),
            val_cfg=dict(interval=1),
            val_dataloader=val_dataloader,
            val_evaluator=ToyEvaluator1(),
            val_evaluator=ToyMetric1(),
            test_cfg=dict(),
            test_dataloader=test_dataloader,
            test_evaluator=ToyEvaluator1(),
            test_evaluator=ToyMetric1(),
            default_hooks=dict(param_scheduler=toy_hook),
            custom_hooks=[toy_hook2],
            experiment_name='test_init14')
@@ -585,20 +583,28 @@ class TestRunner(TestCase):
        runner = Runner.from_cfg(cfg)

        # input is a BaseEvaluator or ComposedEvaluator object
        evaluator = ToyEvaluator1()
        evaluator = Evaluator(ToyMetric1())
        self.assertEqual(id(runner.build_evaluator(evaluator)), id(evaluator))

        evaluator = ComposedEvaluator([ToyEvaluator1(), ToyEvaluator2()])
        evaluator = Evaluator([ToyMetric1(), ToyMetric2()])
        self.assertEqual(id(runner.build_evaluator(evaluator)), id(evaluator))

        # input is a dict or list of dict
        evaluator = dict(type='ToyEvaluator1')
        self.assertIsInstance(runner.build_evaluator(evaluator), ToyEvaluator1)
        # input is a dict
        evaluator = dict(type='ToyMetric1')
        self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator)

        # input is a invalid type
        evaluator = [dict(type='ToyEvaluator1'), dict(type='ToyEvaluator2')]
        self.assertIsInstance(
            runner.build_evaluator(evaluator), ComposedEvaluator)
        # input is a list of dict
        evaluator = [dict(type='ToyMetric1'), dict(type='ToyMetric2')]
        self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator)

        # test collect device
        evaluator = [
            dict(type='ToyMetric1', collect_device='cpu'),
            dict(type='ToyMetric2', collect_device='gpu')
        ]
        _evaluator = runner.build_evaluator(evaluator)
        self.assertEqual(_evaluator.metrics[0].collect_device, 'cpu')
        self.assertEqual(_evaluator.metrics[1].collect_device, 'gpu')

    def test_build_dataloader(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)