[Refactor] Refine LoggerHook (#155)
* rename global accessible and integration get_instance and create_instance * move ManagerMixin to utils * fix as docstring and separate get_instance to get_instance and get_current_instance * fix lint * fix docstring, rename and move test_global_meta * rename LogBuffer to HistoryBuffer, rename MessageHub methods, MessageHub support resume * refine MMLogger timestamp, update unit test * MMLogger add logger_name arguments * Fix docstring * Add LogProcessor and some unit test * update unit test * complete LogProcessor unit test * refine LoggerHook * solve circular import * change default logger_name to mmengine * refactor eta * Fix docstring comment and unit test * Fix with runner * fix docstring fix docstring * fix docstring * Add by_epoch attribute to LoggerHook and fix docstring * Please mypy and fix comment * remove \ in MMLogger * Fix lint * roll back pre-commit-hook * Fix hook unit test * Fix comments * remove \t in log and add docstring * Fix as comment * should not accept other arguments if corresponding instance has been created * fix logging ddp file saving * fix logging ddp file saving * move log processor to logging * move log processor to logging * remove current dataloader * fix docstring * fix unit test * add learning rate in messagehub * Support output training/validation/testing message after iterations/epochs * fix docstring * Fix IterBasedRunner log string * Fix IterBasedRunner log string * Support parse validation loss in log processor pull/197/head
parent
4274679376
commit
e2a2b0438e
|
@ -358,11 +358,11 @@ class Hook:
|
|||
"""
|
||||
return (runner.epoch + 1) % n == 0 if n > 0 else False
|
||||
|
||||
def every_n_inner_iters(self, inner_iter: int, n: int) -> bool:
|
||||
def every_n_inner_iters(self, batch_idx: int, n: int) -> bool:
|
||||
"""Test whether current inner iteration can be evenly divided by n.
|
||||
|
||||
Args:
|
||||
inner_iter (int): Current inner_iter of the training, validation
|
||||
batch_idx (int): Current batch index of the training, validation
|
||||
or testing loop.
|
||||
n (int): Whether current inner iteration can be evenly
|
||||
divided by n.
|
||||
|
@ -371,7 +371,7 @@ class Hook:
|
|||
bool: Whether current inner iteration can be evenly
|
||||
divided by n.
|
||||
"""
|
||||
return (inner_iter + 1) % n == 0 if n > 0 else False
|
||||
return (batch_idx + 1) % n == 0 if n > 0 else False
|
||||
|
||||
def every_n_iters(self, runner, n: int) -> bool:
|
||||
"""Test whether current iteration can be evenly divided by n.
|
||||
|
@ -387,19 +387,18 @@ class Hook:
|
|||
"""
|
||||
return (runner.iter + 1) % n == 0 if n > 0 else False
|
||||
|
||||
def end_of_epoch(self, runner, batch_idx: int) -> bool:
|
||||
def end_of_epoch(self, dataloader, batch_idx: int) -> bool:
|
||||
"""Check whether the current iteration reaches the last iteration of
|
||||
current dataloader.
|
||||
the dataloader.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
dataloader (Dataloader): The dataloader of the training,
|
||||
validation or testing process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
|
||||
Returns:
|
||||
bool: Whether reaches the end of current epoch or not.
|
||||
"""
|
||||
return batch_idx + 1 == len(runner.cur_dataloader)
|
||||
return batch_idx + 1 == len(dataloader)
|
||||
|
||||
def is_last_train_epoch(self, runner) -> bool:
|
||||
"""Test whether current epoch is the last train epoch.
|
||||
|
@ -418,10 +417,10 @@ class Hook:
|
|||
Args:
|
||||
runner (Runner): The runner of the training, validation or testing
|
||||
process.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
|
||||
Returns:
|
||||
bool: Whether current iteration is the last iteration.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
if mode == 'train':
|
||||
return runner.iter + 1 == runner.train_loop.max_iters
|
||||
|
|
|
@ -18,11 +18,25 @@ class IterTimerHook(Hook):
|
|||
|
||||
priority = 'NORMAL'
|
||||
|
||||
def _before_epoch(self, runner, mode: str = 'train') -> None:
|
||||
"""Record time flag before start a epoch.
|
||||
def __init__(self):
|
||||
self.time_sec_tot = 0
|
||||
self.start_iter = 0
|
||||
|
||||
def before_run(self, runner) -> None:
|
||||
"""Synchronize the number of iterations with the runner.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
runner: The runner of the training, validation or testing
|
||||
process.
|
||||
"""
|
||||
self.start_iter = runner.iter
|
||||
|
||||
def _before_epoch(self, runner, mode: str = 'train') -> None:
|
||||
"""Record timestamp before starting an epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training, validation and
|
||||
testing process.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
self.t = time.time()
|
||||
|
@ -32,16 +46,18 @@ class IterTimerHook(Hook):
|
|||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
mode: str = 'train') -> None:
|
||||
"""Logging time for loading data and update the time flag.
|
||||
"""Calculating time for loading data and updating "data_time"
|
||||
``HistoryBuffer`` of ``runner.message_hub``.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
runner (Runner): The runner of the training, validation and
|
||||
testing process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
# TODO: update for new logging system
|
||||
# Update data loading time in `runner.message_hub`.
|
||||
runner.message_hub.update_scalar(f'{mode}/data_time',
|
||||
time.time() - self.t)
|
||||
|
||||
|
@ -52,10 +68,12 @@ class IterTimerHook(Hook):
|
|||
outputs: Optional[Union[dict,
|
||||
Sequence[BaseDataElement]]] = None,
|
||||
mode: str = 'train') -> None:
|
||||
"""Logging time for a iteration and update the time flag.
|
||||
"""Calculating time for an iteration and updating "time"
|
||||
``HistoryBuffer`` of ``runner.message_hub``.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
runner (Runner): The runner of the training, validation and
|
||||
testing process.
|
||||
batch_idx (int): The index of the current batch in the loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional): Data
|
||||
from dataloader. Defaults to None.
|
||||
|
@ -63,7 +81,31 @@ class IterTimerHook(Hook):
|
|||
to None.
|
||||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
# TODO: update for new logging system
|
||||
|
||||
runner.message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
|
||||
# Update iteration time in `runner.message_hub`.
|
||||
message_hub = runner.message_hub
|
||||
message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
|
||||
self.t = time.time()
|
||||
window_size = runner.log_processor.window_size
|
||||
# Calculate eta every `window_size` iterations. Since test and val
|
||||
# loop will not update runner.iter, use `every_n_inner_iters` to check
|
||||
# the interval.
|
||||
if self.every_n_inner_iters(batch_idx, window_size):
|
||||
iter_time = message_hub.get_scalar(f'{mode}/time').mean(
|
||||
window_size)
|
||||
if mode == 'train':
|
||||
self.time_sec_tot += iter_time * window_size
|
||||
# Calculate average iterative time.
|
||||
time_sec_avg = self.time_sec_tot / (
|
||||
runner.iter - self.start_iter + 1)
|
||||
# Calculate eta.
|
||||
eta_sec = time_sec_avg * (
|
||||
runner.train_loop.max_iters - runner.iter - 1)
|
||||
runner.message_hub.update_info('eta', eta_sec)
|
||||
else:
|
||||
if mode == 'val':
|
||||
cur_dataloader = runner.val_loop.dataloader
|
||||
else:
|
||||
cur_dataloader = runner.test_loop.dataloader
|
||||
|
||||
eta_sec = iter_time * (len(cur_dataloader) - batch_idx - 1)
|
||||
runner.message_hub.update_info('eta', eta_sec)
|
||||
|
|
|
@ -1,16 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import datetime
|
||||
import os
|
||||
import os.path as osp
|
||||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, Optional, Sequence, Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
from mmengine.data import BaseDataElement
|
||||
from mmengine.dist import master_only
|
||||
from mmengine.fileio import FileClient
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.registry import HOOKS
|
||||
|
@ -21,33 +15,20 @@ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
|
|||
|
||||
@HOOKS.register_module()
|
||||
class LoggerHook(Hook):
|
||||
"""In this logger hook, the information will be printed on the terminal and
|
||||
saved in JSON file, tensorboard, wandb .etc.
|
||||
"""Collect logs from different components of ``Runner`` and write them to
|
||||
terminal, JSON file, tensorboard, wandb, etc.
|
||||
|
||||
``LoggerHook`` is used to record logs formatted by ``LogProcessor`` during
|
||||
training/validation/testing phase. It is used to control following
|
||||
behaviors:
|
||||
|
||||
- The frequency of log updates in terminal, local files, tensorboard, wandb, etc.
|
||||
- The frequency of showing experiment information in terminal.
|
||||
- The work directory to save logs.
|
||||
|
||||
Args:
|
||||
by_epoch (bool): Whether ``EpochBasedLoop`` is used.
|
||||
Defaults to True.
|
||||
interval (int): Logging interval (every k iterations).
|
||||
Defaults to 10.
|
||||
custom_keys (dict, optional): Defines the keys in the log and which
|
||||
kinds of statistic methods should be used to log them.
|
||||
|
||||
- ``custom_keys`` contains multiple string-dict pairs. In each
|
||||
string-dict pair, the string defines a key name in the log and the
|
||||
dict is a config defines the statistic methods and corresponding
|
||||
arguments used to log the value. For example,
|
||||
``dict(loss=dict(method_name='mean', log_name='global_loss',
|
||||
window_size='global'))`` which means the log key ``loss`` will be
|
||||
counted as global mean and additionally logged as ``global_loss``.
|
||||
If ``log_name`` is not defined in config dict, the original logged
|
||||
key will be overwritten.
|
||||
- The key in ``LoggerHook.fixed_smooth_keys`` cannot be overwritten
|
||||
because ``time`` and ``iter_time`` will be used to calculate
|
||||
estimated time of arrival. If you want to recount the time, you
|
||||
should set ``log_name`` in corresponding values.
|
||||
- For those statistic methods with the ``window_size`` argument,
|
||||
if ``by_epoch`` is set to False, ``windows_size`` should not be
|
||||
`epoch` to statistics log value by epoch.
|
||||
ignore_last (bool): Ignore the log of last iterations in each epoch if
|
||||
the number of remaining iterations is less than :attr:`interval`.
|
||||
Defaults to True.
|
||||
|
@ -72,64 +53,24 @@ class LoggerHook(Hook):
|
|||
Defaults to None.
|
||||
|
||||
Examples:
|
||||
>>> # `log_name` is defined, `loss_mean_window` will be an additional
|
||||
>>> # record.
|
||||
>>> logger_hook_cfg = dict(by_epoch=True,
|
||||
>>> custom_keys=dict(
|
||||
>>> loss=dict(
|
||||
>>> log_name='loss_mean_window',
|
||||
>>> method_name='mean',
|
||||
>>> window_size=10)))
|
||||
>>> # `log_name` is not defined. `loss` will be overwritten by
|
||||
>>> # `global_mean` statistics.
|
||||
>>> logger_hook_cfg = dict(by_epoch=True,
|
||||
>>> custom_keys=dict(
|
||||
>>> loss=dict(
|
||||
>>> method_name='mean',
|
||||
>>> window_size='global')))
|
||||
>>> # `time` cannot be overwritten, `global_time` will be an additional
|
||||
>>> # record.
|
||||
>>> logger_hook_cfg = dict(by_epoch=True,
|
||||
>>> custom_keys=dict(
|
||||
>>> time=dict(
|
||||
>>> log_name='global_time',
|
||||
>>> method='mean',
|
||||
>>> window_size='global')))
|
||||
>>> # Record loss with different statistics methods.
|
||||
>>> logger_hook_cfg = dict(by_epoch=True,
|
||||
>>> custom_keys=dict(loss=[
|
||||
>>> dict(log_name='loss_mean_window',
|
||||
>>> method_name='mean',
|
||||
>>> window_size=10),
|
||||
>>> dict(method_name='mean',
|
||||
>>> window_size='global')]))
|
||||
>>> # A simplest LoggerHook config.
|
||||
>>> logger_hook_cfg = dict(interval=20)
|
||||
"""
|
||||
# eta will be calculated by time. `time` and `data_time` should not be
|
||||
# overwritten.
|
||||
fixed_smooth_keys = ('time', 'data_time')
|
||||
priority = 'BELOW_NORMAL'
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
by_epoch: bool = True,
|
||||
interval: int = 10,
|
||||
custom_keys: Optional[dict] = None,
|
||||
ignore_last: bool = True,
|
||||
interval_exp_name: int = 1000,
|
||||
out_dir: Optional[Union[str, Path]] = None,
|
||||
out_suffix: Union[Sequence[str], str] = ('.log.json', '.log', '.py'),
|
||||
keep_local=True,
|
||||
file_client_args=None,
|
||||
keep_local: bool = True,
|
||||
file_client_args: Optional[dict] = None,
|
||||
):
|
||||
self._inner_iter = 0
|
||||
self.by_epoch = by_epoch
|
||||
self.interval = interval
|
||||
self.custom_keys = custom_keys if custom_keys is not None else dict()
|
||||
self.ignore_last = ignore_last
|
||||
|
||||
self.time_sec_tot = 0
|
||||
self.interval_exp_name = interval_exp_name
|
||||
self._check_custom_keys()
|
||||
|
||||
if out_dir is None and file_client_args is not None:
|
||||
raise ValueError(
|
||||
|
@ -169,7 +110,7 @@ class LoggerHook(Hook):
|
|||
f'{runner.timestamp}.log.json')
|
||||
self.yaml_log_path = osp.join(runner.work_dir,
|
||||
f'{runner.timestamp}.log.json')
|
||||
self.start_iter = runner.iter
|
||||
# TODO Compatible with Visualizer.
|
||||
if runner.meta is not None:
|
||||
runner.writer.add_params(runner.meta, file_path=self.yaml_log_path)
|
||||
|
||||
|
@ -178,41 +119,100 @@ class LoggerHook(Hook):
|
|||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[dict] = None) -> None:
|
||||
"""Record training logs.
|
||||
"""Record training logs after training iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the train loop.
|
||||
data_batch (Sequence[BaseDataElement], optional): Data from
|
||||
dataloader. Defaults to None.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (dict, optional): Outputs from model.
|
||||
Defaults to None.
|
||||
"""
|
||||
self._inner_iter = batch_idx
|
||||
if runner.meta is not None and 'exp_name' in runner.meta:
|
||||
if (self.every_n_iters(runner, self.interval_exp_name)) or (
|
||||
self.by_epoch and self.end_of_epoch(runner, batch_idx)):
|
||||
exp_info = f'Exp name: {runner.meta["exp_name"]}'
|
||||
runner.logger.info(exp_info)
|
||||
if self.by_epoch and self.every_n_inner_iters(batch_idx,
|
||||
self.interval):
|
||||
self._log_train(runner)
|
||||
elif not self.by_epoch and self.every_n_iters(runner, self.interval):
|
||||
self._log_train(runner)
|
||||
elif self.end_of_epoch(runner, batch_idx) and not self.ignore_last:
|
||||
# Print experiment name every n iterations.
|
||||
if self.every_n_iters(runner,
|
||||
self.interval_exp_name) or (self.end_of_epoch(
|
||||
runner.train_dataloader, batch_idx)):
|
||||
exp_info = f'Exp name: {runner.experiment_name}'
|
||||
runner.logger.info(exp_info)
|
||||
if self.every_n_inner_iters(batch_idx, self.interval):
|
||||
tag, log_str = runner.log_processor.get_log_after_iter(
|
||||
runner, batch_idx, 'train')
|
||||
elif (self.end_of_epoch(runner.train_dataloader, batch_idx)
|
||||
and not self.ignore_last):
|
||||
# `runner.max_iters` may not be divisible by `self.interval`. if
|
||||
# `self.ignore_last==False`, the log of remaining iterations will
|
||||
# be recorded (Epoch [4][1000/1007], the logs of 998-1007
|
||||
# iterations will be recorded).
|
||||
self._log_train(runner)
|
||||
tag, log_str = runner.log_processor.get_log_after_iter(
|
||||
runner, batch_idx, 'train')
|
||||
else:
|
||||
return
|
||||
runner.logger.info(log_str)
|
||||
# TODO compatible with visualizer.
|
||||
runner.writer.add_scalars(tag, step=runner.iter + 1)
|
||||
|
||||
def after_val_iter(
|
||||
self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
|
||||
"""Record validation logs after validation iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the val loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (sequence, optional): Outputs from model. Defaults to None.
|
||||
"""
|
||||
if self.every_n_inner_iters(batch_idx, self.interval):
|
||||
tag, log_str = runner.log_processor.get_log_after_iter(
|
||||
runner, batch_idx, 'val')
|
||||
runner.logger.info(log_str)
|
||||
|
||||
def after_test_iter(
|
||||
self,
|
||||
runner,
|
||||
batch_idx: int,
|
||||
data_batch: DATA_BATCH = None,
|
||||
outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
|
||||
"""Record testing logs after iteration.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
batch_idx (int): The index of the current batch in the test loop.
|
||||
data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
|
||||
Data from dataloader. Defaults to None.
|
||||
outputs (sequence, optional): Outputs from model. Defaults to None.
|
||||
"""
|
||||
if self.every_n_inner_iters(batch_idx, self.interval):
|
||||
tag, log_str = runner.log_processor.get_log_after_iter(
|
||||
runner, batch_idx, 'test')
|
||||
runner.logger.info(log_str)
|
||||
|
||||
def after_val_epoch(self, runner) -> None:
|
||||
"""Record validation logs.
|
||||
"""Record validation logs after validation epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
self._log_val(runner)
|
||||
tag, log_str = runner.log_processor.get_log_after_epoch(
|
||||
runner, len(runner.val_dataloader), 'val')
|
||||
runner.logger.info(log_str)
|
||||
# TODO compatible with visualizer.
|
||||
runner.writer.add_scalars(tag, step=runner.iter + 1)
|
||||
|
||||
def after_test_epoch(self, runner) -> None:
|
||||
"""Record testing logs after test epoch.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
tag, log_str = runner.log_processor.get_log_after_epoch(
|
||||
runner, len(runner.val_dataloader), 'test')
|
||||
runner.logger.info(log_str)
|
||||
|
||||
def after_run(self, runner) -> None:
|
||||
"""Copy logs to ``self.out_dir`` if ``self.out_dir is not None``
|
||||
|
@ -237,280 +237,3 @@ class LoggerHook(Hook):
|
|||
os.remove(local_filepath)
|
||||
runner.logger.info((f'{local_filepath} was removed due to the '
|
||||
'`self.keep_local=False`'))
|
||||
|
||||
@master_only
|
||||
def _log_train(self, runner) -> None:
|
||||
"""Collect and record training logs which start named with "train/*".
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
tag = self._collect_info(runner, 'train')
|
||||
# The training log default defines `lr`, `momentum`, `time` and
|
||||
# `data_time`. `log_tag` will pop these keys and loop other keys to
|
||||
# `log_str`.
|
||||
log_tag = copy.deepcopy(tag)
|
||||
cur_iter = self._get_iter(runner, inner_iter=True)
|
||||
cur_epoch = self._get_epoch(runner, 'train')
|
||||
|
||||
# Record learning rate and momentum.
|
||||
lr_str_list = []
|
||||
momentum_str_list = []
|
||||
for key, value in tag.items():
|
||||
if key.startswith('lr'):
|
||||
log_tag.pop(key)
|
||||
lr_str_list.append(f'{key}: {value:.3e}')
|
||||
lr_str = ' '.join(lr_str_list)
|
||||
for key, value in tag.items():
|
||||
if key.startswith('momentum'):
|
||||
log_tag.pop(key)
|
||||
momentum_str_list.append(f'{key}: {value:.3e}')
|
||||
momentum_str = ' '.join(momentum_str_list)
|
||||
lr_momentum_str = f'{lr_str} {momentum_str}'
|
||||
# by epoch: Epoch [4][100/1000]
|
||||
# by iter: Iter [100/100000]
|
||||
if self.by_epoch:
|
||||
log_str = f'Epoch [{cur_epoch}]' \
|
||||
f'[{cur_iter}/{len(runner.cur_dataloader)}]\t'
|
||||
else:
|
||||
log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}]\t'
|
||||
log_str += f'{lr_momentum_str}, '
|
||||
# Calculate eta time.
|
||||
self.time_sec_tot += (tag['time'] * self.interval)
|
||||
time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1)
|
||||
eta_sec = time_sec_avg * (
|
||||
runner.train_loop.max_iters - runner.iter - 1)
|
||||
eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
|
||||
log_str += f'eta: {eta_str}, '
|
||||
log_str += f'time: {tag["time"]:.3f}, ' \
|
||||
f'data_time: {tag["data_time"]:.3f}, '
|
||||
# Pop recorded keys
|
||||
log_tag.pop('time')
|
||||
log_tag.pop('data_time')
|
||||
# statistic memory
|
||||
if torch.cuda.is_available():
|
||||
log_str += f'memory: {self._get_max_memory(runner)}, '
|
||||
# Loop left keys to fill `log_str`.
|
||||
log_items = []
|
||||
for name, val in log_tag.items():
|
||||
if isinstance(val, float):
|
||||
val = f'{val:.4f}'
|
||||
log_items.append(f'{name}: {val}')
|
||||
log_str += ', '.join(log_items)
|
||||
runner.logger.info(log_str)
|
||||
# Write logs to local, tensorboad, and wandb.
|
||||
runner.writer.add_scalars(
|
||||
tag, step=runner.iter + 1, file_path=self.json_log_path)
|
||||
|
||||
@master_only
|
||||
def _log_val(self, runner) -> None:
|
||||
"""Collect and record training logs which start named with "val/*".
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
tag = self._collect_info(runner, 'val')
|
||||
# Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501
|
||||
eval_iter = len(runner.cur_dataloader)
|
||||
cur_iter = self._get_iter(runner)
|
||||
cur_epoch = self._get_epoch(runner, 'val')
|
||||
# val/test time
|
||||
# here 1000 is the length of the val dataloader
|
||||
# by epoch: Epoch[val] [4][1000]
|
||||
# by iter: Iter[val] [1000]
|
||||
if self.by_epoch:
|
||||
# runner.epoch += 1 has been done before val workflow
|
||||
log_str = f'Epoch(val) [{cur_epoch}][{eval_iter}]\t'
|
||||
else:
|
||||
log_str = f'Iter(val) [{eval_iter}]\t'
|
||||
|
||||
log_items = []
|
||||
for name, val in tag.items():
|
||||
if isinstance(val, float):
|
||||
val = f'{val:.4f}'
|
||||
log_items.append(f'{name}: {val}')
|
||||
log_str += ', '.join(log_items)
|
||||
runner.logger.info(log_str)
|
||||
# Write tag.
|
||||
runner.writer.add_scalars(
|
||||
tag, step=cur_iter, file_path=self.json_log_path)
|
||||
|
||||
def _get_window_size(self, runner, window_size: Union[int, str]) \
|
||||
-> int:
|
||||
"""Parse window_size specified in ``self.custom_keys`` to int value.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
window_size (int or str): Smoothing scale of logs.
|
||||
|
||||
Returns:
|
||||
int: Smoothing window for statistical methods.
|
||||
"""
|
||||
if isinstance(window_size, int):
|
||||
assert window_size == self.interval, \
|
||||
'The value of windows size must equal to LoggerHook.interval'
|
||||
return window_size
|
||||
elif window_size == 'epoch':
|
||||
return self._inner_iter + 1
|
||||
elif window_size == 'global':
|
||||
return runner.iter + 1
|
||||
else:
|
||||
raise ValueError('window_size should be int, epoch or global, but '
|
||||
f'got invalid {window_size}')
|
||||
|
||||
def _collect_info(self, runner, mode: str) -> dict:
|
||||
"""Collect log information to a dict according to mode.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
mode (str): 'train' or 'val', which means the prefix attached by
|
||||
runner.
|
||||
|
||||
Returns:
|
||||
dict: Statistical values of logs.
|
||||
"""
|
||||
tag = OrderedDict()
|
||||
log_buffers = runner.message_hub.log_scalars
|
||||
mode_log_buffers = OrderedDict()
|
||||
# Filter log_buffers which starts with `mode`.
|
||||
for prefix_key, log_buffer in log_buffers.items():
|
||||
if prefix_key.startswith(mode):
|
||||
key = prefix_key.split('/')[-1]
|
||||
mode_log_buffers[key] = log_buffer
|
||||
# Ensure all metric and lr values are latest.
|
||||
for key in mode_log_buffers:
|
||||
# Update the latest learning rate and smoothed time logs.
|
||||
if key in self.fixed_smooth_keys or key.startswith('loss'):
|
||||
tag[key] = mode_log_buffers[key].mean(self.interval)
|
||||
else:
|
||||
tag[key] = mode_log_buffers[key].current()
|
||||
# Update custom keys.
|
||||
if mode == 'train':
|
||||
for log_key, log_cfg in self.custom_keys.items():
|
||||
self._parse_custom_keys(runner, log_key,
|
||||
copy.deepcopy(log_cfg),
|
||||
mode_log_buffers, tag)
|
||||
return tag
|
||||
|
||||
def _parse_custom_keys(self, runner, log_key: str, log_cfg: dict,
|
||||
log_buffers: OrderedDict, tag: OrderedDict) -> None:
|
||||
"""Statistics logs in log_buffers according to custom_keys.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
log_key (str): log key specified in ``self.custom_keys``
|
||||
log_cfg (dict): A config dict for describing the logging
|
||||
statistics method.
|
||||
log_buffers (OrderedDict): All logs for the corresponding phase.
|
||||
tag (OrderedDict): A dict which defines all statistic values of
|
||||
logs.
|
||||
"""
|
||||
if isinstance(log_cfg, list):
|
||||
log_names = set()
|
||||
for cfg in log_cfg:
|
||||
log_name = cfg.get('log_name', None)
|
||||
if log_name in log_names:
|
||||
raise KeyError(f'{cfg["log_name"]} cannot be redefined in '
|
||||
'log_key')
|
||||
if log_name is not None:
|
||||
log_names.add(log_name)
|
||||
self._parse_custom_keys(runner, log_key, cfg, log_buffers, tag)
|
||||
assert len(log_names) == len(log_cfg) - 1, \
|
||||
f'{log_key} cannot be overwritten multiple times, please ' \
|
||||
f'check only one key does not contain `log_name` in {log_cfg}.'
|
||||
elif isinstance(log_cfg, dict):
|
||||
if 'window_size' in log_cfg:
|
||||
log_cfg['window_size'] = \
|
||||
self._get_window_size(runner, log_cfg['window_size'])
|
||||
if 'log_name' in log_cfg:
|
||||
name = log_cfg.pop('log_name')
|
||||
else:
|
||||
name = log_key
|
||||
tag[name] = log_buffers[log_key].statistics(**log_cfg).item()
|
||||
else:
|
||||
raise ValueError('The structure of `LoggerHook.custom key` is '
|
||||
'wrong, please make sure the type of each key is '
|
||||
'dict or list.')
|
||||
|
||||
def _get_max_memory(self, runner) -> int:
|
||||
"""Returns the maximum GPU memory occupied by tensors in megabytes (MB)
|
||||
for a given device.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
|
||||
Returns:
|
||||
The maximum GPU memory occupied by tensors in megabytes for a given
|
||||
device.
|
||||
"""
|
||||
device = getattr(runner.model, 'output_device', None)
|
||||
mem = torch.cuda.max_memory_allocated(device=device)
|
||||
mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
|
||||
dtype=torch.int,
|
||||
device=device)
|
||||
torch.cuda.reset_peak_memory_stats()
|
||||
return int(mem_mb.item())
|
||||
|
||||
def _check_custom_keys(self) -> None:
|
||||
"""Check the legality of ``self.custom_keys``.
|
||||
|
||||
If ``self.by_epoch==False``, ``window_size`` should not be "epoch". The
|
||||
key of ``self.fixed_smooth_keys`` cannot be overwritten.
|
||||
"""
|
||||
|
||||
def _check_window_size(item):
|
||||
if not self.by_epoch:
|
||||
assert item['window_size'] != 'epoch', \
|
||||
'window_size cannot be epoch if LoggerHook.by_epoch is ' \
|
||||
'False.'
|
||||
|
||||
def _check_fixed_keys(key, item):
|
||||
if key in self.fixed_smooth_keys:
|
||||
assert 'log_name' in item, f'{key} cannot be overwritten by ' \
|
||||
'custom keys!'
|
||||
|
||||
for key, value in self.custom_keys.items():
|
||||
if isinstance(value, Sequence):
|
||||
[(_check_window_size(item), _check_fixed_keys(key, item))
|
||||
for item in value]
|
||||
|
||||
else:
|
||||
_check_window_size(value)
|
||||
_check_fixed_keys(key, value)
|
||||
|
||||
def _get_epoch(self, runner, mode: str) -> int:
|
||||
"""Get epoch according to mode.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
mode (str): Train or val.
|
||||
|
||||
Returns:
|
||||
int: The current epoch.
|
||||
"""
|
||||
if mode == 'train':
|
||||
epoch = runner.epoch + 1
|
||||
elif mode == 'val':
|
||||
# normal val mode
|
||||
# runner.epoch += 1 has been done before val workflow
|
||||
epoch = runner.epoch
|
||||
else:
|
||||
raise ValueError(f"runner mode should be 'train' or 'val', "
|
||||
f'but got {runner.mode}')
|
||||
return epoch
|
||||
|
||||
def _get_iter(self, runner, inner_iter=False) -> int:
|
||||
"""Get the current training iteration step.
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
inner_iter (bool): Whether to return the inner iter of an epoch.
|
||||
Defaults to False.
|
||||
|
||||
Returns:
|
||||
int: The current global iter or inner iter.
|
||||
"""
|
||||
if self.by_epoch and inner_iter:
|
||||
current_iter = self._inner_iter + 1
|
||||
else:
|
||||
current_iter = runner.iter + 1
|
||||
return current_iter
|
||||
|
|
|
@ -86,6 +86,9 @@ class OptimizerHook(Hook):
|
|||
we keep ``outputs`` here. Defaults to None.
|
||||
"""
|
||||
runner.optimizer.zero_grad()
|
||||
runner.message_hub.update_scalar(
|
||||
'train/lr', runner.optimizer.param_groups[0]['lr'])
|
||||
|
||||
if self.detect_anomalous_params:
|
||||
self.detect_anomalous_parameters(runner.outputs['loss'], runner)
|
||||
runner.outputs['loss'].backward()
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .history_buffer import HistoryBuffer
|
||||
from .log_processor import LogProcessor
|
||||
from .logger import MMLogger, print_log
|
||||
from .message_hub import MessageHub
|
||||
|
||||
__all__ = ['HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log']
|
||||
__all__ = [
|
||||
'HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log', 'LogProcessor'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,409 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import datetime
|
||||
from collections import OrderedDict
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class LogProcessor:
    """A log processor used to format log information collected from
    ``runner.message_hub.log_scalars``.

    ``LogProcessor`` instance is built by runner and will format
    ``runner.message_hub.log_scalars`` to ``tag`` and ``log_str``, which can
    be directly used by ``LoggerHook`` and ``MMLogger``. Besides, the argument
    ``custom_cfg`` of constructor can control the statistics method of logs.

    Args:
        window_size (int): Default smooth interval. Defaults to 10.
        by_epoch (bool): Whether to format logs with epoch style. Defaults to
            True.
        custom_cfg (list[dict], optional): Contains multiple log config dict.
            In each dict, ``data_src`` names the logged item to read,
            the optional ``log_name`` names the item to write, and all
            remaining keys are forwarded as keyword arguments to the
            ``statistics`` method of the corresponding history buffer.
            Defaults to None.

            - If custom_cfg is None, all logs will be formatted via default
              methods, such as smoothing loss by default window_size. If
              custom_cfg is defined as a list of config dict, for example:
              [dict(data_src='loss', method_name='mean',
              log_name='global_loss', window_size='global')]. It means the
              log item ``loss`` will be counted as global mean and
              additionally logged as ``global_loss`` (defined by
              ``log_name``). If ``log_name`` is not defined in config dict,
              the original logged key will be overwritten.

            - The original log item cannot be overwritten twice. Here is
              an error example:
              [dict(data_src='loss', method_name='mean',
              window_size='global'),
              dict(data_src='loss', method_name='mean',
              window_size='epoch')].
              Both log config dict in custom_cfg do not have ``log_name``
              key, which means the loss item will be overwritten twice.

            - For those statistic methods with the ``window_size`` argument,
              if ``by_epoch`` is set to False, ``window_size`` must not be
              ``'epoch'``.

    Examples:
        >>> # `log_name` is defined, `loss_large_window` will be an additional
        >>> # record.
        >>> log_processor = dict(
        >>>     window_size=10,
        >>>     by_epoch=True,
        >>>     custom_cfg=[dict(data_src='loss',
        >>>                      log_name='loss_large_window',
        >>>                      method_name='mean',
        >>>                      window_size=100)])
        >>> # `log_name` is not defined. `loss` will be overwritten.
        >>> log_processor = dict(
        >>>     window_size=10,
        >>>     by_epoch=True,
        >>>     custom_cfg=[dict(data_src='loss',
        >>>                      method_name='mean',
        >>>                      window_size=100)])
        >>> # Record loss with different statistics methods.
        >>> log_processor = dict(
        >>>     window_size=10,
        >>>     by_epoch=True,
        >>>     custom_cfg=[dict(data_src='loss',
        >>>                      log_name='loss_large_window',
        >>>                      method_name='mean',
        >>>                      window_size=100),
        >>>                 dict(data_src='loss',
        >>>                      method_name='mean',
        >>>                      window_size=100)])
        >>> # Overwrite loss item twice will raise an error.
        >>> log_processor = dict(
        >>>     window_size=10,
        >>>     by_epoch=True,
        >>>     custom_cfg=[dict(data_src='loss',
        >>>                      method_name='mean',
        >>>                      window_size=100),
        >>>                 dict(data_src='loss',
        >>>                      method_name='max',
        >>>                      window_size=100)])
        AssertionError
    """

    def __init__(self,
                 window_size=10,
                 by_epoch=True,
                 custom_cfg: Optional[List[dict]] = None):
        self.window_size = window_size
        self.by_epoch = by_epoch
        self.custom_cfg = custom_cfg if custom_cfg else []
        # Fail early on contradictory configs instead of at logging time.
        self._check_custom_cfg()

    def get_log_after_iter(self, runner, batch_idx: int,
                           mode: str) -> Tuple[dict, str]:
        """Format log string after a training, validation or testing
        iteration.

        Args:
            runner (Runner): The runner of training phase.
            batch_idx (int): The index of the current batch in the current
                loop.
            mode (str): Current mode of runner, train, test or val.

        Return:
            Tuple(dict, str): Formatted log dict/string which will be
            recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
        """
        assert mode in ['train', 'test', 'val']
        current_loop = self._get_cur_loop(runner, mode)
        cur_iter = self._get_iter(runner, batch_idx=batch_idx)
        # Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
        custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
        # tag is used to write log information to different backends.
        tag = self._collect_scalars(custom_cfg_copy, runner, mode)
        # `log_tag` is the working copy: keys already rendered into the
        # header (lr, time, data_time) are popped from it so that the final
        # loop only prints what is left. `tag` itself is returned untouched.
        log_tag = copy.deepcopy(tag)
        # Record learning rate.
        lr_str_list = []
        for key, value in tag.items():
            if key.startswith('lr'):
                log_tag.pop(key)
                lr_str_list.append(f'{key}: {value:.3e}')
        lr_str = ' '.join(lr_str_list)
        # Format log header.
        # by_epoch == True
        #   train/val: Epoch(train) [5][5/10] ...
        #   test: Epoch(test) [5/10] ...
        # by_epoch == False
        #   train: Iter(train) [5/10000] ... (divided by `max_iters`)
        #   val/test: Iter(val) [5/2000] ... (divided by length of dataloader)
        if self.by_epoch:
            if mode in ['train', 'val']:
                cur_epoch = self._get_epoch(runner, mode)
                log_str = (f'Epoch({mode}) [{cur_epoch}]'
                           f'[{cur_iter}/{len(current_loop.dataloader)}] ')
            else:
                log_str = (f'Epoch({mode}) '
                           f'[{cur_iter}/{len(current_loop.dataloader)}] ')
        else:
            if mode == 'train':
                log_str = (f'Iter({mode}) '
                           f'[{cur_iter}/{runner.train_loop.max_iters}] ')
            else:
                log_str = (f'Iter({mode}) [{batch_idx+1}'
                           f'/{len(current_loop.dataloader)}] ')
        # Concatenate lr string with log header.
        log_str += f'{lr_str} '
        # eta, time and data_time are only rendered when all of them have
        # been recorded (`time`/`data_time` as scalars of this mode, `eta`
        # as runtime info of the message hub).
        if (all(item in tag for item in ['time', 'data_time'])
                and 'eta' in runner.message_hub.runtime_info):
            eta = runner.message_hub.get_info('eta')
            eta_str = str(datetime.timedelta(seconds=int(eta)))
            log_str += f'eta: {eta_str} '
            log_str += (f'time: {tag["time"]:.3f} '
                        f'data_time: {tag["data_time"]:.3f} ')
            # Pop recorded keys
            log_tag.pop('time')
            log_tag.pop('data_time')

        # If cuda is available, the max memory occupied should be calculated.
        if torch.cuda.is_available():
            log_str += f'memory: {self._get_max_memory(runner)} '
        # Loop left keys to fill `log_str`. Nothing else is printed per
        # iteration in test mode, and only loss items in val mode.
        if mode in ('train', 'val'):
            log_items = []
            for name, val in log_tag.items():
                if mode == 'val' and not name.startswith('loss'):
                    continue
                if isinstance(val, float):
                    val = f'{val:.4f}'
                log_items.append(f'{name}: {val}')
            log_str += ' '.join(log_items)
        return tag, log_str

    def get_log_after_epoch(self, runner, batch_idx: int,
                            mode: str) -> Tuple[dict, str]:
        """Format log string after validation or testing epoch.

        Args:
            runner (Runner): The runner of training phase.
            batch_idx (int): The index of the current batch in the current
                loop.
            mode (str): Current mode of runner, val or test.

        Return:
            Tuple(dict, str): Formatted log dict/string which will be
            recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
        """
        assert mode in [
            'test', 'val'
        ], ('`get_log_after_epoch` only accept val or test mode, but got '
            f'{mode}')
        cur_loop = self._get_cur_loop(runner, mode)
        dataloader_len = len(cur_loop.dataloader)

        custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
        # tag is used to write log information to different backends.
        tag = self._collect_scalars(custom_cfg_copy, runner, mode)
        # The header always reports the loop as finished
        # (`dataloader_len/dataloader_len`); only the epoch-style validation
        # header additionally carries the current epoch.
        if self.by_epoch:
            if mode == 'val':
                cur_epoch = self._get_epoch(runner, mode)
                log_str = (f'Epoch({mode}) [{cur_epoch}][{dataloader_len}/'
                           f'{dataloader_len}] ')
            else:
                log_str = (
                    f'Epoch({mode}) [{dataloader_len}/{dataloader_len}] ')
        else:
            log_str = f'Iter({mode}) [{dataloader_len}/{dataloader_len}] '
        log_items = []
        for name, val in tag.items():
            # Timing information is not part of the epoch summary.
            if name in ('time', 'data_time'):
                continue
            if isinstance(val, float):
                val = f'{val:.4f}'
            log_items.append(f'{name}: {val}')
        log_str += ' '.join(log_items)
        return tag, log_str

    def _collect_scalars(self, custom_cfg: List[dict], runner,
                         mode: str) -> dict:
        """Collect log information to compose a dict according to mode.

        Args:
            custom_cfg (List[dict]): A copy of ``self.custom_cfg`` with int
                ``window_size``. The dicts are consumed (keys are popped),
                so the caller must not pass ``self.custom_cfg`` itself.
            runner (Runner): The runner of the training process.
            mode (str): 'train', 'val' or 'test', which means the prefix
                attached by runner.

        Returns:
            dict: Statistical values of logs.
        """
        tag = OrderedDict()
        # history_scalars of train/val/test phase.
        history_scalars = runner.message_hub.log_scalars
        # corresponding mode history_scalars
        mode_history_scalars = OrderedDict()
        # extract log scalars and remove prefix to `mode_history_scalars`
        # according to mode.
        for prefix_key, log_buffer in history_scalars.items():
            if prefix_key.startswith(mode):
                key = prefix_key.split('/')[-1]
                mode_history_scalars[key] = log_buffer
        for key in mode_history_scalars:
            # Default statistics: losses are smoothed over `window_size`,
            # every other item reports its latest value.
            if key.startswith('loss'):
                tag[key] = mode_history_scalars[key].mean(self.window_size)
            else:
                tag[key] = mode_history_scalars[key].current()
        # Update custom keys.
        for log_cfg in custom_cfg:
            data_src = log_cfg.pop('data_src')
            if 'log_name' in log_cfg:
                log_name = log_cfg.pop('log_name')
            else:
                log_name = data_src
            # A data source may have been recorded in another mode only;
            # such entries are silently skipped for the current mode.
            if data_src in mode_history_scalars:
                # Everything left in `log_cfg` is forwarded to `statistics`.
                tag[log_name] = mode_history_scalars[data_src].statistics(
                    **log_cfg)
        return tag

    def _check_custom_cfg(self) -> None:
        """Check the legality of ``self.custom_cfg``.

        Raises:
            AssertionError: If a config dict has no ``data_src``, if two
                config dicts would write the same log name for the same
                ``data_src``, or if ``window_size='epoch'`` is used while
                ``by_epoch`` is False.
        """

        def _check_window_size():
            # An epoch-sized window is only meaningful in epoch mode.
            if self.by_epoch:
                return
            for log_cfg in self.custom_cfg:
                # `window_size` is optional, so a missing key must not
                # raise a KeyError here.
                assert log_cfg.get('window_size') != 'epoch', \
                    'window_size cannot be epoch if LogProcessor.by_epoch' \
                    ' is False.'

        def _check_repeated_log_name():
            check_dict = dict()
            # The `log_name` of the same data_src should not be repeated.
            # If `log_name` is not specified, `data_src` will be overwritten.
            # But only allowed to be overwritten once.
            for log_cfg in self.custom_cfg:
                assert 'data_src' in log_cfg
                data_src = log_cfg['data_src']
                log_name = log_cfg.get('log_name', data_src)
                check_dict.setdefault(data_src,
                                      dict(log_names=set(), log_counts=0))
                check_dict[data_src]['log_names'].add(log_name)
                check_dict[data_src]['log_counts'] += 1
                # A duplicated name does not grow the set, so the set size
                # falls behind the counter as soon as a name repeats.
                assert (len(
                    check_dict[data_src]
                    ['log_names']) == check_dict[data_src]['log_counts']), (
                        f'If you want to statistic {data_src} with multiple '
                        'statistics method, please check `log_name` is unique '
                        f'and {data_src} will not be overwritten twice. See '
                        f'more information in the docstring of `LogProcessor`')

        _check_repeated_log_name()
        _check_window_size()

    def _parse_windows_size(self, runner, batch_idx: int) -> List[dict]:
        """Parse window_size defined in custom_cfg to int value.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The iteration index of current dataloader.

        Returns:
            List[dict]: A deep copy of ``self.custom_cfg`` in which
            ``'epoch'`` has been replaced by ``batch_idx + 1`` and
            ``'global'`` by ``runner.iter + 1``. Missing and int
            ``window_size`` values are left as they are.

        Raises:
            TypeError: If ``window_size`` is neither an int, ``'epoch'`` nor
                ``'global'``.
        """
        # Work on a copy: `self.custom_cfg` must keep the symbolic values
        # because they resolve to a different int on every call.
        custom_cfg_copy = copy.deepcopy(self.custom_cfg)
        for log_cfg in custom_cfg_copy:
            window_size = log_cfg.get('window_size', None)
            if window_size is None or isinstance(window_size, int):
                continue
            elif window_size == 'epoch':
                log_cfg['window_size'] = batch_idx + 1
            elif window_size == 'global':
                log_cfg['window_size'] = runner.iter + 1
            else:
                raise TypeError(
                    'window_size should be int, epoch or global, but got '
                    f'invalid {window_size}')
        return custom_cfg_copy

    def _get_max_memory(self, runner) -> int:
        """Returns the maximum GPU memory occupied by tensors in megabytes (MB)
        for a given device.

        The peak statistics of that device are reset afterwards.

        Args:
            runner (Runner): The runner of the training process.

        Returns:
            The maximum GPU memory occupied by tensors in megabytes for a given
            device.
        """
        # `output_device` is read from the model when it exists; None is
        # passed through to torch otherwise.
        device = getattr(runner.model, 'output_device', None)
        mem = torch.cuda.max_memory_allocated(device=device)
        mem_mb = int(mem) // (1024 * 1024)
        # Reset the same device that was just queried.
        torch.cuda.reset_peak_memory_stats(device=device)
        return mem_mb

    def _get_iter(self, runner, batch_idx: Optional[int] = None) -> int:
        """Get current training iteration step.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int, optional): The iteration index of current
                dataloader. Defaults to None.

        Returns:
            int: The current global iter or inner iter.
        """
        # Compare against None explicitly: `batch_idx == 0` is the first
        # batch of a loop and must be reported as inner iteration 1.
        if self.by_epoch and batch_idx is not None:
            current_iter = batch_idx + 1
        else:
            current_iter = runner.iter + 1
        return current_iter

    def _get_epoch(self, runner, mode: str) -> int:
        """Get current epoch according to mode.

        Args:
            runner (Runner): The runner of the training/validation process.
            mode (str): Current mode of runner, "train" or "val".

        Returns:
            int: The current epoch.

        Raises:
            ValueError: If ``mode`` is neither "train" nor "val".
        """
        if mode == 'train':
            epoch = runner.epoch + 1
        elif mode == 'val':
            # normal val mode
            # runner.epoch += 1 has been done before validation
            epoch = runner.epoch
        else:
            raise ValueError(
                f"runner mode should be 'train' or 'val', but got {mode}")
        return epoch

    def _get_cur_loop(self, runner, mode: str):
        """Get current loop according to mode.

        Args:
            runner (Runner): The runner of the training/validation/testing
                process.
            mode (str): Current mode of runner, "train", "val" or "test".
                Any value other than "train" and "val" selects the test
                loop.

        Returns:
            BaseLoop: Current loop of runner.
        """
        # returns type hint will occur circular import
        if mode == 'train':
            return runner.train_loop
        elif mode == 'val':
            return runner.val_loop
        else:
            return runner.test_loop
|
|
@ -32,15 +32,15 @@ class MMFormatter(logging.Formatter):
|
|||
info_prefix = self._get_prefix('INFO', color)
|
||||
debug_prefix = self._get_prefix('DEBUG', color)
|
||||
# Config output format.
|
||||
self.err_format = f'%(asctime)s - %(name)s - {error_prefix} - ' \
|
||||
f'%(pathname)s - %(funcName)s - %(lineno)d - ' \
|
||||
'%(message)s'
|
||||
self.warn_format = f'%(asctime)s - %(name)s - {warn_prefix} - %(' \
|
||||
'message)s'
|
||||
self.info_format = f'%(asctime)s - %(name)s - {info_prefix} - %(' \
|
||||
'message)s'
|
||||
self.debug_format = f'%(asctime)s - %(name)s - {debug_prefix} - %(' \
|
||||
'message)s'
|
||||
self.err_format = (f'%(asctime)s - %(name)s - {error_prefix} - '
|
||||
'%(pathname)s - %(funcName)s - %(lineno)d - '
|
||||
'%(message)s')
|
||||
self.warn_format = (f'%(asctime)s - %(name)s - {warn_prefix} - %('
|
||||
'message)s')
|
||||
self.info_format = (f'%(asctime)s - %(name)s - {info_prefix} - %('
|
||||
'message)s')
|
||||
self.debug_format = (f'%(asctime)s - %(name)s - {debug_prefix} - %('
|
||||
'message)s')
|
||||
|
||||
def _get_prefix(self, level: str, color: bool) -> str:
|
||||
"""Get the prefix of the target log level.
|
||||
|
|
|
@ -25,7 +25,7 @@ from mmengine.dist import (broadcast, get_dist_info, init_dist, master_only,
|
|||
sync_random_seed)
|
||||
from mmengine.evaluator import Evaluator
|
||||
from mmengine.hooks import Hook
|
||||
from mmengine.logging import MessageHub, MMLogger
|
||||
from mmengine.logging import LogProcessor, MessageHub, MMLogger
|
||||
from mmengine.model import is_model_wrapper
|
||||
from mmengine.optim import _ParamScheduler, build_optimizer
|
||||
from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
|
||||
|
@ -127,6 +127,8 @@ class Runner:
|
|||
non-distributed environment will be launched.
|
||||
env_cfg (dict): A dict used for setting environment. Defaults to
|
||||
dict(dist_cfg=dict(backend='nccl')).
|
||||
log_processor (dict, optional): A processor to format logs. Defaults to
|
||||
None.
|
||||
log_level (int or str): The log level of MMLogger handlers.
|
||||
Defaults to 'INFO'.
|
||||
writer (ComposedWriter or dict, optional): A ComposedWriter object or a
|
||||
|
@ -184,6 +186,7 @@ class Runner:
|
|||
param_scheduler=dict(type='ParamSchedulerHook')),
|
||||
launcher='none',
|
||||
env_cfg=dict(dist_cfg=dict(backend='nccl')),
|
||||
log_processor=dict(window_size=20),
|
||||
writer=dict(
|
||||
name='composed_writer',
|
||||
writers=[dict(type='LocalWriter', save_dir='temp_dir')])
|
||||
|
@ -218,6 +221,7 @@ class Runner:
|
|||
launcher: str = 'none',
|
||||
env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
|
||||
log_level: str = 'INFO',
|
||||
log_processor: Optional[Dict] = None,
|
||||
writer: Optional[Union[ComposedWriter, Dict]] = None,
|
||||
default_scope: Optional[str] = None,
|
||||
randomness: Dict = dict(seed=None),
|
||||
|
@ -310,6 +314,10 @@ class Runner:
|
|||
else:
|
||||
self._experiment_name = self.timestamp
|
||||
|
||||
log_processor = dict() if log_processor is None else log_processor
|
||||
self.log_processor = LogProcessor(**log_processor)
|
||||
# Since `get_instance` could return any subclass of ManagerMixin. The
|
||||
# corresponding attribute needs a type hint.
|
||||
self.logger = self.build_logger(log_level=log_level)
|
||||
# Build `message_hub` for communication among components.
|
||||
# `message_hub` can store log scalars (loss, learning rate) and
|
||||
|
@ -385,6 +393,7 @@ class Runner:
|
|||
resume=cfg.get('resume', False),
|
||||
launcher=cfg.get('launcher', 'none'),
|
||||
env_cfg=cfg.get('env_cfg'), # type: ignore
|
||||
log_processor=cfg.get('log_processor'),
|
||||
log_level=cfg.get('log_level', 'INFO'),
|
||||
writer=cfg.get('writer'),
|
||||
default_scope=cfg.get('default_scope'),
|
||||
|
|
|
@ -157,18 +157,17 @@ class TestHook:
|
|||
|
||||
def test_end_of_epoch(self):
|
||||
hook = Hook()
|
||||
runner = Mock()
|
||||
|
||||
# last inner iter
|
||||
batch_idx = 1
|
||||
runner.cur_dataloader.__len__ = Mock(return_value=2)
|
||||
runner.cur_dataloader.__len__ = Mock(return_value=2)
|
||||
return_val = hook.end_of_epoch(runner, batch_idx)
|
||||
dataloader = Mock()
|
||||
dataloader.__len__ = Mock(return_value=2)
|
||||
return_val = hook.end_of_epoch(dataloader, batch_idx)
|
||||
assert return_val
|
||||
|
||||
# not the last inner iter
|
||||
batch_idx = 0
|
||||
return_val = hook.end_of_epoch(runner, batch_idx)
|
||||
return_val = hook.end_of_epoch(dataloader, batch_idx)
|
||||
assert not return_val
|
||||
|
||||
def test_is_last_train_epoch(self):
|
||||
|
|
|
@ -1,29 +1,70 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest.mock import Mock
|
||||
from unittest import TestCase
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from mmengine.hooks import IterTimerHook
|
||||
from mmengine.logging import MessageHub
|
||||
|
||||
|
||||
class TestIterTimerHook:
|
||||
def time_patch():
    """Deterministic replacement for ``time.time`` used by the timer tests.

    The counter lives as the ``time`` attribute on the function object
    itself: the first call returns 0 and every later call returns one more
    than the previous call.

    Returns:
        int: The current value of the fake clock.
    """
    # Starting from -1 when the attribute is absent makes the first call
    # yield 0.
    time_patch.time = getattr(time_patch, 'time', -1) + 1
    return time_patch.time
