[Enhance] Support build evaluator from list of built metric (#423)

* Support build evaluator from list of built metric

* regist evaluator

* fix as comment

* add unit test
pull/445/head
Mashiro 2022-08-19 10:56:51 +08:00 committed by GitHub
parent 10330cde9d
commit 4abf1a0454
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 15 deletions

View File

@ -2,10 +2,11 @@
from typing import Iterator, List, Optional, Sequence, Union
from mmengine.data import BaseDataElement
from ..registry.root import METRICS
from ..registry.root import EVALUATOR, METRICS
from .metric import BaseMetric
@EVALUATOR.register_module()
class Evaluator:
"""Wrapper class to compose multiple :class:`BaseMetric` instances.
@ -19,13 +20,10 @@ class Evaluator:
metrics = [metrics]
self.metrics: List[BaseMetric] = []
for metric in metrics:
if isinstance(metric, BaseMetric):
self.metrics.append(metric)
elif isinstance(metric, dict):
if isinstance(metric, dict):
self.metrics.append(METRICS.build(metric))
else:
raise TypeError('metric should be a dict or a BaseMetric, '
f'but got {metric}.')
self.metrics.append(metric)
@property
def dataset_meta(self) -> Optional[dict]:

View File

@ -316,9 +316,12 @@ class ValLoop(BaseLoop):
fp16: bool = False) -> None:
super().__init__(runner, dataloader)
if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
if isinstance(evaluator, dict) or isinstance(evaluator, list):
self.evaluator = runner.build_evaluator(evaluator) # type: ignore
else:
assert isinstance(evaluator, Evaluator), (
'evaluator must be one of dict, list or Evaluator instance, '
f'but got {type(evaluator)}.')
self.evaluator = evaluator # type: ignore
if hasattr(self.dataloader.dataset, 'metainfo'):
self.evaluator.dataset_meta = self.dataloader.dataset.metainfo

View File

@ -38,8 +38,8 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
DefaultScope, count_registered_modules)
from mmengine.registry.root import LOG_PROCESSORS
from mmengine.utils import (TORCH_VERSION, collect_env, digit_version,
get_git_hash, is_list_of, is_seq_of,
revert_sync_batchnorm, set_multi_processing)
get_git_hash, is_seq_of, revert_sync_batchnorm,
set_multi_processing)
from mmengine.visualization import Visualizer
from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
@ -1246,13 +1246,14 @@ class Runner:
return param_schedulers
def build_evaluator(
self, evaluator: Union[Dict, List[Dict], Evaluator]) -> Evaluator:
def build_evaluator(self, evaluator: Union[Dict, List,
Evaluator]) -> Evaluator:
"""Build evaluator.
Examples of ``evaluator``::
evaluator = dict(type='ToyMetric')
# evaluator could be a built Evaluator instance
evaluator = Evaluator(metrics=[ToyMetric()])
# evaluator can also be a list of dict
evaluator = [
@ -1260,6 +1261,14 @@ class Runner:
dict(type='ToyEvaluator2')
]
# evaluator can also be a list of built metric
evaluator = [ToyMetric1(), ToyMetric2()]
# evaluator can also be a dict with key metrics
evaluator = dict(metrics=ToyMetric())
# metric is a list
evaluator = dict(metrics=[ToyMetric()])
Args:
evaluator (Evaluator or dict or list): An Evaluator object or a
config dict or list of config dict used to build an Evaluator.
@ -1272,13 +1281,12 @@ class Runner:
elif isinstance(evaluator, dict):
# if `metrics` in dict keys, it means to build customized evalutor
if 'metrics' in evaluator:
assert 'type' in evaluator, 'expected customized evaluator' \
f' with key `type`, but got {evaluator}'
evaluator.setdefault('type', 'Evaluator')
return EVALUATOR.build(evaluator)
# otherwise, default evalutor will be built
else:
return Evaluator(evaluator) # type: ignore
elif is_list_of(evaluator, dict):
elif isinstance(evaluator, list):
# use the default `Evaluator`
return Evaluator(evaluator) # type: ignore
else:

View File

@ -1094,6 +1094,12 @@ class TestRunner(TestCase):
evaluator = [dict(type='ToyMetric1'), dict(type='ToyMetric2')]
self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator)
# input is a list of built metric.
metric = [ToyMetric1(), ToyMetric2()]
_evaluator = runner.build_evaluator(metric)
self.assertIs(_evaluator.metrics[0], metric[0])
self.assertIs(_evaluator.metrics[1], metric[1])
# test collect device
evaluator = [
dict(type='ToyMetric1', collect_device='cpu'),
@ -1115,6 +1121,10 @@ class TestRunner(TestCase):
self.assertEqual(_evaluator.metrics[0].collect_device, 'cpu')
self.assertEqual(_evaluator.metrics[1].collect_device, 'gpu')
# test evaluator must be a Evaluator instance
with self.assertRaisesRegex(TypeError, 'evaluator should be'):
_evaluator = runner.build_evaluator(ToyMetric1())
def test_build_dataloader(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_dataloader'