# Copyright (c) OpenMMLab. All rights reserved.
import copy
import datetime
from collections import OrderedDict
from itertools import chain
from typing import List, Optional, Tuple

import numpy as np
import torch

from mmengine.device import get_max_cuda_memory, is_cuda_available
from mmengine.registry import LOG_PROCESSORS


@LOG_PROCESSORS.register_module()  # type: ignore
class LogProcessor:
    """A log processor used to format log information collected from
    ``runner.message_hub.log_scalars``.

    A ``LogProcessor`` instance is built by the runner and formats
    ``runner.message_hub.log_scalars`` into ``tag`` and ``log_str``, which
    can be used directly by ``LoggerHook`` and ``MMLogger``. Besides, the
    constructor argument ``custom_cfg`` controls the statistics method
    applied to the logs.

    Args:
        window_size (int): The default smoothing interval. Defaults to 10.
        by_epoch (bool): Whether to format logs with epoch style. Defaults
            to True.
        custom_cfg (list[dict], optional): Contains multiple log config
            dicts, in which the key means the data source name of the log
            and the value means the statistic method and corresponding
            arguments used to count the data source. Defaults to None.

            - If custom_cfg is None, all logs will be formatted via default
              methods, such as smoothing loss by the default window_size.
              custom_cfg can also be defined as a list of config dicts, for
              example: [dict(data_src='loss', method_name='mean',
              log_name='global_loss', window_size='global')]. It means the
              log item ``loss`` will be counted as a global mean and
              additionally logged as ``global_loss`` (defined by
              ``log_name``). If ``log_name`` is not defined in a config
              dict, the original logged key will be overwritten.

            - The original log item cannot be overwritten twice. Here is
              an error example:
              [dict(data_src='loss', method_name='mean',
              window_size='global'),
              dict(data_src='loss', method_name='mean',
              window_size='epoch')].
              Neither log config dict in custom_cfg has a ``log_name`` key,
              which means the ``loss`` item would be overwritten twice.

            - For those statistic methods with the ``window_size`` argument,
              if ``by_epoch`` is set to False, ``window_size`` should not be
              ``'epoch'``, since logs cannot be counted by epoch in that
              case.
        num_digits (int): The number of significant digits shown in the
            logging message. Defaults to 4.
        log_with_hierarchy (bool): Whether to log with hierarchy. If True,
            the information is written to visualizer backends such as
            :obj:`LocalVisBackend` and :obj:`TensorboardBackend` with
            hierarchy. For example, ``loss`` will be saved as
            ``train/loss``, and accuracy will be saved as ``val/accuracy``.
            Defaults to False.
            `New in version 0.7.0.`

    Examples:
        >>> # `log_name` is defined, so `loss_large_window` will be an
        >>> # additional record.
        >>> log_processor = dict(
        >>>     window_size=10,
        >>>     by_epoch=True,
        >>>     custom_cfg=[dict(data_src='loss',
        >>>                      log_name='loss_large_window',
        >>>                      method_name='mean',
        >>>                      window_size=100)])
        >>> # `log_name` is not defined, so `loss` will be overwritten.
        >>> log_processor = dict(
        >>>     window_size=10,
        >>>     by_epoch=True,
        >>>     custom_cfg=[dict(data_src='loss',
        >>>                      method_name='mean',
        >>>                      window_size=100)])
        >>> # Record loss with different statistics methods.
        >>> log_processor = dict(
        >>>     window_size=10,
        >>>     by_epoch=True,
        >>>     custom_cfg=[dict(data_src='loss',
        >>>                      log_name='loss_large_window',
        >>>                      method_name='mean',
        >>>                      window_size=100),
        >>>                 dict(data_src='loss',
        >>>                      method_name='mean',
        >>>                      window_size=100)])
        >>> # Overwriting the loss item twice will raise an error.
        >>> log_processor = dict(
        >>>     window_size=10,
        >>>     by_epoch=True,
        >>>     custom_cfg=[dict(data_src='loss',
        >>>                      method_name='mean',
        >>>                      window_size=100),
        >>>                 dict(data_src='loss',
        >>>                      method_name='max',
        >>>                      window_size=100)])
        AssertionError
    """

    def __init__(self,
                 window_size=10,
                 by_epoch=True,
                 custom_cfg: Optional[List[dict]] = None,
                 num_digits: int = 4,
                 log_with_hierarchy: bool = False):
        self.window_size = window_size
        self.by_epoch = by_epoch
        self.custom_cfg = custom_cfg if custom_cfg else []
        self.num_digits = num_digits
        self.log_with_hierarchy = log_with_hierarchy
        self._check_custom_cfg()

    def get_log_after_iter(self, runner, batch_idx: int,
                           mode: str) -> Tuple[dict, str]:
        """Format log string after a training, validation or testing
        iteration.

        Args:
            runner (Runner): The runner of the training phase.
            batch_idx (int): The index of the current batch in the current
                loop.
            mode (str): Current mode of runner, train, test or val.

        Returns:
            Tuple(dict, str): Formatted log dict/string which will be
            recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
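
        Examples:
            >>> # A rough sketch of the call and the resulting header. A
            >>> # built ``Runner`` is assumed and the exact values depend on
            >>> # the runner state, so the call is shown as comments.
            >>> # tag, log_str = log_processor.get_log_after_iter(
            >>> #     runner, batch_idx=9, mode='train')
            >>> # log_str starts with a header such as
            >>> # 'Epoch(train) [2][ 10/270] lr: 1.0000e-02 time: 0.1221 ...'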
        """
        assert mode in ['train', 'test', 'val']
        cur_iter = self._get_iter(runner, batch_idx=batch_idx)
        # Overwrite ``window_size`` defined in ``custom_cfg`` to int value.
        parsed_cfg = self._parse_windows_size(runner, batch_idx,
                                              self.custom_cfg)
        # log_tag is used to write log information to the terminal.
        # If `self.log_with_hierarchy` is False, the tag is the same as
        # log_tag. Otherwise, each key in tag starts with the prefix
        # `train`, `test` or `val`.
        log_tag = self._collect_scalars(parsed_cfg, runner, mode)

        if not self.log_with_hierarchy:
            tag = copy.deepcopy(log_tag)
        else:
            tag = self._collect_scalars(parsed_cfg, runner, mode, True)

        # Record learning rate.
        lr_str_list = []
        for key, value in tag.items():
            if key.endswith('lr'):
                key = self._remove_prefix(key, f'{mode}/')
                log_tag.pop(key)
                lr_str_list.append(f'{key}: '
                                   f'{value:.{self.num_digits}e}')
        lr_str = ' '.join(lr_str_list)
        # Format log header.
        # by_epoch == True
        #   train/val: Epoch [5][5/10] ...
        #   test: Epoch [5/10]
        # by_epoch == False
        #   train: Iter [5/10000] ... (divided by `max_iters`)
        #   val/test: Iter [5/2000] ... (divided by length of dataloader)
        if self.by_epoch:
            # Align the iteration logs:
            # Epoch(train) [ 9][010/270]
            # ...          ||| |||
            # Epoch(train) [ 10][100/270]
            dataloader_len = self._get_dataloader_size(runner, mode)
            cur_iter_str = str(cur_iter).rjust(len(str(dataloader_len)))

            if mode in ['train', 'val']:
                # Right-align the epoch logs:
                # Epoch(train)   [9][100/270]
                # ...            ||
                # Epoch(train) [100][100/270]
                cur_epoch = self._get_epoch(runner, mode)
                max_epochs = runner.max_epochs
                # 3 means the three characters "[", "]" and " " occupied in
                # " [{max_epochs}]".
                cur_epoch_str = f'[{cur_epoch}]'.rjust(
                    len(str(max_epochs)) + 3, ' ')
                tag['epoch'] = cur_epoch
                log_str = (f'Epoch({mode}){cur_epoch_str}'
                           f'[{cur_iter_str}/{dataloader_len}] ')
            else:
                log_str = (f'Epoch({mode}) '
                           f'[{cur_iter_str}/{dataloader_len}] ')
        else:
            if mode == 'train':
                cur_iter_str = str(cur_iter).rjust(len(str(runner.max_iters)))
                log_str = (f'Iter({mode}) '
                           f'[{cur_iter_str}/{runner.max_iters}] ')
            else:
                dataloader_len = self._get_dataloader_size(runner, mode)
                cur_iter_str = str(batch_idx + 1).rjust(
                    len(str(dataloader_len)))
                log_str = (f'Iter({mode}) [{cur_iter_str}/{dataloader_len}] ')
        # Concatenate the lr string with the log header.
        log_str += f'{lr_str} '
        # If IterTimerHook is used in the runner, eta, time and data_time
        # should be recorded.
        if (all(item in log_tag for item in ['time', 'data_time'])
                and 'eta' in runner.message_hub.runtime_info):
            eta = runner.message_hub.get_info('eta')
            eta_str = str(datetime.timedelta(seconds=int(eta)))
            log_str += f'eta: {eta_str} '
            log_str += (f'time: {log_tag["time"]:.{self.num_digits}f} '
                        f'data_time: '
                        f'{log_tag["data_time"]:.{self.num_digits}f} ')
            # Pop recorded keys.
            log_tag.pop('time')
            log_tag.pop('data_time')

        # If cuda is available, the max memory occupied should be calculated.
        if is_cuda_available():
            max_memory = self._get_max_memory(runner)
            log_str += f'memory: {max_memory} '
            tag['memory'] = max_memory
        # Loop over the remaining keys to fill `log_str`.
        if mode in ('train', 'val'):
            log_items = []
            for name, val in log_tag.items():
                if mode == 'val' and not name.startswith('val/loss'):
                    continue
                if isinstance(val, float):
                    val = f'{val:.{self.num_digits}f}'
                log_items.append(f'{name}: {val}')
            log_str += ' '.join(log_items)
        return tag, log_str

    def get_log_after_epoch(self,
                            runner,
                            batch_idx: int,
                            mode: str,
                            with_non_scalar: bool = False) -> Tuple[dict, str]:
        """Format log string after a validation or testing epoch.

        Args:
            runner (Runner): The runner of the validation/testing phase.
            batch_idx (int): The index of the current batch in the current
                loop.
            mode (str): Current mode of runner.
            with_non_scalar (bool): Whether to include non-scalar infos in
                the returned tag. Defaults to False.

        Returns:
            Tuple(dict, str): Formatted log dict/string which will be
            recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
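
        Examples:
            >>> # A rough sketch. A built ``Runner`` in val/test phase is
            >>> # assumed, so the call and its result are shown as comments.
            >>> # tag, log_str = log_processor.get_log_after_epoch(
            >>> #     runner, batch_idx=batch_idx, mode='val')
            >>> # log_str looks like:
            >>> # 'Epoch(val) [10][1000/1000]  accuracy: 98.0000 time: 0.0512'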
        """
        assert mode in [
            'test', 'val'
        ], ('`get_log_after_epoch` only accepts val or test mode, but got '
            f'{mode}')
        dataloader_len = self._get_dataloader_size(runner, mode)

        # By epoch:
        #   Epoch(val) [10][1000/1000] ...
        #   Epoch(test) [1000/1000] ...
        # By iteration:
        #   Iter(val) [1000/1000] ...
        #   Iter(test) [1000/1000] ...
        if self.by_epoch:
            if mode == 'val':
                cur_epoch = self._get_epoch(runner, mode)
                log_str = (f'Epoch({mode}) [{cur_epoch}][{dataloader_len}/'
                           f'{dataloader_len}] ')
            else:
                log_str = (
                    f'Epoch({mode}) [{dataloader_len}/{dataloader_len}] ')

        else:
            log_str = (f'Iter({mode}) [{dataloader_len}/{dataloader_len}] ')

        custom_cfg_copy = copy.deepcopy(self.custom_cfg)
        # Remove the mode prefix from the data sources.
        custom_keys = [
            self._remove_prefix(cfg['data_src'], f'{mode}/')
            for cfg in custom_cfg_copy
        ]
        # Count the averaged time and data_time by epoch.
        if 'time' not in custom_keys:
            custom_cfg_copy.append(
                dict(
                    data_src=f'{mode}/time',
                    window_size='epoch',
                    method_name='mean'))
        if 'data_time' not in custom_keys:
            custom_cfg_copy.append(
                dict(
                    data_src=f'{mode}/data_time',
                    window_size='epoch',
                    method_name='mean'))
        parsed_cfg = self._parse_windows_size(runner, batch_idx,
                                              custom_cfg_copy)
        # tag is used to write log information to different backends.
        ori_tag = self._collect_scalars(parsed_cfg, runner, mode,
                                        self.log_with_hierarchy)
        non_scalar_tag = self._collect_non_scalars(runner, mode)
        # Move `time` and `data_time` to the end of the log.
        tag = OrderedDict()
        time_tag = OrderedDict()
        for key, value in ori_tag.items():
            if key in (f'{mode}/time', f'{mode}/data_time', 'time',
                       'data_time'):
                time_tag[key] = value
            else:
                tag[key] = value
        # Log other messages.
        log_items = []
        log_str += ' '
        for name, val in chain(tag.items(), non_scalar_tag.items(),
                               time_tag.items()):
            if isinstance(val, float):
                val = f'{val:.{self.num_digits}f}'
            if isinstance(val, (torch.Tensor, np.ndarray)):
                # Newline to display tensors and arrays.
                val = f'\n{val}\n'
            log_items.append(f'{name}: {val}')
        log_str += ' '.join(log_items)

        if with_non_scalar:
            tag.update(non_scalar_tag)
        tag.update(time_tag)
        return tag, log_str

    def _collect_scalars(self,
                         custom_cfg: List[dict],
                         runner,
                         mode: str,
                         reserve_prefix: bool = False) -> dict:
        """Collect log information to compose a dict according to mode.

        Args:
            custom_cfg (List[dict]): A copy of ``self.custom_cfg`` with int
                ``window_size``.
            runner (Runner): The runner of the training/testing/validation
                process.
            mode (str): Current mode of runner.
            reserve_prefix (bool): Whether to keep the prefix of the key.
                Defaults to False.

        Returns:
            dict: Statistical values of logs.
        """
        tag = OrderedDict()
        # history_scalars of train/val/test phase.
        history_scalars = runner.message_hub.log_scalars
        # history_scalars of the corresponding mode.
        mode_history_scalars = OrderedDict()
        # Extract log scalars into `mode_history_scalars` according to mode,
        # removing the prefix if needed.
        for prefix_key, log_buffer in history_scalars.items():
            if prefix_key.startswith(mode):
                if not reserve_prefix:
                    key = self._remove_prefix(prefix_key, f'{mode}/')
                else:
                    key = prefix_key
                mode_history_scalars[key] = log_buffer
        for key in mode_history_scalars:
            # Smooth the loss, time and grad_norm items with window_size.
            if 'loss' in key or key in ('time', 'data_time', 'grad_norm'):
                tag[key] = mode_history_scalars[key].mean(self.window_size)
            else:
                # The default statistic method is the current value.
                tag[key] = mode_history_scalars[key].current()
        # Update custom keys.
        for log_cfg in custom_cfg:
            data_src = log_cfg.pop('data_src')
            if 'log_name' in log_cfg:
                log_name = log_cfg.pop('log_name')
            else:
                log_name = data_src
            # A log item in custom_cfg may only exist in train or val
            # mode.
            if data_src in mode_history_scalars:
                tag[log_name] = mode_history_scalars[data_src].statistics(
                    **log_cfg)
        return tag

    def _collect_non_scalars(self, runner, mode: str) -> dict:
        """Collect non-scalar log information to compose a dict according to
        mode.

        Args:
            runner (Runner): The runner of the training/testing/validation
                process.
            mode (str): Current mode of runner.

        Returns:
            dict: Non-scalar infos of the specified mode.
        """
        # Infos of train/val/test phase.
        infos = runner.message_hub.runtime_info
        # Infos of the corresponding mode.
        mode_infos = OrderedDict()
        # Extract log infos into `mode_infos` according to mode, removing
        # the prefix if needed.
        for prefix_key, value in infos.items():
            if prefix_key.startswith(mode):
                if self.log_with_hierarchy:
                    key = prefix_key
                else:
                    key = self._remove_prefix(prefix_key, f'{mode}/')
                mode_infos[key] = value
        return mode_infos

    def _remove_prefix(self, string: str, prefix: str):
        """Remove the ``train``, ``val`` or ``test`` prefix from the key.
        """
        if string.startswith(prefix):
            return string[len(prefix):]
        else:
            return string

    def _check_custom_cfg(self) -> None:
        """Check the validity of ``self.custom_cfg``.
        """

        def _check_window_size():
            for log_cfg in self.custom_cfg:
                if not self.by_epoch:
                    assert log_cfg['window_size'] != 'epoch', \
                        'window_size cannot be epoch if LoggerHook.by_epoch' \
                        ' is False.'

        def _check_repeated_log_name():
            # The `log_name` of the same data_src should not be repeated.
            # If `log_name` is not specified, `data_src` will be
            # overwritten, but it is only allowed to be overwritten once.
            check_set = set()
            for log_cfg in self.custom_cfg:
                assert 'data_src' in log_cfg
                data_src = log_cfg['data_src']
                log_name = log_cfg.get('log_name', data_src)
                assert log_name not in check_set, (
                    f'Found duplicate {log_name} for {data_src}. Please '
                    'check your `custom_cfg` for `log_processor`. You '
                    f'should neither define duplicate `{log_name}` for '
                    f'{data_src} nor leave {log_name} undefined for '
                    f'multiple {data_src}. See more information in the '
                    'docstring of LogProcessor.')

                check_set.add(log_name)

        _check_repeated_log_name()
        _check_window_size()

    def _parse_windows_size(self,
                            runner,
                            batch_idx: int,
                            custom_cfg: Optional[list] = None) -> list:
        """Parse the ``window_size`` values defined in ``custom_cfg`` to int.

        Args:
            runner (Runner): The runner of the training/testing/validation
                process.
            batch_idx (int): The iteration index of the current dataloader.
            custom_cfg (list, optional): A copy of ``self.custom_cfg``.
                Defaults to None to keep backward compatibility.
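
        A minimal, runnable sketch: ``'epoch'`` resolves to
        ``batch_idx + 1``. ``runner`` is only consulted for ``'global'``,
        so ``None`` suffices here:

        Examples:
            >>> processor = LogProcessor()
            >>> processor._parse_windows_size(
            ...     None, 9, [dict(data_src='loss', window_size='epoch')])
            [{'data_src': 'loss', 'window_size': 10}]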
        """
        if custom_cfg is None:
            custom_cfg = copy.deepcopy(self.custom_cfg)
        else:
            custom_cfg = copy.deepcopy(custom_cfg)
        for log_cfg in custom_cfg:
            window_size = log_cfg.get('window_size', None)
            if window_size is None or isinstance(window_size, int):
                continue
            elif window_size == 'epoch':
                log_cfg['window_size'] = batch_idx + 1
            elif window_size == 'global':
                log_cfg['window_size'] = runner.iter + 1
            else:
                raise TypeError(
                    'window_size should be int, epoch or global, but got '
                    f'invalid {window_size}')
        return custom_cfg

    def _get_max_memory(self, runner) -> int:
        """Return the maximum GPU memory occupied by tensors in megabytes
        (MB) for the given device.

        Args:
            runner (Runner): The runner of the training/testing/validation
                process.

        Returns:
            The maximum GPU memory occupied by tensors in megabytes for the
            given device.
        """
        device = getattr(runner.model, 'output_device', None)
        return get_max_cuda_memory(device)

    def _get_iter(self, runner, batch_idx: Optional[int] = None) -> int:
        """Get the current iteration index.

        Args:
            runner (Runner): The runner of the training/testing/validation
                process.
            batch_idx (int, optional): The iteration index of the current
                dataloader. Defaults to None.

        Returns:
            int: The current global iter or inner iter.
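
        A minimal, runnable sketch: with ``by_epoch=True`` and a given
        ``batch_idx``, the runner is not consulted, so ``None`` suffices:

        Examples:
            >>> LogProcessor(by_epoch=True)._get_iter(None, batch_idx=9)
            10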
        """
        if self.by_epoch and batch_idx is not None:
            current_iter = batch_idx + 1
        else:
            current_iter = runner.iter + 1
        return current_iter

    def _get_epoch(self, runner, mode: str) -> int:
        """Get the current epoch according to mode.

        Args:
            runner (Runner): The runner of the training/testing/validation
                process.
            mode (str): Current mode of runner.

        Returns:
            int: The current epoch.
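
        A minimal sketch with a stand-in runner (only the ``epoch``
        attribute is consulted):

        Examples:
            >>> from types import SimpleNamespace
            >>> processor = LogProcessor()
            >>> processor._get_epoch(SimpleNamespace(epoch=4), 'train')
            5
            >>> processor._get_epoch(SimpleNamespace(epoch=5), 'val')
            5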
        """
        if mode == 'train':
            epoch = runner.epoch + 1
        elif mode == 'val':
            # Normal val mode: `runner.epoch += 1` has already been done
            # before validation.
            epoch = runner.epoch
        else:
            raise ValueError(
                f"runner mode should be 'train' or 'val', but got {mode}")
        return epoch

    def _get_cur_loop(self, runner, mode: str):
        """Get the current loop according to mode.

        Args:
            runner (Runner): The runner of the training/validation/testing
                process.
            mode (str): Current mode of runner.

        Returns:
            BaseLoop: Current loop of runner.
        """
        # Annotating the return type here would cause a circular import.
        if mode == 'train':
            return runner.train_loop
        elif mode == 'val':
            return runner.val_loop
        else:
            return runner.test_loop

    def _get_dataloader_size(self, runner, mode) -> int:
        """Get the dataloader size of the current loop.

        Args:
            runner (Runner): The runner of the training/validation/testing
                process.
            mode (str): Current mode of runner.

        Returns:
            int: The dataloader size of current loop.
        """
        return len(self._get_cur_loop(runner=runner, mode=mode).dataloader)