[Enhancement] Refine logging. (#148)
* rename global accessible and intergration get_sintance and create_instance * move ManagerMixin to utils * fix as docstring and seporate get_instance to get_instance and get_current_instance * fix lint * fix docstring, rename and move test_global_meta * rename LogBuffer to HistoryBuffer, rename MessageHub methods, MessageHub support resume * refine MMLogger timestamp, update unit test * MMLogger add logger_name arguments * Fix docstring * change default logger_name to mmengine * Fix docstring comment and unitt test * fix docstring fix docstring * fix docstring * Fix lint * Fix hook unit test * Fix comments * should not accept other arguments if corresponding instance has been created * fix logging ddp file saving * fix logging ddp file saving * fix docstring * fix unit test * fix docstring as commentpull/192/head
parent
45567b1d1c
commit
82a313d09b
|
@ -42,8 +42,8 @@ class IterTimerHook(Hook):
|
|||
mode (str): Current mode of runner. Defaults to 'train'.
|
||||
"""
|
||||
# TODO: update for new logging system
|
||||
runner.message_hub.update_log(f'{mode}/data_time',
|
||||
time.time() - self.t)
|
||||
runner.message_hub.update_scalar(f'{mode}/data_time',
|
||||
time.time() - self.t)
|
||||
|
||||
def _after_iter(self,
|
||||
runner,
|
||||
|
@ -65,5 +65,5 @@ class IterTimerHook(Hook):
|
|||
"""
|
||||
# TODO: update for new logging system
|
||||
|
||||
runner.message_hub.update_log(f'{mode}/time', time.time() - self.t)
|
||||
runner.message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
|
||||
self.t = time.time()
|
||||
|
|
|
@ -370,7 +370,7 @@ class LoggerHook(Hook):
|
|||
dict: Statistical values of logs.
|
||||
"""
|
||||
tag = OrderedDict()
|
||||
log_buffers = runner.message_hub.log_buffers
|
||||
log_buffers = runner.message_hub.log_scalars
|
||||
mode_log_buffers = OrderedDict()
|
||||
# Filter log_buffers which starts with `mode`.
|
||||
for prefix_key, log_buffer in log_buffers.items():
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .log_buffer import LogBuffer
|
||||
from .history_buffer import HistoryBuffer
|
||||
from .logger import MMLogger, print_log
|
||||
from .message_hub import MessageHub
|
||||
|
||||
__all__ = ['LogBuffer', 'MessageHub', 'MMLogger', 'print_log']
|
||||
__all__ = ['HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log']
|
||||
|
|
|
@ -1,16 +1,28 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import warnings
|
||||
from functools import partial
|
||||
from typing import Any, Callable, Optional, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class BaseLogBuffer:
|
||||
class HistoryBuffer:
|
||||
"""Unified storage format for different log types.
|
||||
|
||||
Record the history of log for further statistics. The subclass inherited
|
||||
from ``BaseLogBuffer`` will implement the specific statistical methods.
|
||||
``HistoryBuffer`` records the history of log for further statistics.
|
||||
|
||||
Examples:
|
||||
>>> history_buffer = HistoryBuffer()
|
||||
>>> # Update history_buffer.
|
||||
>>> history_buffer.update(1)
|
||||
>>> history_buffer.update(2)
|
||||
>>> history_buffer.min() # minimum of (1, 2)
|
||||
1
|
||||
>>> history_buffer.max() # maximum of (1, 2)
|
||||
2
|
||||
>>> history_buffer.mean() # mean of (1, 2)
|
||||
1.5
|
||||
>>> history_buffer.statistics('mean') # access method by string.
|
||||
1.5
|
||||
|
||||
Args:
|
||||
log_history (Sequence): History logs. Defaults to [].
|
||||
|
@ -25,6 +37,7 @@ class BaseLogBuffer:
|
|||
max_length: int = 1000000):
|
||||
|
||||
self.max_length = max_length
|
||||
self._set_default_statistics()
|
||||
assert len(log_history) == len(count_history), \
|
||||
'The lengths of log_history and count_histroy should be equal'
|
||||
if len(log_history) > max_length:
|
||||
|
@ -37,10 +50,18 @@ class BaseLogBuffer:
|
|||
self._log_history = np.array(log_history)
|
||||
self._count_history = np.array(count_history)
|
||||
|
||||
def _set_default_statistics(self) -> None:
|
||||
"""Register default statistic methods: min, max, current and mean."""
|
||||
self._statistics_methods.setdefault('min', HistoryBuffer.min)
|
||||
self._statistics_methods.setdefault('max', HistoryBuffer.max)
|
||||
self._statistics_methods.setdefault('current', HistoryBuffer.current)
|
||||
self._statistics_methods.setdefault('mean', HistoryBuffer.mean)
|
||||
|
||||
def update(self, log_val: Union[int, float], count: int = 1) -> None:
|
||||
"""update the log history. If the length of the buffer exceeds
|
||||
``self._max_length``, the oldest element will be removed from the
|
||||
buffer.
|
||||
"""update the log history.
|
||||
|
||||
If the length of the buffer exceeds ``self._max_length``, the oldest
|
||||
element will be removed from the buffer.
|
||||
|
||||
Args:
|
||||
log_val (int or float): The value of log.
|
||||
|
@ -72,9 +93,23 @@ class BaseLogBuffer:
|
|||
def register_statistics(cls, method: Callable) -> Callable:
|
||||
"""Register custom statistics method to ``_statistics_methods``.
|
||||
|
||||
The registered method can be called by ``history_buffer.statistics``
|
||||
with corresponding method name and arguments.
|
||||
|
||||
Examples:
|
||||
>>> @HistoryBuffer.register_statistics
|
||||
>>> def weighted_mean(self, window_size, weight):
|
||||
>>> assert len(weight) == window_size
|
||||
>>> return (self._log_history[-window_size:] *
|
||||
>>> np.array(weight)).sum() / \
|
||||
>>> self._count_history[-window_size:]
|
||||
|
||||
>>> log_buffer = HistoryBuffer([1, 2], [1, 1])
|
||||
>>> log_buffer.statistics('weighted_mean', 2, [2, 1])
|
||||
2
|
||||
|
||||
Args:
|
||||
method (Callable): Custom statistics method.
|
||||
|
||||
Returns:
|
||||
Callable: Original custom statistics method.
|
||||
"""
|
||||
|
@ -95,26 +130,20 @@ class BaseLogBuffer:
|
|||
"""
|
||||
if method_name not in self._statistics_methods:
|
||||
raise KeyError(f'{method_name} has not been registered in '
|
||||
'BaseLogBuffer._statistics_methods')
|
||||
'HistoryBuffer._statistics_methods')
|
||||
method = self._statistics_methods[method_name]
|
||||
# Provide self arguments for registered functions.
|
||||
method = partial(method, self)
|
||||
return method(*arg, **kwargs)
|
||||
return method(self, *arg, **kwargs)
|
||||
|
||||
|
||||
class LogBuffer(BaseLogBuffer):
|
||||
"""``LogBuffer`` inherits from ``BaseLogBuffer`` and provides some basic
|
||||
statistics methods, such as ``min``, ``max``, ``current`` and ``mean``."""
|
||||
|
||||
@BaseLogBuffer.register_statistics
|
||||
def mean(self, window_size: Optional[int] = None) -> np.ndarray:
|
||||
"""Return the mean of the latest ``window_size`` values in log
|
||||
histories. If ``window_size is None``, return the global mean of
|
||||
history logs.
|
||||
histories.
|
||||
|
||||
If ``window_size is None`` or ``window_size > len(self._log_history)``,
|
||||
return the global mean value of history logs.
|
||||
|
||||
Args:
|
||||
window_size (int, optional): Size of statistics window.
|
||||
|
||||
Returns:
|
||||
np.ndarray: Mean value within the window.
|
||||
"""
|
||||
|
@ -128,15 +157,15 @@ class LogBuffer(BaseLogBuffer):
|
|||
counts_sum = self._count_history[-window_size:].sum()
|
||||
return logs_sum / counts_sum
|
||||
|
||||
@BaseLogBuffer.register_statistics
|
||||
def max(self, window_size: Optional[int] = None) -> np.ndarray:
|
||||
"""Return the maximum value of the latest ``window_size`` values in log
|
||||
histories. If ``window_size is None``, return the global maximum value
|
||||
of history logs.
|
||||
histories.
|
||||
|
||||
If ``window_size is None`` or ``window_size > len(self._log_history)``,
|
||||
return the global maximum value of history logs.
|
||||
|
||||
Args:
|
||||
window_size (int, optional): Size of statistics window.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The maximum value within the window.
|
||||
"""
|
||||
|
@ -148,15 +177,15 @@ class LogBuffer(BaseLogBuffer):
|
|||
window_size = len(self._log_history)
|
||||
return self._log_history[-window_size:].max()
|
||||
|
||||
@BaseLogBuffer.register_statistics
|
||||
def min(self, window_size: Optional[int] = None) -> np.ndarray:
|
||||
"""Return the minimum value of the latest ``window_size`` values in log
|
||||
histories. If ``window_size is None``, return the global minimum value
|
||||
of history logs.
|
||||
histories.
|
||||
|
||||
If ``window_size is None`` or ``window_size > len(self._log_history)``,
|
||||
return the global minimum value of history logs.
|
||||
|
||||
Args:
|
||||
window_size (int, optional): Size of statistics window.
|
||||
|
||||
Returns:
|
||||
np.ndarray: The minimum value within the window.
|
||||
"""
|
||||
|
@ -168,7 +197,6 @@ class LogBuffer(BaseLogBuffer):
|
|||
window_size = len(self._log_history)
|
||||
return self._log_history[-window_size:].min()
|
||||
|
||||
@BaseLogBuffer.register_statistics
|
||||
def current(self) -> np.ndarray:
|
||||
"""Return the recently updated values in log histories.
|
||||
|
||||
|
@ -176,6 +204,6 @@ class LogBuffer(BaseLogBuffer):
|
|||
np.ndarray: Recently updated values in log histories.
|
||||
"""
|
||||
if len(self._log_history) == 0:
|
||||
raise ValueError('LogBuffer._log_history is an empty array! '
|
||||
raise ValueError('HistoryBuffer._log_history is an empty array! '
|
||||
'please call update first')
|
||||
return self._log_history[-1]
|
|
@ -1,6 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import logging
|
||||
import os
|
||||
import os.path as osp
|
||||
import sys
|
||||
from logging import Logger, LogRecord
|
||||
from typing import Optional, Union
|
||||
|
@ -8,7 +9,7 @@ from typing import Optional, Union
|
|||
import torch.distributed as dist
|
||||
from termcolor import colored
|
||||
|
||||
from mmengine.utils import ManagerMixin
|
||||
from mmengine.utils import ManagerMixin, mkdir_or_exist
|
||||
|
||||
|
||||
class MMFormatter(logging.Formatter):
|
||||
|
@ -85,54 +86,134 @@ class MMFormatter(logging.Formatter):
|
|||
|
||||
|
||||
class MMLogger(Logger, ManagerMixin):
|
||||
"""The Logger manager which can create formatted logger and get specified
|
||||
logger globally. MMLogger is created and accessed in the same way as
|
||||
ManagerMixin.
|
||||
"""Formatted logger used to record messages.
|
||||
|
||||
``MMLogger`` can create formatted logger to log message with different
|
||||
log levels and get instance in the same way as ``ManagerMixin``.
|
||||
``MMLogger`` has the following features:
|
||||
|
||||
- Distributed log storage, ``MMLogger`` can choose whether to save log of
|
||||
different ranks according to `log_file`.
|
||||
- Message with different log levels will have different colors and format
|
||||
when displayed on terminal.
|
||||
|
||||
Note:
|
||||
- The `name` of logger and the ``instance_name`` of ``MMLogger`` could
|
||||
be different. We can only get ``MMLogger`` instance by
|
||||
``MMLogger.get_instance`` but not ``logging.getLogger``. This feature
|
||||
ensures ``MMLogger`` will not be incluenced by third-party logging
|
||||
config.
|
||||
- Different from ``logging.Logger``, ``MMLogger`` will not log warrning
|
||||
or error message without ``Handler``.
|
||||
- If `log_file=/path/to/tmp.log`, all logs will be saved to
|
||||
`/path/to/tmp/tmp.log`
|
||||
|
||||
Examples:
|
||||
>>> logger = MMLogger.get_instance(name='MMLogger',
|
||||
>>> logger_name='Logger')
|
||||
>>> # Although logger has name attribute just like `logging.Logger`
|
||||
>>> # We cannot get logger instance by `logging.getLogger`.
|
||||
>>> assert logger.name == 'Logger'
|
||||
>>> assert logger.instance_name = 'MMLogger'
|
||||
>>> assert id(logger) != id(logging.getLogger('Logger'))
|
||||
>>> # Get logger that do not store logs.
|
||||
>>> logger1 = MMLogger.get_instance('logger1')
|
||||
>>> # Get logger only save rank0 logs.
|
||||
>>> logger2 = MMLogger.get_instance('logger2', log_file='out.log')
|
||||
>>> # Get logger only save multiple ranks logs.
|
||||
>>> logger3 = MMLogger.get_instance('logger3', log_file='out.log',
|
||||
>>> distributed=True)
|
||||
|
||||
Args:
|
||||
name (str): Logger name. Defaults to ''.
|
||||
name (str): Global instance name.
|
||||
logger_name (str): ``name`` attribute of ``Logging.Logger`` instance.
|
||||
If `logger_name` is not defined, defaults to 'mmengine'.
|
||||
log_file (str, optional): The log filename. If specified, a
|
||||
``FileHandler`` will be added to the logger. Defaults to None.
|
||||
log_level: The log level of the handler. Defaults to 'NOTSET'.
|
||||
file_mode (str): The file mode used in opening log file.
|
||||
Defaults to 'w'.
|
||||
log_level (str): The log level of the handler and logger. Defaults to
|
||||
"NOTSET".
|
||||
file_mode (str): The file mode used to open log file. Defaults to 'w'.
|
||||
distributed (bool): Whether to save distributed logs, Defaults to
|
||||
false.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
name: str = '',
|
||||
name: str,
|
||||
logger_name='mmengine',
|
||||
log_file: Optional[str] = None,
|
||||
log_level: str = 'NOTSET',
|
||||
file_mode: str = 'w'):
|
||||
Logger.__init__(self, name)
|
||||
log_level: str = 'INFO',
|
||||
file_mode: str = 'w',
|
||||
distributed=False):
|
||||
Logger.__init__(self, logger_name)
|
||||
ManagerMixin.__init__(self, name)
|
||||
# Get rank in DDP mode.
|
||||
if dist.is_available() and dist.is_initialized():
|
||||
rank = dist.get_rank()
|
||||
else:
|
||||
rank = 0
|
||||
|
||||
# Config stream_handler. If `rank != 0`. stream_handler can only
|
||||
# export ERROR logs.
|
||||
stream_handler = logging.StreamHandler(stream=sys.stdout)
|
||||
stream_handler.setFormatter(MMFormatter(color=True))
|
||||
# `StreamHandler` record month, day, hour, minute, and second
|
||||
# timestamp.
|
||||
stream_handler.setFormatter(
|
||||
MMFormatter(color=True, datefmt='%m/%d %H:%M:%S'))
|
||||
# Only rank0 `StreamHandler` will log messages below error level.
|
||||
stream_handler.setLevel(log_level) if rank == 0 else \
|
||||
stream_handler.setLevel(logging.ERROR)
|
||||
self.handlers.append(stream_handler)
|
||||
|
||||
if log_file is not None:
|
||||
# If `log_file=/path/to/tmp.log`, all logs will be saved to
|
||||
# `/path/to/tmp/tmp.log`
|
||||
log_dir = osp.dirname(log_file)
|
||||
filename = osp.basename(log_file)
|
||||
filename_list = filename.split('.')
|
||||
sub_file_name = '.'.join(filename_list[:-1])
|
||||
log_dir = osp.join(log_dir, sub_file_name)
|
||||
mkdir_or_exist(log_dir)
|
||||
log_file = osp.join(log_dir, filename)
|
||||
if rank != 0:
|
||||
# rename `log_file` with rank prefix.
|
||||
# rename `log_file` with rank suffix.
|
||||
path_split = log_file.split(os.sep)
|
||||
path_split[-1] = f'rank{rank}_{path_split[-1]}'
|
||||
if '.' in path_split[-1]:
|
||||
filename_list = path_split[-1].split('.')
|
||||
filename_list[-2] = f'{filename_list[-2]}_rank{rank}'
|
||||
path_split[-1] = '.'.join(filename_list)
|
||||
else:
|
||||
path_split[-1] = f'{path_split[-1]}_rank{rank}'
|
||||
log_file = os.sep.join(path_split)
|
||||
# Here, the default behaviour of the official logger is 'a'. Thus,
|
||||
# we provide an interface to change the file mode to the default
|
||||
# behaviour. `FileHandler` is not supported to have colors,
|
||||
# otherwise it will appear garbled.
|
||||
file_handler = logging.FileHandler(log_file, file_mode)
|
||||
file_handler.setFormatter(MMFormatter(color=False))
|
||||
file_handler.setLevel(log_level)
|
||||
self.handlers.append(file_handler)
|
||||
# Save multi-ranks logs if distributed is True. The logs of rank0
|
||||
# will always be saved.
|
||||
if rank == 0 or distributed:
|
||||
# Here, the default behaviour of the official logger is 'a'.
|
||||
# Thus, we provide an interface to change the file mode to
|
||||
# the default behaviour. `FileHandler` is not supported to
|
||||
# have colors, otherwise it will appear garbled.
|
||||
file_handler = logging.FileHandler(log_file, file_mode)
|
||||
# `StreamHandler` record year, month, day hour, minute,
|
||||
# and second timestamp. file_handler will only record logs
|
||||
# without color to avoid garbled code saved in files.
|
||||
file_handler.setFormatter(
|
||||
MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S'))
|
||||
file_handler.setLevel(log_level)
|
||||
self.handlers.append(file_handler)
|
||||
|
||||
def callHandlers(self, record: LogRecord) -> None:
|
||||
"""Pass a record to all relevant handlers.
|
||||
|
||||
Override ``callHandlers`` method in ``logging.Logger`` to avoid
|
||||
multiple warning messages in DDP mode. Loop through all handlers of
|
||||
the logger instance and its parents in the logger hierarchy. If no
|
||||
handler was found, the record will not be output.
|
||||
|
||||
Args:
|
||||
record (LogRecord): A ``LogRecord`` instance contains logged
|
||||
message.
|
||||
"""
|
||||
for handler in self.handlers:
|
||||
if record.levelno >= handler.level:
|
||||
handler.handle(record)
|
||||
|
||||
|
||||
def print_log(msg,
|
||||
|
|
|
@ -1,14 +1,14 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Union
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from mmengine.utils import ManagerMixin
|
||||
from mmengine.visualization.utils import check_type
|
||||
from .log_buffer import LogBuffer
|
||||
from .history_buffer import HistoryBuffer
|
||||
|
||||
|
||||
class MessageHub(ManagerMixin):
|
||||
|
@ -17,84 +17,210 @@ class MessageHub(ManagerMixin):
|
|||
|
||||
``MessageHub`` will record log information and runtime information. The
|
||||
log information refers to the learning rate, loss, etc. of the model
|
||||
when training a model, which will be stored as ``LogBuffer``. The runtime
|
||||
information refers to the iter times, meta information of runner etc.,
|
||||
which will be overwritten by next update.
|
||||
during training phase, which will be stored as ``HistoryBuffer``. The
|
||||
runtime information refers to the iter times, meta information of
|
||||
runner etc., which will be overwritten by next update.
|
||||
|
||||
Args:
|
||||
name (str): Name of message hub, for global access. Defaults to ''.
|
||||
name (str): Name of message hub used to get corresponding instance
|
||||
globally.
|
||||
log_scalars (OrderedDict, optional): Each key-value pair in the
|
||||
dictionary is the name of the log information such as "loss", "lr",
|
||||
"metric" and their corresponding values. The type of value must be
|
||||
HistoryBuffer. Defaults to None.
|
||||
runtime_info (OrderedDict, optional): Each key-value pair in the
|
||||
dictionary is the name of the runtime information and their
|
||||
corresponding values. Defaults to None.
|
||||
resumed_keys (OrderedDict, optional): Each key-value pair in the
|
||||
dictionary decides whether the key in :attr:`_log_scalars` and
|
||||
:attr:`_runtime_info` will be serialized.
|
||||
|
||||
Note:
|
||||
Key in :attr:`_resumed_keys` belongs to :attr:`_log_scalars` or
|
||||
:attr:`_runtime_info`. The corresponding value cannot be set
|
||||
repeatedly.
|
||||
|
||||
Examples:
|
||||
>>> # create empty `MessageHub`.
|
||||
>>> message_hub1 = MessageHub()
|
||||
>>> log_scalars = OrderedDict(loss=HistoryBuffer())
|
||||
>>> runtime_info = OrderedDict(task='task')
|
||||
>>> resumed_keys = dict(loss=True)
|
||||
>>> # create `MessageHub` from data.
|
||||
>>> message_hub2 = MessageHub(
|
||||
>>> name='name',
|
||||
>>> log_scalars=log_scalars,
|
||||
>>> runtime_info=runtime_info,
|
||||
>>> resumed_keys=resumed_keys)
|
||||
"""
|
||||
|
||||
def __init__(self, name: str = ''):
|
||||
self._log_buffers: OrderedDict = OrderedDict()
|
||||
self._runtime_info: OrderedDict = OrderedDict()
|
||||
def __init__(self,
|
||||
name: str,
|
||||
log_scalars: Optional[OrderedDict] = None,
|
||||
runtime_info: Optional[OrderedDict] = None,
|
||||
resumed_keys: Optional[OrderedDict] = None):
|
||||
super().__init__(name)
|
||||
self._log_scalars = log_scalars if log_scalars is not None else \
|
||||
OrderedDict()
|
||||
self._runtime_info = runtime_info if runtime_info is not None else \
|
||||
OrderedDict()
|
||||
self._resumed_keys = resumed_keys if resumed_keys is not None else \
|
||||
OrderedDict()
|
||||
|
||||
def update_log(self, key: str, value: Union[int, float], count: int = 1) \
|
||||
-> None:
|
||||
"""Update log buffer.
|
||||
assert isinstance(self._log_scalars, OrderedDict)
|
||||
assert isinstance(self._runtime_info, OrderedDict)
|
||||
assert isinstance(self._resumed_keys, OrderedDict)
|
||||
|
||||
for value in self._log_scalars.values():
|
||||
assert isinstance(value, HistoryBuffer), \
|
||||
("The type of log_scalars'value must be HistoryBuffer, but "
|
||||
f'got {type(value)}')
|
||||
|
||||
for key in self._resumed_keys.keys():
|
||||
assert key in self._log_scalars or key in self._runtime_info, \
|
||||
('Key in `resumed_keys` must contained in `log_scalars` or '
|
||||
f'`runtime_info`, but got {key}')
|
||||
|
||||
def update_scalar(self,
|
||||
key: str,
|
||||
value: Union[int, float, np.ndarray, torch.Tensor],
|
||||
count: int = 1,
|
||||
resumed: bool = True) -> None:
|
||||
"""Update :attr:_log_scalars.
|
||||
|
||||
Update ``HistoryBuffer`` in :attr:`_log_scalars`. If corresponding key
|
||||
``HistoryBuffer`` has been created, ``value`` and ``count`` is the
|
||||
argument of ``HistoryBuffer.update``, Otherwise, ``update_scalar``
|
||||
will create an ``HistoryBuffer`` with value and count via the
|
||||
constructor of ``HistoryBuffer``.
|
||||
|
||||
Examples:
|
||||
>>> message_hub = MessageHub
|
||||
>>> # create loss `HistoryBuffer` with value=1, count=1
|
||||
>>> message_hub.update_scalar('loss', 1)
|
||||
>>> # update loss `HistoryBuffer` with value
|
||||
>>> message_hub.update_scalar('loss', 3)
|
||||
>>> message_hub.update_scalar('loss', 3, resumed=False)
|
||||
AssertionError: loss used to be true, but got false now. resumed
|
||||
keys cannot be modified repeatedly'
|
||||
|
||||
Note:
|
||||
resumed cannot be set repeatedly for the same key.
|
||||
|
||||
Args:
|
||||
key (str): Key of ``LogBuffer``.
|
||||
value (int or float): Value of log.
|
||||
count (int): Accumulation times of log, defaults to 1. `count`
|
||||
will be used in smooth statistics.
|
||||
key (str): Key of ``HistoryBuffer``.
|
||||
value (torch.Tensor or np.ndarray or int or float): Value of log.
|
||||
count (torch.Tensor or np.ndarray or int or float): Accumulation
|
||||
times of log, defaults to 1. `count` will be used in smooth
|
||||
statistics.
|
||||
resumed (str): Whether the corresponding ``HistoryBuffer``
|
||||
could be resumed. Defaults to True.
|
||||
"""
|
||||
if key in self._log_buffers:
|
||||
self._log_buffers[key].update(value, count)
|
||||
self._set_resumed_keys(key, resumed)
|
||||
checked_value = self._get_valid_value(key, value)
|
||||
assert isinstance(count, int), (
|
||||
f'The type of count must be int. but got {type(count): {count}}')
|
||||
if key in self._log_scalars:
|
||||
self._log_scalars[key].update(checked_value, count)
|
||||
else:
|
||||
self._log_buffers[key] = LogBuffer([value], [count])
|
||||
self._log_scalars[key] = HistoryBuffer([checked_value], [count])
|
||||
|
||||
def update_log_vars(self, log_dict: dict) -> None:
|
||||
"""Update :attr:`_log_buffers` with a dict.
|
||||
def update_scalars(self, log_dict: dict, resumed: bool = True) -> None:
|
||||
"""Update :attr:`_log_scalars` with a dict.
|
||||
|
||||
``update_scalars`` iterates through each pair of log_dict key-value,
|
||||
and calls ``update_scalar``. If type of value is dict, the value should
|
||||
be ``dict(value=xxx) or dict(value=xxx, count=xxx)``. Item in
|
||||
``log_dict`` has the same resume option.
|
||||
|
||||
Args:
|
||||
log_dict (str): Used for batch updating :attr:`_log_buffers`.
|
||||
log_dict (str): Used for batch updating :attr:`_log_scalars`.
|
||||
resumed (bool): Whether all ``HistoryBuffer`` referred in
|
||||
log_dict should be resumed. Defaults to True.
|
||||
|
||||
Examples:
|
||||
>>> message_hub = MessageHub.get_instance('mmengine')
|
||||
>>> log_dict = dict(a=1, b=2, c=3)
|
||||
>>> message_hub.update_log_vars(log_dict)
|
||||
>>> message_hub.update_scalars(log_dict)
|
||||
>>> # The default count of `a`, `b` and `c` is 1.
|
||||
>>> log_dict = dict(a=1, b=2, c=dict(value=1, count=2))
|
||||
>>> message_hub.update_log_vars(log_dict)
|
||||
>>> message_hub.update_scalars(log_dict)
|
||||
>>> # The count of `c` is 2.
|
||||
"""
|
||||
assert isinstance(log_dict, dict), ('`log_dict` must be a dict!, '
|
||||
f'but got {type(log_dict)}')
|
||||
for log_name, log_val in log_dict.items():
|
||||
self._set_resumed_keys(log_name, resumed)
|
||||
if isinstance(log_val, dict):
|
||||
assert 'value' in log_val, \
|
||||
f'value must be defined in {log_val}'
|
||||
count = log_val.get('count', 1)
|
||||
value = self._get_valid_value(log_name, log_val['value'])
|
||||
count = self._get_valid_value(log_name,
|
||||
log_val.get('count', 1))
|
||||
checked_value = self._get_valid_value(log_name,
|
||||
log_val['value'])
|
||||
else:
|
||||
value = self._get_valid_value(log_name, log_val)
|
||||
count = 1
|
||||
self.update_log(log_name, value, count)
|
||||
checked_value = self._get_valid_value(log_name, log_val)
|
||||
assert isinstance(count,
|
||||
int), ('The type of count must be int. but got '
|
||||
f'{type(count): {count}}')
|
||||
self.update_scalar(log_name, checked_value, count)
|
||||
|
||||
def update_info(self, key: str, value: Any) -> None:
|
||||
def update_info(self, key: str, value: Any, resumed: bool = True) -> None:
|
||||
"""Update runtime information.
|
||||
|
||||
The key corresponding runtime information will be overwritten each
|
||||
time calling ``update_info``.
|
||||
|
||||
Note:
|
||||
resumed cannot be set repeatedly for the same key.
|
||||
|
||||
Examples:
|
||||
>>> message_hub = MessageHub()
|
||||
>>> message_hub.update_info('iter', 100)
|
||||
|
||||
Args:
|
||||
key (str): Key of runtime information.
|
||||
value (Any): Value of runtime information.
|
||||
resumed (bool): Whether the corresponding ``HistoryBuffer``
|
||||
could be resumed.
|
||||
"""
|
||||
self._set_resumed_keys(key, resumed)
|
||||
self._resumed_keys[key] = resumed
|
||||
self._runtime_info[key] = value
|
||||
|
||||
def _set_resumed_keys(self, key: str, resumed: bool) -> None:
|
||||
"""Set corresponding resumed keys.
|
||||
|
||||
This method is called by ``update_scalar``, ``update_scalars`` and
|
||||
``update_info`` to set the corresponding key is true or false in
|
||||
:attr:`_resumed_keys`.
|
||||
|
||||
Args:
|
||||
key (str): Key of :attr:`_log_scalrs` or :attr:`_runtime_info`.
|
||||
resumed (bool): Whether the corresponding ``HistoryBuffer``
|
||||
could be resumed.
|
||||
"""
|
||||
if key not in self._resumed_keys:
|
||||
self._resumed_keys[key] = resumed
|
||||
else:
|
||||
assert self._resumed_keys[key] == resumed, \
|
||||
f'{key} used to be {self._resumed_keys[key]}, but got ' \
|
||||
'{resumed} now. resumed keys cannot be modified repeatedly'
|
||||
|
||||
@property
|
||||
def log_buffers(self) -> OrderedDict:
|
||||
"""Get all ``LogBuffer`` instances.
|
||||
def log_scalars(self) -> OrderedDict:
|
||||
"""Get all ``HistoryBuffer`` instances.
|
||||
|
||||
Note:
|
||||
Considering the large memory footprint of ``log_buffers`` in the
|
||||
post-training, ``MessageHub.log_buffers`` will not return the
|
||||
result of ``copy.deepcopy``.
|
||||
Considering the large memory footprint of history buffers in the
|
||||
post-training, :meth:`get_scalar` will return a reference of
|
||||
history buffer rather than a copy.
|
||||
|
||||
Returns:
|
||||
OrderedDict: All ``LogBuffer`` instances.
|
||||
OrderedDict: All ``HistoryBuffer`` instances.
|
||||
"""
|
||||
return self._log_buffers
|
||||
return self._log_scalars
|
||||
|
||||
@property
|
||||
def runtime_info(self) -> OrderedDict:
|
||||
|
@ -105,24 +231,25 @@ class MessageHub(ManagerMixin):
|
|||
"""
|
||||
return copy.deepcopy(self._runtime_info)
|
||||
|
||||
def get_log(self, key: str) -> LogBuffer:
|
||||
"""Get ``LogBuffer`` instance by key.
|
||||
def get_scalar(self, key: str) -> HistoryBuffer:
|
||||
"""Get ``HistoryBuffer`` instance by key.
|
||||
|
||||
Note:
|
||||
Considering the large memory footprint of ``log_buffers`` in the
|
||||
post-training, ``MessageHub.get_log`` will not return the
|
||||
result of ``copy.deepcopy``.
|
||||
Considering the large memory footprint of history buffers in the
|
||||
post-training, :meth:`get_scalar` will not return a reference of
|
||||
history buffer rather than a copy.
|
||||
|
||||
Args:
|
||||
key (str): Key of ``LogBuffer``.
|
||||
key (str): Key of ``HistoryBuffer``.
|
||||
|
||||
Returns:
|
||||
LogBuffer: Corresponding ``LogBuffer`` instance if the key exists.
|
||||
HistoryBuffer: Corresponding ``HistoryBuffer`` instance if the
|
||||
key exists.
|
||||
"""
|
||||
if key not in self.log_buffers:
|
||||
if key not in self.log_scalars:
|
||||
raise KeyError(f'{key} is not found in Messagehub.log_buffers: '
|
||||
f'instance name is: {MessageHub.instance_name}')
|
||||
return self._log_buffers[key]
|
||||
return self.log_scalars[key]
|
||||
|
||||
def get_info(self, key: str) -> Any:
|
||||
"""Get runtime information by key.
|
||||
|
@ -139,7 +266,7 @@ class MessageHub(ManagerMixin):
|
|||
return copy.deepcopy(self._runtime_info[key])
|
||||
|
||||
def _get_valid_value(self, key: str,
|
||||
value: Union[torch.Tensor, np.ndarray, int, float])\
|
||||
value: Union[torch.Tensor, np.ndarray, int, float]) \
|
||||
-> Union[int, float]:
|
||||
"""Convert value to python built-in type.
|
||||
|
||||
|
@ -158,4 +285,22 @@ class MessageHub(ManagerMixin):
|
|||
value = value.item()
|
||||
else:
|
||||
check_type(key, value, (int, float))
|
||||
return value
|
||||
return value # type: ignore
|
||||
|
||||
def __getstate__(self):
|
||||
for key in list(self._log_scalars.keys()):
|
||||
assert key in self._resumed_keys, (
|
||||
f'Cannot found {key} in {self}._resumed_keys, '
|
||||
'please make sure you do not change the _resumed_keys '
|
||||
'outside the class')
|
||||
if not self._resumed_keys[key]:
|
||||
self._log_scalars.pop(key)
|
||||
|
||||
for key in list(self._runtime_info.keys()):
|
||||
assert key in self._resumed_keys, (
|
||||
f'Cannot found {key} in {self}._resumed_keys, '
|
||||
'please make sure you do not change the _resumed_keys '
|
||||
'outside the class')
|
||||
if not self._resumed_keys[key]:
|
||||
self._runtime_info.pop(key)
|
||||
return self.__dict__
|
||||
|
|
|
@ -77,7 +77,7 @@ class EpochBasedTrainLoop(BaseLoop):
|
|||
|
||||
# TODO, should move to LoggerHook
|
||||
for key, value in self.runner.outputs['log_vars'].items():
|
||||
self.runner.message_hub.update_log(f'train/{key}', value)
|
||||
self.runner.message_hub.update_scalar(f'train/{key}', value)
|
||||
|
||||
self.runner.call_hook(
|
||||
'after_train_iter',
|
||||
|
@ -147,7 +147,7 @@ class IterBasedTrainLoop(BaseLoop):
|
|||
|
||||
# TODO
|
||||
for key, value in self.runner.outputs['log_vars'].items():
|
||||
self.runner.message_hub.update_log(f'train/{key}', value)
|
||||
self.runner.message_hub.update_scalar(f'train/{key}', value)
|
||||
|
||||
self.runner.call_hook(
|
||||
'after_train_iter',
|
||||
|
@ -195,7 +195,7 @@ class ValLoop(BaseLoop):
|
|||
# compute metrics
|
||||
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
||||
for key, value in metrics.items():
|
||||
self.runner.message_hub.update_log(f'val/{key}', value)
|
||||
self.runner.message_hub.update_scalar(f'val/{key}', value)
|
||||
|
||||
self.runner.call_hook('after_val_epoch')
|
||||
self.runner.call_hook('after_val')
|
||||
|
@ -252,7 +252,7 @@ class TestLoop(BaseLoop):
|
|||
# compute metrics
|
||||
metrics = self.evaluator.evaluate(len(self.dataloader.dataset))
|
||||
for key, value in metrics.items():
|
||||
self.runner.message_hub.update_log(f'test/{key}', value)
|
||||
self.runner.message_hub.update_scalar(f'test/{key}', value)
|
||||
|
||||
self.runner.call_hook('after_test_epoch')
|
||||
self.runner.call_hook('after_test')
|
||||
|
|
|
@ -71,6 +71,8 @@ class ManagerMixin(metaclass=ManagerMeta):
|
|||
"""
|
||||
|
||||
def __init__(self, name: str = '', **kwargs):
|
||||
assert isinstance(name, str) and name, \
|
||||
'name argument must be an non-empty string.'
|
||||
self._instance_name = name
|
||||
|
||||
@classmethod
|
||||
|
@ -105,6 +107,10 @@ class ManagerMixin(metaclass=ManagerMeta):
|
|||
if name not in instance_dict:
|
||||
instance = cls(name=name, **kwargs)
|
||||
instance_dict[name] = instance
|
||||
else:
|
||||
assert not kwargs, (
|
||||
f'{cls} instance named of {name} has been created, the method '
|
||||
'`get_instance` should not access any other arguments')
|
||||
# Get latest instantiated instance or root instance.
|
||||
_release_lock()
|
||||
return instance_dict[name]
|
||||
|
@ -117,12 +123,12 @@ class ManagerMixin(metaclass=ManagerMeta):
|
|||
``get_instance(xxx)`` at least once.
|
||||
|
||||
Examples
|
||||
>>> instance = GlobalAccessible.get_current_instance(current=True)
|
||||
>>> instance = GlobalAccessible.get_current_instance()
|
||||
AssertionError: At least one of name and current needs to be set
|
||||
>>> instance = GlobalAccessible.get_instance('name1')
|
||||
>>> instance.instance_name
|
||||
name1
|
||||
>>> instance = GlobalAccessible.get_current_instance(current=True)
|
||||
>>> instance = GlobalAccessible.get_current_instance()
|
||||
>>> instance.instance_name
|
||||
name1
|
||||
|
||||
|
@ -132,9 +138,8 @@ class ManagerMixin(metaclass=ManagerMeta):
|
|||
_accquire_lock()
|
||||
if not cls._instance_dict:
|
||||
raise RuntimeError(
|
||||
f'Before calling {cls.__name__}.get_instance('
|
||||
'current=True), '
|
||||
'you should call get_instance(name=xxx) at least once.')
|
||||
f'Before calling {cls.__name__}.get_current_instance(), you '
|
||||
'should call get_instance(name=xxx) at least once.')
|
||||
name = next(iter(reversed(cls._instance_dict)))
|
||||
_release_lock()
|
||||
return cls._instance_dict[name]
|
||||
|
|
|
@ -18,7 +18,7 @@ class TestIterTimerHook:
|
|||
runner.log_buffer = dict()
|
||||
hook._before_epoch(runner)
|
||||
hook._before_iter(runner, 0)
|
||||
runner.message_hub.update_log.assert_called()
|
||||
runner.message_hub.update_scalar.assert_called()
|
||||
|
||||
def test_after_iter(self):
|
||||
hook = IterTimerHook()
|
||||
|
@ -26,4 +26,4 @@ class TestIterTimerHook:
|
|||
runner.log_buffer = dict()
|
||||
hook._before_epoch(runner)
|
||||
hook._after_iter(runner, 0)
|
||||
runner.message_hub.update_log.assert_called()
|
||||
runner.message_hub.update_scalar.assert_called()
|
||||
|
|
|
@ -287,7 +287,7 @@ class TestLoggerHook:
|
|||
'train/loss_cls': MagicMock(),
|
||||
'val/metric': MagicMock()
|
||||
}
|
||||
runner.message_hub.log_buffers = log_buffers
|
||||
runner.message_hub.log_scalars = log_buffers
|
||||
tag = logger_hook._collect_info(runner, mode='train')
|
||||
# Test parse custom_keys
|
||||
logger_hook._parse_custom_keys.assert_called()
|
||||
|
|
|
@ -3,15 +3,13 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmengine import LogBuffer
|
||||
from mmengine import HistoryBuffer
|
||||
|
||||
|
||||
class TestLoggerBuffer:
|
||||
|
||||
def test_init(self):
|
||||
# `BaseLogBuffer` is an abstract class, using `CurrentLogBuffer` to
|
||||
# test `update` method
|
||||
log_buffer = LogBuffer()
|
||||
log_buffer = HistoryBuffer()
|
||||
assert log_buffer.max_length == 1000000
|
||||
log_history, counts = log_buffer.data
|
||||
assert len(log_history) == 0
|
||||
|
@ -19,7 +17,7 @@ class TestLoggerBuffer:
|
|||
# test the length of array exceed `max_length`
|
||||
logs = np.random.randint(1, 10, log_buffer.max_length + 1)
|
||||
counts = np.random.randint(1, 10, log_buffer.max_length + 1)
|
||||
log_buffer = LogBuffer(logs, counts)
|
||||
log_buffer = HistoryBuffer(logs, counts)
|
||||
log_history, count_history = log_buffer.data
|
||||
|
||||
assert len(log_history) == log_buffer.max_length
|
||||
|
@ -30,14 +28,13 @@ class TestLoggerBuffer:
|
|||
# The different lengths of `log_history` and `count_history` will
|
||||
# raise error
|
||||
with pytest.raises(AssertionError):
|
||||
LogBuffer([1, 2], [1])
|
||||
HistoryBuffer([1, 2], [1])
|
||||
|
||||
@pytest.mark.parametrize('array_method',
|
||||
[torch.tensor, np.array, lambda x: x])
|
||||
def test_update(self, array_method):
|
||||
# `BaseLogBuffer` is an abstract class, using `CurrentLogBuffer` to
|
||||
# test `update` method
|
||||
log_buffer = LogBuffer()
|
||||
log_buffer = HistoryBuffer()
|
||||
log_history = array_method([1, 2, 3, 4, 5])
|
||||
count_history = array_method([5, 5, 5, 5, 5])
|
||||
for i in range(len(log_history)):
|
||||
|
@ -52,7 +49,7 @@ class TestLoggerBuffer:
|
|||
# test the length of `array` exceed `max_length`
|
||||
max_array = array_method([[-1] + [1] * (log_buffer.max_length - 1)])
|
||||
max_count = array_method([[-1] + [1] * (log_buffer.max_length - 1)])
|
||||
log_buffer = LogBuffer(max_array, max_count)
|
||||
log_buffer = HistoryBuffer(max_array, max_count)
|
||||
log_buffer.update(1)
|
||||
log_history, count_history = log_buffer.data
|
||||
assert log_history[0] == 1
|
||||
|
@ -69,7 +66,7 @@ class TestLoggerBuffer:
|
|||
def test_max_min(self, statistics_method, log_buffer_type):
|
||||
log_history = np.random.randint(1, 5, 20)
|
||||
count_history = np.ones(20)
|
||||
log_buffer = LogBuffer(log_history, count_history)
|
||||
log_buffer = HistoryBuffer(log_history, count_history)
|
||||
assert statistics_method(log_history[-10:]) == \
|
||||
getattr(log_buffer, log_buffer_type)(10)
|
||||
assert statistics_method(log_history) == \
|
||||
|
@ -78,7 +75,7 @@ class TestLoggerBuffer:
|
|||
def test_mean(self):
|
||||
log_history = np.random.randint(1, 5, 20)
|
||||
count_history = np.ones(20)
|
||||
log_buffer = LogBuffer(log_history, count_history)
|
||||
log_buffer = HistoryBuffer(log_history, count_history)
|
||||
assert np.sum(log_history[-10:]) / \
|
||||
np.sum(count_history[-10:]) == \
|
||||
log_buffer.mean(10)
|
||||
|
@ -89,17 +86,17 @@ class TestLoggerBuffer:
|
|||
def test_current(self):
|
||||
log_history = np.random.randint(1, 5, 20)
|
||||
count_history = np.ones(20)
|
||||
log_buffer = LogBuffer(log_history, count_history)
|
||||
log_buffer = HistoryBuffer(log_history, count_history)
|
||||
assert log_history[-1] == log_buffer.current()
|
||||
# test get empty array
|
||||
log_buffer = LogBuffer()
|
||||
log_buffer = HistoryBuffer()
|
||||
with pytest.raises(ValueError):
|
||||
log_buffer.current()
|
||||
|
||||
def test_statistics(self):
|
||||
log_history = np.array([1, 2, 3, 4, 5])
|
||||
count_history = np.array([1, 1, 1, 1, 1])
|
||||
log_buffer = LogBuffer(log_history, count_history)
|
||||
log_buffer = HistoryBuffer(log_history, count_history)
|
||||
assert log_buffer.statistics('mean') == 3
|
||||
assert log_buffer.statistics('min') == 1
|
||||
assert log_buffer.statistics('max') == 5
|
||||
|
@ -110,9 +107,9 @@ class TestLoggerBuffer:
|
|||
|
||||
def test_register_statistics(self):
|
||||
|
||||
@LogBuffer.register_statistics
|
||||
@HistoryBuffer.register_statistics
|
||||
def custom_statistics(self):
|
||||
return -1
|
||||
|
||||
log_buffer = LogBuffer()
|
||||
log_buffer = HistoryBuffer()
|
||||
assert log_buffer.statistics('custom_statistics') == -1
|
|
@ -11,14 +11,16 @@ from mmengine import MMLogger, print_log
|
|||
|
||||
|
||||
class TestLogger:
|
||||
regex_time = r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}'
|
||||
stream_handler_regex_time = r'\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
|
||||
file_handler_regex_time = r'\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}'
|
||||
|
||||
@patch('torch.distributed.get_rank', lambda: 0)
|
||||
@patch('torch.distributed.is_initialized', lambda: True)
|
||||
@patch('torch.distributed.is_available', lambda: True)
|
||||
def test_init_rank0(self, tmp_path):
|
||||
logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO')
|
||||
assert logger.name == 'rank0.pkg1'
|
||||
assert logger.name == 'mmengine'
|
||||
assert logger.instance_name == 'rank0.pkg1'
|
||||
assert logger.instance_name == 'rank0.pkg1'
|
||||
# Logger get from `MMLogger.get_instance` does not inherit from
|
||||
# `logging.root`
|
||||
|
@ -38,6 +40,10 @@ class TestLogger:
|
|||
assert isinstance(logger.handlers[1], logging.FileHandler)
|
||||
logger_pkg3 = MMLogger.get_instance('rank0.pkg2')
|
||||
assert id(logger_pkg3) == id(logger)
|
||||
logger = MMLogger.get_instance(
|
||||
'rank0.pkg3', logger_name='logger_test', log_level='INFO')
|
||||
assert logger.name == 'logger_test'
|
||||
assert logger.instance_name == 'rank0.pkg3'
|
||||
logging.shutdown()
|
||||
|
||||
@patch('torch.distributed.get_rank', lambda: 1)
|
||||
|
@ -46,12 +52,18 @@ class TestLogger:
|
|||
def test_init_rank1(self, tmp_path):
|
||||
# If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`.
|
||||
tmp_file = tmp_path / 'tmp_file.log'
|
||||
log_path = tmp_path / 'rank1_tmp_file.log'
|
||||
log_path = tmp_path / 'tmp_file' / 'tmp_file_rank1.log'
|
||||
logger = MMLogger.get_instance(
|
||||
'rank1.pkg2', log_level='INFO', log_file=str(tmp_file))
|
||||
assert len(logger.handlers) == 2
|
||||
assert len(logger.handlers) == 1
|
||||
logger = MMLogger.get_instance(
|
||||
'rank1.pkg3',
|
||||
log_level='INFO',
|
||||
log_file=str(tmp_file),
|
||||
distributed=True)
|
||||
assert logger.handlers[0].level == logging.ERROR
|
||||
assert logger.handlers[1].level == logging.INFO
|
||||
assert len(logger.handlers) == 2
|
||||
assert os.path.exists(log_path)
|
||||
logging.shutdown()
|
||||
|
||||
|
@ -59,32 +71,33 @@ class TestLogger:
|
|||
[logging.WARNING, logging.INFO, logging.DEBUG])
|
||||
def test_handler(self, capsys, tmp_path, log_level):
|
||||
# test stream handler can output correct format logs
|
||||
logger_name = f'test_stream_{str(log_level)}'
|
||||
logger = MMLogger.get_instance(logger_name, log_level=log_level)
|
||||
instance_name = f'test_stream_{str(log_level)}'
|
||||
logger = MMLogger.get_instance(instance_name, log_level=log_level)
|
||||
logger.log(level=log_level, msg='welcome')
|
||||
out, _ = capsys.readouterr()
|
||||
# Skip match colored INFO
|
||||
loglevl_name = logging._levelToName[log_level]
|
||||
match = re.fullmatch(
|
||||
self.regex_time + f' - {logger_name} - '
|
||||
self.stream_handler_regex_time + f' - mmengine - '
|
||||
f'(.*){loglevl_name}(.*) - welcome\n', out)
|
||||
assert match is not None
|
||||
|
||||
# test file_handler output plain text without color.
|
||||
tmp_file = tmp_path / 'tmp_file.log'
|
||||
logger_name = f'test_file_{log_level}'
|
||||
instance_name = f'test_file_{log_level}'
|
||||
logger = MMLogger.get_instance(
|
||||
logger_name, log_level=log_level, log_file=tmp_file)
|
||||
instance_name, log_level=log_level, log_file=tmp_file)
|
||||
logger.log(level=log_level, msg='welcome')
|
||||
with open(tmp_file, 'r') as f:
|
||||
with open(tmp_path / 'tmp_file' / 'tmp_file.log', 'r') as f:
|
||||
log_text = f.read()
|
||||
match = re.fullmatch(
|
||||
self.regex_time + f' - {logger_name} - {loglevl_name} - '
|
||||
self.file_handler_regex_time +
|
||||
f' - mmengine - {loglevl_name} - '
|
||||
f'welcome\n', log_text)
|
||||
assert match is not None
|
||||
logging.shutdown()
|
||||
|
||||
def test_erro_format(self, capsys):
|
||||
def test_error_format(self, capsys):
|
||||
# test error level log can output file path, function name and
|
||||
# line number
|
||||
logger = MMLogger.get_instance('test_error', log_level='INFO')
|
||||
|
@ -92,9 +105,10 @@ class TestLogger:
|
|||
lineno = sys._getframe().f_lineno - 1
|
||||
file_path = __file__
|
||||
function_name = sys._getframe().f_code.co_name
|
||||
pattern = self.regex_time + r' - test_error - (.*)ERROR(.*) - '\
|
||||
f'{file_path} - {function_name} - ' \
|
||||
f'{lineno} - welcome\n'
|
||||
pattern = self.stream_handler_regex_time + \
|
||||
r' - mmengine - (.*)ERROR(.*) - ' \
|
||||
f'{file_path} - {function_name} - ' \
|
||||
f'{lineno} - welcome\n'
|
||||
out, _ = capsys.readouterr()
|
||||
match = re.fullmatch(pattern, out)
|
||||
assert match is not None
|
||||
|
@ -114,21 +128,21 @@ class TestLogger:
|
|||
print_log('welcome', logger=logger)
|
||||
out, _ = capsys.readouterr()
|
||||
match = re.fullmatch(
|
||||
self.regex_time + ' - test_print_log - (.*)INFO(.*) - '
|
||||
self.stream_handler_regex_time + ' - mmengine - (.*)INFO(.*) - '
|
||||
'welcome\n', out)
|
||||
assert match is not None
|
||||
# Test access logger by name.
|
||||
print_log('welcome', logger='test_print_log')
|
||||
out, _ = capsys.readouterr()
|
||||
match = re.fullmatch(
|
||||
self.regex_time + ' - test_print_log - (.*)INFO(.*) - '
|
||||
self.stream_handler_regex_time + ' - mmengine - (.*)INFO(.*) - '
|
||||
'welcome\n', out)
|
||||
assert match is not None
|
||||
# Test access the latest created logger.
|
||||
print_log('welcome', logger='current')
|
||||
out, _ = capsys.readouterr()
|
||||
match = re.fullmatch(
|
||||
self.regex_time + ' - test_print_log - (.*)INFO(.*) - '
|
||||
self.stream_handler_regex_time + ' - mmengine - (.*)INFO(.*) - '
|
||||
'welcome\n', out)
|
||||
assert match is not None
|
||||
# Test invalid logger type.
|
||||
|
|
|
@ -1,4 +1,7 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pickle
|
||||
from collections import OrderedDict
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
@ -11,17 +14,26 @@ class TestMessageHub:
|
|||
def test_init(self):
|
||||
message_hub = MessageHub('name')
|
||||
assert message_hub.instance_name == 'name'
|
||||
assert len(message_hub.log_buffers) == 0
|
||||
assert len(message_hub.log_buffers) == 0
|
||||
assert len(message_hub.log_scalars) == 0
|
||||
assert len(message_hub.log_scalars) == 0
|
||||
# The type of log_scalars's value must be `HistoryBuffer`.
|
||||
with pytest.raises(AssertionError):
|
||||
MessageHub('hello', log_scalars=OrderedDict(a=1))
|
||||
# `Resumed_keys`
|
||||
with pytest.raises(AssertionError):
|
||||
MessageHub(
|
||||
'hello',
|
||||
runtime_info=OrderedDict(iter=1),
|
||||
resumed_keys=OrderedDict(iters=False))
|
||||
|
||||
def test_update_log(self):
|
||||
def test_update_scalar(self):
|
||||
message_hub = MessageHub.get_instance('mmengine')
|
||||
# test create target `LogBuffer` by name
|
||||
message_hub.update_log('name', 1)
|
||||
log_buffer = message_hub.log_buffers['name']
|
||||
# test create target `HistoryBuffer` by name
|
||||
message_hub.update_scalar('name', 1)
|
||||
log_buffer = message_hub.log_scalars['name']
|
||||
assert (log_buffer._log_history == np.array([1])).all()
|
||||
# test update target `LogBuffer` by name
|
||||
message_hub.update_log('name', 1)
|
||||
# test update target `HistoryBuffer` by name
|
||||
message_hub.update_scalar('name', 1)
|
||||
assert (log_buffer._log_history == np.array([1, 1])).all()
|
||||
# unmatched string will raise a key error
|
||||
|
||||
|
@ -33,19 +45,19 @@ class TestMessageHub:
|
|||
message_hub.update_info('key', 1)
|
||||
assert message_hub.runtime_info['key'] == 1
|
||||
|
||||
def test_get_log_buffers(self):
|
||||
def test_get_scalar(self):
|
||||
message_hub = MessageHub.get_instance('mmengine')
|
||||
# Get undefined key will raise error
|
||||
with pytest.raises(KeyError):
|
||||
message_hub.get_log('unknown')
|
||||
message_hub.get_scalar('unknown')
|
||||
# test get log_buffer as wished
|
||||
log_history = np.array([1, 2, 3, 4, 5])
|
||||
count = np.array([1, 1, 1, 1, 1])
|
||||
for i in range(len(log_history)):
|
||||
message_hub.update_log('test_value', float(log_history[i]),
|
||||
int(count[i]))
|
||||
message_hub.update_scalar('test_value', float(log_history[i]),
|
||||
int(count[i]))
|
||||
recorded_history, recorded_count = \
|
||||
message_hub.get_log('test_value').data
|
||||
message_hub.get_scalar('test_value').data
|
||||
assert (log_history == recorded_history).all()
|
||||
assert (recorded_count == count).all()
|
||||
|
||||
|
@ -57,18 +69,18 @@ class TestMessageHub:
|
|||
message_hub.update_info('test_value', recorded_dict)
|
||||
assert message_hub.get_info('test_value') == recorded_dict
|
||||
|
||||
def test_get_log_vars(self):
|
||||
def test_get_scalars(self):
|
||||
message_hub = MessageHub.get_instance('mmengine')
|
||||
log_dict = dict(
|
||||
loss=1,
|
||||
loss_cls=torch.tensor(2),
|
||||
loss_bbox=np.array(3),
|
||||
loss_iou=dict(value=1, count=2))
|
||||
message_hub.update_log_vars(log_dict)
|
||||
loss = message_hub.get_log('loss')
|
||||
loss_cls = message_hub.get_log('loss_cls')
|
||||
loss_bbox = message_hub.get_log('loss_bbox')
|
||||
loss_iou = message_hub.get_log('loss_iou')
|
||||
message_hub.update_scalars(log_dict)
|
||||
loss = message_hub.get_scalar('loss')
|
||||
loss_cls = message_hub.get_scalar('loss_cls')
|
||||
loss_bbox = message_hub.get_scalar('loss_bbox')
|
||||
loss_iou = message_hub.get_scalar('loss_iou')
|
||||
assert loss.current() == 1
|
||||
assert loss_cls.current() == 2
|
||||
assert loss_bbox.current() == 3
|
||||
|
@ -76,8 +88,27 @@ class TestMessageHub:
|
|||
|
||||
with pytest.raises(TypeError):
|
||||
loss_dict = dict(error_type=[])
|
||||
message_hub.update_log_vars(loss_dict)
|
||||
message_hub.update_scalars(loss_dict)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
loss_dict = dict(error_type=dict(count=1))
|
||||
message_hub.update_log_vars(loss_dict)
|
||||
message_hub.update_scalars(loss_dict)
|
||||
|
||||
def test_getstate(self):
|
||||
message_hub = MessageHub.get_instance('name')
|
||||
# update log_scalars.
|
||||
message_hub.update_scalar('loss', 0.1)
|
||||
message_hub.update_scalar('lr', 0.1, resumed=False)
|
||||
# update runtime information
|
||||
message_hub.update_info('iter', 1, resumed=True)
|
||||
message_hub.update_info('feat', [1, 2, 3], resumed=False)
|
||||
obj = pickle.dumps(message_hub)
|
||||
instance = pickle.loads(obj)
|
||||
|
||||
with pytest.raises(KeyError):
|
||||
instance.get_info('feat')
|
||||
with pytest.raises(KeyError):
|
||||
instance.get_info('lr')
|
||||
|
||||
instance.get_info('iter')
|
||||
instance.get_scalar('loss')
|
||||
|
|
|
@ -70,3 +70,7 @@ class TestManagerMixin:
|
|||
# Non-string instance name will raise `AssertionError`.
|
||||
with pytest.raises(AssertionError):
|
||||
SubClassA.get_instance(name=1)
|
||||
# `get_instance` should not accept other arguments if corresponding
|
||||
# instance has been created.
|
||||
with pytest.raises(AssertionError):
|
||||
SubClassA.get_instance('name2', a=1, b=2)
|
||||
|
|
Loading…
Reference in New Issue