|
||||
|
||||
|
||||
class TestIterTimerHook(TestCase):
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.hook = IterTimerHook()
|
||||
|
||||
def test_init(self):
|
||||
assert self.hook.time_sec_tot == 0
|
||||
assert self.hook.start_iter == 0
|
||||
|
||||
def test_before_run(self):
|
||||
runner = MagicMock()
|
||||
runner.iter = 1
|
||||
self.hook.before_run(runner)
|
||||
assert self.hook.start_iter == 1
|
||||
|
||||
def test_before_epoch(self):
|
||||
hook = IterTimerHook()
|
||||
runner = Mock()
|
||||
hook._before_epoch(runner)
|
||||
assert isinstance(hook.t, float)
|
||||
self.hook._before_epoch(runner)
|
||||
assert isinstance(self.hook.t, float)
|
||||
|
||||
@patch('time.time', MagicMock(return_value=1))
|
||||
def test_before_iter(self):
|
||||
hook = IterTimerHook()
|
||||
runner = Mock()
|
||||
runner = MagicMock()
|
||||
runner.log_buffer = dict()
|
||||
hook._before_epoch(runner)
|
||||
hook._before_iter(runner, 0)
|
||||
runner.message_hub.update_scalar.assert_called()
|
||||
self.hook._before_epoch(runner)
|
||||
for mode in ('train', 'val', 'test'):
|
||||
self.hook._before_iter(runner, batch_idx=1, mode=mode)
|
||||
runner.message_hub.update_scalar.assert_called_with(
|
||||
f'{mode}/data_time', 0)
|
||||
|
||||
@patch('time.time', time_patch)
|
||||
def test_after_iter(self):
|
||||
hook = IterTimerHook()
|
||||
runner = Mock()
|
||||
runner = MagicMock()
|
||||
runner.log_buffer = dict()
|
||||
hook._before_epoch(runner)
|
||||
hook._after_iter(runner, 0)
|
||||
runner.log_processor.window_size = 10
|
||||
runner.train_loop.max_iters = 100
|
||||
runner.iter = 0
|
||||
runner.test_loop.dataloader = [0] * 20
|
||||
runner.val_loop.dataloader = [0] * 20
|
||||
self.hook._before_epoch(runner)
|
||||
self.hook.before_run(runner)
|
||||
self.hook._after_iter(runner, batch_idx=1)
|
||||
runner.message_hub.update_scalar.assert_called()
|
||||
runner.message_hub.get_log.assert_not_called()
|
||||
runner.message_hub.update_info.assert_not_called()
|
||||
runner.message_hub = MessageHub.get_instance('test_iter_timer_hook')
|
||||
runner.iter = 9
|
||||
# eta = (100 - 10) / 1
|
||||
self.hook._after_iter(runner, batch_idx=89)
|
||||
assert runner.message_hub.get_info('eta') == 90
|
||||
self.hook._after_iter(runner, batch_idx=9, mode='val')
|
||||
assert runner.message_hub.get_info('eta') == 10
|
||||
self.hook._after_iter(runner, batch_idx=19, mode='test')
|
||||
assert runner.message_hub.get_info('eta') == 0
|
||||
|
|
|
@ -1,13 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import datetime
|
||||
import logging
|
||||
import os.path as osp
|
||||
import sys
|
||||
from collections import OrderedDict
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmengine.fileio.file_client import HardDiskBackend
|
||||
from mmengine.hooks import LoggerHook
|
||||
|
@ -17,11 +12,8 @@ class TestLoggerHook:
|
|||
|
||||
def test_init(self):
|
||||
logger_hook = LoggerHook(out_dir='tmp.txt')
|
||||
assert logger_hook.by_epoch
|
||||
assert logger_hook.interval == 10
|
||||
assert not logger_hook.custom_keys
|
||||
assert logger_hook.ignore_last
|
||||
assert logger_hook.time_sec_tot == 0
|
||||
assert logger_hook.interval_exp_name == 1000
|
||||
assert logger_hook.out_suffix == ('.log.json', '.log', '.py')
|
||||
assert logger_hook.keep_local
|
||||
|
@ -30,22 +22,7 @@ class TestLoggerHook:
|
|||
# out_dir should be None or string or tuple of string.
|
||||
with pytest.raises(TypeError):
|
||||
LoggerHook(out_dir=1)
|
||||
# time cannot be overwritten.
|
||||
with pytest.raises(AssertionError):
|
||||
LoggerHook(custom_keys=dict(time=dict(method='max')))
|
||||
LoggerHook(
|
||||
custom_keys=dict(time=[
|
||||
dict(method='max', log_name='time_max'),
|
||||
dict(method='min', log_name='time_min')
|
||||
]))
|
||||
# Epoch window_size cannot be used when `LoggerHook.by_epoch=False`
|
||||
with pytest.raises(AssertionError):
|
||||
LoggerHook(
|
||||
by_epoch=False,
|
||||
custom_keys=dict(
|
||||
time=dict(
|
||||
method='max', log_name='time_max',
|
||||
window_size='epoch')))
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
LoggerHook(file_client_args=dict(enable_mc=True))
|
||||
|
||||
|
@ -60,20 +37,23 @@ class TestLoggerHook:
|
|||
assert logger_hook.out_dir == osp.join('out_dir', 'work_dir')
|
||||
assert logger_hook.json_log_path == osp.join('work_dir',
|
||||
'timestamp.log.json')
|
||||
assert logger_hook.start_iter == runner.iter
|
||||
runner.writer.add_params.assert_called()
|
||||
|
||||
def test_after_run(self, tmp_path):
|
||||
# Test
|
||||
out_dir = tmp_path / 'out_dir'
|
||||
out_dir.mkdir()
|
||||
work_dir = tmp_path / 'work_dir'
|
||||
work_dir.mkdir()
|
||||
work_dir_json = work_dir / 'tmp.log.json'
|
||||
json_f = open(work_dir_json, 'w')
|
||||
json_f.close()
|
||||
runner = MagicMock()
|
||||
runner.work_dir = work_dir
|
||||
|
||||
# Test without out_dir.
|
||||
logger_hook = LoggerHook()
|
||||
logger_hook.after_run(runner)
|
||||
# Test with out_dir and make sure json file has been moved to out_dir.
|
||||
json_f = open(work_dir_json, 'w')
|
||||
json_f.close()
|
||||
logger_hook = LoggerHook(out_dir=str(tmp_path), keep_local=False)
|
||||
logger_hook.out_dir = str(out_dir)
|
||||
logger_hook.after_run(runner)
|
||||
|
@ -84,274 +64,83 @@ class TestLoggerHook:
|
|||
def test_after_train_iter(self):
|
||||
# Test LoggerHook by iter.
|
||||
runner = MagicMock()
|
||||
runner.iter = 10
|
||||
batch_idx = 5
|
||||
logger_hook = LoggerHook(by_epoch=False)
|
||||
logger_hook._log_train = MagicMock()
|
||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
runner.log_processor.get_log_after_iter = MagicMock(
|
||||
return_value=(dict(), 'log_str'))
|
||||
logger_hook = LoggerHook()
|
||||
logger_hook.after_train_iter(runner, batch_idx=5)
|
||||
# `cur_iter=10+1`, which cannot be exact division by
|
||||
# `logger_hook.interval`
|
||||
logger_hook._log_train.assert_not_called()
|
||||
runner.iter = 9
|
||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
logger_hook._log_train.assert_called()
|
||||
runner.log_processor.get_log_after_iter.assert_not_called()
|
||||
logger_hook.after_train_iter(runner, batch_idx=9)
|
||||
runner.log_processor.get_log_after_iter.assert_called()
|
||||
|
||||
# Test LoggerHook by epoch.
|
||||
logger_hook = LoggerHook(by_epoch=True)
|
||||
logger_hook._log_train = MagicMock()
|
||||
# Only `runner.inner_iter` will work.
|
||||
runner.iter = 9
|
||||
batch_idx = 10
|
||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
logger_hook._log_train.assert_not_called()
|
||||
batch_idx = 9
|
||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
logger_hook._log_train.assert_called()
|
||||
logger_hook = LoggerHook()
|
||||
runner = MagicMock()
|
||||
runner.log_processor.get_log_after_iter = MagicMock(
|
||||
return_value=(dict(), 'log_str'))
|
||||
# Only `batch_idx` will work.
|
||||
logger_hook.after_train_iter(runner, batch_idx=10)
|
||||
runner.log_processor.get_log_after_iter.assert_not_called()
|
||||
logger_hook.after_train_iter(runner, batch_idx=9)
|
||||
runner.log_processor.get_log_after_iter.assert_called()
|
||||
|
||||
# Test end of the epoch.
|
||||
logger_hook = LoggerHook(by_epoch=True, ignore_last=False)
|
||||
logger_hook._log_train = MagicMock()
|
||||
runner.cur_dataloader = [0] * 5
|
||||
batch_idx = 4
|
||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
logger_hook._log_train.assert_called()
|
||||
runner = MagicMock()
|
||||
runner.log_processor.get_log_after_iter = MagicMock(
|
||||
return_value=(dict(), 'log_str'))
|
||||
logger_hook = LoggerHook(ignore_last=False)
|
||||
runner.train_dataloader = [0] * 5
|
||||
logger_hook.after_train_iter(runner, batch_idx=4)
|
||||
runner.log_processor.get_log_after_iter.assert_called()
|
||||
|
||||
# Test print exp_name
|
||||
runner = MagicMock()
|
||||
runner.log_processor.get_log_after_iter = MagicMock(
|
||||
return_value=(dict(), 'log_str'))
|
||||
runner.meta = dict(exp_name='retinanet')
|
||||
logger_hook = LoggerHook()
|
||||
runner.logger = MagicMock()
|
||||
logger_hook._log_train = MagicMock()
|
||||
logger_hook.after_train_iter(runner, batch_idx=batch_idx)
|
||||
runner.logger.info.assert_called_with(
|
||||
f'Exp name: {runner.meta["exp_name"]}')
|
||||
logger_hook = LoggerHook()
|
||||
logger_hook.after_train_iter(runner, batch_idx=999)
|
||||
runner.logger.info.assert_called()
|
||||
|
||||
def test_after_val_epoch(self):
|
||||
logger_hook = LoggerHook()
|
||||
runner = MagicMock()
|
||||
logger_hook._log_val = MagicMock()
|
||||
runner.log_processor.get_log_after_epoch = MagicMock(
|
||||
return_value=(dict(), 'string'))
|
||||
logger_hook.after_val_epoch(runner)
|
||||
logger_hook._log_val.assert_called()
|
||||
runner.log_processor.get_log_after_epoch.assert_called()
|
||||
runner.logger.info.assert_called()
|
||||
runner.writer.add_scalars.assert_called()
|
||||
|
||||
@pytest.mark.parametrize('by_epoch', [True, False])
|
||||
def test_log_train(self, by_epoch, capsys):
|
||||
runner = self._setup_runner()
|
||||
runner.meta = dict(exp_name='retinanet')
|
||||
# Prepare LoggerHook
|
||||
logger_hook = LoggerHook(by_epoch=by_epoch)
|
||||
logger_hook._inner_iter = 1
|
||||
logger_hook.writer = MagicMock()
|
||||
logger_hook.time_sec_tot = 1000
|
||||
logger_hook.start_iter = 0
|
||||
logger_hook._get_max_memory = MagicMock(return_value='100')
|
||||
logger_hook.json_log_path = 'tmp.json'
|
||||
|
||||
# Prepare training information.
|
||||
train_infos = dict(
|
||||
lr=0.1, momentum=0.9, time=1.0, data_time=1.0, loss_cls=1.0)
|
||||
logger_hook._collect_info = MagicMock(return_value=train_infos)
|
||||
logger_hook._log_train(runner)
|
||||
# Verify that the correct variables have been written.
|
||||
runner.writer.add_scalars.assert_called_with(
|
||||
train_infos, step=11, file_path='tmp.json')
|
||||
# Verify that the correct context have been logged.
|
||||
out, _ = capsys.readouterr()
|
||||
time_avg = logger_hook.time_sec_tot / (
|
||||
runner.iter + 1 - logger_hook.start_iter)
|
||||
eta_second = time_avg * (runner.train_loop.max_iters - runner.iter - 1)
|
||||
eta_str = str(datetime.timedelta(seconds=int(eta_second)))
|
||||
if by_epoch:
|
||||
if torch.cuda.is_available():
|
||||
log_str = 'Epoch [2][2/5]\t' \
|
||||
f"lr: {train_infos['lr']:.3e} " \
|
||||
f"momentum: {train_infos['momentum']:.3e}, " \
|
||||
f'eta: {eta_str}, ' \
|
||||
f"time: {train_infos['time']:.3f}, " \
|
||||
f"data_time: {train_infos['data_time']:.3f}, " \
|
||||
f'memory: 100, ' \
|
||||
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
|
||||
else:
|
||||
log_str = 'Epoch [2][2/5]\t' \
|
||||
f"lr: {train_infos['lr']:.3e} " \
|
||||
f"momentum: {train_infos['momentum']:.3e}, " \
|
||||
f'eta: {eta_str}, ' \
|
||||
f"time: {train_infos['time']:.3f}, " \
|
||||
f"data_time: {train_infos['data_time']:.3f}, " \
|
||||
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
|
||||
assert out == log_str
|
||||
else:
|
||||
if torch.cuda.is_available():
|
||||
log_str = 'Iter [11/50]\t' \
|
||||
f"lr: {train_infos['lr']:.3e} " \
|
||||
f"momentum: {train_infos['momentum']:.3e}, " \
|
||||
f'eta: {eta_str}, ' \
|
||||
f"time: {train_infos['time']:.3f}, " \
|
||||
f"data_time: {train_infos['data_time']:.3f}, " \
|
||||
f'memory: 100, ' \
|
||||
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
|
||||
else:
|
||||
log_str = 'Iter [11/50]\t' \
|
||||
f"lr: {train_infos['lr']:.3e} " \
|
||||
f"momentum: {train_infos['momentum']:.3e}, " \
|
||||
f'eta: {eta_str}, ' \
|
||||
f"time: {train_infos['time']:.3f}, " \
|
||||
f"data_time: {train_infos['data_time']:.3f}, " \
|
||||
f"loss_cls: {train_infos['loss_cls']:.4f}\n"
|
||||
assert out == log_str
|
||||
|
||||
@pytest.mark.parametrize('by_epoch', [True, False])
def test_log_val(self, by_epoch, capsys):
    """Check what ``LoggerHook._log_val`` writes and prints.

    ``_collect_info`` is replaced by a mock returning a fixed metric
    dict, so only the writer call and the printed line are verified.
    """
    runner = self._setup_runner()
    # Hook under test, with metric collection stubbed out.
    hook = LoggerHook(by_epoch=by_epoch)
    hook.json_log_path = 'tmp.json'
    val_metric = dict(accuracy=0.9, data_time=1.0)
    hook._collect_info = MagicMock(return_value=val_metric)
    hook._log_val(runner)
    # Grab whatever was printed to stdout during the call.
    printed, _ = capsys.readouterr()
    runner.writer.add_scalars.assert_called_with(
        val_metric, step=11, file_path='tmp.json')
    # The prefix of the printed line depends on the logging granularity.
    if by_epoch:
        expected = 'Epoch(val) [1][5]\taccuracy: 0.9000, ' \
                   'data_time: 1.0000\n'
    else:
        expected = 'Iter(val) [5]\taccuracy: 0.9000, ' \
                   'data_time: 1.0000\n'
    assert printed == expected
