mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor] Refactor after_val_epoch
to make it output metric by epoch (#278)
* [Refactor]:Refactor `after_val_epoch` to make it output metric by epoch * add an option for user to choose the way of outputing metric * rename variable * reformat docstring * add type alias * reformat code * add test function * add comment and test code * add comment and test code
This commit is contained in:
parent
ef946404e6
commit
dceef1f66f
@ -11,6 +11,7 @@ from mmengine.registry import HOOKS
|
|||||||
from mmengine.utils import is_tuple_of, scandir
|
from mmengine.utils import is_tuple_of, scandir
|
||||||
|
|
||||||
DATA_BATCH = Optional[Sequence[dict]]
|
DATA_BATCH = Optional[Sequence[dict]]
|
||||||
|
SUFFIX_TYPE = Union[Sequence[str], str]
|
||||||
|
|
||||||
|
|
||||||
@HOOKS.register_module()
|
@HOOKS.register_module()
|
||||||
@ -51,6 +52,11 @@ class LoggerHook(Hook):
|
|||||||
file_client_args (dict, optional): Arguments to instantiate a
|
file_client_args (dict, optional): Arguments to instantiate a
|
||||||
FileClient. See :class:`mmengine.fileio.FileClient` for details.
|
FileClient. See :class:`mmengine.fileio.FileClient` for details.
|
||||||
Defaults to None.
|
Defaults to None.
|
||||||
|
log_metric_by_epoch (bool): Whether to output metric in validation step
|
||||||
|
by epoch. It can be true when running in epoch based runner.
|
||||||
|
If set to True, `after_val_epoch` will set `step` to self.epoch in
|
||||||
|
`runner.visualizer.add_scalars`. Otherwise `step` will be
|
||||||
|
self.iter. Default to True.
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> # The simplest LoggerHook config.
|
>>> # The simplest LoggerHook config.
|
||||||
@ -58,17 +64,15 @@ class LoggerHook(Hook):
|
|||||||
"""
|
"""
|
||||||
priority = 'BELOW_NORMAL'
|
priority = 'BELOW_NORMAL'
|
||||||
|
|
||||||
def __init__(
|
def __init__(self,
|
||||||
self,
|
interval: int = 10,
|
||||||
interval: int = 10,
|
ignore_last: bool = True,
|
||||||
ignore_last: bool = True,
|
interval_exp_name: int = 1000,
|
||||||
interval_exp_name: int = 1000,
|
out_dir: Optional[Union[str, Path]] = None,
|
||||||
out_dir: Optional[Union[str, Path]] = None,
|
out_suffix: SUFFIX_TYPE = ('.json', '.log', '.py', 'yaml'),
|
||||||
out_suffix: Union[Sequence[str],
|
keep_local: bool = True,
|
||||||
str] = ('.json', '.log', '.py', 'yaml'),
|
file_client_args: Optional[dict] = None,
|
||||||
keep_local: bool = True,
|
log_metric_by_epoch: bool = True):
|
||||||
file_client_args: Optional[dict] = None,
|
|
||||||
):
|
|
||||||
self.interval = interval
|
self.interval = interval
|
||||||
self.ignore_last = ignore_last
|
self.ignore_last = ignore_last
|
||||||
self.interval_exp_name = interval_exp_name
|
self.interval_exp_name = interval_exp_name
|
||||||
@ -91,6 +95,7 @@ class LoggerHook(Hook):
|
|||||||
if self.out_dir is not None:
|
if self.out_dir is not None:
|
||||||
self.file_client = FileClient.infer_client(file_client_args,
|
self.file_client = FileClient.infer_client(file_client_args,
|
||||||
self.out_dir)
|
self.out_dir)
|
||||||
|
self.log_metric_by_epoch = log_metric_by_epoch
|
||||||
|
|
||||||
def before_run(self, runner) -> None:
|
def before_run(self, runner) -> None:
|
||||||
"""Infer ``self.file_client`` from ``self.out_dir``. Initialize the
|
"""Infer ``self.file_client`` from ``self.out_dir``. Initialize the
|
||||||
@ -203,8 +208,21 @@ class LoggerHook(Hook):
|
|||||||
tag, log_str = runner.log_processor.get_log_after_epoch(
|
tag, log_str = runner.log_processor.get_log_after_epoch(
|
||||||
runner, len(runner.val_dataloader), 'val')
|
runner, len(runner.val_dataloader), 'val')
|
||||||
runner.logger.info(log_str)
|
runner.logger.info(log_str)
|
||||||
runner.visualizer.add_scalars(
|
if self.log_metric_by_epoch:
|
||||||
tag, step=runner.iter, file_path=self.json_log_path)
|
# when `log_metric_by_epoch` is set to True, it's expected
|
||||||
|
# that validation metric can be logged by epoch rather than
|
||||||
|
# by iter. At the same time, scalars related to time should
|
||||||
|
# still be logged by iter to avoid messy visualized result.
|
||||||
|
# see details in PR #278.
|
||||||
|
time_tags = {k: v for k, v in tag.items() if 'time' in k}
|
||||||
|
metric_tags = {k: v for k, v in tag.items() if 'time' not in k}
|
||||||
|
runner.visualizer.add_scalars(
|
||||||
|
time_tags, step=runner.iter, file_path=self.json_log_path)
|
||||||
|
runner.visualizer.add_scalars(
|
||||||
|
metric_tags, step=runner.epoch, file_path=self.json_log_path)
|
||||||
|
else:
|
||||||
|
runner.visualizer.add_scalars(
|
||||||
|
tag, step=runner.iter, file_path=self.json_log_path)
|
||||||
|
|
||||||
def after_test_epoch(self,
|
def after_test_epoch(self,
|
||||||
runner,
|
runner,
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import os.path as osp
|
import os.path as osp
|
||||||
from unittest.mock import MagicMock
|
from unittest.mock import ANY, MagicMock
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -119,6 +119,49 @@ class TestLoggerHook:
|
|||||||
runner.logger.info.assert_called()
|
runner.logger.info.assert_called()
|
||||||
runner.visualizer.add_scalars.assert_called()
|
runner.visualizer.add_scalars.assert_called()
|
||||||
|
|
||||||
|
# Test when `log_metric_by_epoch` is True
|
||||||
|
runner.log_processor.get_log_after_epoch = MagicMock(
|
||||||
|
return_value=({
|
||||||
|
'time': 1,
|
||||||
|
'datatime': 1,
|
||||||
|
'acc': 0.8
|
||||||
|
}, 'string'))
|
||||||
|
logger_hook.after_val_epoch(runner)
|
||||||
|
args = {'step': ANY, 'file_path': ANY}
|
||||||
|
# expect visualizer log `time` and `metric` respectively
|
||||||
|
runner.visualizer.add_scalars.assert_any_call(
|
||||||
|
{
|
||||||
|
'time': 1,
|
||||||
|
'datatime': 1
|
||||||
|
}, **args)
|
||||||
|
runner.visualizer.add_scalars.assert_any_call({'acc': 0.8}, **args)
|
||||||
|
|
||||||
|
# Test when `log_metric_by_epoch` is False
|
||||||
|
logger_hook = LoggerHook(log_metric_by_epoch=False)
|
||||||
|
runner.log_processor.get_log_after_epoch = MagicMock(
|
||||||
|
return_value=({
|
||||||
|
'time': 5,
|
||||||
|
'datatime': 5,
|
||||||
|
'acc': 0.5
|
||||||
|
}, 'string'))
|
||||||
|
logger_hook.after_val_epoch(runner)
|
||||||
|
# expect visualizer log `time` and `metric` jointly
|
||||||
|
runner.visualizer.add_scalars.assert_any_call(
|
||||||
|
{
|
||||||
|
'time': 5,
|
||||||
|
'datatime': 5,
|
||||||
|
'acc': 0.5
|
||||||
|
}, **args)
|
||||||
|
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
runner.visualizer.add_scalars.assert_any_call(
|
||||||
|
{
|
||||||
|
'time': 5,
|
||||||
|
'datatime': 5
|
||||||
|
}, **args)
|
||||||
|
with pytest.raises(AssertionError):
|
||||||
|
runner.visualizer.add_scalars.assert_any_call({'acc': 0.5}, **args)
|
||||||
|
|
||||||
def test_after_test_epoch(self):
|
def test_after_test_epoch(self):
|
||||||
logger_hook = LoggerHook()
|
logger_hook = LoggerHook()
|
||||||
runner = MagicMock()
|
runner = MagicMock()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user