[Refactor] Refactor `after_val_epoch` to make it output metric by epoch (#278)
* [Refactor]:Refactor `after_val_epoch` to make it output metric by epoch * add an option for user to choose the way of outputing metric * rename variable * reformat docstring * add type alias * reformat code * add test function * add comment and test code * add comment and test codepull/315/head^2
parent
ef946404e6
commit
dceef1f66f
|
@ -11,6 +11,7 @@ from mmengine.registry import HOOKS
|
|||
from mmengine.utils import is_tuple_of, scandir
|
||||
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
SUFFIX_TYPE = Union[Sequence[str], str]
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
|
@ -51,6 +52,11 @@ class LoggerHook(Hook):
|
|||
file_client_args (dict, optional): Arguments to instantiate a
|
||||
FileClient. See :class:`mmengine.fileio.FileClient` for details.
|
||||
Defaults to None.
|
||||
log_metric_by_epoch (bool): Whether to output metric in validation step
|
||||
by epoch. It can be true when running in epoch based runner.
|
||||
If set to True, `after_val_epoch` will set `step` to self.epoch in
|
||||
`runner.visualizer.add_scalars`. Otherwise `step` will be
|
||||
self.iter. Default to True.
|
||||
|
||||
Examples:
|
||||
>>> # The simplest LoggerHook config.
|
||||
|
@ -58,17 +64,15 @@ class LoggerHook(Hook):
|
|||
"""
|
||||
priority = 'BELOW_NORMAL'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
interval: int = 10,
|
||||
ignore_last: bool = True,
|
||||
interval_exp_name: int = 1000,
|
||||
out_dir: Optional[Union[str, Path]] = None,
|
||||
out_suffix: Union[Sequence[str],
|
||||
str] = ('.json', '.log', '.py', 'yaml'),
|
||||
keep_local: bool = True,
|
||||
file_client_args: Optional[dict] = None,
|
||||
):
|
||||
def __init__(self,
|
||||
interval: int = 10,
|
||||
ignore_last: bool = True,
|
||||
interval_exp_name: int = 1000,
|
||||
out_dir: Optional[Union[str, Path]] = None,
|
||||
out_suffix: SUFFIX_TYPE = ('.json', '.log', '.py', 'yaml'),
|
||||
keep_local: bool = True,
|
||||
file_client_args: Optional[dict] = None,
|
||||
log_metric_by_epoch: bool = True):
|
||||
self.interval = interval
|
||||
self.ignore_last = ignore_last
|
||||
self.interval_exp_name = interval_exp_name
|
||||
|
@ -91,6 +95,7 @@ class LoggerHook(Hook):
|
|||
if self.out_dir is not None:
|
||||
self.file_client = FileClient.infer_client(file_client_args,
|
||||
self.out_dir)
|
||||
self.log_metric_by_epoch = log_metric_by_epoch
|
||||
|
||||
def before_run(self, runner) -> None:
|
||||
"""Infer ``self.file_client`` from ``self.out_dir``. Initialize the
|
||||
|
@ -203,8 +208,21 @@ class LoggerHook(Hook):
|
|||
tag, log_str = runner.log_processor.get_log_after_epoch(
|
||||
runner, len(runner.val_dataloader), 'val')
|
||||
runner.logger.info(log_str)
|
||||
runner.visualizer.add_scalars(
|
||||
tag, step=runner.iter, file_path=self.json_log_path)
|
||||
if self.log_metric_by_epoch:
|
||||
# when `log_metric_by_epoch` is set to True, it's expected
|
||||
# that validation metric can be logged by epoch rather than
|
||||
# by iter. At the same time, scalars related to time should
|
||||
# still be logged by iter to avoid messy visualized result.
|
||||
# see details in PR #278.
|
||||
time_tags = {k: v for k, v in tag.items() if 'time' in k}
|
||||
metric_tags = {k: v for k, v in tag.items() if 'time' not in k}
|
||||
runner.visualizer.add_scalars(
|
||||
time_tags, step=runner.iter, file_path=self.json_log_path)
|
||||
runner.visualizer.add_scalars(
|
||||
metric_tags, step=runner.epoch, file_path=self.json_log_path)
|
||||
else:
|
||||
runner.visualizer.add_scalars(
|
||||
tag, step=runner.iter, file_path=self.json_log_path)
|
||||
|
||||
def after_test_epoch(self,
|
||||
runner,
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import os.path as osp
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import ANY, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -119,6 +119,49 @@ class TestLoggerHook:
|
|||
runner.logger.info.assert_called()
|
||||
runner.visualizer.add_scalars.assert_called()
|
||||
|
||||
# Test when `log_metric_by_epoch` is True
|
||||
runner.log_processor.get_log_after_epoch = MagicMock(
|
||||
return_value=({
|
||||
'time': 1,
|
||||
'datatime': 1,
|
||||
'acc': 0.8
|
||||
}, 'string'))
|
||||
logger_hook.after_val_epoch(runner)
|
||||
args = {'step': ANY, 'file_path': ANY}
|
||||
# expect visualizer log `time` and `metric` respectively
|
||||
runner.visualizer.add_scalars.assert_any_call(
|
||||
{
|
||||
'time': 1,
|
||||
'datatime': 1
|
||||
}, **args)
|
||||
runner.visualizer.add_scalars.assert_any_call({'acc': 0.8}, **args)
|
||||
|
||||
# Test when `log_metric_by_epoch` is False
|
||||
logger_hook = LoggerHook(log_metric_by_epoch=False)
|
||||
runner.log_processor.get_log_after_epoch = MagicMock(
|
||||
return_value=({
|
||||
'time': 5,
|
||||
'datatime': 5,
|
||||
'acc': 0.5
|
||||
}, 'string'))
|
||||
logger_hook.after_val_epoch(runner)
|
||||
# expect visualizer log `time` and `metric` jointly
|
||||
runner.visualizer.add_scalars.assert_any_call(
|
||||
{
|
||||
'time': 5,
|
||||
'datatime': 5,
|
||||
'acc': 0.5
|
||||
}, **args)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
runner.visualizer.add_scalars.assert_any_call(
|
||||
{
|
||||
'time': 5,
|
||||
'datatime': 5
|
||||
}, **args)
|
||||
with pytest.raises(AssertionError):
|
||||
runner.visualizer.add_scalars.assert_any_call({'acc': 0.5}, **args)
|
||||
|
||||
def test_after_test_epoch(self):
|
||||
logger_hook = LoggerHook()
|
||||
runner = MagicMock()
|
||||
|
|
Loading…
Reference in New Issue