|
||||
|
||||
def test_get_window_size(self):
    """Check ``LoggerHook._get_window_size`` for named and int sizes."""
    runner = self._setup_runner()
    hook = LoggerHook()
    hook._inner_iter = 1
    # Named window sizes and an explicit integer window size.
    for window, expected in (('epoch', 2), ('global', 11), (10, 10)):
        assert hook._get_window_size(runner, window) == expected
    # An integer window size must equal `logger_hook.interval`.
    with pytest.raises(AssertionError):
        hook._get_window_size(runner, 20)

    # An unrecognised window name is rejected.
    with pytest.raises(ValueError):
        hook._get_window_size(runner, 'unknwon')
|
||||
|
||||
def test_parse_custom_keys(self):
    """Exercise ``LoggerHook._parse_custom_keys`` with valid and invalid
    custom-key configurations.

    The log buffers are ``MagicMock`` objects, so only the resulting tag
    keys and the raised exception types are really verified here.
    """
    tag = OrderedDict()
    runner = self._setup_runner()
    log_buffers = OrderedDict(lr=MagicMock(), loss=MagicMock())
    cfg_dict = dict(
        lr=dict(method='min'),
        loss=[
            dict(method='min', window_size='global'),
            dict(method='max', log_name='loss_max')
        ])
    logger_hook = LoggerHook()
    for log_key, log_cfg in cfg_dict.items():
        logger_hook._parse_custom_keys(runner, log_key, log_cfg,
                                       log_buffers, tag)
    # One tag per config entry; the second `loss` entry is stored under
    # its `log_name`.
    assert list(tag) == ['lr', 'loss', 'loss_max']
    # NOTE(review): the four asserts below reference `assert_called`
    # without calling it. A bound method is always truthy, so these
    # lines can never fail. Before turning them into real
    # `assert_called()` checks, confirm which buffer methods
    # `_parse_custom_keys` actually invokes.
    assert log_buffers['lr'].min.assert_called
    assert log_buffers['loss'].min.assert_called
    assert log_buffers['loss'].max.assert_called
    assert log_buffers['loss'].mean.assert_called
    # `log_name` Cannot be repeated.
    with pytest.raises(KeyError):
        cfg_dict = dict(loss=[
            dict(method='min', window_size='global'),
            dict(method='max', log_name='loss_max'),
            dict(method='mean', log_name='loss_max')
        ])
        logger_hook.custom_keys = cfg_dict
        for log_key, log_cfg in cfg_dict.items():
            logger_hook._parse_custom_keys(runner, log_key, log_cfg,
                                           log_buffers, tag)
    # `log_key` cannot be overwritten multiple times.
    with pytest.raises(AssertionError):
        cfg_dict = dict(loss=[
            dict(method='min', window_size='global'),
            dict(method='max'),
        ])
        logger_hook.custom_keys = cfg_dict
        for log_key, log_cfg in cfg_dict.items():
            logger_hook._parse_custom_keys(runner, log_key, log_cfg,
                                           log_buffers, tag)
