[Enhancement] Refine logging. (#148)

* rename GlobalAccessible and integrate get_instance and create_instance

* move ManagerMixin to utils

* fix docstring and separate get_instance into get_instance and get_current_instance

* fix lint

* fix docstring, rename and move test_global_meta

* rename LogBuffer to HistoryBuffer, rename MessageHub methods, MessageHub support resume

* refine MMLogger timestamp, update unit test

* MMLogger: add logger_name argument

* Fix docstring

* change default logger_name to mmengine

* Fix docstring comment and unit test

* fix docstring

* fix docstring

* Fix lint

* Fix hook unit test

* Fix comments

* should not accept other arguments if corresponding instance has been created

* fix logging ddp file saving

* fix logging ddp file saving

* fix docstring

* fix unit test

* fix docstring as comment
Mashiro 2022-04-21 19:12:10 +08:00 committed by GitHub
parent 45567b1d1c
commit 82a313d09b
14 changed files with 479 additions and 174 deletions

View File

@ -42,8 +42,8 @@ class IterTimerHook(Hook):
mode (str): Current mode of runner. Defaults to 'train'.
"""
# TODO: update for new logging system
runner.message_hub.update_log(f'{mode}/data_time',
time.time() - self.t)
runner.message_hub.update_scalar(f'{mode}/data_time',
time.time() - self.t)
def _after_iter(self,
runner,
@ -65,5 +65,5 @@ class IterTimerHook(Hook):
"""
# TODO: update for new logging system
runner.message_hub.update_log(f'{mode}/time', time.time() - self.t)
runner.message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
self.t = time.time()
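For orientation, a minimal usage sketch of the renamed API above (editorial example, not part of the diff; the instance name is arbitrary):

from mmengine import MessageHub

message_hub = MessageHub.get_instance('timer_sketch')
# Record per-iteration timings as scalars, as the hook above does.
message_hub.update_scalar('train/data_time', 0.012)
message_hub.update_scalar('train/time', 0.054)
# Each key is backed by a HistoryBuffer, so statistics come for free.
print(message_hub.get_scalar('train/time').mean())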

View File

@ -370,7 +370,7 @@ class LoggerHook(Hook):
dict: Statistical values of logs.
"""
tag = OrderedDict()
log_buffers = runner.message_hub.log_buffers
log_buffers = runner.message_hub.log_scalars
mode_log_buffers = OrderedDict()
# Filter log_buffers which starts with `mode`.
for prefix_key, log_buffer in log_buffers.items():

View File

@ -1,6 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .log_buffer import LogBuffer
from .history_buffer import HistoryBuffer
from .logger import MMLogger, print_log
from .message_hub import MessageHub
__all__ = ['LogBuffer', 'MessageHub', 'MMLogger', 'print_log']
__all__ = ['HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log']

View File

@ -1,16 +1,28 @@
# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from functools import partial
from typing import Any, Callable, Optional, Sequence, Tuple, Union
import numpy as np
class BaseLogBuffer:
class HistoryBuffer:
"""Unified storage format for different log types.
Record the history of log for further statistics. The subclass inherited
from ``BaseLogBuffer`` will implement the specific statistical methods.
``HistoryBuffer`` records the history of log for further statistics.
Examples:
>>> history_buffer = HistoryBuffer()
>>> # Update history_buffer.
>>> history_buffer.update(1)
>>> history_buffer.update(2)
>>> history_buffer.min() # minimum of (1, 2)
1
>>> history_buffer.max() # maximum of (1, 2)
2
>>> history_buffer.mean() # mean of (1, 2)
1.5
>>> history_buffer.statistics('mean') # access method by string.
1.5
Args:
log_history (Sequence): History logs. Defaults to [].
@ -25,6 +37,7 @@ class BaseLogBuffer:
max_length: int = 1000000):
self.max_length = max_length
self._set_default_statistics()
assert len(log_history) == len(count_history), \
'The lengths of log_history and count_history should be equal'
if len(log_history) > max_length:
@ -37,10 +50,18 @@ class BaseLogBuffer:
self._log_history = np.array(log_history)
self._count_history = np.array(count_history)
def _set_default_statistics(self) -> None:
"""Register default statistic methods: min, max, current and mean."""
self._statistics_methods.setdefault('min', HistoryBuffer.min)
self._statistics_methods.setdefault('max', HistoryBuffer.max)
self._statistics_methods.setdefault('current', HistoryBuffer.current)
self._statistics_methods.setdefault('mean', HistoryBuffer.mean)
def update(self, log_val: Union[int, float], count: int = 1) -> None:
"""update the log history. If the length of the buffer exceeds
``self._max_length``, the oldest element will be removed from the
buffer.
"""update the log history.
If the length of the buffer exceeds ``self._max_length``, the oldest
element will be removed from the buffer.
Args:
log_val (int or float): The value of log.
@ -72,9 +93,23 @@ class BaseLogBuffer:
def register_statistics(cls, method: Callable) -> Callable:
"""Register custom statistics method to ``_statistics_methods``.
The registered method can be called by ``history_buffer.statistics``
with corresponding method name and arguments.
Examples:
>>> @HistoryBuffer.register_statistics
>>> def weighted_mean(self, window_size, weight):
>>> assert len(weight) == window_size
>>> return (self._log_history[-window_size:] *
>>> np.array(weight)).sum() / \
>>> self._count_history[-window_size:].sum()
>>> log_buffer = HistoryBuffer([1, 2], [1, 1])
>>> log_buffer.statistics('weighted_mean', 2, [2, 1])
2
Args:
method (Callable): Custom statistics method.
Returns:
Callable: Original custom statistics method.
"""
@ -95,26 +130,20 @@ class BaseLogBuffer:
"""
if method_name not in self._statistics_methods:
raise KeyError(f'{method_name} has not been registered in '
'BaseLogBuffer._statistics_methods')
'HistoryBuffer._statistics_methods')
method = self._statistics_methods[method_name]
# Provide self arguments for registered functions.
method = partial(method, self)
return method(*arg, **kwargs)
return method(self, *arg, **kwargs)
class LogBuffer(BaseLogBuffer):
"""``LogBuffer`` inherits from ``BaseLogBuffer`` and provides some basic
statistics methods, such as ``min``, ``max``, ``current`` and ``mean``."""
@BaseLogBuffer.register_statistics
def mean(self, window_size: Optional[int] = None) -> np.ndarray:
"""Return the mean of the latest ``window_size`` values in log
histories. If ``window_size is None``, return the global mean of
history logs.
histories.
If ``window_size is None`` or ``window_size > len(self._log_history)``,
return the global mean value of history logs.
Args:
window_size (int, optional): Size of statistics window.
Returns:
np.ndarray: Mean value within the window.
"""
@ -128,15 +157,15 @@ class LogBuffer(BaseLogBuffer):
counts_sum = self._count_history[-window_size:].sum()
return logs_sum / counts_sum
@BaseLogBuffer.register_statistics
def max(self, window_size: Optional[int] = None) -> np.ndarray:
"""Return the maximum value of the latest ``window_size`` values in log
histories. If ``window_size is None``, return the global maximum value
of history logs.
histories.
If ``window_size is None`` or ``window_size > len(self._log_history)``,
return the global maximum value of history logs.
Args:
window_size (int, optional): Size of statistics window.
Returns:
np.ndarray: The maximum value within the window.
"""
@ -148,15 +177,15 @@ class LogBuffer(BaseLogBuffer):
window_size = len(self._log_history)
return self._log_history[-window_size:].max()
@BaseLogBuffer.register_statistics
def min(self, window_size: Optional[int] = None) -> np.ndarray:
"""Return the minimum value of the latest ``window_size`` values in log
histories. If ``window_size is None``, return the global minimum value
of history logs.
histories.
If ``window_size is None`` or ``window_size > len(self._log_history)``,
return the global minimum value of history logs.
Args:
window_size (int, optional): Size of statistics window.
Returns:
np.ndarray: The minimum value within the window.
"""
@ -168,7 +197,6 @@ class LogBuffer(BaseLogBuffer):
window_size = len(self._log_history)
return self._log_history[-window_size:].min()
@BaseLogBuffer.register_statistics
def current(self) -> np.ndarray:
"""Return the recently updated values in log histories.
@ -176,6 +204,6 @@ class LogBuffer(BaseLogBuffer):
np.ndarray: Recently updated values in log histories.
"""
if len(self._log_history) == 0:
raise ValueError('LogBuffer._log_history is an empty array! '
raise ValueError('HistoryBuffer._log_history is an empty array! '
'please call update first')
return self._log_history[-1]
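To complement the docstring above, a runnable sketch of registering a custom statistic, mirroring the ``weighted_mean`` example (the names are illustrative):

import numpy as np
from mmengine import HistoryBuffer

@HistoryBuffer.register_statistics
def weighted_mean(self, window_size, weight):
    # Weight the latest `window_size` values, then normalize by counts.
    assert len(weight) == window_size
    return (self._log_history[-window_size:] *
            np.array(weight)).sum() / self._count_history[-window_size:].sum()

buf = HistoryBuffer([1, 2], [1, 1])
assert buf.statistics('weighted_mean', 2, [2, 1]) == 2  # (1*2 + 2*1) / (1+1)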

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import logging
import os
import os.path as osp
import sys
from logging import Logger, LogRecord
from typing import Optional, Union
@ -8,7 +9,7 @@ from typing import Optional, Union
import torch.distributed as dist
from termcolor import colored
from mmengine.utils import ManagerMixin
from mmengine.utils import ManagerMixin, mkdir_or_exist
class MMFormatter(logging.Formatter):
@ -85,54 +86,134 @@ class MMFormatter(logging.Formatter):
class MMLogger(Logger, ManagerMixin):
"""The Logger manager which can create formatted logger and get specified
logger globally. MMLogger is created and accessed in the same way as
ManagerMixin.
"""Formatted logger used to record messages.
``MMLogger`` can create a formatted logger to log messages with
different log levels, and its instances are retrieved in the same way
as ``ManagerMixin``.
``MMLogger`` has the following features:
- Distributed log storage: ``MMLogger`` can choose whether to save the
logs of different ranks according to `log_file` and `distributed`.
- Messages with different log levels will have different colors and
formats when displayed on the terminal.
Note:
- The `name` of logger and the ``instance_name`` of ``MMLogger`` could
be different. We can only get ``MMLogger`` instance by
``MMLogger.get_instance`` but not ``logging.getLogger``. This feature
ensures ``MMLogger`` will not be influenced by third-party logging
config.
- Different from ``logging.Logger``, ``MMLogger`` will not log warning
or error messages without a ``Handler``.
- If `log_file=/path/to/tmp.log`, all logs will be saved to
`/path/to/tmp/tmp.log`
Examples:
>>> logger = MMLogger.get_instance(name='MMLogger',
>>> logger_name='Logger')
>>> # Although logger has name attribute just like `logging.Logger`
>>> # We cannot get logger instance by `logging.getLogger`.
>>> assert logger.name == 'Logger'
>>> assert logger.instance_name == 'MMLogger'
>>> assert id(logger) != id(logging.getLogger('Logger'))
>>> # Get a logger that does not store logs.
>>> logger1 = MMLogger.get_instance('logger1')
>>> # Get a logger that only saves rank0 logs.
>>> logger2 = MMLogger.get_instance('logger2', log_file='out.log')
>>> # Get a logger that saves the logs of multiple ranks.
>>> logger3 = MMLogger.get_instance('logger3', log_file='out.log',
>>> distributed=True)
Args:
name (str): Logger name. Defaults to ''.
name (str): Global instance name.
logger_name (str): ``name`` attribute of ``logging.Logger`` instance.
If `logger_name` is not defined, defaults to 'mmengine'.
log_file (str, optional): The log filename. If specified, a
``FileHandler`` will be added to the logger. Defaults to None.
log_level: The log level of the handler. Defaults to 'NOTSET'.
file_mode (str): The file mode used in opening log file.
Defaults to 'w'.
log_level (str): The log level of the handler and logger. Defaults to
"NOTSET".
file_mode (str): The file mode used to open log file. Defaults to 'w'.
distributed (bool): Whether to save the log files of all ranks.
Defaults to False.
"""
def __init__(self,
name: str = '',
name: str,
logger_name='mmengine',
log_file: Optional[str] = None,
log_level: str = 'NOTSET',
file_mode: str = 'w'):
Logger.__init__(self, name)
log_level: str = 'INFO',
file_mode: str = 'w',
distributed=False):
Logger.__init__(self, logger_name)
ManagerMixin.__init__(self, name)
# Get rank in DDP mode.
if dist.is_available() and dist.is_initialized():
rank = dist.get_rank()
else:
rank = 0
# Config stream_handler. If `rank != 0`, stream_handler can only
# export logs at or above the ERROR level.
stream_handler = logging.StreamHandler(stream=sys.stdout)
stream_handler.setFormatter(MMFormatter(color=True))
# `StreamHandler` records month, day, hour, minute, and second
# timestamps.
stream_handler.setFormatter(
MMFormatter(color=True, datefmt='%m/%d %H:%M:%S'))
# Only rank0 `StreamHandler` will log messages below error level.
stream_handler.setLevel(log_level) if rank == 0 else \
stream_handler.setLevel(logging.ERROR)
self.handlers.append(stream_handler)
if log_file is not None:
# If `log_file=/path/to/tmp.log`, all logs will be saved to
# `/path/to/tmp/tmp.log`
log_dir = osp.dirname(log_file)
filename = osp.basename(log_file)
filename_list = filename.split('.')
sub_file_name = '.'.join(filename_list[:-1])
log_dir = osp.join(log_dir, sub_file_name)
mkdir_or_exist(log_dir)
log_file = osp.join(log_dir, filename)
if rank != 0:
# rename `log_file` with rank prefix.
# rename `log_file` with rank suffix.
path_split = log_file.split(os.sep)
path_split[-1] = f'rank{rank}_{path_split[-1]}'
if '.' in path_split[-1]:
filename_list = path_split[-1].split('.')
filename_list[-2] = f'{filename_list[-2]}_rank{rank}'
path_split[-1] = '.'.join(filename_list)
else:
path_split[-1] = f'{path_split[-1]}_rank{rank}'
log_file = os.sep.join(path_split)
# Here, the default behaviour of the official logger is 'a'. Thus,
# we provide an interface to change the file mode to the default
# behaviour. `FileHandler` is not supported to have colors,
# otherwise it will appear garbled.
file_handler = logging.FileHandler(log_file, file_mode)
file_handler.setFormatter(MMFormatter(color=False))
file_handler.setLevel(log_level)
self.handlers.append(file_handler)
# Save multi-ranks logs if distributed is True. The logs of rank0
# will always be saved.
if rank == 0 or distributed:
# The default file mode of the official FileHandler is 'a', so
# we expose `file_mode` and default it to 'w' instead.
file_handler = logging.FileHandler(log_file, file_mode)
# `FileHandler` records year, month, day, hour, minute, and
# second timestamps, and logs without color to avoid garbled
# text saved in files.
file_handler.setFormatter(
MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S'))
file_handler.setLevel(log_level)
self.handlers.append(file_handler)
def callHandlers(self, record: LogRecord) -> None:
"""Pass a record to all relevant handlers.
Override ``callHandlers`` method in ``logging.Logger`` to avoid
multiple warning messages in DDP mode. Unlike ``logging.Logger``,
this loops only through the handlers of this logger instance, not
those of its parents in the logger hierarchy. If no handler is
found, the record will not be output.
Args:
record (LogRecord): A ``LogRecord`` instance contains logged
message.
"""
for handler in self.handlers:
if record.levelno >= handler.level:
handler.handle(record)
def print_log(msg,
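The path handling in ``__init__`` above can be summarized with a small stand-alone helper (``resolve_log_file`` is a hypothetical name, not part of the commit; it assumes the file name has an extension):

import os.path as osp

def resolve_log_file(log_file: str, rank: int) -> str:
    # `/path/to/tmp.log` -> logs are stored under `/path/to/tmp/`.
    log_dir = osp.dirname(log_file)
    filename = osp.basename(log_file)
    stem = '.'.join(filename.split('.')[:-1])
    log_dir = osp.join(log_dir, stem)
    if rank != 0:
        # Non-zero ranks get a `_rank{N}` suffix before the extension.
        name, ext = filename.rsplit('.', 1)
        filename = f'{name}_rank{rank}.{ext}'
    return osp.join(log_dir, filename)

assert resolve_log_file('work_dir/exp.log', 0) == 'work_dir/exp/exp.log'
assert resolve_log_file('work_dir/exp.log', 1) == 'work_dir/exp/exp_rank1.log'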

View File

@ -1,14 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from collections import OrderedDict
from typing import Any, Union
from typing import Any, Optional, Union
import numpy as np
import torch
from mmengine.utils import ManagerMixin
from mmengine.visualization.utils import check_type
from .log_buffer import LogBuffer
from .history_buffer import HistoryBuffer
class MessageHub(ManagerMixin):
@ -17,84 +17,210 @@ class MessageHub(ManagerMixin):
``MessageHub`` will record log information and runtime information. The
log information refers to the learning rate, loss, etc. of the model
when training a model, which will be stored as ``LogBuffer``. The runtime
information refers to the iter times, meta information of runner etc.,
which will be overwritten by next update.
during training phase, which will be stored as ``HistoryBuffer``. The
runtime information refers to the iter times, meta information of
runner etc., which will be overwritten by next update.
Args:
name (str): Name of message hub, for global access. Defaults to ''.
name (str): Name of message hub used to get corresponding instance
globally.
log_scalars (OrderedDict, optional): Each key-value pair in the
dictionary is the name of the log information such as "loss", "lr",
"metric" and their corresponding values. The type of value must be
HistoryBuffer. Defaults to None.
runtime_info (OrderedDict, optional): Each key-value pair in the
dictionary is the name of the runtime information and their
corresponding values. Defaults to None.
resumed_keys (OrderedDict, optional): Each key-value pair in the
dictionary decides whether the key in :attr:`_log_scalars` and
:attr:`_runtime_info` will be serialized.
Note:
Each key in :attr:`_resumed_keys` must belong to :attr:`_log_scalars`
or :attr:`_runtime_info`, and its corresponding value cannot be set
repeatedly.
Examples:
>>> # create an empty `MessageHub`.
>>> message_hub1 = MessageHub(name='empty')
>>> log_scalars = OrderedDict(loss=HistoryBuffer())
>>> runtime_info = OrderedDict(task='task')
>>> resumed_keys = dict(loss=True)
>>> # create `MessageHub` from data.
>>> message_hub2 = MessageHub(
>>> name='name',
>>> log_scalars=log_scalars,
>>> runtime_info=runtime_info,
>>> resumed_keys=resumed_keys)
"""
def __init__(self, name: str = ''):
self._log_buffers: OrderedDict = OrderedDict()
self._runtime_info: OrderedDict = OrderedDict()
def __init__(self,
name: str,
log_scalars: Optional[OrderedDict] = None,
runtime_info: Optional[OrderedDict] = None,
resumed_keys: Optional[OrderedDict] = None):
super().__init__(name)
self._log_scalars = log_scalars if log_scalars is not None else \
OrderedDict()
self._runtime_info = runtime_info if runtime_info is not None else \
OrderedDict()
self._resumed_keys = resumed_keys if resumed_keys is not None else \
OrderedDict()
def update_log(self, key: str, value: Union[int, float], count: int = 1) \
-> None:
"""Update log buffer.
assert isinstance(self._log_scalars, OrderedDict)
assert isinstance(self._runtime_info, OrderedDict)
assert isinstance(self._resumed_keys, OrderedDict)
for value in self._log_scalars.values():
assert isinstance(value, HistoryBuffer), \
("The type of log_scalars'value must be HistoryBuffer, but "
f'got {type(value)}')
for key in self._resumed_keys.keys():
assert key in self._log_scalars or key in self._runtime_info, \
('Key in `resumed_keys` must be contained in `log_scalars` or '
f'`runtime_info`, but got {key}')
def update_scalar(self,
key: str,
value: Union[int, float, np.ndarray, torch.Tensor],
count: int = 1,
resumed: bool = True) -> None:
"""Update :attr:_log_scalars.
Update ``HistoryBuffer`` in :attr:`_log_scalars`. If corresponding key
``HistoryBuffer`` has been created, ``value`` and ``count`` is the
argument of ``HistoryBuffer.update``, Otherwise, ``update_scalar``
will create an ``HistoryBuffer`` with value and count via the
constructor of ``HistoryBuffer``.
Examples:
>>> message_hub = MessageHub.get_instance('mmengine')
>>> # create loss `HistoryBuffer` with value=1, count=1
>>> message_hub.update_scalar('loss', 1)
>>> # update loss `HistoryBuffer` with value=3.
>>> message_hub.update_scalar('loss', 3)
>>> message_hub.update_scalar('loss', 3, resumed=False)
AssertionError: loss used to be True, but got False now. resumed
keys cannot be modified repeatedly
Note:
resumed cannot be set repeatedly for the same key.
Args:
key (str): Key of ``LogBuffer``.
value (int or float): Value of log.
count (int): Accumulation times of log, defaults to 1. `count`
will be used in smooth statistics.
key (str): Key of ``HistoryBuffer``.
value (torch.Tensor or np.ndarray or int or float): Value of log.
count (int): Accumulation times of log, defaults to 1. `count`
will be used in smooth statistics.
resumed (bool): Whether the corresponding ``HistoryBuffer``
could be resumed. Defaults to True.
"""
if key in self._log_buffers:
self._log_buffers[key].update(value, count)
self._set_resumed_keys(key, resumed)
checked_value = self._get_valid_value(key, value)
assert isinstance(count, int), (
f'The type of count must be int, but got {type(count)}: {count}')
if key in self._log_scalars:
self._log_scalars[key].update(checked_value, count)
else:
self._log_buffers[key] = LogBuffer([value], [count])
self._log_scalars[key] = HistoryBuffer([checked_value], [count])
def update_log_vars(self, log_dict: dict) -> None:
"""Update :attr:`_log_buffers` with a dict.
def update_scalars(self, log_dict: dict, resumed: bool = True) -> None:
"""Update :attr:`_log_scalars` with a dict.
``update_scalars`` iterates through each key-value pair of
``log_dict`` and calls ``update_scalar``. If the type of a value is
dict, it should be ``dict(value=xxx)`` or
``dict(value=xxx, count=xxx)``. All items in ``log_dict`` share the
same resume option.
Args:
log_dict (str): Used for batch updating :attr:`_log_buffers`.
log_dict (dict): Used for batch updating :attr:`_log_scalars`.
resumed (bool): Whether all ``HistoryBuffer`` referred in
log_dict should be resumed. Defaults to True.
Examples:
>>> message_hub = MessageHub.get_instance('mmengine')
>>> log_dict = dict(a=1, b=2, c=3)
>>> message_hub.update_log_vars(log_dict)
>>> message_hub.update_scalars(log_dict)
>>> # The default count of `a`, `b` and `c` is 1.
>>> log_dict = dict(a=1, b=2, c=dict(value=1, count=2))
>>> message_hub.update_log_vars(log_dict)
>>> message_hub.update_scalars(log_dict)
>>> # The count of `c` is 2.
"""
assert isinstance(log_dict, dict), ('`log_dict` must be a dict, '
f'but got {type(log_dict)}')
for log_name, log_val in log_dict.items():
self._set_resumed_keys(log_name, resumed)
if isinstance(log_val, dict):
assert 'value' in log_val, \
f'value must be defined in {log_val}'
count = log_val.get('count', 1)
value = self._get_valid_value(log_name, log_val['value'])
count = self._get_valid_value(log_name,
log_val.get('count', 1))
checked_value = self._get_valid_value(log_name,
log_val['value'])
else:
value = self._get_valid_value(log_name, log_val)
count = 1
self.update_log(log_name, value, count)
checked_value = self._get_valid_value(log_name, log_val)
assert isinstance(count,
int), ('The type of count must be int, but got '
f'{type(count)}: {count}')
self.update_scalar(log_name, checked_value, count)
def update_info(self, key: str, value: Any) -> None:
def update_info(self, key: str, value: Any, resumed: bool = True) -> None:
"""Update runtime information.
The key corresponding runtime information will be overwritten each
time calling ``update_info``.
Note:
resumed cannot be set repeatedly for the same key.
Examples:
>>> message_hub = MessageHub.get_instance('mmengine')
>>> message_hub.update_info('iter', 100)
Args:
key (str): Key of runtime information.
value (Any): Value of runtime information.
resumed (bool): Whether the corresponding runtime information
could be resumed. Defaults to True.
"""
self._set_resumed_keys(key, resumed)
self._runtime_info[key] = value
def _set_resumed_keys(self, key: str, resumed: bool) -> None:
"""Set corresponding resumed keys.
This method is called by ``update_scalar``, ``update_scalars`` and
``update_info`` to set whether the corresponding key in
:attr:`_resumed_keys` is true or false.
Args:
key (str): Key of :attr:`_log_scalars` or :attr:`_runtime_info`.
resumed (bool): Whether the corresponding ``HistoryBuffer``
could be resumed.
"""
if key not in self._resumed_keys:
self._resumed_keys[key] = resumed
else:
assert self._resumed_keys[key] == resumed, \
f'{key} used to be {self._resumed_keys[key]}, but got ' \
f'{resumed} now. resumed keys cannot be modified repeatedly'
@property
def log_buffers(self) -> OrderedDict:
"""Get all ``LogBuffer`` instances.
def log_scalars(self) -> OrderedDict:
"""Get all ``HistoryBuffer`` instances.
Note:
Considering the large memory footprint of ``log_buffers`` in the
post-training, ``MessageHub.log_buffers`` will not return the
result of ``copy.deepcopy``.
Considering the large memory footprint of history buffers in the
post-training phase, this property returns a reference to the
history buffers rather than a copy.
Returns:
OrderedDict: All ``LogBuffer`` instances.
OrderedDict: All ``HistoryBuffer`` instances.
"""
return self._log_buffers
return self._log_scalars
@property
def runtime_info(self) -> OrderedDict:
@ -105,24 +231,25 @@ class MessageHub(ManagerMixin):
"""
return copy.deepcopy(self._runtime_info)
def get_log(self, key: str) -> LogBuffer:
"""Get ``LogBuffer`` instance by key.
def get_scalar(self, key: str) -> HistoryBuffer:
"""Get ``HistoryBuffer`` instance by key.
Note:
Considering the large memory footprint of ``log_buffers`` in the
post-training, ``MessageHub.get_log`` will not return the
result of ``copy.deepcopy``.
Considering the large memory footprint of history buffers in the
post-training phase, :meth:`get_scalar` returns a reference to the
history buffer rather than a copy.
Args:
key (str): Key of ``LogBuffer``.
key (str): Key of ``HistoryBuffer``.
Returns:
LogBuffer: Corresponding ``LogBuffer`` instance if the key exists.
HistoryBuffer: Corresponding ``HistoryBuffer`` instance if the
key exists.
"""
if key not in self.log_buffers:
if key not in self.log_scalars:
raise KeyError(f'{key} is not found in MessageHub.log_scalars: '
f'instance name is: {self.instance_name}')
return self._log_buffers[key]
return self.log_scalars[key]
def get_info(self, key: str) -> Any:
"""Get runtime information by key.
@ -139,7 +266,7 @@ class MessageHub(ManagerMixin):
return copy.deepcopy(self._runtime_info[key])
def _get_valid_value(self, key: str,
value: Union[torch.Tensor, np.ndarray, int, float])\
value: Union[torch.Tensor, np.ndarray, int, float]) \
-> Union[int, float]:
"""Convert value to python built-in type.
@ -158,4 +285,22 @@ class MessageHub(ManagerMixin):
value = value.item()
else:
check_type(key, value, (int, float))
return value
return value # type: ignore
def __getstate__(self):
for key in list(self._log_scalars.keys()):
assert key in self._resumed_keys, (
f'Cannot find {key} in {self}._resumed_keys, '
'please make sure you do not change the _resumed_keys '
'outside the class')
if not self._resumed_keys[key]:
self._log_scalars.pop(key)
for key in list(self._runtime_info.keys()):
assert key in self._resumed_keys, (
f'Cannot find {key} in {self}._resumed_keys, '
'please make sure you do not change the _resumed_keys '
'outside the class')
if not self._resumed_keys[key]:
self._runtime_info.pop(key)
return self.__dict__
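A sketch of the resume semantics introduced by ``__getstate__`` (editorial example; note that, as written above, ``__getstate__`` pops non-resumed keys from the live dicts, so serialization also trims the original hub):

import pickle
from mmengine import MessageHub

hub = MessageHub.get_instance('resume_sketch')
hub.update_scalar('loss', 0.1)                # resumed=True by default
hub.update_scalar('lr', 0.01, resumed=False)  # dropped on serialization
hub.update_info('iter', 10)

restored = pickle.loads(pickle.dumps(hub))
assert 'loss' in restored.log_scalars and 'lr' not in restored.log_scalars
assert restored.get_info('iter') == 10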

View File

@ -77,7 +77,7 @@ class EpochBasedTrainLoop(BaseLoop):
# TODO, should move to LoggerHook
for key, value in self.runner.outputs['log_vars'].items():
self.runner.message_hub.update_log(f'train/{key}', value)
self.runner.message_hub.update_scalar(f'train/{key}', value)
self.runner.call_hook(
'after_train_iter',
@ -147,7 +147,7 @@ class IterBasedTrainLoop(BaseLoop):
# TODO
for key, value in self.runner.outputs['log_vars'].items():
self.runner.message_hub.update_log(f'train/{key}', value)
self.runner.message_hub.update_scalar(f'train/{key}', value)
self.runner.call_hook(
'after_train_iter',
@ -195,7 +195,7 @@ class ValLoop(BaseLoop):
# compute metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
for key, value in metrics.items():
self.runner.message_hub.update_log(f'val/{key}', value)
self.runner.message_hub.update_scalar(f'val/{key}', value)
self.runner.call_hook('after_val_epoch')
self.runner.call_hook('after_val')
@ -252,7 +252,7 @@ class TestLoop(BaseLoop):
# compute metrics
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
for key, value in metrics.items():
self.runner.message_hub.update_log(f'test/{key}', value)
self.runner.message_hub.update_scalar(f'test/{key}', value)
self.runner.call_hook('after_test_epoch')
self.runner.call_hook('after_test')
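Where the loops above update one scalar per key, the batch form ``update_scalars`` (shown in the message hub diff above) accepts plain numbers, tensors, or ``dict(value=..., count=...)`` items. A brief sketch (the instance name is arbitrary):

import torch
from mmengine import MessageHub

hub = MessageHub.get_instance('loop_sketch')
# Tensors are converted to python scalars; dict items carry a count.
hub.update_scalars({
    'train/loss': torch.tensor(0.5),
    'train/acc': dict(value=0.9, count=4),
})
assert hub.get_scalar('train/acc').current() == 0.9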

View File

@ -71,6 +71,8 @@ class ManagerMixin(metaclass=ManagerMeta):
"""
def __init__(self, name: str = '', **kwargs):
assert isinstance(name, str) and name, \
'name argument must be a non-empty string.'
self._instance_name = name
@classmethod
@ -105,6 +107,10 @@ class ManagerMixin(metaclass=ManagerMeta):
if name not in instance_dict:
instance = cls(name=name, **kwargs)
instance_dict[name] = instance
else:
assert not kwargs, (
f'{cls} instance named {name} has already been created; the '
'method `get_instance` should not accept any other arguments')
# Get latest instantiated instance or root instance.
_release_lock()
return instance_dict[name]
@ -117,12 +123,12 @@ class ManagerMixin(metaclass=ManagerMeta):
``get_instance(xxx)`` at least once.
Examples
>>> instance = GlobalAccessible.get_current_instance(current=True)
>>> instance = GlobalAccessible.get_current_instance()
RuntimeError: Before calling GlobalAccessible.get_current_instance(),
you should call get_instance(name=xxx) at least once.
>>> instance = GlobalAccessible.get_instance('name1')
>>> instance.instance_name
name1
>>> instance = GlobalAccessible.get_current_instance(current=True)
>>> instance = GlobalAccessible.get_current_instance()
>>> instance.instance_name
name1
@ -132,9 +138,8 @@ class ManagerMixin(metaclass=ManagerMeta):
_accquire_lock()
if not cls._instance_dict:
raise RuntimeError(
f'Before calling {cls.__name__}.get_instance('
'current=True), '
'you should call get_instance(name=xxx) at least once.')
f'Before calling {cls.__name__}.get_current_instance(), you '
'should call get_instance(name=xxx) at least once.')
name = next(iter(reversed(cls._instance_dict)))
_release_lock()
return cls._instance_dict[name]
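A sketch of the resulting ``ManagerMixin`` contract (the ``Counter`` subclass is illustrative):

from mmengine.utils import ManagerMixin

class Counter(ManagerMixin):
    def __init__(self, name, start=0):
        super().__init__(name)
        self.start = start

c1 = Counter.get_instance('exp1', start=5)   # first call may pass extra kwargs
c2 = Counter.get_instance('exp1')            # later calls must not pass kwargs
assert c1 is c2
assert Counter.get_current_instance() is c1  # the most recently created instance
# Counter.get_instance('exp1', start=7) would now raise an AssertionError.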

View File

@ -18,7 +18,7 @@ class TestIterTimerHook:
runner.log_buffer = dict()
hook._before_epoch(runner)
hook._before_iter(runner, 0)
runner.message_hub.update_log.assert_called()
runner.message_hub.update_scalar.assert_called()
def test_after_iter(self):
hook = IterTimerHook()
@ -26,4 +26,4 @@ class TestIterTimerHook:
runner.log_buffer = dict()
hook._before_epoch(runner)
hook._after_iter(runner, 0)
runner.message_hub.update_log.assert_called()
runner.message_hub.update_scalar.assert_called()

View File

@ -287,7 +287,7 @@ class TestLoggerHook:
'train/loss_cls': MagicMock(),
'val/metric': MagicMock()
}
runner.message_hub.log_buffers = log_buffers
runner.message_hub.log_scalars = log_buffers
tag = logger_hook._collect_info(runner, mode='train')
# Test parse custom_keys
logger_hook._parse_custom_keys.assert_called()

View File

@ -3,15 +3,13 @@ import numpy as np
import pytest
import torch
from mmengine import LogBuffer
from mmengine import HistoryBuffer
class TestLoggerBuffer:
def test_init(self):
# `BaseLogBuffer` is an abstract class, using `CurrentLogBuffer` to
# test `update` method
log_buffer = LogBuffer()
log_buffer = HistoryBuffer()
assert log_buffer.max_length == 1000000
log_history, counts = log_buffer.data
assert len(log_history) == 0
@ -19,7 +17,7 @@ class TestLoggerBuffer:
# test the length of array exceed `max_length`
logs = np.random.randint(1, 10, log_buffer.max_length + 1)
counts = np.random.randint(1, 10, log_buffer.max_length + 1)
log_buffer = LogBuffer(logs, counts)
log_buffer = HistoryBuffer(logs, counts)
log_history, count_history = log_buffer.data
assert len(log_history) == log_buffer.max_length
@ -30,14 +28,13 @@ class TestLoggerBuffer:
# The different lengths of `log_history` and `count_history` will
# raise error
with pytest.raises(AssertionError):
LogBuffer([1, 2], [1])
HistoryBuffer([1, 2], [1])
@pytest.mark.parametrize('array_method',
[torch.tensor, np.array, lambda x: x])
def test_update(self, array_method):
# `BaseLogBuffer` is an abstract class, using `CurrentLogBuffer` to
# test `update` method
log_buffer = LogBuffer()
log_buffer = HistoryBuffer()
log_history = array_method([1, 2, 3, 4, 5])
count_history = array_method([5, 5, 5, 5, 5])
for i in range(len(log_history)):
@ -52,7 +49,7 @@ class TestLoggerBuffer:
# test the length of `array` exceed `max_length`
max_array = array_method([[-1] + [1] * (log_buffer.max_length - 1)])
max_count = array_method([[-1] + [1] * (log_buffer.max_length - 1)])
log_buffer = LogBuffer(max_array, max_count)
log_buffer = HistoryBuffer(max_array, max_count)
log_buffer.update(1)
log_history, count_history = log_buffer.data
assert log_history[0] == 1
@ -69,7 +66,7 @@ class TestLoggerBuffer:
def test_max_min(self, statistics_method, log_buffer_type):
log_history = np.random.randint(1, 5, 20)
count_history = np.ones(20)
log_buffer = LogBuffer(log_history, count_history)
log_buffer = HistoryBuffer(log_history, count_history)
assert statistics_method(log_history[-10:]) == \
getattr(log_buffer, log_buffer_type)(10)
assert statistics_method(log_history) == \
@ -78,7 +75,7 @@ class TestLoggerBuffer:
def test_mean(self):
log_history = np.random.randint(1, 5, 20)
count_history = np.ones(20)
log_buffer = LogBuffer(log_history, count_history)
log_buffer = HistoryBuffer(log_history, count_history)
assert np.sum(log_history[-10:]) / \
np.sum(count_history[-10:]) == \
log_buffer.mean(10)
@ -89,17 +86,17 @@ class TestLoggerBuffer:
def test_current(self):
log_history = np.random.randint(1, 5, 20)
count_history = np.ones(20)
log_buffer = LogBuffer(log_history, count_history)
log_buffer = HistoryBuffer(log_history, count_history)
assert log_history[-1] == log_buffer.current()
# test get empty array
log_buffer = LogBuffer()
log_buffer = HistoryBuffer()
with pytest.raises(ValueError):
log_buffer.current()
def test_statistics(self):
log_history = np.array([1, 2, 3, 4, 5])
count_history = np.array([1, 1, 1, 1, 1])
log_buffer = LogBuffer(log_history, count_history)
log_buffer = HistoryBuffer(log_history, count_history)
assert log_buffer.statistics('mean') == 3
assert log_buffer.statistics('min') == 1
assert log_buffer.statistics('max') == 5
@ -110,9 +107,9 @@ class TestLoggerBuffer:
def test_register_statistics(self):
@LogBuffer.register_statistics
@HistoryBuffer.register_statistics
def custom_statistics(self):
return -1
log_buffer = LogBuffer()
log_buffer = HistoryBuffer()
assert log_buffer.statistics('custom_statistics') == -1

View File

@ -11,14 +11,16 @@ from mmengine import MMLogger, print_log
class TestLogger:
regex_time = r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}'
stream_handler_regex_time = r'\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
file_handler_regex_time = r'\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
@patch('torch.distributed.get_rank', lambda: 0)
@patch('torch.distributed.is_initialized', lambda: True)
@patch('torch.distributed.is_available', lambda: True)
def test_init_rank0(self, tmp_path):
logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO')
assert logger.name == 'rank0.pkg1'
assert logger.name == 'mmengine'
assert logger.instance_name == 'rank0.pkg1'
# Logger get from `MMLogger.get_instance` does not inherit from
# `logging.root`
@ -38,6 +40,10 @@ class TestLogger:
assert isinstance(logger.handlers[1], logging.FileHandler)
logger_pkg3 = MMLogger.get_instance('rank0.pkg2')
assert id(logger_pkg3) == id(logger)
logger = MMLogger.get_instance(
'rank0.pkg3', logger_name='logger_test', log_level='INFO')
assert logger.name == 'logger_test'
assert logger.instance_name == 'rank0.pkg3'
logging.shutdown()
@patch('torch.distributed.get_rank', lambda: 1)
@ -46,12 +52,18 @@ class TestLogger:
def test_init_rank1(self, tmp_path):
# If `rank != 0`, the `loglevel` of the stream_handler is `logging.ERROR`.
tmp_file = tmp_path / 'tmp_file.log'
log_path = tmp_path / 'rank1_tmp_file.log'
log_path = tmp_path / 'tmp_file' / 'tmp_file_rank1.log'
logger = MMLogger.get_instance(
'rank1.pkg2', log_level='INFO', log_file=str(tmp_file))
assert len(logger.handlers) == 2
assert len(logger.handlers) == 1
logger = MMLogger.get_instance(
'rank1.pkg3',
log_level='INFO',
log_file=str(tmp_file),
distributed=True)
assert logger.handlers[0].level == logging.ERROR
assert logger.handlers[1].level == logging.INFO
assert len(logger.handlers) == 2
assert os.path.exists(log_path)
logging.shutdown()
@ -59,32 +71,33 @@ class TestLogger:
[logging.WARNING, logging.INFO, logging.DEBUG])
def test_handler(self, capsys, tmp_path, log_level):
# test stream handler can output correct format logs
logger_name = f'test_stream_{str(log_level)}'
logger = MMLogger.get_instance(logger_name, log_level=log_level)
instance_name = f'test_stream_{str(log_level)}'
logger = MMLogger.get_instance(instance_name, log_level=log_level)
logger.log(level=log_level, msg='welcome')
out, _ = capsys.readouterr()
# Skip match colored INFO
loglevl_name = logging._levelToName[log_level]
match = re.fullmatch(
self.regex_time + f' - {logger_name} - '
self.stream_handler_regex_time + f' - mmengine - '
f'(.*){loglevl_name}(.*) - welcome\n', out)
assert match is not None
# test file_handler output plain text without color.
tmp_file = tmp_path / 'tmp_file.log'
logger_name = f'test_file_{log_level}'
instance_name = f'test_file_{log_level}'
logger = MMLogger.get_instance(
logger_name, log_level=log_level, log_file=tmp_file)
instance_name, log_level=log_level, log_file=tmp_file)
logger.log(level=log_level, msg='welcome')
with open(tmp_file, 'r') as f:
with open(tmp_path / 'tmp_file' / 'tmp_file.log', 'r') as f:
log_text = f.read()
match = re.fullmatch(
self.regex_time + f' - {logger_name} - {loglevl_name} - '
self.file_handler_regex_time +
f' - mmengine - {loglevl_name} - '
f'welcome\n', log_text)
assert match is not None
logging.shutdown()
def test_erro_format(self, capsys):
def test_error_format(self, capsys):
# test error level log can output file path, function name and
# line number
logger = MMLogger.get_instance('test_error', log_level='INFO')
@ -92,9 +105,10 @@ class TestLogger:
lineno = sys._getframe().f_lineno - 1
file_path = __file__
function_name = sys._getframe().f_code.co_name
pattern = self.regex_time + r' - test_error - (.*)ERROR(.*) - '\
f'{file_path} - {function_name} - ' \
f'{lineno} - welcome\n'
pattern = self.stream_handler_regex_time + \
r' - mmengine - (.*)ERROR(.*) - ' \
f'{file_path} - {function_name} - ' \
f'{lineno} - welcome\n'
out, _ = capsys.readouterr()
match = re.fullmatch(pattern, out)
assert match is not None
@ -114,21 +128,21 @@ class TestLogger:
print_log('welcome', logger=logger)
out, _ = capsys.readouterr()
match = re.fullmatch(
self.regex_time + ' - test_print_log - (.*)INFO(.*) - '
self.stream_handler_regex_time + ' - mmengine - (.*)INFO(.*) - '
'welcome\n', out)
assert match is not None
# Test access logger by name.
print_log('welcome', logger='test_print_log')
out, _ = capsys.readouterr()
match = re.fullmatch(
self.regex_time + ' - test_print_log - (.*)INFO(.*) - '
self.stream_handler_regex_time + ' - mmengine - (.*)INFO(.*) - '
'welcome\n', out)
assert match is not None
# Test access the latest created logger.
print_log('welcome', logger='current')
out, _ = capsys.readouterr()
match = re.fullmatch(
self.regex_time + ' - test_print_log - (.*)INFO(.*) - '
self.stream_handler_regex_time + ' - mmengine - (.*)INFO(.*) - '
'welcome\n', out)
assert match is not None
# Test invalid logger type.

View File

@ -1,4 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import pickle
from collections import OrderedDict
import numpy as np
import pytest
import torch
@ -11,17 +14,26 @@ class TestMessageHub:
def test_init(self):
message_hub = MessageHub('name')
assert message_hub.instance_name == 'name'
assert len(message_hub.log_buffers) == 0
assert len(message_hub.log_scalars) == 0
# The type of log_scalars' values must be `HistoryBuffer`.
with pytest.raises(AssertionError):
MessageHub('hello', log_scalars=OrderedDict(a=1))
# Keys in `resumed_keys` must exist in `log_scalars` or `runtime_info`.
with pytest.raises(AssertionError):
MessageHub(
'hello',
runtime_info=OrderedDict(iter=1),
resumed_keys=OrderedDict(iters=False))
def test_update_log(self):
def test_update_scalar(self):
message_hub = MessageHub.get_instance('mmengine')
# test create target `LogBuffer` by name
message_hub.update_log('name', 1)
log_buffer = message_hub.log_buffers['name']
# test create target `HistoryBuffer` by name
message_hub.update_scalar('name', 1)
log_buffer = message_hub.log_scalars['name']
assert (log_buffer._log_history == np.array([1])).all()
# test update target `LogBuffer` by name
message_hub.update_log('name', 1)
# test update target `HistoryBuffer` by name
message_hub.update_scalar('name', 1)
assert (log_buffer._log_history == np.array([1, 1])).all()
# unmatched string will raise a key error
@ -33,19 +45,19 @@ class TestMessageHub:
message_hub.update_info('key', 1)
assert message_hub.runtime_info['key'] == 1
def test_get_log_buffers(self):
def test_get_scalar(self):
message_hub = MessageHub.get_instance('mmengine')
# Get undefined key will raise error
with pytest.raises(KeyError):
message_hub.get_log('unknown')
message_hub.get_scalar('unknown')
# test getting log_buffer as expected
log_history = np.array([1, 2, 3, 4, 5])
count = np.array([1, 1, 1, 1, 1])
for i in range(len(log_history)):
message_hub.update_log('test_value', float(log_history[i]),
int(count[i]))
message_hub.update_scalar('test_value', float(log_history[i]),
int(count[i]))
recorded_history, recorded_count = \
message_hub.get_log('test_value').data
message_hub.get_scalar('test_value').data
assert (log_history == recorded_history).all()
assert (recorded_count == count).all()
@ -57,18 +69,18 @@ class TestMessageHub:
message_hub.update_info('test_value', recorded_dict)
assert message_hub.get_info('test_value') == recorded_dict
def test_get_log_vars(self):
def test_get_scalars(self):
message_hub = MessageHub.get_instance('mmengine')
log_dict = dict(
loss=1,
loss_cls=torch.tensor(2),
loss_bbox=np.array(3),
loss_iou=dict(value=1, count=2))
message_hub.update_log_vars(log_dict)
loss = message_hub.get_log('loss')
loss_cls = message_hub.get_log('loss_cls')
loss_bbox = message_hub.get_log('loss_bbox')
loss_iou = message_hub.get_log('loss_iou')
message_hub.update_scalars(log_dict)
loss = message_hub.get_scalar('loss')
loss_cls = message_hub.get_scalar('loss_cls')
loss_bbox = message_hub.get_scalar('loss_bbox')
loss_iou = message_hub.get_scalar('loss_iou')
assert loss.current() == 1
assert loss_cls.current() == 2
assert loss_bbox.current() == 3
@ -76,8 +88,27 @@ class TestMessageHub:
with pytest.raises(TypeError):
loss_dict = dict(error_type=[])
message_hub.update_log_vars(loss_dict)
message_hub.update_scalars(loss_dict)
with pytest.raises(AssertionError):
loss_dict = dict(error_type=dict(count=1))
message_hub.update_log_vars(loss_dict)
message_hub.update_scalars(loss_dict)
def test_getstate(self):
message_hub = MessageHub.get_instance('name')
# update log_scalars.
message_hub.update_scalar('loss', 0.1)
message_hub.update_scalar('lr', 0.1, resumed=False)
# update runtime information
message_hub.update_info('iter', 1, resumed=True)
message_hub.update_info('feat', [1, 2, 3], resumed=False)
obj = pickle.dumps(message_hub)
instance = pickle.loads(obj)
with pytest.raises(KeyError):
instance.get_info('feat')
with pytest.raises(KeyError):
instance.get_info('lr')
instance.get_info('iter')
instance.get_scalar('loss')

View File

@ -70,3 +70,7 @@ class TestManagerMixin:
# Non-string instance name will raise `AssertionError`.
with pytest.raises(AssertionError):
SubClassA.get_instance(name=1)
# `get_instance` should not accept other arguments if corresponding
# instance has been created.
with pytest.raises(AssertionError):
SubClassA.get_instance('name2', a=1, b=2)