[Feature] Add LoggerHook (#77)
* add logger hook * update * update * update test * update * update test * update * update * update * update * update * Add logger hook * Fix pre-commit * Fix as comment * Fix as comment * Fix as comment * Fix as comment * Fix as comment * Fix bytes * update * Fix as comment * Fix as comment * Update runner * Fix as comment * Fix as comment * Fix as comment * Fix as comment
parent 49b7d0ce6f
commit 6f69039ca9
@@ -3,6 +3,7 @@ from .checkpoint_hook import CheckpointHook
 from .empty_cache_hook import EmptyCacheHook
 from .hook import Hook
 from .iter_timer_hook import IterTimerHook
+from .logger_hook import LoggerHook
 from .optimizer_hook import OptimizerHook
 from .param_scheduler_hook import ParamSchedulerHook
 from .sampler_seed_hook import DistSamplerSeedHook
@@ -10,5 +11,6 @@ from .sync_buffer_hook import SyncBuffersHook
 
 __all__ = [
     'Hook', 'IterTimerHook', 'DistSamplerSeedHook', 'ParamSchedulerHook',
-    'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook'
+    'OptimizerHook', 'SyncBuffersHook', 'EmptyCacheHook', 'CheckpointHook',
+    'LoggerHook'
 ]
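With `LoggerHook` now exported from `mmengine.hooks`, a config dict can reference it by `type` and be built through the registry. A minimal sketch, assuming the registry-driven construction MMEngine uses elsewhere (the config values below are illustrative, not part of this diff):

from mmengine.registry import HOOKS

# Build the hook from a config dict via the registry. `custom_keys`
# additionally records a windowed mean of `loss` under a separate name.
logger_hook = HOOKS.build(
    dict(
        type='LoggerHook',
        by_epoch=True,
        interval=10,
        custom_keys=dict(
            loss=dict(
                log_name='loss_mean_window',
                method_name='mean',
                window_size=10))))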
@@ -0,0 +1,509 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import datetime
import os
import os.path as osp
from collections import OrderedDict
from pathlib import Path
from typing import Any, Optional, Sequence, Tuple, Union

import torch

from mmengine.data import BaseDataSample
from mmengine.fileio import FileClient
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
from mmengine.utils import is_tuple_of, scandir

DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataSample]]]

@HOOKS.register_module()
class LoggerHook(Hook):
    """In this logger hook, the information will be printed on the terminal
    and saved in JSON files, TensorBoard, wandb, etc.

    Args:
        by_epoch (bool): Whether ``EpochBasedLoop`` is used.
            Defaults to True.
        interval (int): Logging interval (every k iterations).
            Defaults to 10.
        custom_keys (dict, optional): Defines the keys in the log and which
            kinds of statistic methods should be used to log them.

            - ``custom_keys`` contains multiple string-dict pairs. In each
              string-dict pair, the string defines a key name in the log and
              the dict is a config that defines the statistic method and the
              corresponding arguments used to log the value. For example,
              ``dict(loss=dict(method_name='mean', log_name='global_loss',
              window_size='global'))`` means the log key ``loss`` will be
              counted as a global mean and additionally logged as
              ``global_loss``. If ``log_name`` is not defined in the config
              dict, the original logged key will be overwritten.
            - The keys in ``LoggerHook.fixed_smooth_keys`` cannot be
              overwritten because ``time`` and ``data_time`` will be used to
              calculate the estimated time of arrival. If you want to recount
              the time, you should set ``log_name`` in the corresponding
              values.
            - For those statistic methods with the ``window_size`` argument,
              if ``by_epoch`` is set to False, ``window_size`` should not be
              ``epoch``, since logs cannot be smoothed by epoch in
              iteration-based training.
        ignore_last (bool): Ignore the log of the last iterations in each
            epoch if the number of remaining iterations is less than
            :attr:`interval`. Defaults to True.
        interval_exp_name (int): Logging interval for the experiment name.
            This feature helps users conveniently get the experiment
            information from the screen or log file. Defaults to 1000.
        out_dir (str or Path, optional): The root directory to save
            checkpoints. If not specified, ``runner.work_dir`` will be used
            by default. If specified, the ``out_dir`` will be the
            concatenation of ``out_dir`` and the last level directory of
            ``runner.work_dir``. For example, if the input ``out_dir`` is
            ``./tmp`` and ``runner.work_dir`` is ``./work_dir/cur_exp``,
            then the log will be saved in ``./tmp/cur_exp``. Defaults to
            None.
        out_suffix (Tuple[str] or str): Those filenames ending with
            ``out_suffix`` will be copied to ``out_dir``. Defaults to
            ('.log.json', '.log', '.py').
        keep_local (bool): Whether to keep local logs on the local machine
            when :attr:`out_dir` is specified. If False, the local logs will
            be removed. Defaults to True.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmengine.fileio.FileClient` for details.
            Defaults to None.

    Examples:
        >>> # `log_name` is defined, `loss_mean_window` will be an
        >>> # additional record.
        >>> logger_hook_cfg = dict(by_epoch=True,
        >>>                        custom_keys=dict(
        >>>                            loss=dict(
        >>>                                log_name='loss_mean_window',
        >>>                                method_name='mean',
        >>>                                window_size=10)))
        >>> # `log_name` is not defined. `loss` will be overwritten by
        >>> # the `global` mean statistics.
        >>> logger_hook_cfg = dict(by_epoch=True,
        >>>                        custom_keys=dict(
        >>>                            loss=dict(
        >>>                                method_name='mean',
        >>>                                window_size='global')))
        >>> # `time` cannot be overwritten, `global_time` will be an
        >>> # additional record.
        >>> logger_hook_cfg = dict(by_epoch=True,
        >>>                        custom_keys=dict(
        >>>                            time=dict(
        >>>                                log_name='global_time',
        >>>                                method_name='mean',
        >>>                                window_size='global')))
        >>> # Record loss with different statistics methods.
        >>> logger_hook_cfg = dict(by_epoch=True,
        >>>                        custom_keys=dict(loss=[
        >>>                            dict(log_name='loss_mean_window',
        >>>                                 method_name='mean',
        >>>                                 window_size=10),
        >>>                            dict(method_name='mean',
        >>>                                 window_size='global')]))
    """
    # eta will be calculated by `time`. `time` and `data_time` should not
    # be overwritten.
    fixed_smooth_keys = ('time', 'data_time')
    priority = 'BELOW_NORMAL'

    def __init__(
        self,
        by_epoch: bool = True,
        interval: int = 10,
        custom_keys: Optional[dict] = None,
        ignore_last: bool = True,
        interval_exp_name: int = 1000,
        out_dir: Optional[Union[str, Path]] = None,
        out_suffix: Union[Sequence[str], str] = ('.log.json', '.log', '.py'),
        keep_local: bool = True,
        file_client_args: Optional[dict] = None,
    ):
        self.by_epoch = by_epoch
        self.interval = interval
        self.custom_keys = custom_keys if custom_keys is not None else dict()
        self.ignore_last = ignore_last

        self.time_sec_tot = 0
        self.interval_exp_name = interval_exp_name
        self._check_custom_keys()

        if out_dir is None and file_client_args is not None:
            raise ValueError(
                'file_client_args should be "None" when `out_dir` is not '
                'specified.')
        self.out_dir = out_dir

        if not (out_dir is None or isinstance(out_dir, str)
                or is_tuple_of(out_dir, str)):
            raise TypeError('out_dir should be None or a string or a tuple '
                            f'of strings, but got {type(out_dir)}')
        self.out_suffix = out_suffix

        self.keep_local = keep_local
        self.file_client_args = file_client_args
        if self.out_dir is not None:
            self.file_client = FileClient.infer_client(
                file_client_args, self.out_dir)

    def before_run(self, runner) -> None:
        """Infer ``self.file_client`` from ``self.out_dir``. Initialize
        ``self.start_iter`` and record the meta information.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.out_dir is not None:
            # The final `self.out_dir` is the concatenation of `self.out_dir`
            # and the last level directory of `runner.work_dir`.
            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
            self.out_dir = self.file_client.join_path(self.out_dir, basename)
            runner.logger.info(
                (f'Text logs will be saved to {self.out_dir} by '
                 f'{self.file_client.name} after the training process.'))

        self.json_log_path = osp.join(runner.work_dir,
                                      f'{runner.timestamp}.log.json')
        self.yaml_log_path = osp.join(runner.work_dir,
                                      f'{runner.timestamp}.yaml')
        self.start_iter = runner.iter
        if runner.meta is not None:
            runner.writer.add_params(runner.meta, file_path=self.yaml_log_path)

    def after_train_iter(
            self,
            runner,
            data_batch: DATA_BATCH = None,
            outputs: Optional[Sequence[BaseDataSample]] = None) -> None:
        """Record training logs.

        Args:
            runner (Runner): The runner of the training process.
            data_batch (Sequence[Tuple[Any, BaseDataSample]], optional):
                Data from the dataloader. Defaults to None.
            outputs (Sequence[BaseDataSample], optional): Outputs from the
                model. Defaults to None.
        """
        if runner.meta is not None and 'exp_name' in runner.meta:
            if (self.every_n_iters(runner, self.interval_exp_name)) or (
                    self.by_epoch and self.end_of_epoch(runner)):
                exp_info = f'Exp name: {runner.meta["exp_name"]}'
                runner.logger.info(exp_info)
        if self.by_epoch and self.every_n_inner_iters(runner, self.interval):
            self._log_train(runner)
        elif not self.by_epoch and self.every_n_iters(runner, self.interval):
            self._log_train(runner)
        elif self.end_of_epoch(runner) and not self.ignore_last:
            # `runner.max_iters` may not be divisible by `self.interval`. If
            # `self.ignore_last == False`, the log of the remaining
            # iterations at the end of each epoch will still be recorded
            # (e.g. for Epoch [4][1000/1007], the logs of iterations
            # 1001-1007 will be recorded).
            self._log_train(runner)

    def after_val_epoch(self, runner) -> None:
        """Record validation logs.

        Args:
            runner (Runner): The runner of the training process.
        """
        self._log_val(runner)

    def after_run(self, runner) -> None:
        """Copy logs to ``self.out_dir`` if ``self.out_dir`` is not None.

        Args:
            runner (Runner): The runner of the training process.
        """
        # Copy or upload logs to self.out_dir.
        if self.out_dir is None:
            return
        for filename in scandir(runner.work_dir, self.out_suffix, True):
            local_filepath = osp.join(runner.work_dir, filename)
            out_filepath = self.file_client.join_path(self.out_dir, filename)
            with open(local_filepath, 'r') as f:
                self.file_client.put_text(f.read(), out_filepath)

            runner.logger.info(
                (f'The file {local_filepath} has been uploaded to '
                 f'{out_filepath}.'))

            if not self.keep_local:
                os.remove(local_filepath)
                runner.logger.info(f'{local_filepath} was removed due to '
                                   '`self.keep_local=False`.')

    def _log_train(self, runner) -> None:
        """Collect and record training logs whose names start with "train/".

        Args:
            runner (Runner): The runner of the training process.
        """
        tag = self._collect_info(runner, 'train')
        # The training log defines `lr`, `momentum`, `time` and `data_time`
        # by default. `log_tag` pops these keys and the remaining keys are
        # looped into `log_str`.
        log_tag = copy.deepcopy(tag)
        cur_iter = self._get_iter(runner, inner_iter=True)
        cur_epoch = self._get_epoch(runner, 'train')

        # Record learning rate and momentum.
        lr_str_list = []
        momentum_str_list = []
        for key, value in tag.items():
            if key.startswith('lr'):
                log_tag.pop(key)
                lr_str_list.append(f'{key}: {value:.3e}')
        lr_str = ' '.join(lr_str_list)
        for key, value in tag.items():
            if key.startswith('momentum'):
                log_tag.pop(key)
                momentum_str_list.append(f'{key}: {value:.3e}')
        momentum_str = ' '.join(momentum_str_list)
        lr_momentum_str = f'{lr_str} {momentum_str}'
        # by epoch: Epoch [4][100/1000]
        # by iter: Iter [100/100000]
        if self.by_epoch:
            log_str = f'Epoch [{cur_epoch}]' \
                      f'[{cur_iter}/{len(runner.data_loader)}]\t'
        else:
            log_str = f'Iter [{cur_iter}/{runner.max_iters}]\t'
        log_str += f'{lr_momentum_str}, '
        # Calculate the eta time.
        self.time_sec_tot += (tag['time'] * self.interval)
        time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1)
        eta_sec = time_sec_avg * (runner.max_iters - runner.iter - 1)
        eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
        log_str += f'eta: {eta_str}, '
        log_str += f'time: {tag["time"]:.3f}, ' \
                   f'data_time: {tag["data_time"]:.3f}, '
        # Pop the recorded keys.
        log_tag.pop('time')
        log_tag.pop('data_time')
        # Statistic memory.
        if torch.cuda.is_available():
            log_str += f'memory: {self._get_max_memory(runner)}, '
        # Loop over the remaining keys to fill `log_str`.
        log_items = []
        for name, val in log_tag.items():
            if isinstance(val, float):
                val = f'{val:.4f}'
            log_items.append(f'{name}: {val}')
        log_str += ', '.join(log_items)
        runner.logger.info(log_str)
        # Write logs to local files, tensorboard, and wandb.
        runner.writer.add_scalars(
            tag, step=runner.iter + 1, file_path=self.json_log_path)

    def _log_val(self, runner) -> None:
        """Collect and record validation logs whose names start with "val/".

        Args:
            runner (Runner): The runner of the training process.
        """
        tag = self._collect_info(runner, 'val')
        # Compatible with the function `log` in https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py  # noqa: E501
        eval_iter = len(runner.data_loader)
        cur_iter = self._get_iter(runner)
        cur_epoch = self._get_epoch(runner, 'val')
        # val/test time
        # here 1000 is the length of the val dataloader
        # by epoch: Epoch[val] [4][1000]
        # by iter: Iter[val] [1000]
        if self.by_epoch:
            # `runner.epoch += 1` has been done before the val workflow.
            log_str = f'Epoch(val) [{cur_epoch}][{eval_iter}]\t'
        else:
            log_str = f'Iter(val) [{eval_iter}]\t'

        log_items = []
        for name, val in tag.items():
            if isinstance(val, float):
                val = f'{val:.4f}'
            log_items.append(f'{name}: {val}')
        log_str += ', '.join(log_items)
        runner.logger.info(log_str)
        # Write the tag.
        runner.writer.add_scalars(
            tag, step=cur_iter, file_path=self.json_log_path)

    def _get_window_size(self, runner, window_size: Union[int, str]) -> int:
        """Parse the window_size specified in ``self.custom_keys`` to an int.

        Args:
            runner (Runner): The runner of the training process.
            window_size (int or str): Smoothing scale of logs.

        Returns:
            int: Smoothing window for statistical methods.
        """
        if isinstance(window_size, int):
            assert window_size == self.interval, \
                'The value of window_size must be equal to ' \
                'LoggerHook.interval.'
            return window_size
        elif window_size == 'epoch':
            return runner.inner_iter + 1
        elif window_size == 'global':
            return runner.iter + 1
        else:
            raise ValueError("window_size should be int, 'epoch' or "
                             f"'global', but got invalid {window_size}")

    def _collect_info(self, runner, mode: str) -> dict:
        """Collect log information to a dict according to mode.

        Args:
            runner (Runner): The runner of the training process.
            mode (str): 'train' or 'val', which is the prefix attached by
                the runner.

        Returns:
            dict: Statistical values of logs.
        """
        tag = OrderedDict()
        log_buffers = runner.message_hub.log_buffers
        mode_log_buffers = OrderedDict()
        # Filter the log_buffers which start with `mode`.
        for prefix_key, log_buffer in log_buffers.items():
            if prefix_key.startswith(mode):
                key = prefix_key.split('/')[-1]
                mode_log_buffers[key] = log_buffer
        # Ensure all metric and lr values are up to date.
        for key in mode_log_buffers:
            # Smooth the time and loss logs; take the latest value of the
            # other keys, such as the learning rate.
            if key in self.fixed_smooth_keys or key.startswith('loss'):
                tag[key] = mode_log_buffers[key].mean(self.interval)
            else:
                tag[key] = mode_log_buffers[key].current()
        # Update custom keys.
        if mode == 'train':
            for log_key, log_cfg in self.custom_keys.items():
                self._parse_custom_keys(runner, log_key,
                                        copy.deepcopy(log_cfg),
                                        mode_log_buffers, tag)
        return tag

    def _parse_custom_keys(self, runner, log_key: str, log_cfg: dict,
                           log_buffers: OrderedDict,
                           tag: OrderedDict) -> None:
        """Compute statistics of logs in ``log_buffers`` according to
        ``custom_keys``.

        Args:
            runner (Runner): The runner of the training process.
            log_key (str): The log key specified in ``self.custom_keys``.
            log_cfg (dict): A config dict describing the statistics method
                used for logging.
            log_buffers (OrderedDict): All logs for the corresponding phase.
            tag (OrderedDict): A dict which holds all statistical values of
                logs.
        """
        if isinstance(log_cfg, list):
            log_names = set()
            for cfg in log_cfg:
                log_name = cfg.get('log_name', None)
                if log_name in log_names:
                    raise KeyError(f'{cfg["log_name"]} cannot be redefined '
                                   'in log_key')
                if log_name is not None:
                    log_names.add(log_name)
                self._parse_custom_keys(runner, log_key, cfg, log_buffers,
                                        tag)
            assert len(log_names) == len(log_cfg) - 1, \
                f'{log_key} cannot be overwritten multiple times, please ' \
                f'check that only one key does not contain `log_name` in ' \
                f'{log_cfg}.'
        elif isinstance(log_cfg, dict):
            if 'window_size' in log_cfg:
                log_cfg['window_size'] = \
                    self._get_window_size(runner, log_cfg['window_size'])
            if 'log_name' in log_cfg:
                name = log_cfg.pop('log_name')
            else:
                name = log_key
            tag[name] = log_buffers[log_key].statistics(**log_cfg)
        else:
            raise ValueError('The structure of `LoggerHook.custom_keys` is '
                             'wrong, please make sure the value of each key '
                             'is a dict or a list.')

    def _get_max_memory(self, runner) -> int:
        """Return the maximum GPU memory occupied by tensors in megabytes
        (MB) for a given device.

        Args:
            runner (Runner): The runner of the training process.

        Returns:
            int: The maximum GPU memory occupied by tensors in megabytes
            for a given device.
        """
        # TODO: use `mmengine.dist.max_memory_allocated` to count mem_mb.
        device = getattr(runner.model, 'output_device', None)
        mem = torch.cuda.max_memory_allocated(device=device)
        mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
                              dtype=torch.int,
                              device=device)
        torch.cuda.reset_peak_memory_stats()
        return int(mem_mb.item())

    def _check_custom_keys(self) -> None:
        """Check the legality of ``self.custom_keys``.

        If ``self.by_epoch == False``, ``window_size`` should not be
        "epoch". The keys in ``self.fixed_smooth_keys`` cannot be
        overwritten.
        """

        def _check_window_size(item):
            if not self.by_epoch:
                assert item.get('window_size') != 'epoch', \
                    'window_size cannot be epoch if LoggerHook.by_epoch is ' \
                    'False.'

        def _check_fixed_keys(key, item):
            if key in self.fixed_smooth_keys:
                assert 'log_name' in item, f'{key} cannot be overwritten ' \
                    'by custom keys!'

        for key, value in self.custom_keys.items():
            if isinstance(value, Sequence):
                for item in value:
                    _check_window_size(item)
                    _check_fixed_keys(key, item)
            else:
                _check_window_size(value)
                _check_fixed_keys(key, value)

    def _get_epoch(self, runner, mode: str) -> int:
        """Get the epoch according to mode.

        Args:
            runner (Runner): The runner of the training process.
            mode (str): 'train' or 'val'.

        Returns:
            int: The current epoch.
        """
        if mode == 'train':
            epoch = runner.epoch + 1
        elif mode == 'val':
            # Normal val mode: `runner.epoch += 1` has been done before the
            # val workflow.
            epoch = runner.epoch
        else:
            raise ValueError(f"runner mode should be 'train' or 'val', "
                             f'but got {mode}')
        return epoch

    def _get_iter(self, runner, inner_iter=False) -> int:
        """Get the current training iteration step.

        Args:
            runner (Runner): The runner of the training process.
            inner_iter (bool): Whether to return the inner iter of an
                epoch. Defaults to False.

        Returns:
            int: The current global iter or inner iter.
        """
        if self.by_epoch and inner_iter:
            current_iter = runner.inner_iter + 1
        else:
            current_iter = runner.iter + 1
        return current_iter
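
Aside: the window handling configured through `custom_keys` is easier to see with a toy example (not part of this commit; the values are made up). A windowed mean smooths over the last `window_size` records, while a 'global' window averages everything recorded so far:

values = [2.0, 1.5, 1.2, 1.1, 1.05]  # hypothetical per-iteration losses

def mean(values, window_size):
    """Average the last `window_size` entries, mirroring a windowed mean."""
    return sum(values[-window_size:]) / min(window_size, len(values))

print(mean(values, 2))            # smoothed over a window of 2 -> 1.075
print(mean(values, len(values)))  # 'global' mean over all values -> 1.37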
@@ -0,0 +1,355 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import logging
import os.path as osp
import sys
from collections import OrderedDict
from unittest.mock import MagicMock, patch

import pytest
import torch

from mmengine.fileio.file_client import HardDiskBackend
from mmengine.hooks import LoggerHook

class TestLoggerHook:

    def test_init(self):
        logger_hook = LoggerHook(out_dir='tmp.txt')
        assert logger_hook.by_epoch
        assert logger_hook.interval == 10
        assert not logger_hook.custom_keys
        assert logger_hook.ignore_last
        assert logger_hook.time_sec_tot == 0
        assert logger_hook.interval_exp_name == 1000
        assert logger_hook.out_suffix == ('.log.json', '.log', '.py')
        assert logger_hook.keep_local
        assert logger_hook.file_client_args is None
        assert isinstance(logger_hook.file_client.client, HardDiskBackend)
        # out_dir should be None, a string or a tuple of strings.
        with pytest.raises(TypeError):
            LoggerHook(out_dir=1)
        # `time` cannot be overwritten.
        with pytest.raises(AssertionError):
            LoggerHook(custom_keys=dict(time=dict(method='max')))
        LoggerHook(
            custom_keys=dict(time=[
                dict(method='max', log_name='time_max'),
                dict(method='min', log_name='time_min')
            ]))
        # The epoch window_size cannot be used when
        # `LoggerHook.by_epoch=False`.
        with pytest.raises(AssertionError):
            LoggerHook(
                by_epoch=False,
                custom_keys=dict(
                    time=dict(
                        method='max', log_name='time_max',
                        window_size='epoch')))
        # `file_client_args` requires `out_dir` to be specified.
        with pytest.raises(ValueError):
            LoggerHook(file_client_args=dict(enable_mc=True))

    def test_before_run(self):
        runner = MagicMock()
        runner.iter = 10
        runner.timestamp = 'timestamp'
        runner.work_dir = 'work_dir'
        runner.logger = MagicMock()
        logger_hook = LoggerHook(out_dir='out_dir')
        logger_hook.before_run(runner)
        assert logger_hook.out_dir == osp.join('out_dir', 'work_dir')
        assert logger_hook.json_log_path == osp.join('work_dir',
                                                     'timestamp.log.json')
        assert logger_hook.start_iter == runner.iter
        runner.writer.add_params.assert_called()

    def test_after_run(self, tmp_path):
        out_dir = tmp_path / 'out_dir'
        out_dir.mkdir()
        work_dir = tmp_path / 'work_dir'
        work_dir.mkdir()
        work_dir_json = work_dir / 'tmp.log.json'
        # Create an empty log file to be copied.
        work_dir_json.touch()
        runner = MagicMock()
        runner.work_dir = work_dir

        logger_hook = LoggerHook(out_dir=str(tmp_path), keep_local=False)
        logger_hook.out_dir = str(out_dir)
        logger_hook.after_run(runner)
        # Verify that the file has been moved to `out_dir`.
        assert not osp.exists(str(work_dir_json))
        assert osp.exists(str(out_dir / 'tmp.log.json'))

    def test_after_train_iter(self):
        # Test LoggerHook by iter.
        runner = MagicMock()
        runner.iter = 10
        logger_hook = LoggerHook(by_epoch=False)
        logger_hook._log_train = MagicMock()
        logger_hook.after_train_iter(runner)
        # `cur_iter=10+1`, which is not divisible by
        # `logger_hook.interval`.
        logger_hook._log_train.assert_not_called()
        runner.iter = 9
        logger_hook.after_train_iter(runner)
        logger_hook._log_train.assert_called()

        # Test LoggerHook by epoch.
        logger_hook = LoggerHook(by_epoch=True)
        logger_hook._log_train = MagicMock()
        # Only `runner.inner_iter` will work.
        runner.iter = 9
        runner.inner_iter = 10
        logger_hook.after_train_iter(runner)
        logger_hook._log_train.assert_not_called()
        runner.inner_iter = 9
        logger_hook.after_train_iter(runner)
        logger_hook._log_train.assert_called()

        # Test the end of the epoch.
        logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
        logger_hook._log_train = MagicMock()
        runner.data_loader = [0] * 5
        runner.inner_iter = 4
        logger_hook.after_train_iter(runner)
        logger_hook._log_train.assert_called()

        # Test printing exp_name.
        runner.meta = dict(exp_name='retinanet')
        logger_hook = LoggerHook()
        runner.logger = MagicMock()
        logger_hook._log_train = MagicMock()
        logger_hook.after_train_iter(runner)
        runner.logger.info.assert_called_with(
            f'Exp name: {runner.meta["exp_name"]}')

    def test_after_val_epoch(self):
        logger_hook = LoggerHook()
        runner = MagicMock()
        logger_hook._log_val = MagicMock()
        logger_hook.after_val_epoch(runner)
        logger_hook._log_val.assert_called()

    @pytest.mark.parametrize('by_epoch', [True, False])
    def test_log_train(self, by_epoch, capsys):
        runner = self._setup_runner()
        runner.meta = dict(exp_name='retinanet')
        # Prepare LoggerHook.
        logger_hook = LoggerHook(by_epoch=by_epoch)
        logger_hook.writer = MagicMock()
        logger_hook.time_sec_tot = 1000
        logger_hook.start_iter = 0
        logger_hook._get_max_memory = MagicMock(return_value='100')
        logger_hook.json_log_path = 'tmp.json'

        # Prepare the training information.
        train_infos = dict(
            lr=0.1, momentum=0.9, time=1.0, data_time=1.0, loss_cls=1.0)
        logger_hook._collect_info = MagicMock(return_value=train_infos)
        logger_hook._log_train(runner)
        # Verify that the correct variables have been written.
        runner.writer.add_scalars.assert_called_with(
            train_infos, step=11, file_path='tmp.json')
        # Verify that the correct context has been logged.
        out, _ = capsys.readouterr()
        time_avg = logger_hook.time_sec_tot / (
            runner.iter + 1 - logger_hook.start_iter)
        eta_second = time_avg * (runner.max_iters - runner.iter - 1)
        eta_str = str(datetime.timedelta(seconds=int(eta_second)))
        if by_epoch:
            if torch.cuda.is_available():
                log_str = 'Epoch [2][2/5]\t' \
                          f"lr: {train_infos['lr']:.3e} " \
                          f"momentum: {train_infos['momentum']:.3e}, " \
                          f'eta: {eta_str}, ' \
                          f"time: {train_infos['time']:.3f}, " \
                          f"data_time: {train_infos['data_time']:.3f}, " \
                          f'memory: 100, ' \
                          f"loss_cls: {train_infos['loss_cls']:.4f}\n"
            else:
                log_str = 'Epoch [2][2/5]\t' \
                          f"lr: {train_infos['lr']:.3e} " \
                          f"momentum: {train_infos['momentum']:.3e}, " \
                          f'eta: {eta_str}, ' \
                          f"time: {train_infos['time']:.3f}, " \
                          f"data_time: {train_infos['data_time']:.3f}, " \
                          f"loss_cls: {train_infos['loss_cls']:.4f}\n"
            assert out == log_str
        else:
            if torch.cuda.is_available():
                log_str = 'Iter [11/50]\t' \
                          f"lr: {train_infos['lr']:.3e} " \
                          f"momentum: {train_infos['momentum']:.3e}, " \
                          f'eta: {eta_str}, ' \
                          f"time: {train_infos['time']:.3f}, " \
                          f"data_time: {train_infos['data_time']:.3f}, " \
                          f'memory: 100, ' \
                          f"loss_cls: {train_infos['loss_cls']:.4f}\n"
            else:
                log_str = 'Iter [11/50]\t' \
                          f"lr: {train_infos['lr']:.3e} " \
                          f"momentum: {train_infos['momentum']:.3e}, " \
                          f'eta: {eta_str}, ' \
                          f"time: {train_infos['time']:.3f}, " \
                          f"data_time: {train_infos['data_time']:.3f}, " \
                          f"loss_cls: {train_infos['loss_cls']:.4f}\n"
            assert out == log_str

    @pytest.mark.parametrize('by_epoch', [True, False])
    def test_log_val(self, by_epoch, capsys):
        runner = self._setup_runner()
        # Prepare LoggerHook.
        logger_hook = LoggerHook(by_epoch=by_epoch)
        logger_hook.json_log_path = 'tmp.json'
        metric = dict(accuracy=0.9, data_time=1.0)
        logger_hook._collect_info = MagicMock(return_value=metric)
        logger_hook._log_val(runner)
        # Verify that the correct context has been logged.
        out, _ = capsys.readouterr()
        runner.writer.add_scalars.assert_called_with(
            metric, step=11, file_path='tmp.json')
        if by_epoch:
            assert out == 'Epoch(val) [1][5]\taccuracy: 0.9000, ' \
                          'data_time: 1.0000\n'
        else:
            assert out == 'Iter(val) [5]\taccuracy: 0.9000, ' \
                          'data_time: 1.0000\n'

    def test_get_window_size(self):
        runner = self._setup_runner()
        logger_hook = LoggerHook()
        # Test getting the window size by name.
        assert logger_hook._get_window_size(runner, 'epoch') == 2
        assert logger_hook._get_window_size(runner, 'global') == 11
        assert logger_hook._get_window_size(runner, 10) == 10
        # The window size must be equal to `logger_hook.interval`.
        with pytest.raises(AssertionError):
            logger_hook._get_window_size(runner, 20)

        with pytest.raises(ValueError):
            logger_hook._get_window_size(runner, 'unknown')

    def test_parse_custom_keys(self):
        tag = OrderedDict()
        runner = self._setup_runner()
        log_buffers = OrderedDict(lr=MagicMock(), loss=MagicMock())
        cfg_dict = dict(
            lr=dict(method='min'),
            loss=[
                dict(method='min', window_size='global'),
                dict(method='max', log_name='loss_max')
            ])
        logger_hook = LoggerHook()
        for log_key, log_cfg in cfg_dict.items():
            logger_hook._parse_custom_keys(runner, log_key, log_cfg,
                                           log_buffers, tag)
        assert list(tag) == ['lr', 'loss', 'loss_max']
        # `LogBuffer.statistics` is called with the configured methods.
        log_buffers['lr'].statistics.assert_called()
        log_buffers['loss'].statistics.assert_called()
        # `log_name` cannot be repeated.
        with pytest.raises(KeyError):
            cfg_dict = dict(loss=[
                dict(method='min', window_size='global'),
                dict(method='max', log_name='loss_max'),
                dict(method='mean', log_name='loss_max')
            ])
            logger_hook.custom_keys = cfg_dict
            for log_key, log_cfg in cfg_dict.items():
                logger_hook._parse_custom_keys(runner, log_key, log_cfg,
                                               log_buffers, tag)
        # `log_key` cannot be overwritten multiple times.
        with pytest.raises(AssertionError):
            cfg_dict = dict(loss=[
                dict(method='min', window_size='global'),
                dict(method='max'),
            ])
            logger_hook.custom_keys = cfg_dict
            for log_key, log_cfg in cfg_dict.items():
                logger_hook._parse_custom_keys(runner, log_key, log_cfg,
                                               log_buffers, tag)

    def test_collect_info(self):
        runner = self._setup_runner()
        logger_hook = LoggerHook(
            custom_keys=dict(time=dict(method='max', log_name='time_max')))
        logger_hook._parse_custom_keys = MagicMock()
        # Collect with prefixes.
        log_buffers = {
            'train/time': MagicMock(),
            'lr': MagicMock(),
            'train/loss_cls': MagicMock(),
            'val/metric': MagicMock()
        }
        runner.message_hub.log_buffers = log_buffers
        tag = logger_hook._collect_info(runner, mode='train')
        # Test parsing custom_keys.
        logger_hook._parse_custom_keys.assert_called()
        # Test that only the training keys are in the tag.
        assert list(tag.keys()) == ['time', 'loss_cls']
        # Test the statistics: `loss` and `time` are smoothed with 'mean',
        # while the other keys use `current`.
        log_buffers['train/time'].mean.assert_called()
        log_buffers['train/loss_cls'].mean.assert_called()
        log_buffers['train/loss_cls'].current.assert_not_called()

        tag = logger_hook._collect_info(runner, mode='val')
        assert list(tag.keys()) == ['metric']
        log_buffers['val/metric'].current.assert_called()

    @patch('torch.distributed.reduce', MagicMock())
    def test_get_max_memory(self):
        logger_hook = LoggerHook()
        runner = MagicMock()
        runner.world_size = 1
        runner.model = torch.nn.Linear(1, 1)
        logger_hook._get_max_memory(runner)
        torch.distributed.reduce.assert_not_called()
        runner.world_size = 2
        logger_hook._get_max_memory(runner)
        torch.distributed.reduce.assert_called()

    def test_get_iter(self):
        runner = self._setup_runner()
        logger_hook = LoggerHook()
        # Get the global iter when `inner_iter=False`.
        iter = logger_hook._get_iter(runner)
        assert iter == 11
        # Get the inner iter.
        iter = logger_hook._get_iter(runner, inner_iter=True)
        assert iter == 2
        # Still get the global iter when `logger_hook.by_epoch==False`.
        logger_hook.by_epoch = False
        iter = logger_hook._get_iter(runner, inner_iter=True)
        assert iter == 11

    def test_get_epoch(self):
        runner = self._setup_runner()
        logger_hook = LoggerHook()
        epoch = logger_hook._get_epoch(runner, 'train')
        assert epoch == 2
        epoch = logger_hook._get_epoch(runner, 'val')
        assert epoch == 1
        with pytest.raises(ValueError):
            logger_hook._get_epoch(runner, 'test')

    def _setup_runner(self):
        runner = MagicMock()
        runner.epoch = 1
        runner.data_loader = [0] * 5
        runner.inner_iter = 1
        runner.iter = 10
        runner.max_iters = 50
        logger = logging.getLogger()
        logger.setLevel(logging.INFO)
        # Make sure the root logger has a stdout StreamHandler so that
        # `capsys` can capture the logged output.
        if not any(
                isinstance(handler, logging.StreamHandler)
                for handler in logger.handlers):
            logger.addHandler(logging.StreamHandler(stream=sys.stdout))
        runner.logger = logger
        runner.message_hub = MagicMock()
        runner.composed_writer = MagicMock()
        return runner