|
||||
|
||||
def test_collect_info(self):
    """Check which buffers ``LoggerHook._collect_info`` reads per mode.

    ``_parse_custom_keys`` is mocked, so only the default collection of
    prefixed scalars is verified.
    """
    runner = self._setup_runner()
    hook = LoggerHook(
        custom_keys=dict(time=dict(method='max', log_name='time_max')))
    hook._parse_custom_keys = MagicMock()
    # Scalars keyed with and without a mode prefix.
    scalar_names = ('train/time', 'lr', 'train/loss_cls', 'val/metric')
    buffers = {name: MagicMock() for name in scalar_names}
    runner.message_hub.log_scalars = buffers

    train_tag = hook._collect_info(runner, mode='train')
    # Custom keys are handed over to `_parse_custom_keys`.
    hook._parse_custom_keys.assert_called()
    # Only the `train/` scalars end up in the tag, prefix removed.
    assert list(train_tag.keys()) == ['time', 'loss_cls']
    # Training scalars are averaged rather than read as current value.
    buffers['train/time'].mean.assert_called()
    buffers['train/loss_cls'].mean.assert_called()
    buffers['train/loss_cls'].current.assert_not_called()

    val_tag = hook._collect_info(runner, mode='val')
    assert list(val_tag.keys()) == ['metric']
    # Validation scalars are read through `current`.
    buffers['val/metric'].current.assert_called()
