[Enhance] Support building an evaluator from a list of built metrics (#423)

* Support building an evaluator from a list of built metrics
* Register Evaluator in the EVALUATOR registry
* Fix issues raised in review comments
* Add unit tests
parent 10330cde9d
commit 4abf1a0454
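With this change an evaluator can be assembled from metric instances that were built ahead of time, not only from config dicts. A minimal usage sketch, assuming a hypothetical ``ToyMetric`` subclass of ``BaseMetric``::

    from mmengine.evaluator import BaseMetric, Evaluator

    class ToyMetric(BaseMetric):
        """Hypothetical metric used only for illustration."""

        def process(self, data_batch, predictions):
            # collect one dummy result per processed batch
            self.results.append(0)

        def compute_metrics(self, results):
            # reduce the collected results to a dummy score
            return dict(toy_acc=1.0)

    # an Evaluator may now be composed directly from built metrics
    evaluator = Evaluator([ToyMetric(), ToyMetric()])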
mmengine/evaluator/evaluator.py

@@ -2,10 +2,11 @@
 from typing import Iterator, List, Optional, Sequence, Union

 from mmengine.data import BaseDataElement
-from ..registry.root import METRICS
+from ..registry.root import EVALUATOR, METRICS
 from .metric import BaseMetric


+@EVALUATOR.register_module()
 class Evaluator:
     """Wrapper class to compose multiple :class:`BaseMetric` instances.
@@ -19,13 +20,10 @@ class Evaluator:
             metrics = [metrics]
         self.metrics: List[BaseMetric] = []
         for metric in metrics:
-            if isinstance(metric, BaseMetric):
-                self.metrics.append(metric)
-            elif isinstance(metric, dict):
+            if isinstance(metric, dict):
                 self.metrics.append(METRICS.build(metric))
             else:
-                raise TypeError('metric should be a dict or a BaseMetric, '
-                                f'but got {metric}.')
+                self.metrics.append(metric)

     @property
     def dataset_meta(self) -> Optional[dict]:
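The constructor now simply builds every dict through the ``METRICS`` registry and appends anything else as-is, so config dicts and built metric instances can be mixed in one list. A hedged sketch, reusing the hypothetical ``ToyMetric`` from above (its registration under that name is an assumption)::

    from mmengine.registry import METRICS

    # register the ToyMetric sketched earlier so config dicts resolve to it
    METRICS.register_module(module=ToyMetric)

    # dicts go through METRICS.build(); instances are appended unchanged
    evaluator = Evaluator([dict(type='ToyMetric'), ToyMetric()])
    assert len(evaluator.metrics) == 2

Note that the explicit ``TypeError`` is gone: an object that is neither a dict nor a metric is now appended unchecked and only fails once the evaluator is used.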
mmengine/runner/loops.py

@@ -316,9 +316,12 @@ class ValLoop(BaseLoop):
                  fp16: bool = False) -> None:
         super().__init__(runner, dataloader)

-        if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
+        if isinstance(evaluator, dict) or isinstance(evaluator, list):
             self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
         else:
+            assert isinstance(evaluator, Evaluator), (
+                'evaluator must be one of dict, list or Evaluator instance, '
+                f'but got {type(evaluator)}.')
             self.evaluator = evaluator  # type: ignore
         if hasattr(self.dataloader.dataset, 'metainfo'):
             self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
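``ValLoop`` now forwards both dicts and lists to ``runner.build_evaluator`` and merely asserts that anything else is already an ``Evaluator``. All of the following forms should therefore be accepted as the loop's ``evaluator`` argument (``ToyMetric1`` and ``ToyMetric2`` are hypothetical registered metrics)::

    # a single config dict
    evaluator = dict(type='ToyMetric1')

    # a list of config dicts and/or built metric instances
    evaluator = [dict(type='ToyMetric1'), ToyMetric2()]

    # a pre-built Evaluator, used as-is
    evaluator = Evaluator([ToyMetric1(), ToyMetric2()])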
mmengine/runner/runner.py

@@ -38,8 +38,8 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
                                DefaultScope, count_registered_modules)
 from mmengine.registry.root import LOG_PROCESSORS
 from mmengine.utils import (TORCH_VERSION, collect_env, digit_version,
-                            get_git_hash, is_list_of, is_seq_of,
-                            revert_sync_batchnorm, set_multi_processing)
+                            get_git_hash, is_seq_of, revert_sync_batchnorm,
+                            set_multi_processing)
 from mmengine.visualization import Visualizer
 from .base_loop import BaseLoop
 from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
@@ -1246,13 +1246,14 @@ class Runner:

         return param_schedulers

-    def build_evaluator(
-            self, evaluator: Union[Dict, List[Dict], Evaluator]) -> Evaluator:
+    def build_evaluator(self, evaluator: Union[Dict, List,
+                                               Evaluator]) -> Evaluator:
         """Build evaluator.

         Examples of ``evaluator``::

-            evaluator = dict(type='ToyMetric')
+            # evaluator could be a built Evaluator instance
+            evaluator = Evaluator(metrics=[ToyMetric()])

             # evaluator can also be a list of dict
             evaluator = [
@@ -1260,6 +1261,14 @@ class Runner:
                 dict(type='ToyEvaluator2')
             ]

+            # evaluator can also be a list of built metric
+            evaluator = [ToyMetric1(), ToyMetric2()]
+
+            # evaluator can also be a dict with key metrics
+            evaluator = dict(metrics=ToyMetric())
+            # metric is a list
+            evaluator = dict(metrics=[ToyMetric()])
+
         Args:
             evaluator (Evaluator or dict or list): An Evaluator object or a
                 config dict or list of config dict used to build an Evaluator.
@@ -1272,13 +1281,12 @@ class Runner:
         elif isinstance(evaluator, dict):
             # if `metrics` in dict keys, it means to build customized evalutor
             if 'metrics' in evaluator:
-                assert 'type' in evaluator, 'expected customized evaluator' \
-                    f' with key `type`, but got {evaluator}'
+                evaluator.setdefault('type', 'Evaluator')
                 return EVALUATOR.build(evaluator)
             # otherwise, default evalutor will be built
             else:
                 return Evaluator(evaluator)  # type: ignore
-        elif is_list_of(evaluator, dict):
+        elif isinstance(evaluator, list):
             # use the default `Evaluator`
             return Evaluator(evaluator)  # type: ignore
         else:
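Replacing the ``assert 'type' in evaluator`` with ``setdefault`` means a dict carrying a ``metrics`` key no longer has to name an evaluator class; it falls back to the ``Evaluator`` registered in this commit. A hedged sketch (``ToyMetric`` and ``CustomEvaluator`` stand for hypothetical registered classes)::

    # 'type' may be omitted and defaults to 'Evaluator'
    evaluator = runner.build_evaluator(
        dict(metrics=[dict(type='ToyMetric')]))

    # a customized evaluator registered in EVALUATOR can still be selected
    evaluator = runner.build_evaluator(
        dict(type='CustomEvaluator', metrics=[dict(type='ToyMetric')]))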
tests/test_runner/test_runner.py

@@ -1094,6 +1094,12 @@ class TestRunner(TestCase):
         evaluator = [dict(type='ToyMetric1'), dict(type='ToyMetric2')]
         self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator)

+        # input is a list of built metric.
+        metric = [ToyMetric1(), ToyMetric2()]
+        _evaluator = runner.build_evaluator(metric)
+        self.assertIs(_evaluator.metrics[0], metric[0])
+        self.assertIs(_evaluator.metrics[1], metric[1])
+
         # test collect device
         evaluator = [
             dict(type='ToyMetric1', collect_device='cpu'),
@@ -1115,6 +1121,10 @@ class TestRunner(TestCase):
         self.assertEqual(_evaluator.metrics[0].collect_device, 'cpu')
         self.assertEqual(_evaluator.metrics[1].collect_device, 'gpu')

+        # test evaluator must be a Evaluator instance
+        with self.assertRaisesRegex(TypeError, 'evaluator should be'):
+            _evaluator = runner.build_evaluator(ToyMetric1())
+
     def test_build_dataloader(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_build_dataloader'