[Enhance] Support build evaluator from list of built metric (#423)

* Support build evaluator from list of built metric

* regist evaluator

* fix as comment

* add unit test
pull/445/head
Mashiro 2022-08-19 10:56:51 +08:00 committed by GitHub
parent 10330cde9d
commit 4abf1a0454
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 34 additions and 15 deletions

View File

@ -2,10 +2,11 @@
from typing import Iterator, List, Optional, Sequence, Union from typing import Iterator, List, Optional, Sequence, Union
from mmengine.data import BaseDataElement from mmengine.data import BaseDataElement
from ..registry.root import METRICS from ..registry.root import EVALUATOR, METRICS
from .metric import BaseMetric from .metric import BaseMetric
@EVALUATOR.register_module()
class Evaluator: class Evaluator:
"""Wrapper class to compose multiple :class:`BaseMetric` instances. """Wrapper class to compose multiple :class:`BaseMetric` instances.
@ -19,13 +20,10 @@ class Evaluator:
metrics = [metrics] metrics = [metrics]
self.metrics: List[BaseMetric] = [] self.metrics: List[BaseMetric] = []
for metric in metrics: for metric in metrics:
if isinstance(metric, BaseMetric): if isinstance(metric, dict):
self.metrics.append(metric)
elif isinstance(metric, dict):
self.metrics.append(METRICS.build(metric)) self.metrics.append(METRICS.build(metric))
else: else:
raise TypeError('metric should be a dict or a BaseMetric, ' self.metrics.append(metric)
f'but got {metric}.')
@property @property
def dataset_meta(self) -> Optional[dict]: def dataset_meta(self) -> Optional[dict]:

View File

@ -316,9 +316,12 @@ class ValLoop(BaseLoop):
fp16: bool = False) -> None: fp16: bool = False) -> None:
super().__init__(runner, dataloader) super().__init__(runner, dataloader)
if isinstance(evaluator, dict) or is_list_of(evaluator, dict): if isinstance(evaluator, dict) or isinstance(evaluator, list):
self.evaluator = runner.build_evaluator(evaluator) # type: ignore self.evaluator = runner.build_evaluator(evaluator) # type: ignore
else: else:
assert isinstance(evaluator, Evaluator), (
'evaluator must be one of dict, list or Evaluator instance, '
f'but got {type(evaluator)}.')
self.evaluator = evaluator # type: ignore self.evaluator = evaluator # type: ignore
if hasattr(self.dataloader.dataset, 'metainfo'): if hasattr(self.dataloader.dataset, 'metainfo'):
self.evaluator.dataset_meta = self.dataloader.dataset.metainfo self.evaluator.dataset_meta = self.dataloader.dataset.metainfo

View File

@ -38,8 +38,8 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
DefaultScope, count_registered_modules) DefaultScope, count_registered_modules)
from mmengine.registry.root import LOG_PROCESSORS from mmengine.registry.root import LOG_PROCESSORS
from mmengine.utils import (TORCH_VERSION, collect_env, digit_version, from mmengine.utils import (TORCH_VERSION, collect_env, digit_version,
get_git_hash, is_list_of, is_seq_of, get_git_hash, is_seq_of, revert_sync_batchnorm,
revert_sync_batchnorm, set_multi_processing) set_multi_processing)
from mmengine.visualization import Visualizer from mmengine.visualization import Visualizer
from .base_loop import BaseLoop from .base_loop import BaseLoop
from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model, from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
@ -1246,13 +1246,14 @@ class Runner:
return param_schedulers return param_schedulers
def build_evaluator( def build_evaluator(self, evaluator: Union[Dict, List,
self, evaluator: Union[Dict, List[Dict], Evaluator]) -> Evaluator: Evaluator]) -> Evaluator:
"""Build evaluator. """Build evaluator.
Examples of ``evaluator``:: Examples of ``evaluator``::
evaluator = dict(type='ToyMetric') # evaluator could be a built Evaluator instance
evaluator = Evaluator(metrics=[ToyMetric()])
# evaluator can also be a list of dict # evaluator can also be a list of dict
evaluator = [ evaluator = [
@ -1260,6 +1261,14 @@ class Runner:
dict(type='ToyEvaluator2') dict(type='ToyEvaluator2')
] ]
# evaluator can also be a list of built metric
evaluator = [ToyMetric1(), ToyMetric2()]
# evaluator can also be a dict with key metrics
evaluator = dict(metrics=ToyMetric())
# metric is a list
evaluator = dict(metrics=[ToyMetric()])
Args: Args:
evaluator (Evaluator or dict or list): An Evaluator object or a evaluator (Evaluator or dict or list): An Evaluator object or a
config dict or list of config dict used to build an Evaluator. config dict or list of config dict used to build an Evaluator.
@ -1272,13 +1281,12 @@ class Runner:
elif isinstance(evaluator, dict): elif isinstance(evaluator, dict):
# if `metrics` in dict keys, it means to build customized evalutor # if `metrics` in dict keys, it means to build customized evalutor
if 'metrics' in evaluator: if 'metrics' in evaluator:
assert 'type' in evaluator, 'expected customized evaluator' \ evaluator.setdefault('type', 'Evaluator')
f' with key `type`, but got {evaluator}'
return EVALUATOR.build(evaluator) return EVALUATOR.build(evaluator)
# otherwise, default evalutor will be built # otherwise, default evalutor will be built
else: else:
return Evaluator(evaluator) # type: ignore return Evaluator(evaluator) # type: ignore
elif is_list_of(evaluator, dict): elif isinstance(evaluator, list):
# use the default `Evaluator` # use the default `Evaluator`
return Evaluator(evaluator) # type: ignore return Evaluator(evaluator) # type: ignore
else: else:

View File

@ -1094,6 +1094,12 @@ class TestRunner(TestCase):
evaluator = [dict(type='ToyMetric1'), dict(type='ToyMetric2')] evaluator = [dict(type='ToyMetric1'), dict(type='ToyMetric2')]
self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator) self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator)
# input is a list of built metric.
metric = [ToyMetric1(), ToyMetric2()]
_evaluator = runner.build_evaluator(metric)
self.assertIs(_evaluator.metrics[0], metric[0])
self.assertIs(_evaluator.metrics[1], metric[1])
# test collect device # test collect device
evaluator = [ evaluator = [
dict(type='ToyMetric1', collect_device='cpu'), dict(type='ToyMetric1', collect_device='cpu'),
@ -1115,6 +1121,10 @@ class TestRunner(TestCase):
self.assertEqual(_evaluator.metrics[0].collect_device, 'cpu') self.assertEqual(_evaluator.metrics[0].collect_device, 'cpu')
self.assertEqual(_evaluator.metrics[1].collect_device, 'gpu') self.assertEqual(_evaluator.metrics[1].collect_device, 'gpu')
# test evaluator must be a Evaluator instance
with self.assertRaisesRegex(TypeError, 'evaluator should be'):
_evaluator = runner.build_evaluator(ToyMetric1())
def test_build_dataloader(self): def test_build_dataloader(self):
cfg = copy.deepcopy(self.epoch_based_cfg) cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_dataloader' cfg.experiment_name = 'test_build_dataloader'