|
||||
|
||||
@patch('torch.cuda.max_memory_allocated', MagicMock())
|
||||
@patch('torch.cuda.reset_peak_memory_stats', MagicMock())
|
||||
def test_get_max_memory(self):
|
||||
def test_after_test_epoch(self):
|
||||
logger_hook = LoggerHook()
|
||||
runner = MagicMock()
|
||||
runner.world_size = 1
|
||||
runner.model = torch.nn.Linear(1, 1)
|
||||
logger_hook._get_max_memory(runner)
|
||||
torch.cuda.max_memory_allocated.assert_called()
|
||||
torch.cuda.reset_peak_memory_stats.assert_called()
|
||||
runner.log_processor.get_log_after_epoch = MagicMock(
|
||||
return_value=(dict(), 'log_str'))
|
||||
logger_hook.after_test_epoch(runner)
|
||||
runner.log_processor.get_log_after_epoch.assert_called()
|
||||
runner.logger.info.assert_called()
|
||||
|
||||
def test_get_iter(self):
|
||||
runner = self._setup_runner()
|
||||
def test_after_val_iter(self):
|
||||
logger_hook = LoggerHook()
|
||||
logger_hook._inner_iter = 1
|
||||
# Get global iter when `inner_iter=False`
|
||||
iter = logger_hook._get_iter(runner)
|
||||
assert iter == 11
|
||||
# Get inner iter
|
||||
iter = logger_hook._get_iter(runner, inner_iter=True)
|
||||
assert iter == 2
|
||||
# Still get global iter when `logger_hook.by_epoch==False`
|
||||
logger_hook.by_epoch = False
|
||||
iter = logger_hook._get_iter(runner, inner_iter=True)
|
||||
assert iter == 11
|
||||
|
||||
def test_get_epoch(self):
    """Check ``LoggerHook._get_epoch`` for each supported mode."""
    runner = self._setup_runner()
    hook = LoggerHook()
    # The mocked runner reports `epoch = 1`.
    assert hook._get_epoch(runner, 'train') == 2
    assert hook._get_epoch(runner, 'val') == 1
    # 'test' is not an accepted mode here.
    with pytest.raises(ValueError):
        hook._get_epoch(runner, 'test')
