diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py index bf123cae..d281745d 100644 --- a/mmengine/hooks/iter_timer_hook.py +++ b/mmengine/hooks/iter_timer_hook.py @@ -42,8 +42,8 @@ class IterTimerHook(Hook): mode (str): Current mode of runner. Defaults to 'train'. """ # TODO: update for new logging system - runner.message_hub.update_log(f'{mode}/data_time', - time.time() - self.t) + runner.message_hub.update_scalar(f'{mode}/data_time', + time.time() - self.t) def _after_iter(self, runner, @@ -65,5 +65,5 @@ class IterTimerHook(Hook): """ # TODO: update for new logging system - runner.message_hub.update_log(f'{mode}/time', time.time() - self.t) + runner.message_hub.update_scalar(f'{mode}/time', time.time() - self.t) self.t = time.time() diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py index 786fd311..aed1d0e0 100644 --- a/mmengine/hooks/logger_hook.py +++ b/mmengine/hooks/logger_hook.py @@ -370,7 +370,7 @@ class LoggerHook(Hook): dict: Statistical values of logs. """ tag = OrderedDict() - log_buffers = runner.message_hub.log_buffers + log_buffers = runner.message_hub.log_scalars mode_log_buffers = OrderedDict() # Filter log_buffers which starts with `mode`. for prefix_key, log_buffer in log_buffers.items(): diff --git a/mmengine/logging/__init__.py b/mmengine/logging/__init__.py index 13945401..ba5533c2 100644 --- a/mmengine/logging/__init__.py +++ b/mmengine/logging/__init__.py @@ -1,6 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from .log_buffer import LogBuffer +from .history_buffer import HistoryBuffer from .logger import MMLogger, print_log from .message_hub import MessageHub -__all__ = ['LogBuffer', 'MessageHub', 'MMLogger', 'print_log'] +__all__ = ['HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log'] diff --git a/mmengine/logging/log_buffer.py b/mmengine/logging/history_buffer.py similarity index 70% rename from mmengine/logging/log_buffer.py rename to mmengine/logging/history_buffer.py index c5523648..34a78ac6 100644 --- a/mmengine/logging/log_buffer.py +++ b/mmengine/logging/history_buffer.py @@ -1,16 +1,28 @@ # Copyright (c) OpenMMLab. All rights reserved. import warnings -from functools import partial from typing import Any, Callable, Optional, Sequence, Tuple, Union import numpy as np -class BaseLogBuffer: +class HistoryBuffer: """Unified storage format for different log types. - Record the history of log for further statistics. The subclass inherited - from ``BaseLogBuffer`` will implement the specific statistical methods. + ``HistoryBuffer`` records the history of log for further statistics. + + Examples: + >>> history_buffer = HistoryBuffer() + >>> # Update history_buffer. + >>> history_buffer.update(1) + >>> history_buffer.update(2) + >>> history_buffer.min() # minimum of (1, 2) + 1 + >>> history_buffer.max() # maximum of (1, 2) + 2 + >>> history_buffer.mean() # mean of (1, 2) + 1.5 + >>> history_buffer.statistics('mean') # access method by string. + 1.5 Args: log_history (Sequence): History logs. Defaults to []. 
@@ -25,6 +37,7 @@ class BaseLogBuffer: max_length: int = 1000000): self.max_length = max_length + self._set_default_statistics() assert len(log_history) == len(count_history), \ 'The lengths of log_history and count_histroy should be equal' if len(log_history) > max_length: @@ -37,10 +50,18 @@ class BaseLogBuffer: self._log_history = np.array(log_history) self._count_history = np.array(count_history) + def _set_default_statistics(self) -> None: + """Register default statistic methods: min, max, current and mean.""" + self._statistics_methods.setdefault('min', HistoryBuffer.min) + self._statistics_methods.setdefault('max', HistoryBuffer.max) + self._statistics_methods.setdefault('current', HistoryBuffer.current) + self._statistics_methods.setdefault('mean', HistoryBuffer.mean) + def update(self, log_val: Union[int, float], count: int = 1) -> None: - """update the log history. If the length of the buffer exceeds - ``self._max_length``, the oldest element will be removed from the - buffer. + """update the log history. + + If the length of the buffer exceeds ``self._max_length``, the oldest + element will be removed from the buffer. Args: log_val (int or float): The value of log. @@ -72,9 +93,23 @@ class BaseLogBuffer: def register_statistics(cls, method: Callable) -> Callable: """Register custom statistics method to ``_statistics_methods``. + The registered method can be called by ``history_buffer.statistics`` + with corresponding method name and arguments. + + Examples: + >>> @HistoryBuffer.register_statistics + >>> def weighted_mean(self, window_size, weight): + >>> assert len(weight) == window_size + >>> return (self._log_history[-window_size:] * + >>> np.array(weight)).sum() / \ + >>> self._count_history[-window_size:] + + >>> log_buffer = HistoryBuffer([1, 2], [1, 1]) + >>> log_buffer.statistics('weighted_mean', 2, [2, 1]) + 2 + Args: method (Callable): Custom statistics method. - Returns: Callable: Original custom statistics method. 
""" @@ -95,26 +130,20 @@ class BaseLogBuffer: """ if method_name not in self._statistics_methods: raise KeyError(f'{method_name} has not been registered in ' - 'BaseLogBuffer._statistics_methods') + 'HistoryBuffer._statistics_methods') method = self._statistics_methods[method_name] # Provide self arguments for registered functions. - method = partial(method, self) - return method(*arg, **kwargs) + return method(self, *arg, **kwargs) - -class LogBuffer(BaseLogBuffer): - """``LogBuffer`` inherits from ``BaseLogBuffer`` and provides some basic - statistics methods, such as ``min``, ``max``, ``current`` and ``mean``.""" - - @BaseLogBuffer.register_statistics def mean(self, window_size: Optional[int] = None) -> np.ndarray: """Return the mean of the latest ``window_size`` values in log - histories. If ``window_size is None``, return the global mean of - history logs. + histories. + + If ``window_size is None`` or ``window_size > len(self._log_history)``, + return the global mean value of history logs. Args: window_size (int, optional): Size of statistics window. - Returns: np.ndarray: Mean value within the window. """ @@ -128,15 +157,15 @@ class LogBuffer(BaseLogBuffer): counts_sum = self._count_history[-window_size:].sum() return logs_sum / counts_sum - @BaseLogBuffer.register_statistics def max(self, window_size: Optional[int] = None) -> np.ndarray: """Return the maximum value of the latest ``window_size`` values in log - histories. If ``window_size is None``, return the global maximum value - of history logs. + histories. + + If ``window_size is None`` or ``window_size > len(self._log_history)``, + return the global maximum value of history logs. Args: window_size (int, optional): Size of statistics window. - Returns: np.ndarray: The maximum value within the window. 
""" @@ -148,15 +177,15 @@ class LogBuffer(BaseLogBuffer): window_size = len(self._log_history) return self._log_history[-window_size:].max() - @BaseLogBuffer.register_statistics def min(self, window_size: Optional[int] = None) -> np.ndarray: """Return the minimum value of the latest ``window_size`` values in log - histories. If ``window_size is None``, return the global minimum value - of history logs. + histories. + + If ``window_size is None`` or ``window_size > len(self._log_history)``, + return the global minimum value of history logs. Args: window_size (int, optional): Size of statistics window. - Returns: np.ndarray: The minimum value within the window. """ @@ -168,7 +197,6 @@ class LogBuffer(BaseLogBuffer): window_size = len(self._log_history) return self._log_history[-window_size:].min() - @BaseLogBuffer.register_statistics def current(self) -> np.ndarray: """Return the recently updated values in log histories. @@ -176,6 +204,6 @@ class LogBuffer(BaseLogBuffer): np.ndarray: Recently updated values in log histories. """ if len(self._log_history) == 0: - raise ValueError('LogBuffer._log_history is an empty array! ' + raise ValueError('HistoryBuffer._log_history is an empty array! ' 'please call update first') return self._log_history[-1] diff --git a/mmengine/logging/logger.py b/mmengine/logging/logger.py index 25a14520..3ae26524 100644 --- a/mmengine/logging/logger.py +++ b/mmengine/logging/logger.py @@ -1,6 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. 
import logging import os +import os.path as osp import sys from logging import Logger, LogRecord from typing import Optional, Union @@ -8,7 +9,7 @@ from typing import Optional, Union import torch.distributed as dist from termcolor import colored -from mmengine.utils import ManagerMixin +from mmengine.utils import ManagerMixin, mkdir_or_exist class MMFormatter(logging.Formatter): @@ -85,54 +86,134 @@ class MMFormatter(logging.Formatter): class MMLogger(Logger, ManagerMixin): - """The Logger manager which can create formatted logger and get specified - logger globally. MMLogger is created and accessed in the same way as - ManagerMixin. + """Formatted logger used to record messages. + + ``MMLogger`` can create formatted logger to log message with different + log levels and get instance in the same way as ``ManagerMixin``. + ``MMLogger`` has the following features: + + - Distributed log storage, ``MMLogger`` can choose whether to save log of + different ranks according to `log_file`. + - Message with different log levels will have different colors and format + when displayed on terminal. + + Note: + - The `name` of logger and the ``instance_name`` of ``MMLogger`` could + be different. We can only get ``MMLogger`` instance by + ``MMLogger.get_instance`` but not ``logging.getLogger``. This feature + ensures ``MMLogger`` will not be influenced by third-party logging + config. + - Different from ``logging.Logger``, ``MMLogger`` will not log warning + or error message without ``Handler``. + - If `log_file=/path/to/tmp.log`, all logs will be saved to + `/path/to/tmp/tmp.log` + + Examples: + >>> logger = MMLogger.get_instance(name='MMLogger', + >>> logger_name='Logger') + >>> # Although logger has name attribute just like `logging.Logger` + >>> # We cannot get logger instance by `logging.getLogger`. 
+ >>> assert logger.name == 'Logger' + >>> assert logger.instance_name == 'MMLogger' + >>> assert id(logger) != id(logging.getLogger('Logger')) + >>> # Get logger that do not store logs. + >>> logger1 = MMLogger.get_instance('logger1') + >>> # Get logger only save rank0 logs. + >>> logger2 = MMLogger.get_instance('logger2', log_file='out.log') + >>> # Get logger only save multiple ranks logs. + >>> logger3 = MMLogger.get_instance('logger3', log_file='out.log', + >>> distributed=True) Args: - name (str): Logger name. Defaults to ''. + name (str): Global instance name. + logger_name (str): ``name`` attribute of ``Logging.Logger`` instance. + If `logger_name` is not defined, defaults to 'mmengine'. log_file (str, optional): The log filename. If specified, a ``FileHandler`` will be added to the logger. Defaults to None. - log_level: The log level of the handler. Defaults to 'NOTSET'. - file_mode (str): The file mode used in opening log file. - Defaults to 'w'. + log_level (str): The log level of the handler and logger. Defaults to + "INFO". + file_mode (str): The file mode used to open log file. Defaults to 'w'. + distributed (bool): Whether to save distributed logs. Defaults to + False. """ def __init__(self, - name: str = '', + name: str, + logger_name='mmengine', log_file: Optional[str] = None, - log_level: str = 'NOTSET', - file_mode: str = 'w'): - Logger.__init__(self, name) + log_level: str = 'INFO', + file_mode: str = 'w', + distributed=False): + Logger.__init__(self, logger_name) ManagerMixin.__init__(self, name) # Get rank in DDP mode. if dist.is_available() and dist.is_initialized(): rank = dist.get_rank() else: rank = 0 - # Config stream_handler. If `rank != 0`. stream_handler can only # export ERROR logs. stream_handler = logging.StreamHandler(stream=sys.stdout) - stream_handler.setFormatter(MMFormatter(color=True)) + # `StreamHandler` record month, day, hour, minute, and second + # timestamp. 
+ stream_handler.setFormatter( + MMFormatter(color=True, datefmt='%m/%d %H:%M:%S')) + # Only rank0 `StreamHandler` will log messages below error level. stream_handler.setLevel(log_level) if rank == 0 else \ stream_handler.setLevel(logging.ERROR) self.handlers.append(stream_handler) if log_file is not None: + # If `log_file=/path/to/tmp.log`, all logs will be saved to + # `/path/to/tmp/tmp.log` + log_dir = osp.dirname(log_file) + filename = osp.basename(log_file) + filename_list = filename.split('.') + sub_file_name = '.'.join(filename_list[:-1]) + log_dir = osp.join(log_dir, sub_file_name) + mkdir_or_exist(log_dir) + log_file = osp.join(log_dir, filename) if rank != 0: - # rename `log_file` with rank prefix. + # rename `log_file` with rank suffix. path_split = log_file.split(os.sep) - path_split[-1] = f'rank{rank}_{path_split[-1]}' + if '.' in path_split[-1]: + filename_list = path_split[-1].split('.') + filename_list[-2] = f'{filename_list[-2]}_rank{rank}' + path_split[-1] = '.'.join(filename_list) + else: + path_split[-1] = f'{path_split[-1]}_rank{rank}' log_file = os.sep.join(path_split) - # Here, the default behaviour of the official logger is 'a'. Thus, - # we provide an interface to change the file mode to the default - # behaviour. `FileHandler` is not supported to have colors, - # otherwise it will appear garbled. - file_handler = logging.FileHandler(log_file, file_mode) - file_handler.setFormatter(MMFormatter(color=False)) - file_handler.setLevel(log_level) - self.handlers.append(file_handler) + # Save multi-ranks logs if distributed is True. The logs of rank0 + # will always be saved. + if rank == 0 or distributed: + # Here, the default behaviour of the official logger is 'a'. + # Thus, we provide an interface to change the file mode to + # the default behaviour. `FileHandler` is not supported to + # have colors, otherwise it will appear garbled. 
+ file_handler = logging.FileHandler(log_file, file_mode) + # `StreamHandler` record year, month, day hour, minute, + # and second timestamp. file_handler will only record logs + # without color to avoid garbled code saved in files. + file_handler.setFormatter( + MMFormatter(color=False, datefmt='%Y/%m/%d %H:%M:%S')) + file_handler.setLevel(log_level) + self.handlers.append(file_handler) + + def callHandlers(self, record: LogRecord) -> None: + """Pass a record to all relevant handlers. + + Override ``callHandlers`` method in ``logging.Logger`` to avoid + multiple warning messages in DDP mode. Loop through all handlers of + the logger instance and its parents in the logger hierarchy. If no + handler was found, the record will not be output. + + Args: + record (LogRecord): A ``LogRecord`` instance contains logged + message. + """ + for handler in self.handlers: + if record.levelno >= handler.level: + handler.handle(record) def print_log(msg, diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index 75a2a4bc..f6f393a0 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -1,14 +1,14 @@ # Copyright (c) OpenMMLab. All rights reserved. import copy from collections import OrderedDict -from typing import Any, Union +from typing import Any, Optional, Union import numpy as np import torch from mmengine.utils import ManagerMixin from mmengine.visualization.utils import check_type -from .log_buffer import LogBuffer +from .history_buffer import HistoryBuffer class MessageHub(ManagerMixin): @@ -17,84 +17,210 @@ class MessageHub(ManagerMixin): ``MessageHub`` will record log information and runtime information. The log information refers to the learning rate, loss, etc. of the model - when training a model, which will be stored as ``LogBuffer``. The runtime - information refers to the iter times, meta information of runner etc., - which will be overwritten by next update. 
+ during training phase, which will be stored as ``HistoryBuffer``. The + runtime information refers to the iter times, meta information of + runner etc., which will be overwritten by next update. Args: - name (str): Name of message hub, for global access. Defaults to ''. + name (str): Name of message hub used to get corresponding instance + globally. + log_scalars (OrderedDict, optional): Each key-value pair in the + dictionary is the name of the log information such as "loss", "lr", + "metric" and their corresponding values. The type of value must be + HistoryBuffer. Defaults to None. + runtime_info (OrderedDict, optional): Each key-value pair in the + dictionary is the name of the runtime information and their + corresponding values. Defaults to None. + resumed_keys (OrderedDict, optional): Each key-value pair in the + dictionary decides whether the key in :attr:`_log_scalars` and + :attr:`_runtime_info` will be serialized. + + Note: + Key in :attr:`_resumed_keys` belongs to :attr:`_log_scalars` or + :attr:`_runtime_info`. The corresponding value cannot be set + repeatedly. + + Examples: + >>> # create empty `MessageHub`. + >>> message_hub1 = MessageHub() + >>> log_scalars = OrderedDict(loss=HistoryBuffer()) + >>> runtime_info = OrderedDict(task='task') + >>> resumed_keys = dict(loss=True) + >>> # create `MessageHub` from data. 
+ >>> message_hub2 = MessageHub( + >>> name='name', + >>> log_scalars=log_scalars, + >>> runtime_info=runtime_info, + >>> resumed_keys=resumed_keys) """ - def __init__(self, name: str = ''): - self._log_buffers: OrderedDict = OrderedDict() - self._runtime_info: OrderedDict = OrderedDict() + def __init__(self, + name: str, + log_scalars: Optional[OrderedDict] = None, + runtime_info: Optional[OrderedDict] = None, + resumed_keys: Optional[OrderedDict] = None): super().__init__(name) + self._log_scalars = log_scalars if log_scalars is not None else \ + OrderedDict() + self._runtime_info = runtime_info if runtime_info is not None else \ + OrderedDict() + self._resumed_keys = resumed_keys if resumed_keys is not None else \ + OrderedDict() - def update_log(self, key: str, value: Union[int, float], count: int = 1) \ - -> None: - """Update log buffer. + assert isinstance(self._log_scalars, OrderedDict) + assert isinstance(self._runtime_info, OrderedDict) + assert isinstance(self._resumed_keys, OrderedDict) + + for value in self._log_scalars.values(): + assert isinstance(value, HistoryBuffer), \ + ("The type of log_scalars'value must be HistoryBuffer, but " + f'got {type(value)}') + + for key in self._resumed_keys.keys(): + assert key in self._log_scalars or key in self._runtime_info, \ + ('Key in `resumed_keys` must contained in `log_scalars` or ' + f'`runtime_info`, but got {key}') + + def update_scalar(self, + key: str, + value: Union[int, float, np.ndarray, torch.Tensor], + count: int = 1, + resumed: bool = True) -> None: + """Update :attr:_log_scalars. + + Update ``HistoryBuffer`` in :attr:`_log_scalars`. If corresponding key + ``HistoryBuffer`` has been created, ``value`` and ``count`` is the + argument of ``HistoryBuffer.update``, Otherwise, ``update_scalar`` + will create an ``HistoryBuffer`` with value and count via the + constructor of ``HistoryBuffer``. 
+ + Examples: + >>> message_hub = MessageHub.get_instance('mmengine') + >>> # create loss `HistoryBuffer` with value=1, count=1 + >>> message_hub.update_scalar('loss', 1) + >>> # update loss `HistoryBuffer` with value + >>> message_hub.update_scalar('loss', 3) + >>> message_hub.update_scalar('loss', 3, resumed=False) + AssertionError: loss used to be true, but got false now. resumed + keys cannot be modified repeatedly + + Note: + resumed cannot be set repeatedly for the same key. Args: - key (str): Key of ``LogBuffer``. - value (int or float): Value of log. - count (int): Accumulation times of log, defaults to 1. `count` - will be used in smooth statistics. + key (str): Key of ``HistoryBuffer``. + value (torch.Tensor or np.ndarray or int or float): Value of log. + count (torch.Tensor or np.ndarray or int or float): Accumulation + times of log, defaults to 1. `count` will be used in smooth + statistics. + resumed (bool): Whether the corresponding ``HistoryBuffer`` + could be resumed. Defaults to True. """ - if key in self._log_buffers: - self._log_buffers[key].update(value, count) + self._set_resumed_keys(key, resumed) + checked_value = self._get_valid_value(key, value) + assert isinstance(count, int), ( + f'The type of count must be int, but got {type(count)}: {count}') + if key in self._log_scalars: + self._log_scalars[key].update(checked_value, count) else: - self._log_buffers[key] = LogBuffer([value], [count]) + self._log_scalars[key] = HistoryBuffer([checked_value], [count]) - def update_log_vars(self, log_dict: dict) -> None: - """Update :attr:`_log_buffers` with a dict. + def update_scalars(self, log_dict: dict, resumed: bool = True) -> None: + """Update :attr:`_log_scalars` with a dict. + + ``update_scalars`` iterates through each pair of log_dict key-value, + and calls ``update_scalar``. If type of value is dict, the value should + be ``dict(value=xxx) or dict(value=xxx, count=xxx)``. Item in + ``log_dict`` has the same resume option. 
 Args: - log_dict (str): Used for batch updating :attr:`_log_buffers`. + log_dict (dict): Used for batch updating :attr:`_log_scalars`. + resumed (bool): Whether all ``HistoryBuffer`` referred in + log_dict should be resumed. Defaults to True. Examples: >>> message_hub = MessageHub.get_instance('mmengine') >>> log_dict = dict(a=1, b=2, c=3) - >>> message_hub.update_log_vars(log_dict) + >>> message_hub.update_scalars(log_dict) >>> # The default count of `a`, `b` and `c` is 1. >>> log_dict = dict(a=1, b=2, c=dict(value=1, count=2)) - >>> message_hub.update_log_vars(log_dict) + >>> message_hub.update_scalars(log_dict) >>> # The count of `c` is 2. """ assert isinstance(log_dict, dict), ('`log_dict` must be a dict!, ' f'but got {type(log_dict)}') for log_name, log_val in log_dict.items(): + self._set_resumed_keys(log_name, resumed) if isinstance(log_val, dict): assert 'value' in log_val, \ f'value must be defined in {log_val}' - count = log_val.get('count', 1) - value = self._get_valid_value(log_name, log_val['value']) + count = self._get_valid_value(log_name, + log_val.get('count', 1)) + checked_value = self._get_valid_value(log_name, + log_val['value']) else: - value = self._get_valid_value(log_name, log_val) count = 1 - self.update_log(log_name, value, count) + checked_value = self._get_valid_value(log_name, log_val) + assert isinstance(count, + int), ('The type of count must be int, but got ' + f'{type(count)}: {count}') + self.update_scalar(log_name, checked_value, count) - def update_info(self, key: str, value: Any) -> None: + def update_info(self, key: str, value: Any, resumed: bool = True) -> None: """Update runtime information. + The key corresponding runtime information will be overwritten each + time calling ``update_info``. + + Note: + resumed cannot be set repeatedly for the same key. + + Examples: + >>> message_hub = MessageHub.get_instance('mmengine') + >>> message_hub.update_info('iter', 100) + Args: key (str): Key of runtime information. 
 value (Any): Value of runtime information. + resumed (bool): Whether the corresponding ``HistoryBuffer`` + could be resumed. """ + self._set_resumed_keys(key, resumed) + self._resumed_keys[key] = resumed self._runtime_info[key] = value + def _set_resumed_keys(self, key: str, resumed: bool) -> None: + """Set corresponding resumed keys. + + This method is called by ``update_scalar``, ``update_scalars`` and + ``update_info`` to set the corresponding key is true or false in + :attr:`_resumed_keys`. + + Args: + key (str): Key of :attr:`_log_scalars` or :attr:`_runtime_info`. + resumed (bool): Whether the corresponding ``HistoryBuffer`` + could be resumed. + """ + if key not in self._resumed_keys: + self._resumed_keys[key] = resumed + else: + assert self._resumed_keys[key] == resumed, \ + f'{key} used to be {self._resumed_keys[key]}, but got ' \ + f'{resumed} now. resumed keys cannot be modified repeatedly' + @property - def log_buffers(self) -> OrderedDict: - """Get all ``LogBuffer`` instances. + def log_scalars(self) -> OrderedDict: + """Get all ``HistoryBuffer`` instances. Note: - Considering the large memory footprint of ``log_buffers`` in the - post-training, ``MessageHub.log_buffers`` will not return the - result of ``copy.deepcopy``. + Considering the large memory footprint of history buffers in the + post-training, :meth:`get_scalar` will return a reference of + history buffer rather than a copy. Returns: - OrderedDict: All ``LogBuffer`` instances. + OrderedDict: All ``HistoryBuffer`` instances. """ - return self._log_buffers + return self._log_scalars @property def runtime_info(self) -> OrderedDict: @@ -105,24 +231,25 @@ class MessageHub(ManagerMixin): """ return copy.deepcopy(self._runtime_info) - def get_log(self, key: str) -> LogBuffer: - """Get ``LogBuffer`` instance by key. + def get_scalar(self, key: str) -> HistoryBuffer: + """Get ``HistoryBuffer`` instance by key. 
 Note: - Considering the large memory footprint of ``log_buffers`` in the - post-training, ``MessageHub.get_log`` will not return the - result of ``copy.deepcopy``. + Considering the large memory footprint of history buffers in the + post-training, :meth:`get_scalar` will return a reference of + history buffer rather than a copy. Args: - key (str): Key of ``LogBuffer``. + key (str): Key of ``HistoryBuffer``. Returns: - LogBuffer: Corresponding ``LogBuffer`` instance if the key exists. + HistoryBuffer: Corresponding ``HistoryBuffer`` instance if the + key exists. """ - if key not in self.log_buffers: raise KeyError(f'{key} is not found in Messagehub.log_buffers: ' f'instance name is: {MessageHub.instance_name}') - return self._log_buffers[key] + if key not in self.log_scalars: + raise KeyError(f'{key} is not found in MessageHub.log_scalars: ' f'instance name is: {MessageHub.instance_name}') + return self.log_scalars[key] def get_info(self, key: str) -> Any: """Get runtime information by key. @@ -139,7 +266,7 @@ class MessageHub(ManagerMixin): return copy.deepcopy(self._runtime_info[key]) def _get_valid_value(self, key: str, - value: Union[torch.Tensor, np.ndarray, int, float])\ + value: Union[torch.Tensor, np.ndarray, int, float]) \ -> Union[int, float]: """Convert value to python built-in type. 
