[Refactor] Refine LoggerHook (#155)

* rename globally accessible and integrate get_instance and create_instance

* move ManagerMixin to utils

* fix docstring and separate get_instance into get_instance and get_current_instance

* fix lint

* fix docstring, rename and move test_global_meta

* rename LogBuffer to HistoryBuffer, rename MessageHub methods, MessageHub support resume

* refine MMLogger timestamp, update unit test

* MMLogger add logger_name arguments

* Fix docstring

* Add LogProcessor and some unit test

* update unit test

* complete LogProcessor unit test

* refine LoggerHook

* solve circle import

* change default logger_name to mmengine

* refactor eta

* Fix docstring comment and unit test

* Fix with runner

* fix docstring

* fix docstring

* fix docstring

* Add by_epoch attribute to LoggerHook and fix docstring

* Please mypy and fix comment

* remove \ in MMLogger

* Fix lint

* roll back pre-commit-hook

* Fix hook unit test

* Fix comments

* remove \t in log and add docstring

* Fix as comment

* should not accept other arguments if corresponding instance has been created

* fix logging ddp file saving

* fix logging ddp file saving

* move log processor to logging

* move log processor to logging

* remove current dataloader

* fix docstring

* fix unit test

* add learning rate in messagehub

* Support outputting training/validation/testing messages after iterations/epochs

* fix docstring

* Fix IterBasedRunner log string

* Fix IterBasedRunner log string

* Support parsing validation loss in log processor
Mashiro 2022-04-24 19:23:28 +08:00 committed by GitHub
parent 4274679376
commit e2a2b0438e
14 changed files with 961 additions and 702 deletions


@ -358,11 +358,11 @@ class Hook:
"""
return (runner.epoch + 1) % n == 0 if n > 0 else False
def every_n_inner_iters(self, inner_iter: int, n: int) -> bool:
def every_n_inner_iters(self, batch_idx: int, n: int) -> bool:
"""Test whether current inner iteration can be evenly divided by n.
Args:
inner_iter (int): Current inner_iter of the training, validation
batch_idx (int): Current batch index of the training, validation
or testing loop.
n (int): Whether current inner iteration can be evenly
divided by n.
@ -371,7 +371,7 @@ class Hook:
bool: Whether current inner iteration can be evenly
divided by n.
"""
return (inner_iter + 1) % n == 0 if n > 0 else False
return (batch_idx + 1) % n == 0 if n > 0 else False
def every_n_iters(self, runner, n: int) -> bool:
"""Test whether current iteration can be evenly divided by n.
@ -387,19 +387,18 @@ class Hook:
"""
return (runner.iter + 1) % n == 0 if n > 0 else False
def end_of_epoch(self, runner, batch_idx: int) -> bool:
def end_of_epoch(self, dataloader, batch_idx: int) -> bool:
"""Check whether the current iteration reaches the last iteration of
current dataloader.
the dataloader.
Args:
runner (Runner): The runner of the training, validation or testing
process.
dataloader (Dataloader): The dataloader of the training,
validation or testing process.
batch_idx (int): The index of the current batch in the loop.
Returns:
bool: Whether reaches the end of current epoch or not.
"""
return batch_idx + 1 == len(runner.cur_dataloader)
return batch_idx + 1 == len(dataloader)
def is_last_train_epoch(self, runner) -> bool:
"""Test whether current epoch is the last train epoch.
@ -418,10 +417,10 @@ class Hook:
Args:
runner (Runner): The runner of the training, validation or testing
process.
mode (str): Current mode of runner. Defaults to 'train'.
Returns:
bool: Whether current iteration is the last iteration.
mode (str): Current mode of runner. Defaults to 'train'.
"""
if mode == 'train':
return runner.iter + 1 == runner.train_loop.max_iters


@ -18,11 +18,25 @@ class IterTimerHook(Hook):
priority = 'NORMAL'
def _before_epoch(self, runner, mode: str = 'train') -> None:
"""Record time flag before start a epoch.
def __init__(self):
self.time_sec_tot = 0
self.start_iter = 0
def before_run(self, runner) -> None:
"""Synchronize the number of iterations with the runner.
Args:
runner (Runner): The runner of the training process.
runner: The runner of the training, validation or testing
process.
"""
self.start_iter = runner.iter
def _before_epoch(self, runner, mode: str = 'train') -> None:
"""Record timestamp before start an epoch.
Args:
runner (Runner): The runner of the training, validation and
testing process.
mode (str): Current mode of runner. Defaults to 'train'.
"""
self.t = time.time()
@ -32,16 +46,18 @@ class IterTimerHook(Hook):
batch_idx: int,
data_batch: DATA_BATCH = None,
mode: str = 'train') -> None:
"""Logging time for loading data and update the time flag.
"""Calculating time for loading data and updating "data_time"
``HistoryBuffer`` of ``runner.message_hub``.
Args:
runner (Runner): The runner of the training process.
runner (Runner): The runner of the training, validation and
testing process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
from dataloader. Defaults to None.
mode (str): Current mode of runner. Defaults to 'train'.
"""
# TODO: update for new logging system
# Update data loading time in `runner.message_hub`.
runner.message_hub.update_scalar(f'{mode}/data_time',
time.time() - self.t)
@ -52,10 +68,12 @@ class IterTimerHook(Hook):
outputs: Optional[Union[dict,
Sequence[BaseDataElement]]] = None,
mode: str = 'train') -> None:
"""Logging time for a iteration and update the time flag.
"""Calculating time for an iteration and updating "time"
``HistoryBuffer`` of ``runner.message_hub``.
Args:
runner (Runner): The runner of the training process.
runner (Runner): The runner of the training, validation and
testing process.
batch_idx (int): The index of the current batch in the loop.
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
from dataloader. Defaults to None.
@ -63,7 +81,31 @@ class IterTimerHook(Hook):
to None.
mode (str): Current mode of runner. Defaults to 'train'.
"""
# TODO: update for new logging system
runner.message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
# Update iteration time in `runner.message_hub`.
message_hub = runner.message_hub
message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
self.t = time.time()
window_size = runner.log_processor.window_size
# Calculate eta every `window_size` iterations. Since test and val
# loop will not update runner.iter, use `every_n_inner_iters` to check
# the interval.
if self.every_n_inner_iters(batch_idx, window_size):
iter_time = message_hub.get_scalar(f'{mode}/time').mean(
window_size)
if mode == 'train':
self.time_sec_tot += iter_time * window_size
# Calculate the average iteration time.
time_sec_avg = self.time_sec_tot / (
runner.iter - self.start_iter + 1)
# Calculate eta.
eta_sec = time_sec_avg * (
runner.train_loop.max_iters - runner.iter - 1)
runner.message_hub.update_info('eta', eta_sec)
else:
if mode == 'val':
cur_dataloader = runner.val_loop.dataloader
else:
cur_dataloader = runner.test_loop.dataloader
eta_sec = iter_time * (len(cur_dataloader) - batch_idx - 1)
runner.message_hub.update_info('eta', eta_sec)
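
To make the ETA arithmetic above concrete, here is a minimal, self-contained sketch of the training-branch formula. `estimate_eta` and its argument names are illustrative only, not part of the hook:

```python
import datetime

def estimate_eta(time_sec_tot: float, cur_iter: int, start_iter: int,
                 max_iters: int) -> str:
    """Mirror the training-mode ETA computed in IterTimerHook above."""
    # Average time per iteration observed so far.
    time_sec_avg = time_sec_tot / (cur_iter - start_iter + 1)
    # Remaining iterations times the average iteration time.
    eta_sec = time_sec_avg * (max_iters - cur_iter - 1)
    return str(datetime.timedelta(seconds=int(eta_sec)))

# 1000 s spent over iterations 0-9 of a 100-iteration run:
print(estimate_eta(1000, cur_iter=9, start_iter=0, max_iters=100))
# -> '2:30:00' (100 s/iter * 90 remaining iterations = 9000 s)
```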


@ -1,16 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import datetime
import os
import os.path as osp
from collections import OrderedDict
from pathlib import Path
from typing import Any, Optional, Sequence, Tuple, Union
import torch
from mmengine.data import BaseDataElement
from mmengine.dist import master_only
from mmengine.fileio import FileClient
from mmengine.hooks import Hook
from mmengine.registry import HOOKS
@ -21,33 +15,20 @@ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
@HOOKS.register_module()
class LoggerHook(Hook):
"""In this logger hook, the information will be printed on the terminal and
saved in JSON file, tensorboard, wandb .etc.
"""Collect logs from different components of ``Runner`` and write them to
terminal, JSON file, tensorboard, wandb, etc.
``LoggerHook`` is used to record logs formatted by ``LogProcessor`` during
training/validation/testing phase. It is used to control the following
behaviors:
- The frequency of logs written to terminal, local files, tensorboard, wandb, etc.
- The frequency of showing experiment information in the terminal.
- The work directory to save logs.
Args:
by_epoch (bool): Whether ``EpochBasedLoop`` is used.
Defaults to True.
interval (int): Logging interval (every k iterations).
Defaults to 10.
custom_keys (dict, optional): Defines the keys in the log and which
kinds of statistic methods should be used to log them.
- ``custom_keys`` contains multiple string-dict pairs. In each
string-dict pair, the string defines a key name in the log and the
dict is a config defines the statistic methods and corresponding
arguments used to log the value. For example,
``dict(loss=dict(method_name='mean', log_name='global_loss',
window_size='global'))`` which means the log key ``loss`` will be
counted as global mean and additionally logged as ``global_loss``.
If ``log_name`` is not defined in config dict, the original logged
key will be overwritten.
- The key in ``LoggerHook.fixed_smooth_keys`` cannot be overwritten
because ``time`` and ``iter_time`` will be used to calculate
estimated time of arrival. If you want to recount the time, you
should set ``log_name`` in corresponding values.
- For those statistic methods with the ``window_size`` argument,
if ``by_epoch`` is set to False, ``windows_size`` should not be
`epoch` to statistics log value by epoch.
ignore_last (bool): Ignore the log of last iterations in each epoch if
the number of remaining iterations is less than :attr:`interval`.
Defaults to True.
@ -72,64 +53,24 @@ class LoggerHook(Hook):
Defaults to None.
Examples:
>>> # `log_name` is defined, `loss_mean_window` will be an additional
>>> # record.
>>> logger_hook_cfg = dict(by_epoch=True,
>>> custom_keys=dict(
>>> loss=dict(
>>> log_name='loss_mean_window',
>>> method_name='mean',
>>> window_size=10)))
>>> # `log_name` is not defined. `loss` will be overwritten by
>>> # `global_mean` statistics.
>>> logger_hook_cfg = dict(by_epoch=True,
>>> custom_keys=dict(
>>> loss=dict(
>>> method_name='mean',
>>> window_size='global')))
>>> # `time` cannot be overwritten, `global_time` will be an additional
>>> # record.
>>> logger_hook_cfg = dict(by_epoch=True,
>>> custom_keys=dict(
>>> time=dict(
>>> log_name='global_time',
>>> method='mean',
>>> window_size='global')))
>>> # Record loss with different statistics methods.
>>> logger_hook_cfg = dict(by_epoch=True,
>>> custom_keys=dict(loss=[
>>> dict(log_name='loss_mean_window',
>>> method_name='mean',
>>> window_size=10),
>>> dict(method_name='mean',
>>> window_size='global')]))
>>> # A simplest LoggerHook config.
>>> logger_hook_cfg = dict(interval=20)
"""
# eta will be calculated by time. `time` and `data_time` should not be
# overwritten.
fixed_smooth_keys = ('time', 'data_time')
priority = 'BELOW_NORMAL'
def __init__(
self,
by_epoch: bool = True,
interval: int = 10,
custom_keys: Optional[dict] = None,
ignore_last: bool = True,
interval_exp_name: int = 1000,
out_dir: Optional[Union[str, Path]] = None,
out_suffix: Union[Sequence[str], str] = ('.log.json', '.log', '.py'),
keep_local=True,
file_client_args=None,
keep_local: bool = True,
file_client_args: Optional[dict] = None,
):
self._inner_iter = 0
self.by_epoch = by_epoch
self.interval = interval
self.custom_keys = custom_keys if custom_keys is not None else dict()
self.ignore_last = ignore_last
self.time_sec_tot = 0
self.interval_exp_name = interval_exp_name
self._check_custom_keys()
if out_dir is None and file_client_args is not None:
raise ValueError(
@ -169,7 +110,7 @@ class LoggerHook(Hook):
f'{runner.timestamp}.log.json')
self.yaml_log_path = osp.join(runner.work_dir,
f'{runner.timestamp}.log.json')
self.start_iter = runner.iter
# TODO Compatible with Visualizer.
if runner.meta is not None:
runner.writer.add_params(runner.meta, file_path=self.yaml_log_path)
@ -178,41 +119,100 @@ class LoggerHook(Hook):
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[dict] = None) -> None:
"""Record training logs.
"""Record training logs after training iteration.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The index of the current batch in the train loop.
data_batch (Sequence[BaseDataElement], optional): Data from
dataloader. Defaults to None.
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
Data from dataloader. Defaults to None.
outputs (dict, optional): Outputs from model.
Defaults to None.
"""
self._inner_iter = batch_idx
if runner.meta is not None and 'exp_name' in runner.meta:
if (self.every_n_iters(runner, self.interval_exp_name)) or (
self.by_epoch and self.end_of_epoch(runner, batch_idx)):
exp_info = f'Exp name: {runner.meta["exp_name"]}'
# Print experiment name every n iterations.
if self.every_n_iters(runner,
self.interval_exp_name) or (self.end_of_epoch(
runner.train_dataloader, batch_idx)):
exp_info = f'Exp name: {runner.experiment_name}'
runner.logger.info(exp_info)
if self.by_epoch and self.every_n_inner_iters(batch_idx,
self.interval):
self._log_train(runner)
elif not self.by_epoch and self.every_n_iters(runner, self.interval):
self._log_train(runner)
elif self.end_of_epoch(runner, batch_idx) and not self.ignore_last:
if self.every_n_inner_iters(batch_idx, self.interval):
tag, log_str = runner.log_processor.get_log_after_iter(
runner, batch_idx, 'train')
elif (self.end_of_epoch(runner.train_dataloader, batch_idx)
and not self.ignore_last):
# `runner.max_iters` may not be divisible by `self.interval`. If
# `self.ignore_last==False`, the log of the remaining iterations will
# be recorded (e.g. for Epoch [4][1007/1007], the logs of iterations
# 1001-1007 will be recorded).
self._log_train(runner)
tag, log_str = runner.log_processor.get_log_after_iter(
runner, batch_idx, 'train')
else:
return
runner.logger.info(log_str)
# TODO compatible with visualizer.
runner.writer.add_scalars(tag, step=runner.iter + 1)
def after_val_iter(
self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
"""Record validation logs after validation iteration.
Args:
runner (Runner): The runner of the validation process.
batch_idx (int): The index of the current batch in the validation loop.
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
Data from dataloader. Defaults to None.
outputs (sequence, optional): Outputs from model. Defaults to None.
"""
if self.every_n_inner_iters(batch_idx, self.interval):
tag, log_str = runner.log_processor.get_log_after_iter(
runner, batch_idx, 'val')
runner.logger.info(log_str)
def after_test_iter(
self,
runner,
batch_idx: int,
data_batch: DATA_BATCH = None,
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
"""Record testing logs after iteration.
Args:
runner (Runner): The runner of the testing process.
batch_idx (int): The index of the current batch in the test loop.
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
Data from dataloader. Defaults to None.
outputs (sequence, optional): Outputs from model. Defaults to None.
"""
if self.every_n_inner_iters(batch_idx, self.interval):
tag, log_str = runner.log_processor.get_log_after_iter(
runner, batch_idx, 'test')
runner.logger.info(log_str)
def after_val_epoch(self, runner) -> None:
"""Record validation logs.
"""Record validation logs after validation epoch.
Args:
runner (Runner): The runner of the validation process.
"""
self._log_val(runner)
tag, log_str = runner.log_processor.get_log_after_epoch(
runner, len(runner.val_dataloader), 'val')
runner.logger.info(log_str)
# TODO compatible with visualizer.
runner.writer.add_scalars(tag, step=runner.iter + 1)
def after_test_epoch(self, runner) -> None:
"""Record testing logs after test epoch.
Args:
runner (Runner): The runner of the testing process.
"""
tag, log_str = runner.log_processor.get_log_after_epoch(
runner, len(runner.test_dataloader), 'test')
runner.logger.info(log_str)
def after_run(self, runner) -> None:
"""Copy logs to ``self.out_dir`` if ``self.out_dir is not None``
@ -237,280 +237,3 @@ class LoggerHook(Hook):
os.remove(local_filepath)
runner.logger.info((f'{local_filepath} was removed due to the '
'`self.keep_local=False`'))
@master_only
def _log_train(self, runner) -> None:
"""Collect and record training logs which start named with "train/*".
Args:
runner (Runner): The runner of the training process.
"""
tag = self._collect_info(runner, 'train')
# The training log default defines `lr`, `momentum`, `time` and
# `data_time`. `log_tag` will pop these keys and loop other keys to
# `log_str`.
log_tag = copy.deepcopy(tag)
cur_iter = self._get_iter(runner, inner_iter=True)
cur_epoch = self._get_epoch(runner, 'train')
# Record learning rate and momentum.
lr_str_list = []
momentum_str_list = []
for key, value in tag.items():
if key.startswith('lr'):
log_tag.pop(key)
lr_str_list.append(f'{key}: {value:.3e}')
lr_str = ' '.join(lr_str_list)
for key, value in tag.items():
if key.startswith('momentum'):
log_tag.pop(key)
momentum_str_list.append(f'{key}: {value:.3e}')
momentum_str = ' '.join(momentum_str_list)
lr_momentum_str = f'{lr_str} {momentum_str}'
# by epoch: Epoch [4][100/1000]
# by iter: Iter [100/100000]
if self.by_epoch:
log_str = f'Epoch [{cur_epoch}]' \
f'[{cur_iter}/{len(runner.cur_dataloader)}]\t'
else:
log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}]\t'
log_str += f'{lr_momentum_str}, '
# Calculate eta time.
self.time_sec_tot += (tag['time'] * self.interval)
time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1)
eta_sec = time_sec_avg * (
runner.train_loop.max_iters - runner.iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
log_str += f'eta: {eta_str}, '
log_str += f'time: {tag["time"]:.3f}, ' \
f'data_time: {tag["data_time"]:.3f}, '
# Pop recorded keys
log_tag.pop('time')
log_tag.pop('data_time')
# statistic memory
if torch.cuda.is_available():
log_str += f'memory: {self._get_max_memory(runner)}, '
# Loop left keys to fill `log_str`.
log_items = []
for name, val in log_tag.items():
if isinstance(val, float):
val = f'{val:.4f}'
log_items.append(f'{name}: {val}')
log_str += ', '.join(log_items)
runner.logger.info(log_str)
# Write logs to local, tensorboad, and wandb.
runner.writer.add_scalars(
tag, step=runner.iter + 1, file_path=self.json_log_path)
@master_only
def _log_val(self, runner) -> None:
"""Collect and record training logs which start named with "val/*".
Args:
runner (Runner): The runner of the training process.
"""
tag = self._collect_info(runner, 'val')
# Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501
eval_iter = len(runner.cur_dataloader)
cur_iter = self._get_iter(runner)
cur_epoch = self._get_epoch(runner, 'val')
# val/test time
# here 1000 is the length of the val dataloader
# by epoch: Epoch[val] [4][1000]
# by iter: Iter[val] [1000]
if self.by_epoch:
# runner.epoch += 1 has been done before val workflow
log_str = f'Epoch(val) [{cur_epoch}][{eval_iter}]\t'
else:
log_str = f'Iter(val) [{eval_iter}]\t'
log_items = []
for name, val in tag.items():
if isinstance(val, float):
val = f'{val:.4f}'
log_items.append(f'{name}: {val}')
log_str += ', '.join(log_items)
runner.logger.info(log_str)
# Write tag.
runner.writer.add_scalars(
tag, step=cur_iter, file_path=self.json_log_path)
def _get_window_size(self, runner, window_size: Union[int, str]) \
-> int:
"""Parse window_size specified in ``self.custom_keys`` to int value.
Args:
runner (Runner): The runner of the training process.
window_size (int or str): Smoothing scale of logs.
Returns:
int: Smoothing window for statistical methods.
"""
if isinstance(window_size, int):
assert window_size == self.interval, \
'The value of windows size must equal to LoggerHook.interval'
return window_size
elif window_size == 'epoch':
return self._inner_iter + 1
elif window_size == 'global':
return runner.iter + 1
else:
raise ValueError('window_size should be int, epoch or global, but '
f'got invalid {window_size}')
def _collect_info(self, runner, mode: str) -> dict:
"""Collect log information to a dict according to mode.
Args:
runner (Runner): The runner of the training process.
mode (str): 'train' or 'val', which means the prefix attached by
runner.
Returns:
dict: Statistical values of logs.
"""
tag = OrderedDict()
log_buffers = runner.message_hub.log_scalars
mode_log_buffers = OrderedDict()
# Filter log_buffers which starts with `mode`.
for prefix_key, log_buffer in log_buffers.items():
if prefix_key.startswith(mode):
key = prefix_key.split('/')[-1]
mode_log_buffers[key] = log_buffer
# Ensure all metric and lr values are latest.
for key in mode_log_buffers:
# Update the latest learning rate and smoothed time logs.
if key in self.fixed_smooth_keys or key.startswith('loss'):
tag[key] = mode_log_buffers[key].mean(self.interval)
else:
tag[key] = mode_log_buffers[key].current()
# Update custom keys.
if mode == 'train':
for log_key, log_cfg in self.custom_keys.items():
self._parse_custom_keys(runner, log_key,
copy.deepcopy(log_cfg),
mode_log_buffers, tag)
return tag
def _parse_custom_keys(self, runner, log_key: str, log_cfg: dict,
log_buffers: OrderedDict, tag: OrderedDict) -> None:
"""Statistics logs in log_buffers according to custom_keys.
Args:
runner (Runner): The runner of the training process.
log_key (str): log key specified in ``self.custom_keys``
log_cfg (dict): A config dict for describing the logging
statistics method.
log_buffers (OrderedDict): All logs for the corresponding phase.
tag (OrderedDict): A dict which defines all statistic values of
logs.
"""
if isinstance(log_cfg, list):
log_names = set()
for cfg in log_cfg:
log_name = cfg.get('log_name', None)
if log_name in log_names:
raise KeyError(f'{cfg["log_name"]} cannot be redefined in '
'log_key')
if log_name is not None:
log_names.add(log_name)
self._parse_custom_keys(runner, log_key, cfg, log_buffers, tag)
assert len(log_names) == len(log_cfg) - 1, \
f'{log_key} cannot be overwritten multiple times, please ' \
f'check only one key does not contain `log_name` in {log_cfg}.'
elif isinstance(log_cfg, dict):
if 'window_size' in log_cfg:
log_cfg['window_size'] = \
self._get_window_size(runner, log_cfg['window_size'])
if 'log_name' in log_cfg:
name = log_cfg.pop('log_name')
else:
name = log_key
tag[name] = log_buffers[log_key].statistics(**log_cfg).item()
else:
raise ValueError('The structure of `LoggerHook.custom key` is '
'wrong, please make sure the type of each key is '
'dict or list.')
def _get_max_memory(self, runner) -> int:
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB)
for a given device.
Args:
runner (Runner): The runner of the training process.
Returns:
The maximum GPU memory occupied by tensors in megabytes for a given
device.
"""
device = getattr(runner.model, 'output_device', None)
mem = torch.cuda.max_memory_allocated(device=device)
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
dtype=torch.int,
device=device)
torch.cuda.reset_peak_memory_stats()
return int(mem_mb.item())
def _check_custom_keys(self) -> None:
"""Check the legality of ``self.custom_keys``.
If ``self.by_epoch==False``, ``window_size`` should not be "epoch". The
key of ``self.fixed_smooth_keys`` cannot be overwritten.
"""
def _check_window_size(item):
if not self.by_epoch:
assert item['window_size'] != 'epoch', \
'window_size cannot be epoch if LoggerHook.by_epoch is ' \
'False.'
def _check_fixed_keys(key, item):
if key in self.fixed_smooth_keys:
assert 'log_name' in item, f'{key} cannot be overwritten by ' \
'custom keys!'
for key, value in self.custom_keys.items():
if isinstance(value, Sequence):
[(_check_window_size(item), _check_fixed_keys(key, item))
for item in value]
else:
_check_window_size(value)
_check_fixed_keys(key, value)
def _get_epoch(self, runner, mode: str) -> int:
"""Get epoch according to mode.
Args:
runner (Runner): The runner of the training process.
mode (str): Train or val.
Returns:
int: The current epoch.
"""
if mode == 'train':
epoch = runner.epoch + 1
elif mode == 'val':
# normal val mode
# runner.epoch += 1 has been done before val workflow
epoch = runner.epoch
else:
raise ValueError(f"runner mode should be 'train' or 'val', "
f'but got {runner.mode}')
return epoch
def _get_iter(self, runner, inner_iter=False) -> int:
"""Get the current training iteration step.
Args:
runner (Runner): The runner of the training process.
inner_iter (bool): Whether to return the inner iter of an epoch.
Defaults to False.
Returns:
int: The current global iter or inner iter.
"""
if self.by_epoch and inner_iter:
current_iter = self._inner_iter + 1
else:
current_iter = runner.iter + 1
return current_iter


@ -86,6 +86,9 @@ class OptimizerHook(Hook):
we keep ``outputs`` here. Defaults to None.
"""
runner.optimizer.zero_grad()
runner.message_hub.update_scalar(
'train/lr', runner.optimizer.param_groups[0]['lr'])
if self.detect_anomalous_params:
self.detect_anomalous_parameters(runner.outputs['loss'], runner)
runner.outputs['loss'].backward()
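
As a hedged sketch of how the learning rate recorded here travels through the logging stack, the snippet below uses the ``MessageHub``/``HistoryBuffer`` calls that appear elsewhere in this diff; the instance name is made up, and in a real run the runner owns the message hub:

```python
from mmengine.logging import MessageHub

# Standalone MessageHub for illustration only; the runner normally creates it.
message_hub = MessageHub.get_instance('lr_logging_sketch')

# What OptimizerHook now does once per training iteration.
message_hub.update_scalar('train/lr', 0.01)
message_hub.update_scalar('train/lr', 0.009)

# LogProcessor later reads the same HistoryBuffer; lr is reported with the
# `current` statistic, i.e. the latest recorded value.
print(message_hub.get_scalar('train/lr').current())  # 0.009
```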


@ -1,6 +1,9 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .history_buffer import HistoryBuffer
from .log_processor import LogProcessor
from .logger import MMLogger, print_log
from .message_hub import MessageHub
__all__ = ['HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log']
__all__ = [
'HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log', 'LogProcessor'
]


@ -0,0 +1,409 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import datetime
from collections import OrderedDict
from typing import List, Optional, Tuple
import torch
class LogProcessor:
"""A log processor used to format log information collected from
``runner.message_hub.log_scalars``.
``LogProcessor`` instance is built by runner and will format
``runner.message_hub.log_scalars`` to ``tag`` and ``log_str``, which can
be directly used by ``LoggerHook`` and ``MMLogger``. Besides, the argument
``custom_cfg`` of constructor can control the statistics method of logs.
Args:
window_size (int): Default smoothing interval. Defaults to 10.
by_epoch (bool): Whether to format logs with epoch style. Defaults to
True.
custom_cfg (list[dict], optional): Contains multiple log config dicts,
in which the key means the data source name of the log and the value
means the statistic method and corresponding arguments used to count
the data source. Defaults to None.
- If custom_cfg is None, all logs will be formatted via default
methods, such as smoothing loss by default window_size. If
custom_cfg is defined as a list of config dict, for example:
[dict(data_src='loss', method_name='mean', log_name='global_loss',
window_size='global')]. It means the log item ``loss`` will be
counted as global mean and additionally logged as ``global_loss``
(defined by ``log_name``). If ``log_name`` is not defined in
config dict, the original logged key will be overwritten.
- The original log item cannot be overwritten twice. Here is
an error example:
[dict(data_src='loss', method_name='mean', window_size='global'),
dict(data_src='loss', method_name='mean', window_size='epoch')].
Both log config dict in custom_cfg do not have ``log_name`` key,
which means the loss item will be overwritten twice.
- For those statistic methods with the ``window_size`` argument, if
``by_epoch`` is set to False, ``window_size`` should not be 'epoch',
since log values cannot be counted by epoch.
Examples:
>>> # `log_name` is defined, `loss_large_window` will be an additional
>>> # record.
>>> log_processor = dict(
>>> window_size=10,
>>> by_epoch=True,
>>> custom_cfg=[dict(data_src='loss',
>>> log_name='loss_large_window',
>>> method_name='mean',
>>> window_size=100)])
>>> # `log_name` is not defined. `loss` will be overwritten.
>>> log_processor = dict(
>>> window_size=10,
>>> by_epoch=True,
>>> custom_cfg=[dict(data_src='loss',
>>> method_name='mean',
>>> window_size=100)])
>>> # Record loss with different statistics methods.
>>> log_processor = dict(
>>> window_size=10,
>>> by_epoch=True,
>>> custom_cfg=[dict(data_src='loss',
>>> log_name='loss_large_window',
>>> method_name='mean',
>>> window_size=100),
>>> dict(data_src='loss',
>>> method_name='mean',
>>> window_size=100)])
>>> # Overwrite loss item twice will raise an error.
>>> log_processor = dict(
>>> window_size=10,
>>> by_epoch=True,
>>> custom_cfg=[dict(data_src='loss',
>>> method_name='mean',
>>> window_size=100),
>>> dict(data_src='loss',
>>> method_name='max',
>>> window_size=100)])
AssertionError
"""
def __init__(self,
window_size=10,
by_epoch=True,
custom_cfg: Optional[List[dict]] = None):
self.window_size = window_size
self.by_epoch = by_epoch
self.custom_cfg = custom_cfg if custom_cfg else []
self._check_custom_cfg()
def get_log_after_iter(self, runner, batch_idx: int,
mode: str) -> Tuple[dict, str]:
"""Format log string after training, validation or testing epoch.
Args:
runner (Runner): The runner of training phase.
batch_idx (int): The index of the current batch in the current
loop.
mode (str): Current mode of runner, train, test or val.
Return:
Tuple(dict, str): Formatted log dict/string which will be
recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
"""
assert mode in ['train', 'test', 'val']
current_loop = self._get_cur_loop(runner, mode)
cur_iter = self._get_iter(runner, batch_idx=batch_idx)
# Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
# tag is used to write log information to different backends.
tag = self._collect_scalars(custom_cfg_copy, runner, mode)
# `log_tag` will pop 'lr' and loop other keys to `log_str`.
log_tag = copy.deepcopy(tag)
# Record learning rate.
lr_str_list = []
for key, value in tag.items():
if key.startswith('lr'):
log_tag.pop(key)
lr_str_list.append(f'{key}: {value:.3e}')
lr_str = ' '.join(lr_str_list)
# Format log header.
# by_epoch == True
# train/val: Epoch [5][5/10] ...
# test: Epoch [5/10]
# by_epoch == False
# train: Iter [5/10000] ... (divided by `max_iters`)
# val/test: Iter [5/2000] ... (divided by length of dataloader)
if self.by_epoch:
if mode in ['train', 'val']:
cur_epoch = self._get_epoch(runner, mode)
log_str = (f'Epoch({mode}) [{cur_epoch}]'
f'[{cur_iter}/{len(current_loop.dataloader)}] ')
else:
log_str = (f'Epoch({mode}) '
f'[{cur_iter}/{len(current_loop.dataloader)}] ')
else:
if mode == 'train':
log_str = (f'Iter({mode}) '
f'[{cur_iter}/{runner.train_loop.max_iters}] ')
else:
log_str = (f'Iter({mode}) [{batch_idx+1}'
f'/{len(current_loop.dataloader)}] ')
# Concatenate lr string with log header.
log_str += f'{lr_str} '
# If IterTimerHook used in runner, eta, time, and data_time should be
# recorded.
if (all(item in tag for item in ['time', 'data_time'])
and 'eta' in runner.message_hub.runtime_info):
eta = runner.message_hub.get_info('eta')
eta_str = str(datetime.timedelta(seconds=int(eta)))
log_str += f'eta: {eta_str} '
log_str += (f'time: {tag["time"]:.3f} '
f'data_time: {tag["data_time"]:.3f} ')
# Pop recorded keys
log_tag.pop('time')
log_tag.pop('data_time')
# If cuda is available, the max memory occupied should be calculated.
if torch.cuda.is_available():
log_str += f'memory: {self._get_max_memory(runner)} '
# Loop left keys to fill `log_str`.
if mode in ('train', 'val'):
log_items = []
for name, val in log_tag.items():
if mode == 'val' and not name.startswith('loss'):
continue
if isinstance(val, float):
val = f'{val:.4f}'
log_items.append(f'{name}: {val}')
log_str += ' '.join(log_items)
return tag, log_str
def get_log_after_epoch(self, runner, batch_idx: int,
mode: str) -> Tuple[dict, str]:
"""Format log string after validation or testing epoch.
Args:
runner (Runner): The runner of training phase.
batch_idx (int): The index of the current batch in the current
loop.
mode (str): Current mode of runner.
Return:
Tuple(dict, str): Formatted log dict/string which will be
recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
"""
assert mode in [
'test', 'val'
], ('`get_log_after_epoch` only accepts val or test mode, but got '
f'{mode}')
cur_loop = self._get_cur_loop(runner, mode)
dataloader_len = len(cur_loop.dataloader)
custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
# tag is used to write log information to different backends.
tag = self._collect_scalars(custom_cfg_copy, runner, mode)
# validation log string needs cur epoch/iteration and max
# epochs/iterations. test log string only needs length of test
# dataloader.
cur_iter = self._get_iter(runner, batch_idx)
if self.by_epoch:
if mode == 'val':
cur_epoch = self._get_epoch(runner, mode)
log_str = (f'Epoch({mode}) [{cur_epoch}][{dataloader_len}/'
f'{dataloader_len}] ')
else:
log_str = (
f'Epoch({mode}) [{dataloader_len}/{dataloader_len}] ')
else:
if mode == 'train':
log_str = (f'Iter({mode}) [{cur_iter}/'
f'{runner.train_loop.max_iters}] ')
else:
log_str = (
f'Iter({mode}) [{dataloader_len}/{dataloader_len}] ')
log_items = []
for name, val in tag.items():
if name in ('time', 'data_time'):
continue
if isinstance(val, float):
val = f'{val:.4f}'
log_items.append(f'{name}: {val}')
log_str += ' '.join(log_items)
return tag, log_str
def _collect_scalars(self, custom_cfg: List[dict], runner,
mode: str) -> dict:
"""Collect log information to compose a dict according to mode.
Args:
custom_cfg (List[dict]): A copy of ``self.custom_cfg`` with int
``window_size``.
runner (Runner): The runner of the training process.
mode (str): 'train' or 'val', which means the prefix attached by
runner.
Returns:
dict: Statistical values of logs.
"""
tag = OrderedDict()
# history_scalars of train/val/test phase.
history_scalars = runner.message_hub.log_scalars
# corresponding mode history_scalars
mode_history_scalars = OrderedDict()
# extract log scalars and remove prefix to `mode_history_scalars`
# according to mode.
for prefix_key, log_buffer in history_scalars.items():
if prefix_key.startswith(mode):
key = prefix_key.split('/')[-1]
mode_history_scalars[key] = log_buffer
for key in mode_history_scalars:
# Update the latest learning rate and smoothed time logs.
if key.startswith('loss'):
tag[key] = mode_history_scalars[key].mean(self.window_size)
else:
# Default statistic method is current.
tag[key] = mode_history_scalars[key].current()
# Update custom keys.
for log_cfg in custom_cfg:
data_src = log_cfg.pop('data_src')
if 'log_name' in log_cfg:
log_name = log_cfg.pop('log_name')
else:
log_name = data_src
# log item in custom_cfg could only exist in train or val
# mode.
if data_src in mode_history_scalars:
tag[log_name] = mode_history_scalars[data_src].statistics(
**log_cfg)
return tag
def _check_custom_cfg(self) -> None:
"""Check the legality of ``self.custom_cfg``."""
def _check_window_size():
for log_cfg in self.custom_cfg:
if not self.by_epoch:
assert log_cfg['window_size'] != 'epoch', \
'window_size cannot be epoch if LogProcessor.by_epoch' \
' is False.'
def _check_repeated_log_name():
check_dict = dict()
# The `log_name` of the same data_src should not be repeated.
# If `log_name` is not specified, `data_src` will be overwritten.
# But only allowed to be overwritten once.
for log_cfg in self.custom_cfg:
assert 'data_src' in log_cfg
data_src = log_cfg['data_src']
log_name = log_cfg.get('log_name', data_src)
check_dict.setdefault(data_src,
dict(log_names=set(), log_counts=0))
check_dict[data_src]['log_names'].add(log_name)
check_dict[data_src]['log_counts'] += 1
assert (len(
check_dict[data_src]
['log_names']) == check_dict[data_src]['log_counts']), (
f'If you want to compute statistics of {data_src} with '
'multiple methods, please check that `log_name` is unique '
f'and that {data_src} will not be overwritten twice. See '
'more information in the docstring of `LogProcessor`.')
_check_repeated_log_name()
_check_window_size()
def _parse_windows_size(self, runner, batch_idx: int) -> list:
"""Parse window_size defined in custom_cfg to int value.
Args:
runner (Runner): The runner of the training process.
batch_idx (int): The iteration index of current dataloader.
"""
custom_cfg_copy = copy.deepcopy(self.custom_cfg)
for log_cfg in custom_cfg_copy:
window_size = log_cfg.get('window_size', None)
if window_size is None or isinstance(window_size, int):
continue
elif window_size == 'epoch':
log_cfg['window_size'] = batch_idx + 1
elif window_size == 'global':
log_cfg['window_size'] = runner.iter + 1
else:
raise TypeError(
'window_size should be int, epoch or global, but got '
f'invalid {window_size}')
return custom_cfg_copy
def _get_max_memory(self, runner) -> int:
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB)
for a given device.
Args:
runner (Runner): The runner of the training process.
Returns:
The maximum GPU memory occupied by tensors in megabytes for a given
device.
"""
device = getattr(runner.model, 'output_device', None)
mem = torch.cuda.max_memory_allocated(device=device)
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
dtype=torch.int,
device=device)
torch.cuda.reset_peak_memory_stats()
return int(mem_mb.item())
def _get_iter(self, runner, batch_idx: int = None) -> int:
"""Get current training iteration step.
Args:
runner (Runner): The runner of the training process.
batch_idx (int, optional): The iteration index of the current
dataloader. Defaults to None.
Returns:
int: The current global iter or inner iter.
"""
if self.by_epoch and batch_idx:
current_iter = batch_idx + 1
else:
current_iter = runner.iter + 1
return current_iter
def _get_epoch(self, runner, mode: str) -> int:
"""Get current epoch according to mode.
Args:
runner (Runner): The runner of the training/validation process.
mode (str): Current mode of runner, "train" or "val".
Returns:
int: The current epoch.
"""
if mode == 'train':
epoch = runner.epoch + 1
elif mode == 'val':
# normal val mode
# runner.epoch += 1 has been done before validation
epoch = runner.epoch
else:
raise ValueError(
f"runner mode should be 'train' or 'val', but got {mode}")
return epoch
def _get_cur_loop(self, runner, mode: str):
"""Get current loop according to mode.
Args:
runner (Runner): The runner of the training/validation/testing
process.
mode (str): Current mode of runner, "train", "val" or test.
Returns:
BaseLoop: Current loop of runner.
"""
# Returning the loop type hint here would cause a circular import.
if mode == 'train':
return runner.train_loop
elif mode == 'val':
return runner.val_loop
else:
return runner.test_loop
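
For reference, a small hedged usage sketch of the ``custom_cfg`` validation described above; only the constructor is exercised, so no runner is needed:

```python
from mmengine.logging import LogProcessor

# Two configs for the same `data_src` are accepted as long as at most one of
# them omits `log_name`; otherwise `loss` would be overwritten twice.
log_processor = LogProcessor(
    window_size=10,
    by_epoch=True,
    custom_cfg=[
        dict(data_src='loss', log_name='loss_large_window',
             method_name='mean', window_size=100),
        dict(data_src='loss', method_name='mean', window_size='global'),
    ])

# This, by contrast, raises an AssertionError in `_check_custom_cfg`:
# LogProcessor(custom_cfg=[dict(data_src='loss'), dict(data_src='loss')])
```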


@ -32,15 +32,15 @@ class MMFormatter(logging.Formatter):
info_prefix = self._get_prefix('INFO', color)
debug_prefix = self._get_prefix('DEBUG', color)
# Config output format.
self.err_format = f'%(asctime)s - %(name)s - {error_prefix} - ' \
f'%(pathname)s - %(funcName)s - %(lineno)d - ' \
'%(message)s'
self.warn_format = f'%(asctime)s - %(name)s - {warn_prefix} - %(' \
'message)s'
self.info_format = f'%(asctime)s - %(name)s - {info_prefix} - %(' \
'message)s'
self.debug_format = f'%(asctime)s - %(name)s - {debug_prefix} - %(' \
'message)s'
self.err_format = (f'%(asctime)s - %(name)s - {error_prefix} - '
'%(pathname)s - %(funcName)s - %(lineno)d - '
'%(message)s')
self.warn_format = (f'%(asctime)s - %(name)s - {warn_prefix} - %('
'message)s')
self.info_format = (f'%(asctime)s - %(name)s - {info_prefix} - %('
'message)s')
self.debug_format = (f'%(asctime)s - %(name)s - {debug_prefix} - %('
'message)s')
def _get_prefix(self, level: str, color: bool) -> str:
"""Get the prefix of the target log level.


@ -25,7 +25,7 @@ from mmengine.dist import (broadcast, get_dist_info, init_dist, master_only,
sync_random_seed)
from mmengine.evaluator import Evaluator
from mmengine.hooks import Hook
from mmengine.logging import MessageHub, MMLogger
from mmengine.logging import LogProcessor, MessageHub, MMLogger
from mmengine.model import is_model_wrapper
from mmengine.optim import _ParamScheduler, build_optimizer
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
@ -127,6 +127,8 @@ class Runner:
non-distributed environment will be launched.
env_cfg (dict): A dict used for setting environment. Defaults to
dict(dist_cfg=dict(backend='nccl')).
log_processor (dict, optional): A processor to format logs. Defaults to
None.
log_level (int or str): The log level of MMLogger handlers.
Defaults to 'INFO'.
writer (ComposedWriter or dict, optional): A ComposedWriter object or a
@ -184,6 +186,7 @@ class Runner:
param_scheduler=dict(type='ParamSchedulerHook')),
launcher='none',
env_cfg=dict(dist_cfg=dict(backend='nccl')),
log_processor=dict(window_size=20),
writer=dict(
name='composed_writer',
writers=[dict(type='LocalWriter', save_dir='temp_dir')])
@ -218,6 +221,7 @@ class Runner:
launcher: str = 'none',
env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
log_level: str = 'INFO',
log_processor: Optional[Dict] = None,
writer: Optional[Union[ComposedWriter, Dict]] = None,
default_scope: Optional[str] = None,
randomness: Dict = dict(seed=None),
@ -310,6 +314,10 @@ class Runner:
else:
self._experiment_name = self.timestamp
log_processor = dict() if log_processor is None else log_processor
self.log_processor = LogProcessor(**log_processor)
# Since `get_instance` could return any subclass of ManagerMixin, the
# corresponding attribute needs a type hint.
self.logger = self.build_logger(log_level=log_level)
# Build `message_hub` for communication among components.
# `message_hub` can store log scalars (loss, learning rate) and
@ -385,6 +393,7 @@ class Runner:
resume=cfg.get('resume', False),
launcher=cfg.get('launcher', 'none'),
env_cfg=cfg.get('env_cfg'), # type: ignore
log_processor=cfg.get('log_processor'),
log_level=cfg.get('log_level', 'INFO'),
writer=cfg.get('writer'),
default_scope=cfg.get('default_scope'),
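
Below is a hedged config fragment showing how the new ``log_processor`` entry is passed through ``Runner.from_cfg``; only the logging-related keys are shown, and a real config also needs the model, dataloader and optimizer entries from the docstring example above:

```python
# Logging-related fragment of a runner config; other required keys omitted.
cfg_fragment = dict(
    log_level='INFO',
    log_processor=dict(
        window_size=20,
        by_epoch=True,
        custom_cfg=[
            dict(data_src='loss', log_name='global_loss',
                 method_name='mean', window_size='global'),
        ]),
)
```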


@ -157,18 +157,17 @@ class TestHook:
def test_end_of_epoch(self):
hook = Hook()
runner = Mock()
# last inner iter
batch_idx = 1
runner.cur_dataloader.__len__ = Mock(return_value=2)
runner.cur_dataloader.__len__ = Mock(return_value=2)
return_val = hook.end_of_epoch(runner, batch_idx)
dataloader = Mock()
dataloader.__len__ = Mock(return_value=2)
return_val = hook.end_of_epoch(dataloader, batch_idx)
assert return_val
# not the last inner iter
batch_idx = 0
return_val = hook.end_of_epoch(runner, batch_idx)
return_val = hook.end_of_epoch(dataloader, batch_idx)
assert not return_val
def test_is_last_train_epoch(self):


@ -1,29 +1,70 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock
from unittest import TestCase
from unittest.mock import MagicMock, Mock, patch
from mmengine.hooks import IterTimerHook
from mmengine.logging import MessageHub
class TestIterTimerHook:
def time_patch():
if not hasattr(time_patch, 'time'):
time_patch.time = 0
else:
time_patch.time += 1
return time_patch.time
class TestIterTimerHook(TestCase):
def setUp(self) -> None:
self.hook = IterTimerHook()
def test_init(self):
assert self.hook.time_sec_tot == 0
assert self.hook.start_iter == 0
def test_before_run(self):
runner = MagicMock()
runner.iter = 1
self.hook.before_run(runner)
assert self.hook.start_iter == 1
def test_before_epoch(self):
hook = IterTimerHook()
runner = Mock()
hook._before_epoch(runner)
assert isinstance(hook.t, float)
self.hook._before_epoch(runner)
assert isinstance(self.hook.t, float)
@patch('time.time', MagicMock(return_value=1))
def test_before_iter(self):
hook = IterTimerHook()
runner = Mock()
runner = MagicMock()
runner.log_buffer = dict()
hook._before_epoch(runner)
hook._before_iter(runner, 0)
runner.message_hub.update_scalar.assert_called()
self.hook._before_epoch(runner)
for mode in ('train', 'val', 'test'):
self.hook._before_iter(runner, batch_idx=1, mode=mode)
runner.message_hub.update_scalar.assert_called_with(
f'{mode}/data_time', 0)
@patch('time.time', time_patch)
def test_after_iter(self):
hook = IterTimerHook()
runner = Mock()
runner = MagicMock()
runner.log_buffer = dict()
hook._before_epoch(runner)
hook._after_iter(runner, 0)
runner.log_processor.window_size = 10
runner.train_loop.max_iters = 100
runner.iter = 0
runner.test_loop.dataloader = [0] * 20
runner.val_loop.dataloader = [0] * 20
self.hook._before_epoch(runner)
self.hook.before_run(runner)
self.hook._after_iter(runner, batch_idx=1)
runner.message_hub.update_scalar.assert_called()
runner.message_hub.get_log.assert_not_called()
runner.message_hub.update_info.assert_not_called()
runner.message_hub = MessageHub.get_instance('test_iter_timer_hook')
runner.iter = 9
# eta = (100 - 10) / 1
self.hook._after_iter(runner, batch_idx=89)
assert runner.message_hub.get_info('eta') == 90
self.hook._after_iter(runner, batch_idx=9, mode='val')
assert runner.message_hub.get_info('eta') == 10
self.hook._after_iter(runner, batch_idx=19, mode='test')
assert runner.message_hub.get_info('eta') == 0


@ -1,13 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
import datetime
import logging
import os.path as osp
import sys
from collections import OrderedDict
from unittest.mock import MagicMock, patch
from unittest.mock import MagicMock
import pytest
import torch
from mmengine.fileio.file_client import HardDiskBackend
from mmengine.hooks import LoggerHook
@ -17,11 +12,8 @@ class TestLoggerHook:
def test_init(self):
logger_hook = LoggerHook(out_dir='tmp.txt')
assert logger_hook.by_epoch
assert logger_hook.interval == 10
assert not logger_hook.custom_keys
assert logger_hook.ignore_last
assert logger_hook.time_sec_tot == 0
assert logger_hook.interval_exp_name == 1000
assert logger_hook.out_suffix == ('.log.json', '.log', '.py')
assert logger_hook.keep_local
@ -30,22 +22,7 @@ class TestLoggerHook:
# out_dir should be None or string or tuple of string.
with pytest.raises(TypeError):
LoggerHook(out_dir=1)
# time cannot be overwritten.
with pytest.raises(AssertionError):
LoggerHook(custom_keys=dict(time=dict(method='max')))
LoggerHook(
custom_keys=dict(time=[
dict(method='max', log_name='time_max'),
dict(method='min', log_name='time_min')
]))
# Epoch window_size cannot be used when `LoggerHook.by_epoch=False`
with pytest.raises(AssertionError):
LoggerHook(
by_epoch=False,
custom_keys=dict(
time=dict(
method='max', log_name='time_max',
window_size='epoch')))
with pytest.raises(ValueError):
LoggerHook(file_client_args=dict(enable_mc=True))
@ -60,20 +37,23 @@ class TestLoggerHook:
assert logger_hook.out_dir == osp.join('out_dir', 'work_dir')
assert logger_hook.json_log_path == osp.join('work_dir',
'timestamp.log.json')
assert logger_hook.start_iter == runner.iter
runner.writer.add_params.assert_called()
def test_after_run(self, tmp_path):
# Test
out_dir = tmp_path / 'out_dir'
out_dir.mkdir()
work_dir = tmp_path / 'work_dir'
work_dir.mkdir()
work_dir_json = work_dir / 'tmp.log.json'
json_f = open(work_dir_json, 'w')
json_f.close()
runner = MagicMock()
runner.work_dir = work_dir
# Test without out_dir.
logger_hook = LoggerHook()
logger_hook.after_run(runner)
# Test with out_dir and make sure json file has been moved to out_dir.
json_f = open(work_dir_json, 'w')
json_f.close()
logger_hook = LoggerHook(out_dir=str(tmp_path), keep_local=False)
logger_hook.out_dir = str(out_dir)
logger_hook.after_run(runner)
@ -84,274 +64,83 @@ class TestLoggerHook:
def test_after_train_iter(self):
# Test LoggerHook by iter.
runner = MagicMock()
runner.iter = 10
batch_idx = 5
logger_hook = LoggerHook(by_epoch=False)
logger_hook._log_train = MagicMock()
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
runner.log_processor.get_log_after_iter = MagicMock(
return_value=(dict(), 'log_str'))
logger_hook = LoggerHook()
logger_hook.after_train_iter(runner, batch_idx=5)
# `cur_iter=10+1`, which cannot be exact division by
# `logger_hook.interval`
logger_hook._log_train.assert_not_called()
runner.iter = 9
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
logger_hook._log_train.assert_called()
runner.log_processor.get_log_after_iter.assert_not_called()
logger_hook.after_train_iter(runner, batch_idx=9)
runner.log_processor.get_log_after_iter.assert_called()
# Test LoggerHook by epoch.
logger_hook = LoggerHook(by_epoch=True)
logger_hook._log_train = MagicMock()
# Only `runner.inner_iter` will work.
runner.iter = 9
batch_idx = 10
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
logger_hook._log_train.assert_not_called()
batch_idx = 9
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
logger_hook._log_train.assert_called()
logger_hook = LoggerHook()
runner = MagicMock()
runner.log_processor.get_log_after_iter = MagicMock(
return_value=(dict(), 'log_str'))
# Only `batch_idx` will work.
logger_hook.after_train_iter(runner, batch_idx=10)
runner.log_processor.get_log_after_iter.assert_not_called()
logger_hook.after_train_iter(runner, batch_idx=9)
runner.log_processor.get_log_after_iter.assert_called()
# Test end of the epoch.
logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
logger_hook._log_train = MagicMock()
runner.cur_dataloader = [0] * 5
batch_idx = 4
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
logger_hook._log_train.assert_called()
runner = MagicMock()
runner.log_processor.get_log_after_iter = MagicMock(
return_value=(dict(), 'log_str'))
logger_hook = LoggerHook(ignore_last=False)
runner.train_dataloader = [0] * 5
logger_hook.after_train_iter(runner, batch_idx=4)
runner.log_processor.get_log_after_iter.assert_called()
# Test print exp_name
runner = MagicMock()
runner.log_processor.get_log_after_iter = MagicMock(
return_value=(dict(), 'log_str'))
runner.meta = dict(exp_name='retinanet')
logger_hook = LoggerHook()
runner.logger = MagicMock()
logger_hook._log_train = MagicMock()
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
runner.logger.info.assert_called_with(
f'Exp name: {runner.meta["exp_name"]}')
logger_hook = LoggerHook()
logger_hook.after_train_iter(runner, batch_idx=999)
runner.logger.info.assert_called()
def test_after_val_epoch(self):
logger_hook = LoggerHook()
runner = MagicMock()
logger_hook._log_val = MagicMock()
runner.log_processor.get_log_after_epoch = MagicMock(
return_value=(dict(), 'string'))
logger_hook.after_val_epoch(runner)
logger_hook._log_val.assert_called()
runner.log_processor.get_log_after_epoch.assert_called()
runner.logger.info.assert_called()
runner.writer.add_scalars.assert_called()
@pytest.mark.parametrize('by_epoch', [True, False])
def test_log_train(self, by_epoch, capsys):
runner = self._setup_runner()
runner.meta = dict(exp_name='retinanet')
# Prepare LoggerHook
logger_hook = LoggerHook(by_epoch=by_epoch)
logger_hook._inner_iter = 1
logger_hook.writer = MagicMock()
logger_hook.time_sec_tot = 1000
logger_hook.start_iter = 0
logger_hook._get_max_memory = MagicMock(return_value='100')
logger_hook.json_log_path = 'tmp.json'
# Prepare training information.
train_infos = dict(
lr=0.1, momentum=0.9, time=1.0, data_time=1.0, loss_cls=1.0)
logger_hook._collect_info = MagicMock(return_value=train_infos)
logger_hook._log_train(runner)
# Verify that the correct variables have been written.
runner.writer.add_scalars.assert_called_with(
train_infos, step=11, file_path='tmp.json')
# Verify that the correct context have been logged.
out, _ = capsys.readouterr()
time_avg = logger_hook.time_sec_tot / (
runner.iter + 1 - logger_hook.start_iter)
eta_second = time_avg * (runner.train_loop.max_iters - runner.iter - 1)
eta_str = str(datetime.timedelta(seconds=int(eta_second)))
if by_epoch:
if torch.cuda.is_available():
log_str = 'Epoch [2][2/5]\t' \
f"lr: {train_infos['lr']:.3e} " \
f"momentum: {train_infos['momentum']:.3e}, " \
f'eta: {eta_str}, ' \
f"time: {train_infos['time']:.3f}, " \
f"data_time: {train_infos['data_time']:.3f}, " \
f'memory: 100, ' \
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
else:
log_str = 'Epoch [2][2/5]\t' \
f"lr: {train_infos['lr']:.3e} " \
f"momentum: {train_infos['momentum']:.3e}, " \
f'eta: {eta_str}, ' \
f"time: {train_infos['time']:.3f}, " \
f"data_time: {train_infos['data_time']:.3f}, " \
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
assert out == log_str
else:
if torch.cuda.is_available():
log_str = 'Iter [11/50]\t' \
f"lr: {train_infos['lr']:.3e} " \
f"momentum: {train_infos['momentum']:.3e}, " \
f'eta: {eta_str}, ' \
f"time: {train_infos['time']:.3f}, " \
f"data_time: {train_infos['data_time']:.3f}, " \
f'memory: 100, ' \
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
else:
log_str = 'Iter [11/50]\t' \
f"lr: {train_infos['lr']:.3e} " \
f"momentum: {train_infos['momentum']:.3e}, " \
f'eta: {eta_str}, ' \
f"time: {train_infos['time']:.3f}, " \
f"data_time: {train_infos['data_time']:.3f}, " \
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
assert out == log_str
@pytest.mark.parametrize('by_epoch', [True, False])
def test_log_val(self, by_epoch, capsys):
runner = self._setup_runner()
# Prepare LoggerHook.
logger_hook = LoggerHook(by_epoch=by_epoch)
logger_hook.json_log_path = 'tmp.json'
metric = dict(accuracy=0.9, data_time=1.0)
logger_hook._collect_info = MagicMock(return_value=metric)
logger_hook._log_val(runner)
# Verify that the correct context have been logged.
out, _ = capsys.readouterr()
runner.writer.add_scalars.assert_called_with(
metric, step=11, file_path='tmp.json')
if by_epoch:
assert out == 'Epoch(val) [1][5]\taccuracy: 0.9000, ' \
'data_time: 1.0000\n'
else:
assert out == 'Iter(val) [5]\taccuracy: 0.9000, ' \
'data_time: 1.0000\n'
def test_get_window_size(self):
runner = self._setup_runner()
logger_hook = LoggerHook()
logger_hook._inner_iter = 1
# Test get window size by name.
assert logger_hook._get_window_size(runner, 'epoch') == 2
assert logger_hook._get_window_size(runner, 'global') == 11
assert logger_hook._get_window_size(runner, 10) == 10
# Window size must equal to `logger_hook.interval`.
with pytest.raises(AssertionError):
logger_hook._get_window_size(runner, 20)
with pytest.raises(ValueError):
logger_hook._get_window_size(runner, 'unknwon')
def test_parse_custom_keys(self):
tag = OrderedDict()
runner = self._setup_runner()
log_buffers = OrderedDict(lr=MagicMock(), loss=MagicMock())
cfg_dict = dict(
lr=dict(method='min'),
loss=[
dict(method='min', window_size='global'),
dict(method='max', log_name='loss_max')
])
logger_hook = LoggerHook()
for log_key, log_cfg in cfg_dict.items():
logger_hook._parse_custom_keys(runner, log_key, log_cfg,
log_buffers, tag)
assert list(tag) == ['lr', 'loss', 'loss_max']
assert log_buffers['lr'].min.assert_called
assert log_buffers['loss'].min.assert_called
assert log_buffers['loss'].max.assert_called
assert log_buffers['loss'].mean.assert_called
# `log_name` Cannot be repeated.
with pytest.raises(KeyError):
cfg_dict = dict(loss=[
dict(method='min', window_size='global'),
dict(method='max', log_name='loss_max'),
dict(method='mean', log_name='loss_max')
])
logger_hook.custom_keys = cfg_dict
for log_key, log_cfg in cfg_dict.items():
logger_hook._parse_custom_keys(runner, log_key, log_cfg,
log_buffers, tag)
# `log_key` cannot be overwritten multiple times.
with pytest.raises(AssertionError):
cfg_dict = dict(loss=[
dict(method='min', window_size='global'),
dict(method='max'),
])
logger_hook.custom_keys = cfg_dict
for log_key, log_cfg in cfg_dict.items():
logger_hook._parse_custom_keys(runner, log_key, log_cfg,
log_buffers, tag)
def test_collect_info(self):
runner = self._setup_runner()
logger_hook = LoggerHook(
custom_keys=dict(time=dict(method='max', log_name='time_max')))
logger_hook._parse_custom_keys = MagicMock()
# Collect with prefix.
log_buffers = {
'train/time': MagicMock(),
'lr': MagicMock(),
'train/loss_cls': MagicMock(),
'val/metric': MagicMock()
}
runner.message_hub.log_scalars = log_buffers
tag = logger_hook._collect_info(runner, mode='train')
# Test parse custom_keys
logger_hook._parse_custom_keys.assert_called()
# Test training key in tag.
assert list(tag.keys()) == ['time', 'loss_cls']
# Test statistics lr with `current`, loss and time with 'mean'
log_buffers['train/time'].mean.assert_called()
log_buffers['train/loss_cls'].mean.assert_called()
log_buffers['train/loss_cls'].current.assert_not_called()
tag = logger_hook._collect_info(runner, mode='val')
assert list(tag.keys()) == ['metric']
log_buffers['val/metric'].current.assert_called()
@patch('torch.cuda.max_memory_allocated', MagicMock())
@patch('torch.cuda.reset_peak_memory_stats', MagicMock())
def test_get_max_memory(self):
def test_after_test_epoch(self):
logger_hook = LoggerHook()
runner = MagicMock()
runner.world_size = 1
runner.model = torch.nn.Linear(1, 1)
logger_hook._get_max_memory(runner)
torch.cuda.max_memory_allocated.assert_called()
torch.cuda.reset_peak_memory_stats.assert_called()
runner.log_processor.get_log_after_epoch = MagicMock(
return_value=(dict(), 'log_str'))
logger_hook.after_test_epoch(runner)
runner.log_processor.get_log_after_epoch.assert_called()
runner.logger.info.assert_called()
def test_get_iter(self):
runner = self._setup_runner()
def test_after_val_iter(self):
logger_hook = LoggerHook()
logger_hook._inner_iter = 1
# Get global iter when `inner_iter=False`
iter = logger_hook._get_iter(runner)
assert iter == 11
# Get inner iter
iter = logger_hook._get_iter(runner, inner_iter=True)
assert iter == 2
# Still get global iter when `logger_hook.by_epoch==False`
logger_hook.by_epoch = False
iter = logger_hook._get_iter(runner, inner_iter=True)
assert iter == 11
def test_get_epoch(self):
runner = self._setup_runner()
logger_hook = LoggerHook()
epoch = logger_hook._get_epoch(runner, 'train')
assert epoch == 2
epoch = logger_hook._get_epoch(runner, 'val')
assert epoch == 1
with pytest.raises(ValueError):
logger_hook._get_epoch(runner, 'test')
def _setup_runner(self):
runner = MagicMock()
runner.epoch = 1
runner.cur_dataloader = [0] * 5
runner.iter = 10
runner.train_loop.max_iters = 50
logger = logging.getLogger()
logger.setLevel(logging.INFO)
for handler in logger.handlers:
if not isinstance(handler, logging.StreamHandler):
continue
else:
logger.addHandler(logging.StreamHandler(stream=sys.stdout))
runner.logger = logger
runner.message_hub = MagicMock()
runner.composed_writer = MagicMock()
return runner
runner.iter = 0
runner.log_processor.get_log_after_iter = MagicMock(
return_value=(dict(), 'log_str'))
logger_hook.after_val_iter(runner, 1)
runner.log_processor.get_log_after_iter.assert_not_called()
logger_hook.after_val_iter(runner, 9)
runner.log_processor.get_log_after_iter.assert_called()
def test_after_test_iter(self):
logger_hook = LoggerHook()
runner = MagicMock()
runner.iter = 0
runner.log_processor.get_log_after_iter = MagicMock(
return_value=(dict(), 'log_str'))
logger_hook.after_test_iter(runner, 1)
runner.log_processor.get_log_after_iter.assert_not_called()
logger_hook.after_test_iter(runner, 9)
runner.log_processor.get_log_after_iter.assert_called()
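
The hook tests above only pin down a delegation pattern: after every N-th inner iteration (N appears to default to 10, judging by the batch indices used) the hook asks `runner.log_processor` for a formatted message and presumably writes it through `runner.logger`. The sketch below is illustrative only, not the actual LoggerHook implementation; the batch index passed to `get_log_after_epoch` is a guess, since the tests only check that the call happens.

# Hypothetical sketch inferred from the assertions above.
class SketchLoggerHook:
    interval = 10  # assumed default, matching batch_idx 9 triggering a log

    def after_val_iter(self, runner, batch_idx):
        # Only log every `interval` inner iterations.
        if (batch_idx + 1) % self.interval == 0:
            _, log_str = runner.log_processor.get_log_after_iter(
                runner, batch_idx, 'val')
            runner.logger.info(log_str)

    def after_test_epoch(self, runner):
        # Epoch-level summary: collect the tag dict and the formatted string.
        _, log_str = runner.log_processor.get_log_after_epoch(
            runner, len(runner.test_loop.dataloader), 'test')
        runner.logger.info(log_str)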

View File

@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock
from unittest.mock import MagicMock, Mock
import torch
from torch import nn
@ -45,7 +45,7 @@ class TestOptimizerHook:
model = Model()
x = torch.rand(1, 1, 3, 3)
dummy_runner = Mock()
dummy_runner = MagicMock()
dummy_runner.optimizer.zero_grad = Mock(return_value=None)
dummy_runner.optimizer.step = Mock(return_value=None)
dummy_runner.model = model

View File

@ -0,0 +1,242 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest.mock import MagicMock, patch
import pytest
import torch
from mmengine.logging import LogProcessor, MessageHub, MMLogger
class TestLogProcessor:
def test_init(self):
log_processor = LogProcessor(
window_size=10, by_epoch=True, custom_cfg=None)
assert log_processor.by_epoch
assert log_processor.window_size == 10
assert log_processor.custom_cfg == []
def test_check_custom_cfg(self):
# Setting `window_size='epoch'` in a log config while `by_epoch=False`
# will raise an AssertionError.
custom_cfg = [dict(data_src='loss', window_size='epoch')]
with pytest.raises(AssertionError):
LogProcessor(by_epoch=False, custom_cfg=custom_cfg)
# Duplicate log_name will raise AssertionError.
custom_cfg = [
dict(data_src='loss', log_name='loss_1'),
dict(data_src='loss', log_name='loss_1')
]
with pytest.raises(AssertionError):
LogProcessor(custom_cfg=custom_cfg)
# Overwriting the same loss item twice will raise an AssertionError.
custom_cfg = [dict(data_src='loss'), dict(data_src='loss')]
with pytest.raises(AssertionError):
LogProcessor(custom_cfg=custom_cfg)
custom_cfg = [
dict(data_src='loss_cls', window_size=100, method_name='min'),
dict(data_src='loss', log_name='loss_min', method_name='max'),
dict(data_src='loss', log_name='loss_max', method_name='max')
]
LogProcessor(custom_cfg=custom_cfg)
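
For reference, a config sketch (names are illustrative only) that satisfies all of the constraints exercised above: duplicated `data_src` entries must be disambiguated with unique `log_name`s, and `window_size='epoch'` is only allowed together with `by_epoch=True`.

# Illustrative only; mirrors the valid case at the end of the test above.
valid_custom_cfg = [
    # Re-log `loss` under two new names so the original entry is never
    # overwritten twice.
    dict(data_src='loss', log_name='loss_min', method_name='min'),
    dict(data_src='loss', log_name='loss_max', method_name='max'),
    # An epoch-sized window is only legal when by_epoch=True.
    dict(data_src='time', window_size='epoch', method_name='mean'),
]
LogProcessor(by_epoch=True, custom_cfg=valid_custom_cfg)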
def test_parse_windows_size(self):
log_processor = LogProcessor()
# Test parse 'epoch' window_size.
log_processor.custom_cfg = [
dict(data_src='loss_cls', window_size='epoch')
]
custom_cfg = log_processor._parse_windows_size(self.runner, 1)
assert custom_cfg[0]['window_size'] == 2
# Test parse 'global' window_size.
log_processor.custom_cfg = [
dict(data_src='loss_cls', window_size='global')
]
custom_cfg = log_processor._parse_windows_size(self.runner, 1)
assert custom_cfg[0]['window_size'] == 11
# Test parse int window_size
log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=100)]
custom_cfg = log_processor._parse_windows_size(self.runner, 1)
assert custom_cfg[0]['window_size'] == 100
# An invalid window_size type will raise a TypeError.
log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=[])]
with pytest.raises(TypeError):
log_processor._parse_windows_size(self.runner, 1)
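
Taken together, these assertions suggest the resolution rule sketched below. The helper is hypothetical and written only for illustration; the real logic lives in `LogProcessor._parse_windows_size`.

def resolve_window_size(window_size, runner, batch_idx):
    # 'epoch' -> number of iterations seen in the current epoch
    # (batch_idx starts at 0).
    if window_size == 'epoch':
        return batch_idx + 1          # 1 -> 2 in the test above
    # 'global' -> all iterations so far.
    if window_size == 'global':
        return runner.iter + 1        # 10 -> 11 in the test above
    if isinstance(window_size, int):
        return window_size            # e.g. 100 is kept unchanged
    raise TypeError("window_size must be an int, 'epoch' or 'global'")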
@pytest.mark.parametrize('by_epoch,mode',
([True, 'train'], [False, 'train'], [True, 'val'],
[False, 'val'], [True, 'test'], [False, 'test']))
def test_get_log_after_iter(self, by_epoch, mode):
# Prepare LogProcessor.
log_processor = LogProcessor(by_epoch=by_epoch)
log_processor._get_max_memory = MagicMock(return_value='100')
eta = 40
self.runner.message_hub.update_info('eta', eta)
# Prepare training information.
if mode == 'train':
train_logs = dict(lr=0.1, time=1.0, data_time=1.0, loss_cls=1.0)
else:
train_logs = dict(time=1.0, data_time=1.0, loss_cls=1.0)
log_processor._collect_scalars = MagicMock(return_value=train_logs)
tag, out = log_processor.get_log_after_iter(self.runner, 1, mode)
# Verify that the correct content has been logged.
cur_loop = log_processor._get_cur_loop(self.runner, mode)
if by_epoch:
if mode in ['train', 'val']:
cur_epoch = log_processor._get_epoch(self.runner, mode)
log_str = (f'Epoch({mode}) [{cur_epoch}][2/'
f'{len(cur_loop.dataloader)}] ')
else:
log_str = (f'Epoch({mode}) [2/{len(cur_loop.dataloader)}] ')
if mode == 'train':
log_str += f"lr: {train_logs['lr']:.3e} "
else:
log_str += ' '
log_str += (f'eta: 0:00:40 '
f"time: {train_logs['time']:.3f} "
f"data_time: {train_logs['data_time']:.3f} ")
if torch.cuda.is_available():
log_str += 'memory: 100 '
if mode == 'train':
log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
assert out == log_str
else:
if mode == 'train':
max_iters = self.runner.train_loop.max_iters
log_str = f'Iter({mode}) [11/{max_iters}] '
else:
max_iters = len(cur_loop.dataloader)
log_str = f'Iter({mode}) [2/{max_iters}] '
if mode == 'train':
log_str += f"lr: {train_logs['lr']:.3e} "
else:
log_str += ' '
log_str += (f'eta: 0:00:40 '
f"time: {train_logs['time']:.3f} "
f"data_time: {train_logs['data_time']:.3f} ")
if torch.cuda.is_available():
log_str += 'memory: 100 '
if mode == 'train':
log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
assert out == log_str
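
Plugging the mocked values above into the by_epoch training branch (epoch 2, second batch of 20, lr=0.1, time=1.0, data_time=1.0, loss_cls=1.0, eta of 40 seconds) gives a message of roughly the following shape; the `memory` field only appears when CUDA is available.

expected = ('Epoch(train) [2][2/20] lr: 1.000e-01 eta: 0:00:40 '
            'time: 1.000 data_time: 1.000 memory: 100 loss_cls: 1.0000')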
@pytest.mark.parametrize(
'by_epoch,mode',
([True, 'val'], [False, 'val'], [True, 'test'], [False, 'test']))
def test_log_val(self, by_epoch, mode):
# Prepare LogProcessor.
log_processor = LogProcessor(by_epoch=by_epoch)
# Prepare validation information.
val_logs = dict(accuracy=0.9, data_time=1.0)
log_processor._collect_scalars = MagicMock(return_value=val_logs)
_, out = log_processor.get_log_after_epoch(self.runner, 2, mode)
if by_epoch:
if mode == 'test':
assert out == 'Epoch(test) [5/5] accuracy: 0.9000'
else:
assert out == 'Epoch(val) [1][10/10] accuracy: 0.9000'
else:
if mode == 'test':
assert out == 'Iter(test) [5/5] accuracy: 0.9000'
else:
assert out == 'Iter(val) [10/10] accuracy: 0.9000'
def test_collect_scalars(self):
custom_cfg = [
dict(data_src='time', method_name='mean', window_size=100),
dict(data_src='time', method_name='max', log_name='time_max')
]
log_processor = LogProcessor(custom_cfg=custom_cfg)
# Collect with prefix.
log_scalars = {
'train/time': MagicMock(),
'lr': MagicMock(),
'train/loss_cls': MagicMock(),
'val/metric': MagicMock()
}
self.runner.message_hub._log_scalars = log_scalars
tag = log_processor._collect_scalars(
copy.deepcopy(custom_cfg), self.runner, mode='train')
# Only training keys should appear in the tag.
assert list(tag.keys()) == ['time', 'loss_cls', 'time_max']
# `time` should be re-collected with the `max` statistic (logged as
# `time_max`), while `loss_cls` falls back to `mean`.
log_scalars['train/time'].statistics.assert_called_with(
method_name='max')
log_scalars['train/loss_cls'].mean.assert_called()
tag = log_processor._collect_scalars(
copy.deepcopy(custom_cfg), self.runner, mode='val')
assert list(tag.keys()) == ['metric']
log_scalars['val/metric'].current.assert_called()
@patch('torch.cuda.max_memory_allocated', MagicMock())
@patch('torch.cuda.reset_peak_memory_stats', MagicMock())
def test_get_max_memory(self):
log_processor = LogProcessor()
runner = MagicMock()
runner.world_size = 1
runner.model = torch.nn.Linear(1, 1)
log_processor._get_max_memory(runner)
torch.cuda.max_memory_allocated.assert_called()
torch.cuda.reset_peak_memory_stats.assert_called()
def test_get_iter(self):
log_processor = LogProcessor()
# Get the global iter when no batch_idx is given.
iter = log_processor._get_iter(self.runner)
assert iter == 11
# Get the inner iter when batch_idx is given.
iter = log_processor._get_iter(self.runner, 1)
assert iter == 2
# Still get the global iter when `log_processor.by_epoch == False`.
log_processor.by_epoch = False
iter = log_processor._get_iter(self.runner, 1)
assert iter == 11
def test_get_epoch(self):
log_processor = LogProcessor()
epoch = log_processor._get_epoch(self.runner, 'train')
assert epoch == 2
epoch = log_processor._get_epoch(self.runner, 'val')
assert epoch == 1
with pytest.raises(ValueError):
log_processor._get_epoch(self.runner, 'test')
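
These assertions pin down a small convention, mirrored by the hypothetical helper below: training reports the in-progress epoch (`runner.epoch + 1`), validation reports `runner.epoch` as-is (the assumption being that the counter has already been bumped before validation runs), and 'test' has no epoch notion at all.

def epoch_for_logging(runner, mode):
    if mode == 'train':
        return runner.epoch + 1   # 1 -> 2 with the mocked runner above
    if mode == 'val':
        return runner.epoch       # 1 -> 1 with the mocked runner above
    raise ValueError(f"mode should be 'train' or 'val', but got {mode}")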
def test_get_cur_loop(self):
log_processor = LogProcessor()
loop = log_processor._get_cur_loop(self.runner, 'train')
assert len(loop.dataloader) == 20
loop = log_processor._get_cur_loop(self.runner, 'val')
assert len(loop.dataloader) == 10
loop = log_processor._get_cur_loop(self.runner, 'test')
assert len(loop.dataloader) == 5
def setup(self):
runner = MagicMock()
runner.epoch = 1
runner.iter = 10
runner.train_loop.max_iters = 50
runner.train_loop.dataloader = [0] * 20
runner.val_loop.dataloader = [0] * 10
runner.test_loop.dataloader = [0] * 5
logger = MMLogger.get_instance('log_processor_test')
runner.logger = logger
message_hub = MessageHub.get_instance('log_processor_test')
for i in range(10):
message_hub.update_scalar('train/loss', 10 - i)
for i in range(10):
message_hub.update_scalar('val/acc', i * 0.1)
runner.message_hub = message_hub
self.runner = runner
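
The fixture above pre-fills a real MessageHub rather than mocking it. Assuming MessageHub exposes its history buffers through `log_scalars` (as the LoggerHook tests earlier in this diff do), the seeded values can be read back roughly as follows; the instance name and the printed results are illustrative.

from mmengine.logging import MessageHub

hub = MessageHub.get_instance('log_processor_example')
for i in range(10):
    hub.update_scalar('train/loss', 10 - i)   # 10, 9, ..., 1
loss_buffer = hub.log_scalars['train/loss']
print(loss_buffer.current())   # most recent value: 1
print(loss_buffer.mean(5))     # mean over the last 5 values: 3.0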

View File

@ -221,7 +221,7 @@ class TestRunner(TestCase):
self.iter_based_cfg.default_hooks = dict(
timer=dict(type='IterTimerHook'),
checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
logger=dict(type='LoggerHook', by_epoch=False),
logger=dict(type='LoggerHook'),
optimizer=dict(type='OptimizerHook', grad_clip=None),
param_scheduler=dict(type='ParamSchedulerHook'))
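
With `by_epoch` dropped from the hook config here, the epoch/iteration flavour of the log message is presumably controlled by the log processor instead; a minimal sketch of the corresponding setting (exactly where the processor is attached, e.g. on the runner, is an assumption not shown in this diff):

from mmengine.logging import LogProcessor

# Iteration-style messages for an iter-based training setup.
log_processor = LogProcessor(by_epoch=False, window_size=10)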