|
||||
|
||||
def _setup_runner(self):
    """Build the mocked runner shared by the LoggerHook tests.

    Returns:
        MagicMock: runner with ``epoch=1``, ``iter=10``, a 5-item
        ``cur_dataloader``, ``train_loop.max_iters=50``, the root
        logger set to INFO, and a mocked ``message_hub``.
    """
    runner = MagicMock()
    runner.epoch = 1
    runner.cur_dataloader = [0] * 5
    runner.iter = 10
    runner.train_loop.max_iters = 50
    logger = logging.getLogger()
    logger.setLevel(logging.INFO)
    # NOTE(review): the `else` is read as belonging to the `for`. The
    # loop has no `break`, so a new stdout handler is added on every
    # call. Presumably this lets `capsys` see the output of each test;
    # confirm that it is intended.
    for handler in logger.handlers:
        if not isinstance(handler, logging.StreamHandler):
            continue
    else:
        logger.addHandler(logging.StreamHandler(stream=sys.stdout))
    runner.logger = logger
    runner.message_hub = MagicMock()
    # NOTE(review): `composed_wirter` looks like a misspelling of
    # `composed_writer`; the tests in this file assert on
    # `runner.writer`. Confirm whether this attribute is needed at all.
    runner.composed_wirter = MagicMock()
    return runner