@@ -158,4 +285,22 @@ class MessageHub(ManagerMixin): value = value.item() else: check_type(key, value, (int, float)) - return value + return value # type: ignore + + def __getstate__(self): + for key in list(self._log_scalars.keys()): + assert key in self._resumed_keys, ( + f'Cannot found {key} in {self}._resumed_keys, ' + 'please make sure you do not change the _resumed_keys ' + 'outside the class') + if not self._resumed_keys[key]: + self._log_scalars.pop(key) + + for key in list(self._runtime_info.keys()): + assert key in self._resumed_keys, ( + f'Cannot found {key} in {self}._resumed_keys, ' + 'please make sure you do not change the _resumed_keys ' + 'outside the class') + if not self._resumed_keys[key]: + self._runtime_info.pop(key) + return self.__dict__ diff --git a/mmengine/runner/loops.py b/mmengine/runner/loops.py index 62d99317..4de30628 100644 --- a/mmengine/runner/loops.py +++ b/mmengine/runner/loops.py @@ -77,7 +77,7 @@ class EpochBasedTrainLoop(BaseLoop): # TODO, should move to LoggerHook for key, value in self.runner.outputs['log_vars'].items(): - self.runner.message_hub.update_log(f'train/{key}', value) + self.runner.message_hub.update_scalar(f'train/{key}', value) self.runner.call_hook( 'after_train_iter', @@ -147,7 +147,7 @@ class IterBasedTrainLoop(BaseLoop): # TODO for key, value in self.runner.outputs['log_vars'].items(): - self.runner.message_hub.update_log(f'train/{key}', value) + self.runner.message_hub.update_scalar(f'train/{key}', value) self.runner.call_hook( 'after_train_iter', @@ -195,7 +195,7 @@ class ValLoop(BaseLoop): # compute metrics metrics = self.evaluator.evaluate(len(self.dataloader.dataset)) for key, value in metrics.items(): - self.runner.message_hub.update_log(f'val/{key}', value) + self.runner.message_hub.update_scalar(f'val/{key}', value) self.runner.call_hook('after_val_epoch') self.runner.call_hook('after_val') @@ -252,7 +252,7 @@ class TestLoop(BaseLoop): # compute metrics metrics = 
self.evaluator.evaluate(len(self.dataloader.dataset)) for key, value in metrics.items(): - self.runner.message_hub.update_log(f'test/{key}', value) + self.runner.message_hub.update_scalar(f'test/{key}', value) self.runner.call_hook('after_test_epoch') self.runner.call_hook('after_test') diff --git a/mmengine/utils/manager.py b/mmengine/utils/manager.py index 923ba0f0..b00bf981 100644 --- a/mmengine/utils/manager.py +++ b/mmengine/utils/manager.py @@ -71,6 +71,8 @@ class ManagerMixin(metaclass=ManagerMeta): """ def __init__(self, name: str = '', **kwargs): + assert isinstance(name, str) and name, \ + 'name argument must be an non-empty string.' self._instance_name = name @classmethod @@ -105,6 +107,10 @@ class ManagerMixin(metaclass=ManagerMeta): if name not in instance_dict: instance = cls(name=name, **kwargs) instance_dict[name] = instance + else: + assert not kwargs, ( + f'{cls} instance named of {name} has been created, the method ' + '`get_instance` should not access any other arguments') # Get latest instantiated instance or root instance. _release_lock() return instance_dict[name] @@ -117,12 +123,12 @@ class ManagerMixin(metaclass=ManagerMeta): ``get_instance(xxx)`` at least once. 
Examples - >>> instance = GlobalAccessible.get_current_instance(current=True) + >>> instance = GlobalAccessible.get_current_instance() AssertionError: At least one of name and current needs to be set >>> instance = GlobalAccessible.get_instance('name1') >>> instance.instance_name name1 - >>> instance = GlobalAccessible.get_current_instance(current=True) + >>> instance = GlobalAccessible.get_current_instance() >>> instance.instance_name name1 @@ -132,9 +138,8 @@ class ManagerMixin(metaclass=ManagerMeta): _accquire_lock() if not cls._instance_dict: raise RuntimeError( - f'Before calling {cls.__name__}.get_instance(' - 'current=True), ' - 'you should call get_instance(name=xxx) at least once.') + f'Before calling {cls.__name__}.get_current_instance(), you ' + 'should call get_instance(name=xxx) at least once.') name = next(iter(reversed(cls._instance_dict))) _release_lock() return cls._instance_dict[name] diff --git a/tests/test_hook/test_iter_timer_hook.py b/tests/test_hook/test_iter_timer_hook.py index 44e09bc3..af149f2f 100644 --- a/tests/test_hook/test_iter_timer_hook.py +++ b/tests/test_hook/test_iter_timer_hook.py @@ -18,7 +18,7 @@ class TestIterTimerHook: runner.log_buffer = dict() hook._before_epoch(runner) hook._before_iter(runner, 0) - runner.message_hub.update_log.assert_called() + runner.message_hub.update_scalar.assert_called() def test_after_iter(self): hook = IterTimerHook() @@ -26,4 +26,4 @@ class TestIterTimerHook: runner.log_buffer = dict() hook._before_epoch(runner) hook._after_iter(runner, 0) - runner.message_hub.update_log.assert_called() + runner.message_hub.update_scalar.assert_called() diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py index 6fac1a93..cac2e45b 100644 --- a/tests/test_hook/test_logger_hook.py +++ b/tests/test_hook/test_logger_hook.py @@ -287,7 +287,7 @@ class TestLoggerHook: 'train/loss_cls': MagicMock(), 'val/metric': MagicMock() } - runner.message_hub.log_buffers = log_buffers + 
runner.message_hub.log_scalars = log_buffers tag = logger_hook._collect_info(runner, mode='train') # Test parse custom_keys logger_hook._parse_custom_keys.assert_called() diff --git a/tests/test_logging/test_loggger_buffer.py b/tests/test_logging/test_history_buffer.py similarity index 84% rename from tests/test_logging/test_loggger_buffer.py rename to tests/test_logging/test_history_buffer.py index 5ea5ff82..5cdda029 100644 --- a/tests/test_logging/test_loggger_buffer.py +++ b/tests/test_logging/test_history_buffer.py @@ -3,15 +3,13 @@ import numpy as np import pytest import torch -from mmengine import LogBuffer +from mmengine import HistoryBuffer class TestLoggerBuffer: def test_init(self): - # `BaseLogBuffer` is an abstract class, using `CurrentLogBuffer` to - # test `update` method - log_buffer = LogBuffer() + log_buffer = HistoryBuffer() assert log_buffer.max_length == 1000000 log_history, counts = log_buffer.data assert len(log_history) == 0 @@ -19,7 +17,7 @@ class TestLoggerBuffer: # test the length of array exceed `max_length` logs = np.random.randint(1, 10, log_buffer.max_length + 1) counts = np.random.randint(1, 10, log_buffer.max_length + 1) - log_buffer = LogBuffer(logs, counts) + log_buffer = HistoryBuffer(logs, counts) log_history, count_history = log_buffer.data assert len(log_history) == log_buffer.max_length @@ -30,14 +28,13 @@ class TestLoggerBuffer: # The different lengths of `log_history` and `count_history` will # raise error with pytest.raises(AssertionError): - LogBuffer([1, 2], [1]) + HistoryBuffer([1, 2], [1]) @pytest.mark.parametrize('array_method', [torch.tensor, np.array, lambda x: x]) def test_update(self, array_method): - # `BaseLogBuffer` is an abstract class, using `CurrentLogBuffer` to # test `update` method - log_buffer = LogBuffer() + log_buffer = HistoryBuffer() log_history = array_method([1, 2, 3, 4, 5]) count_history = array_method([5, 5, 5, 5, 5]) for i in range(len(log_history)): @@ -52,7 +49,7 @@ class TestLoggerBuffer: # 
test the length of `array` exceed `max_length` max_array = array_method([[-1] + [1] * (log_buffer.max_length - 1)]) max_count = array_method([[-1] + [1] * (log_buffer.max_length - 1)]) - log_buffer = LogBuffer(max_array, max_count) + log_buffer = HistoryBuffer(max_array, max_count) log_buffer.update(1) log_history, count_history = log_buffer.data assert log_history[0] == 1 @@ -69,7 +66,7 @@ class TestLoggerBuffer: def test_max_min(self, statistics_method, log_buffer_type): log_history = np.random.randint(1, 5, 20) count_history = np.ones(20) - log_buffer = LogBuffer(log_history, count_history) + log_buffer = HistoryBuffer(log_history, count_history) assert statistics_method(log_history[-10:]) == \ getattr(log_buffer, log_buffer_type)(10) assert statistics_method(log_history) == \ @@ -78,7 +75,7 @@ class TestLoggerBuffer: def test_mean(self): log_history = np.random.randint(1, 5, 20) count_history = np.ones(20) - log_buffer = LogBuffer(log_history, count_history) + log_buffer = HistoryBuffer(log_history, count_history) assert np.sum(log_history[-10:]) / \ np.sum(count_history[-10:]) == \ log_buffer.mean(10) @@ -89,17 +86,17 @@ class TestLoggerBuffer: def test_current(self): log_history = np.random.randint(1, 5, 20) count_history = np.ones(20) - log_buffer = LogBuffer(log_history, count_history) + log_buffer = HistoryBuffer(log_history, count_history) assert log_history[-1] == log_buffer.current() # test get empty array - log_buffer = LogBuffer() + log_buffer = HistoryBuffer() with pytest.raises(ValueError): log_buffer.current() def test_statistics(self): log_history = np.array([1, 2, 3, 4, 5]) count_history = np.array([1, 1, 1, 1, 1]) - log_buffer = LogBuffer(log_history, count_history) + log_buffer = HistoryBuffer(log_history, count_history) assert log_buffer.statistics('mean') == 3 assert log_buffer.statistics('min') == 1 assert log_buffer.statistics('max') == 5 @@ -110,9 +107,9 @@ class TestLoggerBuffer: def test_register_statistics(self): - 
@LogBuffer.register_statistics + @HistoryBuffer.register_statistics def custom_statistics(self): return -1 - log_buffer = LogBuffer() + log_buffer = HistoryBuffer() assert log_buffer.statistics('custom_statistics') == -1 diff --git a/tests/test_logging/test_logger.py b/tests/test_logging/test_logger.py index 8c41fe24..0483cd97 100644 --- a/tests/test_logging/test_logger.py +++ b/tests/test_logging/test_logger.py @@ -11,14 +11,16 @@ from mmengine import MMLogger, print_log class TestLogger: - regex_time = r'\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2},\d{3}' + stream_handler_regex_time = r'\d{2}/\d{2} \d{2}:\d{2}:\d{2}' + file_handler_regex_time = r'\d{4}/\d{2}/\d{2} \d{2}:\d{2}:\d{2}' @patch('torch.distributed.get_rank', lambda: 0) @patch('torch.distributed.is_initialized', lambda: True) @patch('torch.distributed.is_available', lambda: True) def test_init_rank0(self, tmp_path): logger = MMLogger.get_instance('rank0.pkg1', log_level='INFO') - assert logger.name == 'rank0.pkg1' + assert logger.name == 'mmengine' + assert logger.instance_name == 'rank0.pkg1' assert logger.instance_name == 'rank0.pkg1' # Logger get from `MMLogger.get_instance` does not inherit from # `logging.root` @@ -38,6 +40,10 @@ class TestLogger: assert isinstance(logger.handlers[1], logging.FileHandler) logger_pkg3 = MMLogger.get_instance('rank0.pkg2') assert id(logger_pkg3) == id(logger) + logger = MMLogger.get_instance( + 'rank0.pkg3', logger_name='logger_test', log_level='INFO') + assert logger.name == 'logger_test' + assert logger.instance_name == 'rank0.pkg3' logging.shutdown() @patch('torch.distributed.get_rank', lambda: 1) @@ -46,12 +52,18 @@ class TestLogger: def test_init_rank1(self, tmp_path): # If `rank!=1`, the `loglevel` of file_handler is `logging.ERROR`. 
tmp_file = tmp_path / 'tmp_file.log' - log_path = tmp_path / 'rank1_tmp_file.log' + log_path = tmp_path / 'tmp_file' / 'tmp_file_rank1.log' logger = MMLogger.get_instance( 'rank1.pkg2', log_level='INFO', log_file=str(tmp_file)) - assert len(logger.handlers) == 2 + assert len(logger.handlers) == 1 + logger = MMLogger.get_instance( + 'rank1.pkg3', + log_level='INFO', + log_file=str(tmp_file), + distributed=True) assert logger.handlers[0].level == logging.ERROR assert logger.handlers[1].level == logging.INFO + assert len(logger.handlers) == 2 assert os.path.exists(log_path) logging.shutdown() @@ -59,32 +71,33 @@ class TestLogger: [logging.WARNING, logging.INFO, logging.DEBUG]) def test_handler(self, capsys, tmp_path, log_level): # test stream handler can output correct format logs - logger_name = f'test_stream_{str(log_level)}' - logger = MMLogger.get_instance(logger_name, log_level=log_level) + instance_name = f'test_stream_{str(log_level)}' + logger = MMLogger.get_instance(instance_name, log_level=log_level) logger.log(level=log_level, msg='welcome') out, _ = capsys.readouterr() # Skip match colored INFO loglevl_name = logging._levelToName[log_level] match = re.fullmatch( - self.regex_time + f' - {logger_name} - ' + self.stream_handler_regex_time + f' - mmengine - ' f'(.*){loglevl_name}(.*) - welcome\n', out) assert match is not None # test file_handler output plain text without color. 
tmp_file = tmp_path / 'tmp_file.log' - logger_name = f'test_file_{log_level}' + instance_name = f'test_file_{log_level}' logger = MMLogger.get_instance( - logger_name, log_level=log_level, log_file=tmp_file) + instance_name, log_level=log_level, log_file=tmp_file) logger.log(level=log_level, msg='welcome') - with open(tmp_file, 'r') as f: + with open(tmp_path / 'tmp_file' / 'tmp_file.log', 'r') as f: log_text = f.read() match = re.fullmatch( - self.regex_time + f' - {logger_name} - {loglevl_name} - ' + self.file_handler_regex_time + + f' - mmengine - {loglevl_name} - ' f'welcome\n', log_text) assert match is not None logging.shutdown() - def test_erro_format(self, capsys): + def test_error_format(self, capsys): # test error level log can output file path, function name and # line number logger = MMLogger.get_instance('test_error', log_level='INFO') @@ -92,9 +105,10 @@ class TestLogger: lineno = sys._getframe().f_lineno - 1 file_path = __file__ function_name = sys._getframe().f_code.co_name - pattern = self.regex_time + r' - test_error - (.*)ERROR(.*) - '\ - f'{file_path} - {function_name} - ' \ - f'{lineno} - welcome\n' + pattern = self.stream_handler_regex_time + \ + r' - mmengine - (.*)ERROR(.*) - ' \ + f'{file_path} - {function_name} - ' \ + f'{lineno} - welcome\n' out, _ = capsys.readouterr() match = re.fullmatch(pattern, out) assert match is not None @@ -114,21 +128,21 @@ class TestLogger: print_log('welcome', logger=logger) out, _ = capsys.readouterr() match = re.fullmatch( - self.regex_time + ' - test_print_log - (.*)INFO(.*) - ' + self.stream_handler_regex_time + ' - mmengine - (.*)INFO(.*) - ' 'welcome\n', out) assert match is not None # Test access logger by name. 
print_log('welcome', logger='test_print_log') out, _ = capsys.readouterr() match = re.fullmatch( - self.regex_time + ' - test_print_log - (.*)INFO(.*) - ' + self.stream_handler_regex_time + ' - mmengine - (.*)INFO(.*) - ' 'welcome\n', out) assert match is not None # Test access the latest created logger. print_log('welcome', logger='current') out, _ = capsys.readouterr() match = re.fullmatch( - self.regex_time + ' - test_print_log - (.*)INFO(.*) - ' + self.stream_handler_regex_time + ' - mmengine - (.*)INFO(.*) - ' 'welcome\n', out) assert match is not None # Test invalid logger type. diff --git a/tests/test_logging/test_message_hub.py b/tests/test_logging/test_message_hub.py index 4623c3b5..80db11c2 100644 --- a/tests/test_logging/test_message_hub.py +++ b/tests/test_logging/test_message_hub.py @@ -1,4 +1,7 @@ # Copyright (c) OpenMMLab. All rights reserved. +import pickle +from collections import OrderedDict + import numpy as np import pytest import torch @@ -11,17 +14,26 @@ class TestMessageHub: def test_init(self): message_hub = MessageHub('name') assert message_hub.instance_name == 'name' - assert len(message_hub.log_buffers) == 0 - assert len(message_hub.log_buffers) == 0 + assert len(message_hub.log_scalars) == 0 + assert len(message_hub.log_scalars) == 0 + # The type of log_scalars's value must be `HistoryBuffer`. 
+ with pytest.raises(AssertionError): + MessageHub('hello', log_scalars=OrderedDict(a=1)) + # `Resumed_keys` + with pytest.raises(AssertionError): + MessageHub( + 'hello', + runtime_info=OrderedDict(iter=1), + resumed_keys=OrderedDict(iters=False)) - def test_update_log(self): + def test_update_scalar(self): message_hub = MessageHub.get_instance('mmengine') - # test create target `LogBuffer` by name - message_hub.update_log('name', 1) - log_buffer = message_hub.log_buffers['name'] + # test create target `HistoryBuffer` by name + message_hub.update_scalar('name', 1) + log_buffer = message_hub.log_scalars['name'] assert (log_buffer._log_history == np.array([1])).all() - # test update target `LogBuffer` by name - message_hub.update_log('name', 1) + # test update target `HistoryBuffer` by name + message_hub.update_scalar('name', 1) assert (log_buffer._log_history == np.array([1, 1])).all() # unmatched string will raise a key error @@ -33,19 +45,19 @@ class TestMessageHub: message_hub.update_info('key', 1) assert message_hub.runtime_info['key'] == 1 - def test_get_log_buffers(self): + def test_get_scalar(self): message_hub = MessageHub.get_instance('mmengine') # Get undefined key will raise error with pytest.raises(KeyError): - message_hub.get_log('unknown') + message_hub.get_scalar('unknown') # test get log_buffer as wished log_history = np.array([1, 2, 3, 4, 5]) count = np.array([1, 1, 1, 1, 1]) for i in range(len(log_history)): - message_hub.update_log('test_value', float(log_history[i]), - int(count[i])) + message_hub.update_scalar('test_value', float(log_history[i]), + int(count[i])) recorded_history, recorded_count = \ - message_hub.get_log('test_value').data + message_hub.get_scalar('test_value').data assert (log_history == recorded_history).all() assert (recorded_count == count).all() @@ -57,18 +69,18 @@ class TestMessageHub: message_hub.update_info('test_value', recorded_dict) assert message_hub.get_info('test_value') == recorded_dict - def 
test_get_log_vars(self): + def test_get_scalars(self): message_hub = MessageHub.get_instance('mmengine') log_dict = dict( loss=1, loss_cls=torch.tensor(2), loss_bbox=np.array(3), loss_iou=dict(value=1, count=2)) - message_hub.update_log_vars(log_dict) - loss = message_hub.get_log('loss') - loss_cls = message_hub.get_log('loss_cls') - loss_bbox = message_hub.get_log('loss_bbox') - loss_iou = message_hub.get_log('loss_iou') + message_hub.update_scalars(log_dict) + loss = message_hub.get_scalar('loss') + loss_cls = message_hub.get_scalar('loss_cls') + loss_bbox = message_hub.get_scalar('loss_bbox') + loss_iou = message_hub.get_scalar('loss_iou') assert loss.current() == 1 assert loss_cls.current() == 2 assert loss_bbox.current() == 3 @@ -76,8 +88,27 @@ class TestMessageHub: with pytest.raises(TypeError): loss_dict = dict(error_type=[]) - message_hub.update_log_vars(loss_dict) + message_hub.update_scalars(loss_dict) with pytest.raises(AssertionError): loss_dict = dict(error_type=dict(count=1)) - message_hub.update_log_vars(loss_dict) + message_hub.update_scalars(loss_dict) + + def test_getstate(self): + message_hub = MessageHub.get_instance('name') + # update log_scalars. + message_hub.update_scalar('loss', 0.1) + message_hub.update_scalar('lr', 0.1, resumed=False) + # update runtime information + message_hub.update_info('iter', 1, resumed=True) + message_hub.update_info('feat', [1, 2, 3], resumed=False) + obj = pickle.dumps(message_hub) + instance = pickle.loads(obj) + + with pytest.raises(KeyError): + instance.get_info('feat') + with pytest.raises(KeyError): + instance.get_info('lr') + + instance.get_info('iter') + instance.get_scalar('loss') diff --git a/tests/test_utils/test_manager.py b/tests/test_utils/test_manager.py index f51b397f..be9348e2 100644 --- a/tests/test_utils/test_manager.py +++ b/tests/test_utils/test_manager.py @@ -70,3 +70,7 @@ class TestManagerMixin: # Non-string instance name will raise `AssertionError`. 
with pytest.raises(AssertionError): SubClassA.get_instance(name=1) + # `get_instance` should not accept other arguments if corresponding + # instance has been created. + with pytest.raises(AssertionError): + SubClassA.get_instance('name2', a=1, b=2)