[Enhance] Support building an evaluator from a list of built metrics (#423)

* Support building an evaluator from a list of built metrics
* Register evaluator
* Fix as commented
* Add unit test
parent 10330cde9d
commit 4abf1a0454
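
In short: Runner.build_evaluator and the val/test loops now accept an already-built Evaluator, a config dict, or a list mixing config dicts with built metric instances. A minimal sketch of the call forms this commit enables, where ToyMetric stands in for any metric registered in METRICS:

    evaluator = runner.build_evaluator(dict(type='ToyMetric'))    # config dict
    evaluator = runner.build_evaluator([dict(type='ToyMetric')])  # list of config dicts
    evaluator = runner.build_evaluator([ToyMetric()])             # list of built metrics (new)
    evaluator = runner.build_evaluator(Evaluator([ToyMetric()]))  # built instance, returned as-is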

mmengine/evaluator/evaluator.py
@@ -2,10 +2,11 @@
 from typing import Iterator, List, Optional, Sequence, Union
 
 from mmengine.data import BaseDataElement
-from ..registry.root import METRICS
+from ..registry.root import EVALUATOR, METRICS
 from .metric import BaseMetric
 
 
+@EVALUATOR.register_module()
 class Evaluator:
     """Wrapper class to compose multiple :class:`BaseMetric` instances.
 
@@ -19,13 +20,10 @@ class Evaluator:
             metrics = [metrics]
         self.metrics: List[BaseMetric] = []
         for metric in metrics:
-            if isinstance(metric, BaseMetric):
-                self.metrics.append(metric)
-            elif isinstance(metric, dict):
+            if isinstance(metric, dict):
                 self.metrics.append(METRICS.build(metric))
             else:
-                raise TypeError('metric should be a dict or a BaseMetric, '
-                                f'but got {metric}.')
+                self.metrics.append(metric)
 
     @property
     def dataset_meta(self) -> Optional[dict]:
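
Note that the TypeError branch is gone: Evaluator.__init__ now treats anything that is not a dict as an already-built metric, and type checking moves to the loops and Runner.build_evaluator below. One consequence, sketched here ('Accuracy' is a hypothetical registered metric name, MyMetric a hypothetical BaseMetric subclass), is that configs and instances can be mixed in one list:

    from mmengine.evaluator import Evaluator

    # dicts go through METRICS.build, instances are appended as-is
    evaluator = Evaluator(metrics=[dict(type='Accuracy'), MyMetric()])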

mmengine/runner/loops.py
@@ -316,9 +316,12 @@ class ValLoop(BaseLoop):
                  fp16: bool = False) -> None:
         super().__init__(runner, dataloader)
 
-        if isinstance(evaluator, dict) or is_list_of(evaluator, dict):
+        if isinstance(evaluator, dict) or isinstance(evaluator, list):
             self.evaluator = runner.build_evaluator(evaluator)  # type: ignore
         else:
+            assert isinstance(evaluator, Evaluator), (
+                'evaluator must be one of dict, list or Evaluator instance, '
+                f'but got {type(evaluator)}.')
             self.evaluator = evaluator  # type: ignore
         if hasattr(self.dataloader.dataset, 'metainfo'):
             self.evaluator.dataset_meta = self.dataloader.dataset.metainfo
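
The widened check routes plain lists (of dicts, built metrics, or a mix) through runner.build_evaluator, while the new assert fails fast on anything else. Roughly, for an already-configured runner and dataloader (a sketch with the hypothetical ToyMetric1, not verbatim usage):

    # a list of built metrics is now wrapped via build_evaluator
    loop = ValLoop(runner, dataloader, evaluator=[ToyMetric1()])

    # a bare metric is neither dict, list, nor Evaluator -> AssertionError
    loop = ValLoop(runner, dataloader, evaluator=ToyMetric1())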

mmengine/runner/runner.py
@@ -38,8 +38,8 @@ from mmengine.registry import (DATA_SAMPLERS, DATASETS, EVALUATOR, HOOKS,
                                DefaultScope, count_registered_modules)
 from mmengine.registry.root import LOG_PROCESSORS
 from mmengine.utils import (TORCH_VERSION, collect_env, digit_version,
-                            get_git_hash, is_list_of, is_seq_of,
-                            revert_sync_batchnorm, set_multi_processing)
+                            get_git_hash, is_seq_of, revert_sync_batchnorm,
+                            set_multi_processing)
 from mmengine.visualization import Visualizer
 from .base_loop import BaseLoop
 from .checkpoint import (_load_checkpoint, _load_checkpoint_to_model,
@@ -1246,13 +1246,14 @@ class Runner:
 
         return param_schedulers
 
-    def build_evaluator(
-            self, evaluator: Union[Dict, List[Dict], Evaluator]) -> Evaluator:
+    def build_evaluator(self, evaluator: Union[Dict, List,
+                                                Evaluator]) -> Evaluator:
         """Build evaluator.
 
         Examples of ``evaluator``::
 
-            evaluator = dict(type='ToyMetric')
+            # evaluator could be a built Evaluator instance
+            evaluator = Evaluator(metrics=[ToyMetric()])
 
             # evaluator can also be a list of dict
             evaluator = [
@@ -1260,6 +1261,14 @@ class Runner:
                 dict(type='ToyEvaluator2')
             ]
 
+            # evaluator can also be a list of built metric
+            evaluator = [ToyMetric1(), ToyMetric2()]
+
+            # evaluator can also be a dict with key metrics
+            evaluator = dict(metrics=ToyMetric())
+            # metric is a list
+            evaluator = dict(metrics=[ToyMetric()])
+
         Args:
             evaluator (Evaluator or dict or list): An Evaluator object or a
                 config dict or list of config dict used to build an Evaluator.
@@ -1272,13 +1281,12 @@ class Runner:
         elif isinstance(evaluator, dict):
             # if `metrics` in dict keys, it means to build customized evalutor
             if 'metrics' in evaluator:
-                assert 'type' in evaluator, 'expected customized evaluator' \
-                    f' with key `type`, but got {evaluator}'
+                evaluator.setdefault('type', 'Evaluator')
                 return EVALUATOR.build(evaluator)
             # otherwise, default evalutor will be built
             else:
                 return Evaluator(evaluator)  # type: ignore
-        elif is_list_of(evaluator, dict):
+        elif isinstance(evaluator, list):
             # use the default `Evaluator`
             return Evaluator(evaluator)  # type: ignore
         else:
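
Two behavior changes here: a dict carrying a metrics key no longer needs an explicit type (setdefault falls back to the registered Evaluator), and the list branch no longer requires every element to be a dict. A sketch of the resulting dispatch (ToyMetric is hypothetical; the last line is grounded in the unit test below):

    runner.build_evaluator(dict(type='ToyMetric'))                  # default Evaluator
    runner.build_evaluator(dict(metrics=[dict(type='ToyMetric')]))  # EVALUATOR.build
    runner.build_evaluator([dict(type='ToyMetric'), ToyMetric()])   # mixed list, OK now
    runner.build_evaluator(ToyMetric())                             # TypeError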

tests/test_runner/test_runner.py
@@ -1094,6 +1094,12 @@ class TestRunner(TestCase):
         evaluator = [dict(type='ToyMetric1'), dict(type='ToyMetric2')]
         self.assertIsInstance(runner.build_evaluator(evaluator), Evaluator)
 
+        # input is a list of built metric.
+        metric = [ToyMetric1(), ToyMetric2()]
+        _evaluator = runner.build_evaluator(metric)
+        self.assertIs(_evaluator.metrics[0], metric[0])
+        self.assertIs(_evaluator.metrics[1], metric[1])
+
         # test collect device
         evaluator = [
             dict(type='ToyMetric1', collect_device='cpu'),
@@ -1115,6 +1121,10 @@ class TestRunner(TestCase):
         self.assertEqual(_evaluator.metrics[0].collect_device, 'cpu')
         self.assertEqual(_evaluator.metrics[1].collect_device, 'gpu')
 
+        # test evaluator must be a Evaluator instance
+        with self.assertRaisesRegex(TypeError, 'evaluator should be'):
+            _evaluator = runner.build_evaluator(ToyMetric1())
+
     def test_build_dataloader(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_build_dataloader'