|
||||
runner.iter = 0
|
||||
runner.log_processor.get_log_after_iter = MagicMock(
|
||||
return_value=(dict(), 'log_str'))
|
||||
logger_hook.after_val_iter(runner, 1)
|
||||
runner.log_processor.get_log_after_iter.assert_not_called()
|
||||
logger_hook.after_val_iter(runner, 9)
|
||||
runner.log_processor.get_log_after_iter.assert_called()
|
||||
|
||||
def test_after_test_iter(self):
    """Check when ``LoggerHook.after_test_iter`` asks for a log line.

    With the hook's defaults, batch index 1 must not reach the log
    processor, while batch index 9 must.
    """
    hook = LoggerHook()
    runner = MagicMock()
    runner.iter = 0
    get_log = MagicMock(return_value=(dict(), 'log_str'))
    runner.log_processor.get_log_after_iter = get_log

    hook.after_test_iter(runner, 1)
    get_log.assert_not_called()

    hook.after_test_iter(runner, 9)
    get_log.assert_called()
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import MagicMock, Mock
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
@ -45,7 +45,7 @@ class TestOptimizerHook:
|
|||
model = Model()
|
||||
x = torch.rand(1, 1, 3, 3)
|
||||
|
||||
dummy_runner = Mock()
|
||||
dummy_runner = MagicMock()
|
||||
dummy_runner.optimizer.zero_grad = Mock(return_value=None)
|
||||
dummy_runner.optimizer.step = Mock(return_value=None)
|
||||
dummy_runner.model = model
|
||||
|
|
|
@ -0,0 +1,242 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmengine.logging import LogProcessor, MessageHub, MMLogger
|
||||
|
||||
|
||||
class TestLogProcessor:
|
||||
|
||||
def test_init(self):
    """Check that ``LogProcessor`` stores its constructor arguments."""
    init_kwargs = dict(window_size=10, by_epoch=True, custom_cfg=None)
    processor = LogProcessor(**init_kwargs)
    assert processor.by_epoch
    assert processor.window_size == 10
    # `custom_cfg=None` ends up as an empty list.
    assert processor.custom_cfg == []
|
||||
|
||||
def test_check_custom_cfg(self):
    """Check the validation ``LogProcessor`` applies to ``custom_cfg``.

    Three invalid configurations must raise ``AssertionError`` at
    construction time; a final valid configuration must be accepted.
    """
    # ``by_epoch==False`` and `window_size='epoch'` in log config will
    # raise AssertionError.
    custom_cfg = [dict(data_src='loss', window_size='epoch')]
    with pytest.raises(AssertionError):
        LogProcessor(by_epoch=False, custom_cfg=custom_cfg)
    # Duplicate log_name will raise AssertionError.
    custom_cfg = [
        dict(data_src='loss', log_name='loss_1'),
        dict(data_src='loss', log_name='loss_1')
    ]
    with pytest.raises(AssertionError):
        LogProcessor(custom_cfg=custom_cfg)
    # Overwrite loss item twice will raise AssertionError.
    custom_cfg = [dict(data_src='loss'), dict(data_src='loss')]
    with pytest.raises(AssertionError):
        LogProcessor(custom_cfg=custom_cfg)

    # A valid configuration: distinct log names, no entry overwritten
    # twice. `loss_min` is paired with the `min` statistic so that each
    # log name matches the method it records.
    custom_cfg = [
        dict(data_src='loss_cls', window_size=100, method_name='min'),
        dict(data_src='loss', log_name='loss_min', method_name='min'),
        dict(data_src='loss', log_name='loss_max', method_name='max')
    ]
    LogProcessor(custom_cfg=custom_cfg)
|
||||
|
||||
def test_parse_windows_size(self):
    """Check how ``LogProcessor._parse_windows_size`` resolves sizes.

    Every call uses the ``(runner, batch_idx)`` argument order, with the
    runner from ``setup`` (``iter == 10``) and batch index 1.
    """
    log_processor = LogProcessor()
    # Test parse 'epoch' window_size.
    log_processor.custom_cfg = [
        dict(data_src='loss_cls', window_size='epoch')
    ]
    custom_cfg = log_processor._parse_windows_size(self.runner, 1)
    assert custom_cfg[0]['window_size'] == 2

    # Test parse 'global' window_size.
    log_processor.custom_cfg = [
        dict(data_src='loss_cls', window_size='global')
    ]
    custom_cfg = log_processor._parse_windows_size(self.runner, 1)
    assert custom_cfg[0]['window_size'] == 11

    # Test parse int window_size
    log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=100)]
    custom_cfg = log_processor._parse_windows_size(self.runner, 1)
    assert custom_cfg[0]['window_size'] == 100

    # Invalid type window_size will raise TypeError.
    log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=[])]
    with pytest.raises(TypeError):
        # Use the same (runner, batch_idx) order as the calls above so
        # the TypeError can only come from the invalid window size,
        # not from swapped arguments.
        log_processor._parse_windows_size(self.runner, 1)
|
||||
|
||||
@pytest.mark.parametrize('by_epoch,mode',
                         ([True, 'train'], [False, 'train'], [True, 'val'],
                          [False, 'val'], [True, 'test'], [False, 'test']))
def test_get_log_after_iter(self, by_epoch, mode):
    """Check the log string built by ``get_log_after_iter``.

    Scalar collection and the memory query are mocked; the expected
    string is assembled piece by piece and compared with the output.
    """
    processor = LogProcessor(by_epoch=by_epoch)
    processor._get_max_memory = MagicMock(return_value='100')
    eta = 40
    self.runner.message_hub.update_info('eta', eta)
    # Only training logs carry a learning rate.
    is_train = mode == 'train'
    if is_train:
        logs = dict(lr=0.1, time=1.0, data_time=1.0, loss_cls=1.0)
    else:
        logs = dict(time=1.0, data_time=1.0, loss_cls=1.0)
    processor._collect_scalars = MagicMock(return_value=logs)
    _, out = processor.get_log_after_iter(self.runner, 1, mode)

    # Header: depends on the granularity and on the mode.
    loop = processor._get_cur_loop(self.runner, mode)
    if by_epoch and mode in ['train', 'val']:
        epoch = processor._get_epoch(self.runner, mode)
        header = (f'Epoch({mode}) [{epoch}][2/'
                  f'{len(loop.dataloader)}] ')
    elif by_epoch:
        header = (f'Epoch({mode}) [2/{len(loop.dataloader)}] ')
    elif is_train:
        max_iters = self.runner.train_loop.max_iters
        header = f'Iter({mode}) [11/{max_iters}] '
    else:
        max_iters = len(loop.dataloader)
        header = f'Iter({mode}) [2/{max_iters}] '

    # Remainder of the line: identical for both granularities.
    pieces = [header]
    pieces.append(f"lr: {logs['lr']:.3e} " if is_train else ' ')
    pieces.append('eta: 0:00:40 '
                  f"time: {logs['time']:.3f} "
                  f"data_time: {logs['data_time']:.3f} ")
    if torch.cuda.is_available():
        pieces.append('memory: 100 ')
    if is_train:
        pieces.append(f"loss_cls: {logs['loss_cls']:.4f}")
    assert out == ''.join(pieces)
|
||||
|
||||
@pytest.mark.parametrize(
    'by_epoch,mode',
    ([True, 'val'], [False, 'val'], [True, 'test'], [False, 'test']))
def test_log_val(self, by_epoch, mode):
    """Check the log string built by ``get_log_after_epoch``.

    Scalar collection is mocked to return a fixed metric dict.
    """
    processor = LogProcessor(by_epoch=by_epoch)
    metrics = dict(accuracy=0.9, data_time=1.0)
    processor._collect_scalars = MagicMock(return_value=metrics)
    _, out = processor.get_log_after_epoch(self.runner, 2, mode)
    # Expected line, keyed by (epoch based?, test mode?).
    expected = {
        (True, True): 'Epoch(test) [5/5] accuracy: 0.9000',
        (True, False): 'Epoch(val) [1][10/10] accuracy: 0.9000',
        (False, True): 'Iter(test) [5/5] accuracy: 0.9000',
        (False, False): 'Iter(val) [10/10] accuracy: 0.9000',
    }
    assert out == expected[(bool(by_epoch), mode == 'test')]
|
||||
|
||||
def test_collect_scalars(self):
    """Check which scalars ``_collect_scalars`` gathers for each mode.

    The message hub's scalar buffers are replaced by mocks, so the test
    verifies the resulting tag keys and which buffer methods were used.
    """
    custom_cfg = [
        dict(data_src='time', method_name='mean', window_size=100),
        dict(data_src='time', method_name='max', log_name='time_max')
    ]
    processor = LogProcessor(custom_cfg=custom_cfg)
    # Scalars keyed with and without a mode prefix.
    scalar_names = ('train/time', 'lr', 'train/loss_cls', 'val/metric')
    scalars = {name: MagicMock() for name in scalar_names}
    self.runner.message_hub._log_scalars = scalars

    train_tag = processor._collect_scalars(
        copy.deepcopy(custom_cfg), self.runner, mode='train')
    # Training scalars plus the extra custom entry `time_max`.
    assert list(train_tag.keys()) == ['time', 'loss_cls', 'time_max']
    # The most recent `statistics` call comes from the `max` entry;
    # scalars without custom config are averaged.
    scalars['train/time'].statistics.assert_called_with(
        method_name='max')
    scalars['train/loss_cls'].mean.assert_called()

    val_tag = processor._collect_scalars(
        copy.deepcopy(custom_cfg), self.runner, mode='val')
    assert list(val_tag.keys()) == ['metric']
    # Validation scalars are read through `current`.
    scalars['val/metric'].current.assert_called()
|
||||
|
||||
@patch('torch.cuda.max_memory_allocated', MagicMock())
@patch('torch.cuda.reset_peak_memory_stats', MagicMock())
def test_get_max_memory(self):
    """Check that ``_get_max_memory`` queries and resets CUDA stats.

    Both ``torch.cuda`` functions are patched with mocks, so no GPU is
    needed to run this test.
    """
    processor = LogProcessor()
    # Attributes are configured directly through the mock constructor.
    runner = MagicMock(world_size=1, model=torch.nn.Linear(1, 1))
    processor._get_max_memory(runner)
    for cuda_fn in (torch.cuda.max_memory_allocated,
                    torch.cuda.reset_peak_memory_stats):
        cuda_fn.assert_called()
|
||||
|
||||
def test_get_iter(self):
    """Check ``LogProcessor._get_iter`` with and without a batch index."""
    processor = LogProcessor()
    # Without a batch index the global iteration is reported.
    assert processor._get_iter(self.runner) == 11
    # With a batch index the inner iteration is reported.
    assert processor._get_iter(self.runner, 1) == 2
    # When `by_epoch` is False the global iteration is reported again,
    # even though a batch index is passed.
    processor.by_epoch = False
    assert processor._get_iter(self.runner, 1) == 11
|
||||
|
||||
def test_get_epoch(self):
    """Check ``LogProcessor._get_epoch`` for each supported mode."""
    processor = LogProcessor()
    # The runner built in `setup` reports `epoch = 1`.
    assert processor._get_epoch(self.runner, 'train') == 2
    assert processor._get_epoch(self.runner, 'val') == 1
    # 'test' is not an accepted mode here.
    with pytest.raises(ValueError):
        processor._get_epoch(self.runner, 'test')
|
||||
|
||||
def test_get_cur_loop(self):
    """Check that ``_get_cur_loop`` returns the loop matching the mode."""
    processor = LogProcessor()
    # Dataloader lengths are the ones assigned to each loop in `setup`.
    for mode, expected_len in (('train', 20), ('val', 10), ('test', 5)):
        loop = processor._get_cur_loop(self.runner, mode)
        assert len(loop.dataloader) == expected_len
|
||||
|
||||
def setup(self):
    """Create the mocked runner used by every test in this class.

    The runner has ``epoch=1``, ``iter=10``, ``train_loop.max_iters=50``
    and dataloaders of length 20/10/5 for the train/val/test loops. It
    carries a real ``MMLogger`` and a real ``MessageHub`` holding ten
    ``train/loss`` and ten ``val/acc`` scalars.
    """
    mock_runner = MagicMock()
    mock_runner.epoch = 1
    mock_runner.iter = 10
    mock_runner.train_loop.max_iters = 50
    # One fake dataloader per loop; only its length matters to the tests.
    loop_sizes = (('train_loop', 20), ('val_loop', 10), ('test_loop', 5))
    for loop_name, size in loop_sizes:
        getattr(mock_runner, loop_name).dataloader = [0] * size
    mock_runner.logger = MMLogger.get_instance('log_processor_test')
    hub = MessageHub.get_instance('log_processor_test')
    # Decreasing training loss: 10, 9, ..., 1.
    for step in range(10):
        hub.update_scalar('train/loss', 10 - step)
    # Increasing validation accuracy: 0.0, 0.1, ..., 0.9.
    for step in range(10):
        hub.update_scalar('val/acc', step * 0.1)
    mock_runner.message_hub = hub
    self.runner = mock_runner
|
|
@ -221,7 +221,7 @@ class TestRunner(TestCase):
|
|||
self.iter_based_cfg.default_hooks = dict(
|
||||
timer=dict(type='IterTimerHook'),
|
||||
checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
|
||||
logger=dict(type='LoggerHook', by_epoch=False),
|
||||
logger=dict(type='LoggerHook'),
|
||||
optimizer=dict(type='OptimizerHook', grad_clip=None),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'))
|
||||
|
||||
|
|
Loading…
Reference in New Issue