From e2a2b0438edbdca44414f73db6fc9f0a3bcc6a3a Mon Sep 17 00:00:00 2001
From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Date: Sun, 24 Apr 2022 19:23:28 +0800
Subject: [PATCH] [Refactor] Refine LoggerHook (#155)

* rename global accessible and integrate get_instance and create_instance
* move ManagerMixin to utils
* fix as docstring and separate get_instance into get_instance and get_current_instance
* fix lint
* fix docstring, rename and move test_global_meta
* rename LogBuffer to HistoryBuffer, rename MessageHub methods, MessageHub support resume
* refine MMLogger timestamp, update unit test
* MMLogger add logger_name arguments
* Fix docstring
* Add LogProcessor and some unit test
* update unit test
* complete LogProcessor unit test
* refine LoggerHook
* solve circle import
* change default logger_name to mmengine
* refactor eta
* Fix docstring comment and unit test
* Fix with runner
* fix docstring
* Add by_epoch attribute to LoggerHook and fix docstring
* Please mypy and fix comment
* remove \ in MMLogger
* Fix lint
* roll back pre-commit-hook
* Fix hook unit test
* Fix comments
* remove \t in log and add docstring
* Fix as comment
* should not accept other arguments if corresponding instance has been created
* fix logging ddp file saving
* move log processor to logging
* remove current dataloader
* fix docstring
* fix unit test
* add learning rate in messagehub
* Support output training/validation/testing message after iterations/epochs
* Fix IterBasedRunner log string
* Support parse validation loss in log processor
---
 mmengine/hooks/hook.py                   |  19 +-
 mmengine/hooks/iter_timer_hook.py        |  64 +++-
 mmengine/hooks/logger_hook.py            | 461 +++++------------------
 mmengine/hooks/optimizer_hook.py         |   3 +
 mmengine/logging/__init__.py             |   5 +-
 mmengine/logging/log_processor.py        | 409 ++++++++++++++++++++
 mmengine/logging/logger.py               |  18 +-
 mmengine/runner/runner.py                |  11 +-
 tests/test_hook/test_hook.py             |   9 +-
 tests/test_hook/test_iter_timer_hook.py  |  69 +++-
 tests/test_hook/test_logger_hook.py      | 347 ++++-------------
 tests/test_hook/test_optimizer_hook.py   |   4 +-
 tests/test_logging/test_log_processor.py | 242 ++++++++++++
 tests/test_runner/test_runner.py         |   2 +-
 14 files changed, 961 insertions(+), 702 deletions(-)
 create mode 100644 mmengine/logging/log_processor.py
 create mode 100644 tests/test_logging/test_log_processor.py

diff --git a/mmengine/hooks/hook.py b/mmengine/hooks/hook.py
index 84060334..1e6e9370 100644
--- a/mmengine/hooks/hook.py
+++ b/mmengine/hooks/hook.py
@@ -358,11 +358,11 @@ class Hook:
         """
         return (runner.epoch + 1) % n == 0 if n > 0 else False
 
-    def every_n_inner_iters(self, inner_iter: int, n: int) -> bool:
+    def every_n_inner_iters(self, batch_idx: int, n: int) -> bool:
         """Test whether current inner iteration can be evenly divided by n.
 
         Args:
-            inner_iter (int): Current inner_iter of the training, validation
+            batch_idx (int): Current batch index of the training, validation
                 or testing loop.
             n (int): Whether current inner iteration can be evenly
                 divided by n.
@@ -371,7 +371,7 @@ class Hook:
             bool: Whether current inner iteration can be evenly
             divided by n.
         """
-        return (inner_iter + 1) % n == 0 if n > 0 else False
+        return (batch_idx + 1) % n == 0 if n > 0 else False
 
     def every_n_iters(self, runner, n: int) -> bool:
         """Test whether current iteration can be evenly divided by n.
@@ -387,19 +387,18 @@ class Hook:
         """
         return (runner.iter + 1) % n == 0 if n > 0 else False
 
-    def end_of_epoch(self, runner, batch_idx: int) -> bool:
+    def end_of_epoch(self, dataloader, batch_idx: int) -> bool:
         """Check whether the current iteration reaches the last iteration of
-        current dataloader.
+        the dataloader.
 
         Args:
-            runner (Runner): The runner of the training, validation or testing
-                process.
+            dataloader (Dataloader): The dataloader of the training,
+                validation or testing process.
             batch_idx (int): The index of the current batch in the loop.
-
         Returns:
             bool: Whether reaches the end of current epoch or not.
         """
-        return batch_idx + 1 == len(runner.cur_dataloader)
+        return batch_idx + 1 == len(dataloader)
 
     def is_last_train_epoch(self, runner) -> bool:
         """Test whether current epoch is the last train epoch.
@@ -418,10 +417,10 @@ class Hook:
         Args:
             runner (Runner): The runner of the training, validation or testing
                 process.
+            mode (str): Current mode of runner. Defaults to 'train'.
 
         Returns:
             bool: Whether current iteration is the last iteration.
-            mode (str): Current mode of runner. Defaults to 'train'.
         """
         if mode == 'train':
             return runner.iter + 1 == runner.train_loop.max_iters
diff --git a/mmengine/hooks/iter_timer_hook.py b/mmengine/hooks/iter_timer_hook.py
index d281745d..ef7124d5 100644
--- a/mmengine/hooks/iter_timer_hook.py
+++ b/mmengine/hooks/iter_timer_hook.py
@@ -18,11 +18,25 @@ class IterTimerHook(Hook):
 
     priority = 'NORMAL'
 
-    def _before_epoch(self, runner, mode: str = 'train') -> None:
-        """Record time flag before start a epoch.
+    def __init__(self):
+        self.time_sec_tot = 0
+        self.start_iter = 0
+
+    def before_run(self, runner) -> None:
+        """Synchronize the number of iterations with the runner.
 
         Args:
-            runner (Runner): The runner of the training process.
+            runner: The runner of the training, validation or testing
+                process.
+        """
+        self.start_iter = runner.iter
+
+    def _before_epoch(self, runner, mode: str = 'train') -> None:
+        """Record the timestamp before starting an epoch.
+
+        Args:
+            runner (Runner): The runner of the training, validation and
+                testing process.
             mode (str): Current mode of runner. Defaults to 'train'.
         """
         self.t = time.time()
@@ -32,16 +46,18 @@ class IterTimerHook(Hook):
                      batch_idx: int,
                      data_batch: DATA_BATCH = None,
                      mode: str = 'train') -> None:
-        """Logging time for loading data and update the time flag.
+        """Calculate the time for loading data and update the "data_time"
+        ``HistoryBuffer`` of ``runner.message_hub``.
 
         Args:
-            runner (Runner): The runner of the training process.
+            runner (Runner): The runner of the training, validation and
+                testing process.
             batch_idx (int): The index of the current batch in the loop.
             data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
                 Data from dataloader. Defaults to None.
             mode (str): Current mode of runner. Defaults to 'train'.
         """
-        # TODO: update for new logging system
+        # Update data loading time in `runner.message_hub`.
         runner.message_hub.update_scalar(f'{mode}/data_time',
                                          time.time() - self.t)
 
@@ -52,10 +68,12 @@ class IterTimerHook(Hook):
                     batch_idx: int,
                     data_batch: DATA_BATCH = None,
                     outputs:
                     Optional[Union[dict, Sequence[BaseDataElement]]] = None,
                     mode: str = 'train') -> None:
-        """Logging time for a iteration and update the time flag.
+        """Calculate the time for an iteration and update the "time"
+        ``HistoryBuffer`` of ``runner.message_hub``.
 
         Args:
-            runner (Runner): The runner of the training process.
+            runner (Runner): The runner of the training, validation and
+                testing process.
             batch_idx (int): The index of the current batch in the loop.
            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
                Data from dataloader.
@@ -63,7 +81,31 @@ class IterTimerHook(Hook):
                 to None.
             mode (str): Current mode of runner. Defaults to 'train'.
         """
-        # TODO: update for new logging system
-
-        runner.message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
+        # Update iteration time in `runner.message_hub`.
+        message_hub = runner.message_hub
+        message_hub.update_scalar(f'{mode}/time', time.time() - self.t)
         self.t = time.time()
+        window_size = runner.log_processor.window_size
+        # Calculate eta every `window_size` iterations. Since the test and
+        # val loops do not update runner.iter, use `every_n_inner_iters` to
+        # check the interval.
+        if self.every_n_inner_iters(batch_idx, window_size):
+            iter_time = message_hub.get_scalar(f'{mode}/time').mean(
+                window_size)
+            if mode == 'train':
+                self.time_sec_tot += iter_time * window_size
+                # Calculate the average iteration time.
+                time_sec_avg = self.time_sec_tot / (
+                    runner.iter - self.start_iter + 1)
+                # Calculate eta.
+                eta_sec = time_sec_avg * (
+                    runner.train_loop.max_iters - runner.iter - 1)
+                runner.message_hub.update_info('eta', eta_sec)
+            else:
+                if mode == 'val':
+                    cur_dataloader = runner.val_loop.dataloader
+                else:
+                    cur_dataloader = runner.test_loop.dataloader
+
+                eta_sec = iter_time * (len(cur_dataloader) - batch_idx - 1)
+                runner.message_hub.update_info('eta', eta_sec)
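The eta bookkeeping added above reduces to a few lines of arithmetic. A worked
sketch in plain Python, with illustrative numbers chosen to match the unit test
added later in this patch (not values from a real run):

    time_sec_tot = 0.0
    start_iter, cur_iter, max_iters, window_size = 0, 9, 100, 10
    iter_time = 1.0  # smoothed mean of the 'train/time' buffer over `window_size`
    time_sec_tot += iter_time * window_size
    time_sec_avg = time_sec_tot / (cur_iter - start_iter + 1)  # 1.0 s/iter
    eta_sec = time_sec_avg * (max_iters - cur_iter - 1)
    assert eta_sec == 90  # later read back through message_hub.get_info('eta')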
diff --git a/mmengine/hooks/logger_hook.py b/mmengine/hooks/logger_hook.py
index aed1d0e0..87b69114 100644
--- a/mmengine/hooks/logger_hook.py
+++ b/mmengine/hooks/logger_hook.py
@@ -1,16 +1,10 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-import copy
-import datetime
 import os
 import os.path as osp
-from collections import OrderedDict
 from pathlib import Path
 from typing import Any, Optional, Sequence, Tuple, Union
 
-import torch
-
 from mmengine.data import BaseDataElement
-from mmengine.dist import master_only
 from mmengine.fileio import FileClient
 from mmengine.hooks import Hook
 from mmengine.registry import HOOKS
@@ -21,33 +15,20 @@ DATA_BATCH = Optional[Sequence[Tuple[Any, BaseDataElement]]]
 
 @HOOKS.register_module()
 class LoggerHook(Hook):
-    """In this logger hook, the information will be printed on the terminal and
-    saved in JSON file, tensorboard, wandb .etc.
+    """Collect logs from different components of ``Runner`` and write them to
+    terminal, JSON file, tensorboard, wandb, etc.
+
+    ``LoggerHook`` is used to record logs formatted by ``LogProcessor`` during
+    the training/validation/testing phase. It is used to control the following
+    behaviors:
+
+    - The frequency of updating logs in terminal, local files, tensorboard,
+      wandb, etc.
+    - The frequency of showing experiment information in the terminal.
+    - The work directory to save logs.
 
     Args:
-        by_epoch (bool): Whether ``EpochBasedLoop`` is used.
-            Defaults to True.
         interval (int): Logging interval (every k iterations).
             Defaults to 10.
-        custom_keys (dict, optional): Defines the keys in the log and which
-            kinds of statistic methods should be used to log them.
-
-            - ``custom_keys`` contains multiple string-dict pairs. In each
-            string-dict pair, the string defines a key name in the log and the
-            dict is a config defines the statistic methods and corresponding
-            arguments used to log the value. For example,
-            ``dict(loss=dict(method_name='mean', log_name='global_loss',
-            window_size='global'))`` which means the log key ``loss`` will be
-            counted as global mean and additionally logged as ``global_loss``.
-            If ``log_name`` is not defined in config dict, the original logged
-            key will be overwritten.
-            - The key in ``LoggerHook.fixed_smooth_keys`` cannot be overwritten
-            because ``time`` and ``iter_time`` will be used to calculate
-            estimated time of arrival. If you want to recount the time, you
-            should set ``log_name`` in corresponding values.
-            - For those statistic methods with the ``window_size`` argument,
-            if ``by_epoch`` is set to False, ``windows_size`` should not be
-            `epoch` to statistics log value by epoch.
         ignore_last (bool): Ignore the log of last iterations in each epoch
             if the number of remaining iterations is less than
             :attr:`interval`. Defaults to True.
@@ -72,64 +53,24 @@ class LoggerHook(Hook):
             Defaults to None.
 
     Examples:
-        >>> # `log_name` is defined, `loss_mean_window` will be an additional
-        >>> # record.
-        >>> logger_hook_cfg = dict(by_epoch=True,
-        >>>                        custom_keys=dict(
-        >>>                            loss=dict(
-        >>>                                log_name='loss_mean_window',
-        >>>                                method_name='mean',
-        >>>                                window_size=10)))
-        >>> # `log_name` is not defined. `loss` will be overwritten by
-        >>> # `global_mean` statistics.
-        >>> logger_hook_cfg = dict(by_epoch=True,
-        >>>                        custom_keys=dict(
-        >>>                            loss=dict(
-        >>>                                method_name='mean',
-        >>>                                window_size='global')))
-        >>> # `time` cannot be overwritten, `global_time` will be an additional
-        >>> # record.
-        >>> logger_hook_cfg = dict(by_epoch=True,
-        >>>                        custom_keys=dict(
-        >>>                            time=dict(
-        >>>                                log_name='global_time',
-        >>>                                method='mean',
-        >>>                                window_size='global')))
-        >>> # Record loss with different statistics methods.
-        >>> logger_hook_cfg = dict(by_epoch=True,
-        >>>                        custom_keys=dict(loss=[
-        >>>                            dict(log_name='loss_mean_window',
-        >>>                                 method_name='mean',
-        >>>                                 window_size=10),
-        >>>                            dict(method_name='mean',
-        >>>                                 window_size='global')]))
+        >>> # The simplest LoggerHook config.
+        >>> logger_hook_cfg = dict(interval=20)
     """
-    # eta will be calculated by time. `time` and `data_time` should not be
-    # overwritten.
-    fixed_smooth_keys = ('time', 'data_time')
     priority = 'BELOW_NORMAL'
 
     def __init__(
         self,
-        by_epoch: bool = True,
         interval: int = 10,
-        custom_keys: Optional[dict] = None,
         ignore_last: bool = True,
         interval_exp_name: int = 1000,
         out_dir: Optional[Union[str, Path]] = None,
         out_suffix: Union[Sequence[str], str] = ('.log.json', '.log', '.py'),
-        keep_local=True,
-        file_client_args=None,
+        keep_local: bool = True,
+        file_client_args: Optional[dict] = None,
     ):
-        self._inner_iter = 0
-        self.by_epoch = by_epoch
         self.interval = interval
-        self.custom_keys = custom_keys if custom_keys is not None else dict()
         self.ignore_last = ignore_last
-
-        self.time_sec_tot = 0
         self.interval_exp_name = interval_exp_name
-        self._check_custom_keys()
 
         if out_dir is None and file_client_args is not None:
             raise ValueError(
@@ -169,7 +110,7 @@ class LoggerHook(Hook):
                                           f'{runner.timestamp}.log.json')
         self.yaml_log_path = osp.join(runner.work_dir,
                                       f'{runner.timestamp}.log.json')
-        self.start_iter = runner.iter
+        # TODO: Compatible with Visualizer.
         if runner.meta is not None:
             runner.writer.add_params(runner.meta, file_path=self.yaml_log_path)
 
@@ -178,41 +119,100 @@ class LoggerHook(Hook):
                          batch_idx: int,
                          data_batch: DATA_BATCH = None,
                          outputs: Optional[dict] = None) -> None:
-        """Record training logs.
+        """Record training logs after a training iteration.
 
         Args:
             runner (Runner): The runner of the training process.
             batch_idx (int): The index of the current batch in the train loop.
-            data_batch (Sequence[BaseDataElement], optional): Data from
-                dataloader. Defaults to None.
+            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
+                Data from dataloader. Defaults to None.
             outputs (dict, optional): Outputs from model.
                 Defaults to None.
         """
-        self._inner_iter = batch_idx
-        if runner.meta is not None and 'exp_name' in runner.meta:
-            if (self.every_n_iters(runner, self.interval_exp_name)) or (
-                    self.by_epoch and self.end_of_epoch(runner, batch_idx)):
-                exp_info = f'Exp name: {runner.meta["exp_name"]}'
-                runner.logger.info(exp_info)
-        if self.by_epoch and self.every_n_inner_iters(batch_idx,
-                                                      self.interval):
-            self._log_train(runner)
-        elif not self.by_epoch and self.every_n_iters(runner, self.interval):
-            self._log_train(runner)
-        elif self.end_of_epoch(runner, batch_idx) and not self.ignore_last:
+        # Print experiment name every n iterations.
+        if self.every_n_iters(runner,
+                              self.interval_exp_name) or (self.end_of_epoch(
+                                  runner.train_dataloader, batch_idx)):
+            exp_info = f'Exp name: {runner.experiment_name}'
+            runner.logger.info(exp_info)
+        if self.every_n_inner_iters(batch_idx, self.interval):
+            tag, log_str = runner.log_processor.get_log_after_iter(
+                runner, batch_idx, 'train')
+        elif (self.end_of_epoch(runner.train_dataloader, batch_idx)
+              and not self.ignore_last):
             # `runner.max_iters` may not be divisible by `self.interval`. If
             # `self.ignore_last==False`, the log of the remaining iterations
             # at the end of each epoch will still be recorded (for
             # Epoch [4][1000/1007], the logs of iterations 1001-1007 will be
             # recorded).
-            self._log_train(runner)
+            tag, log_str = runner.log_processor.get_log_after_iter(
+                runner, batch_idx, 'train')
+        else:
+            return
+        runner.logger.info(log_str)
+        # TODO: compatible with visualizer.
+        runner.writer.add_scalars(tag, step=runner.iter + 1)
+
+    def after_val_iter(
+            self,
+            runner,
+            batch_idx: int,
+            data_batch: DATA_BATCH = None,
+            outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
+        """Record validation logs after a validation iteration.
+
+        Args:
+            runner (Runner): The runner of the validation process.
+            batch_idx (int): The index of the current batch in the validation
+                loop.
+            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
+                Data from dataloader. Defaults to None.
+            outputs (sequence, optional): Outputs from model. Defaults to None.
+        """
+        if self.every_n_inner_iters(batch_idx, self.interval):
+            tag, log_str = runner.log_processor.get_log_after_iter(
+                runner, batch_idx, 'val')
+            runner.logger.info(log_str)
+
+    def after_test_iter(
+            self,
+            runner,
+            batch_idx: int,
+            data_batch: DATA_BATCH = None,
+            outputs: Optional[Sequence[BaseDataElement]] = None) -> None:
+        """Record testing logs after a testing iteration.
+
+        Args:
+            runner (Runner): The runner of the testing process.
+            batch_idx (int): The index of the current batch in the test loop.
+            data_batch (Sequence[Tuple[Any, BaseDataElement]], optional):
+                Data from dataloader. Defaults to None.
+            outputs (sequence, optional): Outputs from model. Defaults to None.
+        """
+        if self.every_n_inner_iters(batch_idx, self.interval):
+            tag, log_str = runner.log_processor.get_log_after_iter(
+                runner, batch_idx, 'test')
+            runner.logger.info(log_str)
 
     def after_val_epoch(self, runner) -> None:
-        """Record validation logs.
+        """Record validation logs after a validation epoch.
 
         Args:
             runner (Runner): The runner of the training process.
         """
-        self._log_val(runner)
+        tag, log_str = runner.log_processor.get_log_after_epoch(
+            runner, len(runner.val_dataloader), 'val')
+        runner.logger.info(log_str)
+        # TODO: compatible with visualizer.
+        runner.writer.add_scalars(tag, step=runner.iter + 1)
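The branching in ``after_train_iter`` above boils down to one predicate: log
when the inner iteration hits the interval, or at the end of an epoch unless
the tail is ignored. A condensed, self-contained sketch of the same gating
(``should_log`` is a hypothetical helper, not part of this patch):

    def should_log(batch_idx: int, dataloader_len: int,
                   interval: int = 10, ignore_last: bool = True) -> bool:
        if (batch_idx + 1) % interval == 0:  # mirrors every_n_inner_iters
            return True
        # mirrors end_of_epoch combined with the ignore_last check
        return batch_idx + 1 == dataloader_len and not ignore_last

    assert should_log(9, 1007)                        # Epoch [4][10/1007]
    assert not should_log(1006, 1007)                 # tail dropped by default
    assert should_log(1006, 1007, ignore_last=False)  # tail kept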
+
+    def after_test_epoch(self, runner) -> None:
+        """Record testing logs after a test epoch.
+
+        Args:
+            runner (Runner): The runner of the testing process.
+        """
+        tag, log_str = runner.log_processor.get_log_after_epoch(
+            runner, len(runner.test_dataloader), 'test')
+        runner.logger.info(log_str)
 
     def after_run(self, runner) -> None:
         """Copy logs to ``self.out_dir`` if ``self.out_dir is not None``
@@ -237,280 +237,3 @@ class LoggerHook(Hook):
                 os.remove(local_filepath)
                 runner.logger.info((f'{local_filepath} was removed due to the '
                                     '`self.keep_local=False`'))
-
-    @master_only
-    def _log_train(self, runner) -> None:
-        """Collect and record training logs which start named with "train/*".
-
-        Args:
-            runner (Runner): The runner of the training process.
-        """
-        tag = self._collect_info(runner, 'train')
-        # The training log default defines `lr`, `momentum`, `time` and
-        # `data_time`. `log_tag` will pop these keys and loop other keys to
-        # `log_str`.
-        log_tag = copy.deepcopy(tag)
-        cur_iter = self._get_iter(runner, inner_iter=True)
-        cur_epoch = self._get_epoch(runner, 'train')
-
-        # Record learning rate and momentum.
-        lr_str_list = []
-        momentum_str_list = []
-        for key, value in tag.items():
-            if key.startswith('lr'):
-                log_tag.pop(key)
-                lr_str_list.append(f'{key}: {value:.3e}')
-        lr_str = ' '.join(lr_str_list)
-        for key, value in tag.items():
-            if key.startswith('momentum'):
-                log_tag.pop(key)
-                momentum_str_list.append(f'{key}: {value:.3e}')
-        momentum_str = ' '.join(momentum_str_list)
-        lr_momentum_str = f'{lr_str} {momentum_str}'
-        # by epoch: Epoch [4][100/1000]
-        # by iter: Iter [100/100000]
-        if self.by_epoch:
-            log_str = f'Epoch [{cur_epoch}]' \
-                      f'[{cur_iter}/{len(runner.cur_dataloader)}]\t'
-        else:
-            log_str = f'Iter [{cur_iter}/{runner.train_loop.max_iters}]\t'
-        log_str += f'{lr_momentum_str}, '
-        # Calculate eta time.
-        self.time_sec_tot += (tag['time'] * self.interval)
-        time_sec_avg = self.time_sec_tot / (runner.iter - self.start_iter + 1)
-        eta_sec = time_sec_avg * (
-            runner.train_loop.max_iters - runner.iter - 1)
-        eta_str = str(datetime.timedelta(seconds=int(eta_sec)))
-        log_str += f'eta: {eta_str}, '
-        log_str += f'time: {tag["time"]:.3f}, ' \
-                   f'data_time: {tag["data_time"]:.3f}, '
-        # Pop recorded keys
-        log_tag.pop('time')
-        log_tag.pop('data_time')
-        # statistic memory
-        if torch.cuda.is_available():
-            log_str += f'memory: {self._get_max_memory(runner)}, '
-        # Loop left keys to fill `log_str`.
-        log_items = []
-        for name, val in log_tag.items():
-            if isinstance(val, float):
-                val = f'{val:.4f}'
-            log_items.append(f'{name}: {val}')
-        log_str += ', '.join(log_items)
-        runner.logger.info(log_str)
-        # Write logs to local, tensorboad, and wandb.
-        runner.writer.add_scalars(
-            tag, step=runner.iter + 1, file_path=self.json_log_path)
-
-    @master_only
-    def _log_val(self, runner) -> None:
-        """Collect and record training logs which start named with "val/*".
-
-        Args:
-            runner (Runner): The runner of the training process.
- """ - tag = self._collect_info(runner, 'val') - # Compatible with function `log` https://github.com/open-mmlab/mmcv/blob/master/mmcv/runner/hooks/logger/text.py # noqa E501 - eval_iter = len(runner.cur_dataloader) - cur_iter = self._get_iter(runner) - cur_epoch = self._get_epoch(runner, 'val') - # val/test time - # here 1000 is the length of the val dataloader - # by epoch: Epoch[val] [4][1000] - # by iter: Iter[val] [1000] - if self.by_epoch: - # runner.epoch += 1 has been done before val workflow - log_str = f'Epoch(val) [{cur_epoch}][{eval_iter}]\t' - else: - log_str = f'Iter(val) [{eval_iter}]\t' - - log_items = [] - for name, val in tag.items(): - if isinstance(val, float): - val = f'{val:.4f}' - log_items.append(f'{name}: {val}') - log_str += ', '.join(log_items) - runner.logger.info(log_str) - # Write tag. - runner.writer.add_scalars( - tag, step=cur_iter, file_path=self.json_log_path) - - def _get_window_size(self, runner, window_size: Union[int, str]) \ - -> int: - """Parse window_size specified in ``self.custom_keys`` to int value. - - Args: - runner (Runner): The runner of the training process. - window_size (int or str): Smoothing scale of logs. - - Returns: - int: Smoothing window for statistical methods. - """ - if isinstance(window_size, int): - assert window_size == self.interval, \ - 'The value of windows size must equal to LoggerHook.interval' - return window_size - elif window_size == 'epoch': - return self._inner_iter + 1 - elif window_size == 'global': - return runner.iter + 1 - else: - raise ValueError('window_size should be int, epoch or global, but ' - f'got invalid {window_size}') - - def _collect_info(self, runner, mode: str) -> dict: - """Collect log information to a dict according to mode. - - Args: - runner (Runner): The runner of the training process. - mode (str): 'train' or 'val', which means the prefix attached by - runner. - - Returns: - dict: Statistical values of logs. - """ - tag = OrderedDict() - log_buffers = runner.message_hub.log_scalars - mode_log_buffers = OrderedDict() - # Filter log_buffers which starts with `mode`. - for prefix_key, log_buffer in log_buffers.items(): - if prefix_key.startswith(mode): - key = prefix_key.split('/')[-1] - mode_log_buffers[key] = log_buffer - # Ensure all metric and lr values are latest. - for key in mode_log_buffers: - # Update the latest learning rate and smoothed time logs. - if key in self.fixed_smooth_keys or key.startswith('loss'): - tag[key] = mode_log_buffers[key].mean(self.interval) - else: - tag[key] = mode_log_buffers[key].current() - # Update custom keys. - if mode == 'train': - for log_key, log_cfg in self.custom_keys.items(): - self._parse_custom_keys(runner, log_key, - copy.deepcopy(log_cfg), - mode_log_buffers, tag) - return tag - - def _parse_custom_keys(self, runner, log_key: str, log_cfg: dict, - log_buffers: OrderedDict, tag: OrderedDict) -> None: - """Statistics logs in log_buffers according to custom_keys. - - Args: - runner (Runner): The runner of the training process. - log_key (str): log key specified in ``self.custom_keys`` - log_cfg (dict): A config dict for describing the logging - statistics method. - log_buffers (OrderedDict): All logs for the corresponding phase. - tag (OrderedDict): A dict which defines all statistic values of - logs. 
- """ - if isinstance(log_cfg, list): - log_names = set() - for cfg in log_cfg: - log_name = cfg.get('log_name', None) - if log_name in log_names: - raise KeyError(f'{cfg["log_name"]} cannot be redefined in ' - 'log_key') - if log_name is not None: - log_names.add(log_name) - self._parse_custom_keys(runner, log_key, cfg, log_buffers, tag) - assert len(log_names) == len(log_cfg) - 1, \ - f'{log_key} cannot be overwritten multiple times, please ' \ - f'check only one key does not contain `log_name` in {log_cfg}.' - elif isinstance(log_cfg, dict): - if 'window_size' in log_cfg: - log_cfg['window_size'] = \ - self._get_window_size(runner, log_cfg['window_size']) - if 'log_name' in log_cfg: - name = log_cfg.pop('log_name') - else: - name = log_key - tag[name] = log_buffers[log_key].statistics(**log_cfg).item() - else: - raise ValueError('The structure of `LoggerHook.custom key` is ' - 'wrong, please make sure the type of each key is ' - 'dict or list.') - - def _get_max_memory(self, runner) -> int: - """Returns the maximum GPU memory occupied by tensors in megabytes (MB) - for a given device. - - Args: - runner (Runner): The runner of the training process. - - Returns: - The maximum GPU memory occupied by tensors in megabytes for a given - device. - """ - device = getattr(runner.model, 'output_device', None) - mem = torch.cuda.max_memory_allocated(device=device) - mem_mb = torch.tensor([int(mem) // (1024 * 1024)], - dtype=torch.int, - device=device) - torch.cuda.reset_peak_memory_stats() - return int(mem_mb.item()) - - def _check_custom_keys(self) -> None: - """Check the legality of ``self.custom_keys``. - - If ``self.by_epoch==False``, ``window_size`` should not be "epoch". The - key of ``self.fixed_smooth_keys`` cannot be overwritten. - """ - - def _check_window_size(item): - if not self.by_epoch: - assert item['window_size'] != 'epoch', \ - 'window_size cannot be epoch if LoggerHook.by_epoch is ' \ - 'False.' - - def _check_fixed_keys(key, item): - if key in self.fixed_smooth_keys: - assert 'log_name' in item, f'{key} cannot be overwritten by ' \ - 'custom keys!' - - for key, value in self.custom_keys.items(): - if isinstance(value, Sequence): - [(_check_window_size(item), _check_fixed_keys(key, item)) - for item in value] - - else: - _check_window_size(value) - _check_fixed_keys(key, value) - - def _get_epoch(self, runner, mode: str) -> int: - """Get epoch according to mode. - - Args: - runner (Runner): The runner of the training process. - mode (str): Train or val. - - Returns: - int: The current epoch. - """ - if mode == 'train': - epoch = runner.epoch + 1 - elif mode == 'val': - # normal val mode - # runner.epoch += 1 has been done before val workflow - epoch = runner.epoch - else: - raise ValueError(f"runner mode should be 'train' or 'val', " - f'but got {runner.mode}') - return epoch - - def _get_iter(self, runner, inner_iter=False) -> int: - """Get the current training iteration step. - Args: - runner (Runner): The runner of the training process. - inner_iter (bool): Whether to return the inner iter of an epoch. - Defaults to False. - - Returns: - int: The current global iter or inner iter. 
- """ - if self.by_epoch and inner_iter: - current_iter = self._inner_iter + 1 - else: - current_iter = runner.iter + 1 - return current_iter diff --git a/mmengine/hooks/optimizer_hook.py b/mmengine/hooks/optimizer_hook.py index ff33b54a..61870de8 100644 --- a/mmengine/hooks/optimizer_hook.py +++ b/mmengine/hooks/optimizer_hook.py @@ -86,6 +86,9 @@ class OptimizerHook(Hook): we keep ``outputs`` here. Defaults to None. """ runner.optimizer.zero_grad() + runner.message_hub.update_scalar( + 'train/lr', runner.optimizer.param_groups[0]['lr']) + if self.detect_anomalous_params: self.detect_anomalous_parameters(runner.outputs['loss'], runner) runner.outputs['loss'].backward() diff --git a/mmengine/logging/__init__.py b/mmengine/logging/__init__.py index ba5533c2..eeac7ff1 100644 --- a/mmengine/logging/__init__.py +++ b/mmengine/logging/__init__.py @@ -1,6 +1,9 @@ # Copyright (c) OpenMMLab. All rights reserved. from .history_buffer import HistoryBuffer +from .log_processor import LogProcessor from .logger import MMLogger, print_log from .message_hub import MessageHub -__all__ = ['HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log'] +__all__ = [ + 'HistoryBuffer', 'MessageHub', 'MMLogger', 'print_log', 'LogProcessor' +] diff --git a/mmengine/logging/log_processor.py b/mmengine/logging/log_processor.py new file mode 100644 index 00000000..cb97286c --- /dev/null +++ b/mmengine/logging/log_processor.py @@ -0,0 +1,409 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import copy +import datetime +from collections import OrderedDict +from typing import List, Optional, Tuple + +import torch + + +class LogProcessor: + """A log processor used to format log information collected from + ``runner.message_hub.log_scalars``. + + ``LogProcessor`` instance is built by runner and will format + ``runner.message_hub.log_scalars`` to ``tag`` and ``log_str``, which can + directly used by ``LoggerHook`` and ``MMLogger``. Besides, the argument + ``custom_cfg`` of constructor can control the statistics method of logs. + + Args: + window_size (int): default smooth interval Defaults to 10. + by_epoch (bool): Whether to format logs with epoch stype. Defaults to + True. + custom_cfg (list[dict], optional): Contains multiple log config dict, + in which key means the data source name of log and value means the + statistic method and corresponding arguments used to count the + data source. Defaults to None + - If custom_cfg is None, all logs will be formatted via default + methods, such as smoothing loss by default window_size. If + custom_cfg is defined as a list of config dict, for example: + [dict(data_src=loss, method='mean', log_name='global_loss', + window_size='global')]. It means the log item ``loss`` will be + counted as global mean and additionally logged as ``global_loss`` + (defined by ``log_name``). If ``log_name`` is not defined in + config dict, the original logged key will be overwritten. + + - The original log item cannot be overwritten twice. Here is + an error example: + [dict(data_src=loss, method='mean', window_size='global'), + dict(data_src=loss, method='mean', window_size='epoch')]. + Both log config dict in custom_cfg do not have ``log_name`` key, + which means the loss item will be overwritten twice. + + - For those statistic methods with the ``window_size`` argument, + if ``by_epoch`` is set to False, ``windows_size`` should not be + `epoch` to statistics log value by epoch. + + Examples: + >>> # `log_name` is defined, `loss_large_window` will be an additional + >>> # record. 
+        >>> log_processor = dict(
+        >>>     window_size=10,
+        >>>     by_epoch=True,
+        >>>     custom_cfg=[dict(data_src='loss',
+        >>>                      log_name='loss_large_window',
+        >>>                      method_name='mean',
+        >>>                      window_size=100)])
+        >>> # `log_name` is not defined. `loss` will be overwritten.
+        >>> log_processor = dict(
+        >>>     window_size=10,
+        >>>     by_epoch=True,
+        >>>     custom_cfg=[dict(data_src='loss',
+        >>>                      method_name='mean',
+        >>>                      window_size=100)])
+        >>> # Record loss with different statistics methods.
+        >>> log_processor = dict(
+        >>>     window_size=10,
+        >>>     by_epoch=True,
+        >>>     custom_cfg=[dict(data_src='loss',
+        >>>                      log_name='loss_large_window',
+        >>>                      method_name='mean',
+        >>>                      window_size=100),
+        >>>                 dict(data_src='loss',
+        >>>                      method_name='mean',
+        >>>                      window_size=100)])
+        >>> # Overwriting the loss item twice will raise an error.
+        >>> log_processor = dict(
+        >>>     window_size=10,
+        >>>     by_epoch=True,
+        >>>     custom_cfg=[dict(data_src='loss',
+        >>>                      method_name='mean',
+        >>>                      window_size=100),
+        >>>                 dict(data_src='loss',
+        >>>                      method_name='max',
+        >>>                      window_size=100)])
+        AssertionError
+    """
+
+    def __init__(self,
+                 window_size=10,
+                 by_epoch=True,
+                 custom_cfg: Optional[List[dict]] = None):
+        self.window_size = window_size
+        self.by_epoch = by_epoch
+        self.custom_cfg = custom_cfg if custom_cfg else []
+        self._check_custom_cfg()
+
+    def get_log_after_iter(self, runner, batch_idx: int,
+                           mode: str) -> Tuple[dict, str]:
+        """Format log string after a training, validation or testing
+        iteration.
+
+        Args:
+            runner (Runner): The runner of the training phase.
+            batch_idx (int): The index of the current batch in the current
+                loop.
+            mode (str): Current mode of runner, train, test or val.
+
+        Returns:
+            Tuple[dict, str]: Formatted log dict/string which will be
+            recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
+        """
+        assert mode in ['train', 'test', 'val']
+        current_loop = self._get_cur_loop(runner, mode)
+        cur_iter = self._get_iter(runner, batch_idx=batch_idx)
+        # Resolve ``window_size`` defined in ``custom_cfg`` to an int value.
+        custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
+        # tag is used to write log information to different backends.
+        tag = self._collect_scalars(custom_cfg_copy, runner, mode)
+        # `log_tag` will pop 'lr' and loop other keys to `log_str`.
+        log_tag = copy.deepcopy(tag)
+        # Record learning rate.
+        lr_str_list = []
+        for key, value in tag.items():
+            if key.startswith('lr'):
+                log_tag.pop(key)
+                lr_str_list.append(f'{key}: {value:.3e}')
+        lr_str = ' '.join(lr_str_list)
+        # Format log header.
+        # by_epoch == True
+        #   train/val: Epoch [5][5/10]  ...
+        #   test: Epoch [5/10]
+        # by_epoch == False
+        #   train: Iter [5/10000]  ... (divided by `max_iters`)
+        #   val/test: Iter [5/2000]  ... (divided by length of dataloader)
+        if self.by_epoch:
+            if mode in ['train', 'val']:
+                cur_epoch = self._get_epoch(runner, mode)
+                log_str = (f'Epoch({mode}) [{cur_epoch}]'
+                           f'[{cur_iter}/{len(current_loop.dataloader)}] ')
+            else:
+                log_str = (f'Epoch({mode}) '
+                           f'[{cur_iter}/{len(current_loop.dataloader)}] ')
+        else:
+            if mode == 'train':
+                log_str = (f'Iter({mode}) '
+                           f'[{cur_iter}/{runner.train_loop.max_iters}] ')
+            else:
+                log_str = (f'Iter({mode}) [{batch_idx+1}'
+                           f'/{len(current_loop.dataloader)}] ')
+        # Concatenate lr string with log header.
+        log_str += f'{lr_str} '
+        # If IterTimerHook is used in runner, eta, time, and data_time
+        # should be recorded.
+        if (all(item in tag for item in ['time', 'data_time'])
+                and 'eta' in runner.message_hub.runtime_info):
+            eta = runner.message_hub.get_info('eta')
+            eta_str = str(datetime.timedelta(seconds=int(eta)))
+            log_str += f'eta: {eta_str} '
+            log_str += (f'time: {tag["time"]:.3f} '
+                        f'data_time: {tag["data_time"]:.3f} ')
+            # Pop recorded keys
+            log_tag.pop('time')
+            log_tag.pop('data_time')
+
+        # If cuda is available, the max memory occupied should be calculated.
+        if torch.cuda.is_available():
+            log_str += f'memory: {self._get_max_memory(runner)} '
+        # Loop left keys to fill `log_str`.
+        if mode in ('train', 'val'):
+            log_items = []
+            for name, val in log_tag.items():
+                if mode == 'val' and not name.startswith('loss'):
+                    continue
+                if isinstance(val, float):
+                    val = f'{val:.4f}'
+                log_items.append(f'{name}: {val}')
+            log_str += ' '.join(log_items)
+        return tag, log_str
+
+    def get_log_after_epoch(self, runner, batch_idx: int,
+                            mode: str) -> Tuple[dict, str]:
+        """Format log string after a validation or testing epoch.
+
+        Args:
+            runner (Runner): The runner of the training phase.
+            batch_idx (int): The index of the current batch in the current
+                loop.
+            mode (str): Current mode of runner.
+
+        Returns:
+            Tuple[dict, str]: Formatted log dict/string which will be
+            recorded by :obj:`runner.message_hub` and :obj:`runner.visualizer`.
+        """
+        assert mode in [
+            'test', 'val'
+        ], ('`get_log_after_epoch` only accepts val or test mode, but got '
+            f'{mode}')
+        cur_loop = self._get_cur_loop(runner, mode)
+        dataloader_len = len(cur_loop.dataloader)
+
+        custom_cfg_copy = self._parse_windows_size(runner, batch_idx)
+        # tag is used to write log information to different backends.
+        tag = self._collect_scalars(custom_cfg_copy, runner, mode)
+        # The validation log string needs the current epoch/iteration and
+        # max epochs/iterations. The test log string only needs the length
+        # of the test dataloader.
+        cur_iter = self._get_iter(runner, batch_idx)
+        if self.by_epoch:
+            if mode == 'val':
+                cur_epoch = self._get_epoch(runner, mode)
+                log_str = (f'Epoch({mode}) [{cur_epoch}][{dataloader_len}/'
+                           f'{dataloader_len}] ')
+            else:
+                log_str = (
+                    f'Epoch({mode}) [{dataloader_len}/{dataloader_len}] ')
+
+        else:
+            if mode == 'train':
+                log_str = (f'Iter({mode}) [{cur_iter}/'
+                           f'{runner.train_loop.max_iters}] ')
+            else:
+                log_str = (
+                    f'Iter({mode}) [{dataloader_len}/{dataloader_len}] ')
+        log_items = []
+        for name, val in tag.items():
+            if name in ('time', 'data_time'):
+                continue
+            if isinstance(val, float):
+                val = f'{val:.4f}'
+            log_items.append(f'{name}: {val}')
+        log_str += ' '.join(log_items)
+        return tag, log_str
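A hedged sketch of what ``_collect_scalars`` (below) computes for one smoothed
key and one ``custom_cfg`` entry, assuming the ``HistoryBuffer.update``/
``mean``/``statistics`` signatures this patch relies on elsewhere; the values
are made up:

    from mmengine.logging import HistoryBuffer

    loss_buffer = HistoryBuffer()
    for loss in (1.0, 0.8, 0.6, 0.4):
        loss_buffer.update(loss, count=1)
    # Default behavior: keys starting with 'loss' are smoothed over
    # `window_size`; other keys fall back to their latest value.
    tag = {'loss': loss_buffer.mean(2)}  # (0.6 + 0.4) / 2 = 0.5
    # An entry such as dict(data_src='loss', log_name='loss_global',
    # method_name='mean', window_size='global') resolves 'global' to
    # `runner.iter + 1` and adds an extra key instead of overwriting 'loss':
    tag['loss_global'] = loss_buffer.statistics('mean', window_size=4)  # 0.7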
+    def _collect_scalars(self, custom_cfg: List[dict], runner,
+                         mode: str) -> dict:
+        """Collect log information to compose a dict according to mode.
+
+        Args:
+            custom_cfg (List[dict]): A copy of ``self.custom_cfg`` with int
+                ``window_size``.
+            runner (Runner): The runner of the training process.
+            mode (str): 'train' or 'val', which means the prefix attached by
+                runner.
+
+        Returns:
+            dict: Statistical values of logs.
+        """
+        tag = OrderedDict()
+        # history_scalars of train/val/test phase.
+        history_scalars = runner.message_hub.log_scalars
+        # history_scalars of the corresponding mode.
+        mode_history_scalars = OrderedDict()
+        # Extract log scalars and remove the prefix, saving them to
+        # `mode_history_scalars` according to mode.
+        for prefix_key, log_buffer in history_scalars.items():
+            if prefix_key.startswith(mode):
+                key = prefix_key.split('/')[-1]
+                mode_history_scalars[key] = log_buffer
+        for key in mode_history_scalars:
+            # Smooth the loss items over `window_size`; other items use
+            # their latest value.
+            if key.startswith('loss'):
+                tag[key] = mode_history_scalars[key].mean(self.window_size)
+            else:
+                # Default statistics method is current.
+                tag[key] = mode_history_scalars[key].current()
+        # Update custom keys.
+        for log_cfg in custom_cfg:
+            data_src = log_cfg.pop('data_src')
+            if 'log_name' in log_cfg:
+                log_name = log_cfg.pop('log_name')
+            else:
+                log_name = data_src
+            # A log item in custom_cfg may only exist in train or val
+            # mode.
+            if data_src in mode_history_scalars:
+                tag[log_name] = mode_history_scalars[data_src].statistics(
+                    **log_cfg)
+        return tag
+
+    def _check_custom_cfg(self) -> None:
+        """Check the legality of ``self.custom_cfg``."""
+
+        def _check_window_size():
+            for log_cfg in self.custom_cfg:
+                if not self.by_epoch:
+                    assert log_cfg['window_size'] != 'epoch', \
+                        'window_size cannot be epoch if ' \
+                        'LogProcessor.by_epoch is False.'
+
+        def _check_repeated_log_name():
+            check_dict = dict()
+            # The `log_name` of the same data_src should not be repeated.
+            # If `log_name` is not specified, `data_src` will be overwritten.
+            # But only allowed to be overwritten once.
+            for log_cfg in self.custom_cfg:
+                assert 'data_src' in log_cfg
+                data_src = log_cfg['data_src']
+                log_name = log_cfg.get('log_name', data_src)
+                check_dict.setdefault(data_src,
+                                      dict(log_names=set(), log_counts=0))
+                check_dict[data_src]['log_names'].add(log_name)
+                check_dict[data_src]['log_counts'] += 1
+                assert (len(
+                    check_dict[data_src]
+                    ['log_names']) == check_dict[data_src]['log_counts']), (
+                        f'If you want to compute statistics of {data_src} '
+                        'with multiple methods, please check that `log_name` '
+                        f'is unique and that {data_src} will not be '
+                        'overwritten twice. See more information in the '
+                        'docstring of `LogProcessor`.')
+
+        _check_repeated_log_name()
+        _check_window_size()
+
+    def _parse_windows_size(self, runner, batch_idx: int) -> list:
+        """Parse window_size defined in custom_cfg to an int value.
+
+        Args:
+            runner (Runner): The runner of the training process.
+            batch_idx (int): The iteration index of the current dataloader.
+        """
+        custom_cfg_copy = copy.deepcopy(self.custom_cfg)
+        for log_cfg in custom_cfg_copy:
+            window_size = log_cfg.get('window_size', None)
+            if window_size is None or isinstance(window_size, int):
+                continue
+            elif window_size == 'epoch':
+                log_cfg['window_size'] = batch_idx + 1
+            elif window_size == 'global':
+                log_cfg['window_size'] = runner.iter + 1
+            else:
+                raise TypeError(
+                    'window_size should be int, epoch or global, but got '
+                    f'invalid {window_size}')
+        return custom_cfg_copy
+
+    def _get_max_memory(self, runner) -> int:
+        """Returns the maximum GPU memory occupied by tensors in megabytes
+        (MB) for a given device.
+
+        Args:
+            runner (Runner): The runner of the training process.
+
+        Returns:
+            The maximum GPU memory occupied by tensors in megabytes for a
+            given device.
+        """
+        device = getattr(runner.model, 'output_device', None)
+        mem = torch.cuda.max_memory_allocated(device=device)
+        mem_mb = torch.tensor([int(mem) // (1024 * 1024)],
+                              dtype=torch.int,
+                              device=device)
+        torch.cuda.reset_peak_memory_stats()
+        return int(mem_mb.item())
+
+    def _get_iter(self, runner, batch_idx: int = None) -> int:
+        """Get current training iteration step.
+
+        Args:
+            runner (Runner): The runner of the training process.
+            batch_idx (int, optional): The iteration index of the current
+                dataloader. Defaults to None.
+
+        Returns:
+            int: The current global iter or inner iter.
+        """
+        if self.by_epoch and batch_idx is not None:
+            current_iter = batch_idx + 1
+        else:
+            current_iter = runner.iter + 1
+        return current_iter
+
+    def _get_epoch(self, runner, mode: str) -> int:
+        """Get current epoch according to mode.
+
+        Args:
+            runner (Runner): The runner of the training/validation process.
+            mode (str): Current mode of runner, "train" or "val".
+
+        Returns:
+            int: The current epoch.
+        """
+        if mode == 'train':
+            epoch = runner.epoch + 1
+        elif mode == 'val':
+            # normal val mode
+            # runner.epoch += 1 has been done before validation
+            epoch = runner.epoch
+        else:
+            raise ValueError(
+                f"runner mode should be 'train' or 'val', but got {mode}")
+        return epoch
+
+    def _get_cur_loop(self, runner, mode: str):
+        """Get current loop according to mode.
+
+        Args:
+            runner (Runner): The runner of the training/validation/testing
+                process.
+            mode (str): Current mode of runner, "train", "val" or "test".
+
+        Returns:
+            BaseLoop: Current loop of runner.
+        """
+        # Returning the actual type hint here would cause a circular import.
+        if mode == 'train':
+            return runner.train_loop
+        elif mode == 'val':
+            return runner.val_loop
+        else:
+            return runner.test_loop
diff --git a/mmengine/logging/logger.py b/mmengine/logging/logger.py
index 3ae26524..6066449f 100644
--- a/mmengine/logging/logger.py
+++ b/mmengine/logging/logger.py
@@ -32,15 +32,15 @@ class MMFormatter(logging.Formatter):
         info_prefix = self._get_prefix('INFO', color)
         debug_prefix = self._get_prefix('DEBUG', color)
         # Config output format.
-        self.err_format = f'%(asctime)s - %(name)s - {error_prefix} - ' \
-                          f'%(pathname)s - %(funcName)s - %(lineno)d - ' \
-                          '%(message)s'
-        self.warn_format = f'%(asctime)s - %(name)s - {warn_prefix} - %(' \
-                           'message)s'
-        self.info_format = f'%(asctime)s - %(name)s - {info_prefix} - %(' \
-                           'message)s'
-        self.debug_format = f'%(asctime)s - %(name)s - {debug_prefix} - %(' \
-                            'message)s'
+        self.err_format = (f'%(asctime)s - %(name)s - {error_prefix} - '
+                           '%(pathname)s - %(funcName)s - %(lineno)d - '
+                           '%(message)s')
+        self.warn_format = (f'%(asctime)s - %(name)s - {warn_prefix} - %('
+                            'message)s')
+        self.info_format = (f'%(asctime)s - %(name)s - {info_prefix} - %('
+                            'message)s')
+        self.debug_format = (f'%(asctime)s - %(name)s - {debug_prefix} - %('
+                             'message)s')
 
     def _get_prefix(self, level: str, color: bool) -> str:
         """Get the prefix of the target log level.
diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py
index a1b11511..e43c0b08 100644
--- a/mmengine/runner/runner.py
+++ b/mmengine/runner/runner.py
@@ -25,7 +25,7 @@ from mmengine.dist import (broadcast, get_dist_info, init_dist, master_only,
                            sync_random_seed)
 from mmengine.evaluator import Evaluator
 from mmengine.hooks import Hook
-from mmengine.logging import MessageHub, MMLogger
+from mmengine.logging import LogProcessor, MessageHub, MMLogger
 from mmengine.model import is_model_wrapper
 from mmengine.optim import _ParamScheduler, build_optimizer
 from mmengine.registry import (DATA_SAMPLERS, DATASETS, HOOKS, LOOPS,
@@ -127,6 +127,8 @@ class Runner:
             non-distributed environment will be launched.
         env_cfg (dict): A dict used for setting environment. Defaults to
             dict(dist_cfg=dict(backend='nccl')).
+        log_processor (dict, optional): A processor to format logs. Defaults
+            to None.
         log_level (int or str): The log level of MMLogger handlers.
             Defaults to 'INFO'.
        writer (ComposedWriter or dict, optional): A ComposedWriter object or a
@@ -184,6 +186,7 @@ class Runner:
                 param_scheduler=dict(type='ParamSchedulerHook')),
             launcher='none',
             env_cfg=dict(dist_cfg=dict(backend='nccl')),
+            log_processor=dict(window_size=20),
             writer=dict(
                 name='composed_writer',
                 writers=[dict(type='LocalWriter', save_dir='temp_dir')])
@@ -218,6 +221,7 @@ class Runner:
         launcher: str = 'none',
         env_cfg: Dict = dict(dist_cfg=dict(backend='nccl')),
         log_level: str = 'INFO',
+        log_processor: Optional[Dict] = None,
         writer: Optional[Union[ComposedWriter, Dict]] = None,
         default_scope: Optional[str] = None,
         randomness: Dict = dict(seed=None),
@@ -310,6 +314,10 @@ class Runner:
         else:
             self._experiment_name = self.timestamp
 
+        log_processor = dict() if log_processor is None else log_processor
+        self.log_processor = LogProcessor(**log_processor)
+        # Since `get_instance` could return any subclass of ManagerMixin,
+        # the corresponding attribute needs a type hint.
         self.logger = self.build_logger(log_level=log_level)
         # Build `message_hub` for communication among components.
         # `message_hub` can store log scalars (loss, learning rate) and
@@ -385,6 +393,7 @@ class Runner:
             resume=cfg.get('resume', False),
             launcher=cfg.get('launcher', 'none'),
             env_cfg=cfg.get('env_cfg'),  # type: ignore
+            log_processor=cfg.get('log_processor'),
             log_level=cfg.get('log_level', 'INFO'),
             writer=cfg.get('writer'),
             default_scope=cfg.get('default_scope'),
diff --git a/tests/test_hook/test_hook.py b/tests/test_hook/test_hook.py
index db80ed4a..771c54f6 100644
--- a/tests/test_hook/test_hook.py
+++ b/tests/test_hook/test_hook.py
@@ -157,18 +157,17 @@ class TestHook:
 
     def test_end_of_epoch(self):
         hook = Hook()
-        runner = Mock()
 
         # last inner iter
         batch_idx = 1
-        runner.cur_dataloader.__len__ = Mock(return_value=2)
-        runner.cur_dataloader.__len__ = Mock(return_value=2)
-        return_val = hook.end_of_epoch(runner, batch_idx)
+        dataloader = Mock()
+        dataloader.__len__ = Mock(return_value=2)
+        return_val = hook.end_of_epoch(dataloader, batch_idx)
         assert return_val
 
         # not the last inner iter
         batch_idx = 0
-        return_val = hook.end_of_epoch(runner, batch_idx)
+        return_val = hook.end_of_epoch(dataloader, batch_idx)
         assert not return_val
 
     def test_is_last_train_epoch(self):
diff --git a/tests/test_hook/test_iter_timer_hook.py b/tests/test_hook/test_iter_timer_hook.py
index af149f2f..8d3dfb9d 100644
--- a/tests/test_hook/test_iter_timer_hook.py
+++ b/tests/test_hook/test_iter_timer_hook.py
@@ -1,29 +1,70 @@
 # Copyright (c) OpenMMLab. All rights reserved.
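The runner-side wiring above can be exercised directly against the new class;
a sketch under the assumption that ``LogProcessor`` is importable as below
(the failing config is the error case documented in its docstring):

    from mmengine.logging import LogProcessor

    # `log_processor=None` in the Runner falls back to an empty dict:
    processor = LogProcessor(**dict())  # window_size=10, by_epoch=True
    # Registering the same data_src twice without distinct `log_name`s
    # fails fast at construction time:
    try:
        LogProcessor(custom_cfg=[
            dict(data_src='loss', method_name='mean', window_size=100),
            dict(data_src='loss', method_name='max', window_size=100),
        ])
    except AssertionError:
        pass  # the repeated-log_name check raised, as expected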
-from unittest.mock import Mock +from unittest import TestCase +from unittest.mock import MagicMock, Mock, patch from mmengine.hooks import IterTimerHook +from mmengine.logging import MessageHub -class TestIterTimerHook: +def time_patch(): + if not hasattr(time_patch, 'time'): + time_patch.time = 0 + else: + time_patch.time += 1 + return time_patch.time + + +class TestIterTimerHook(TestCase): + + def setUp(self) -> None: + self.hook = IterTimerHook() + + def test_init(self): + assert self.hook.time_sec_tot == 0 + assert self.hook.start_iter == 0 + + def test_before_run(self): + runner = MagicMock() + runner.iter = 1 + self.hook.before_run(runner) + assert self.hook.start_iter == 1 def test_before_epoch(self): - hook = IterTimerHook() runner = Mock() - hook._before_epoch(runner) - assert isinstance(hook.t, float) + self.hook._before_epoch(runner) + assert isinstance(self.hook.t, float) + @patch('time.time', MagicMock(return_value=1)) def test_before_iter(self): - hook = IterTimerHook() - runner = Mock() + runner = MagicMock() runner.log_buffer = dict() - hook._before_epoch(runner) - hook._before_iter(runner, 0) - runner.message_hub.update_scalar.assert_called() + self.hook._before_epoch(runner) + for mode in ('train', 'val', 'test'): + self.hook._before_iter(runner, batch_idx=1, mode=mode) + runner.message_hub.update_scalar.assert_called_with( + f'{mode}/data_time', 0) + @patch('time.time', time_patch) def test_after_iter(self): - hook = IterTimerHook() - runner = Mock() + runner = MagicMock() runner.log_buffer = dict() - hook._before_epoch(runner) - hook._after_iter(runner, 0) + runner.log_processor.window_size = 10 + runner.train_loop.max_iters = 100 + runner.iter = 0 + runner.test_loop.dataloader = [0] * 20 + runner.val_loop.dataloader = [0] * 20 + self.hook._before_epoch(runner) + self.hook.before_run(runner) + self.hook._after_iter(runner, batch_idx=1) runner.message_hub.update_scalar.assert_called() + runner.message_hub.get_log.assert_not_called() + runner.message_hub.update_info.assert_not_called() + runner.message_hub = MessageHub.get_instance('test_iter_timer_hook') + runner.iter = 9 + # eta = (100 - 10) / 1 + self.hook._after_iter(runner, batch_idx=89) + assert runner.message_hub.get_info('eta') == 90 + self.hook._after_iter(runner, batch_idx=9, mode='val') + assert runner.message_hub.get_info('eta') == 10 + self.hook._after_iter(runner, batch_idx=19, mode='test') + assert runner.message_hub.get_info('eta') == 0 diff --git a/tests/test_hook/test_logger_hook.py b/tests/test_hook/test_logger_hook.py index cac2e45b..3caed5dd 100644 --- a/tests/test_hook/test_logger_hook.py +++ b/tests/test_hook/test_logger_hook.py @@ -1,13 +1,8 @@ # Copyright (c) OpenMMLab. All rights reserved. -import datetime -import logging import os.path as osp -import sys -from collections import OrderedDict -from unittest.mock import MagicMock, patch +from unittest.mock import MagicMock import pytest -import torch from mmengine.fileio.file_client import HardDiskBackend from mmengine.hooks import LoggerHook @@ -17,11 +12,8 @@ class TestLoggerHook: def test_init(self): logger_hook = LoggerHook(out_dir='tmp.txt') - assert logger_hook.by_epoch assert logger_hook.interval == 10 - assert not logger_hook.custom_keys assert logger_hook.ignore_last - assert logger_hook.time_sec_tot == 0 assert logger_hook.interval_exp_name == 1000 assert logger_hook.out_suffix == ('.log.json', '.log', '.py') assert logger_hook.keep_local @@ -30,22 +22,7 @@ class TestLoggerHook: # out_dir should be None or string or tuple of string. 
        with pytest.raises(TypeError):
            LoggerHook(out_dir=1)
-        # time cannot be overwritten.
-        with pytest.raises(AssertionError):
-            LoggerHook(custom_keys=dict(time=dict(method='max')))
-        LoggerHook(
-            custom_keys=dict(time=[
-                dict(method='max', log_name='time_max'),
-                dict(method='min', log_name='time_min')
-            ]))
-        # Epoch window_size cannot be used when `LoggerHook.by_epoch=False`
-        with pytest.raises(AssertionError):
-            LoggerHook(
-                by_epoch=False,
-                custom_keys=dict(
-                    time=dict(
-                        method='max', log_name='time_max',
-                        window_size='epoch')))
+
         with pytest.raises(ValueError):
             LoggerHook(file_client_args=dict(enable_mc=True))
 
@@ -60,20 +37,23 @@ class TestLoggerHook:
         assert logger_hook.out_dir == osp.join('out_dir', 'work_dir')
         assert logger_hook.json_log_path == osp.join('work_dir',
                                                      'timestamp.log.json')
-        assert logger_hook.start_iter == runner.iter
         runner.writer.add_params.assert_called()
 
     def test_after_run(self, tmp_path):
+        # Test `after_run` without and with `out_dir`.
         out_dir = tmp_path / 'out_dir'
         out_dir.mkdir()
         work_dir = tmp_path / 'work_dir'
         work_dir.mkdir()
         work_dir_json = work_dir / 'tmp.log.json'
-        json_f = open(work_dir_json, 'w')
-        json_f.close()
         runner = MagicMock()
         runner.work_dir = work_dir
-
+        # Test without out_dir.
+        logger_hook = LoggerHook()
+        logger_hook.after_run(runner)
+        # Test with out_dir and make sure the json file has been moved to
+        # out_dir.
+        json_f = open(work_dir_json, 'w')
+        json_f.close()
         logger_hook = LoggerHook(out_dir=str(tmp_path), keep_local=False)
         logger_hook.out_dir = str(out_dir)
         logger_hook.after_run(runner)
@@ -84,274 +64,83 @@ class TestLoggerHook:
     def test_after_train_iter(self):
         # Test LoggerHook by iter.
         runner = MagicMock()
-        runner.iter = 10
-        batch_idx = 5
-        logger_hook = LoggerHook(by_epoch=False)
-        logger_hook._log_train = MagicMock()
-        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
+        runner.log_processor.get_log_after_iter = MagicMock(
+            return_value=(dict(), 'log_str'))
+        logger_hook = LoggerHook()
+        logger_hook.after_train_iter(runner, batch_idx=5)
        # `batch_idx == 5` means `cur_inner_iter == 6`, which cannot be
        # evenly divided by `logger_hook.interval`.
-        logger_hook._log_train.assert_not_called()
-        runner.iter = 9
-        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
-        logger_hook._log_train.assert_called()
+        runner.log_processor.get_log_after_iter.assert_not_called()
+        logger_hook.after_train_iter(runner, batch_idx=9)
+        runner.log_processor.get_log_after_iter.assert_called()
 
         # Test LoggerHook by epoch.
-        logger_hook = LoggerHook(by_epoch=True)
-        logger_hook._log_train = MagicMock()
-        # Only `runner.inner_iter` will work.
-        runner.iter = 9
-        batch_idx = 10
-        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
-        logger_hook._log_train.assert_not_called()
-        batch_idx = 9
-        logger_hook.after_train_iter(runner, batch_idx=batch_idx)
-        logger_hook._log_train.assert_called()
+        logger_hook = LoggerHook()
+        runner = MagicMock()
+        runner.log_processor.get_log_after_iter = MagicMock(
+            return_value=(dict(), 'log_str'))
+        # Only `batch_idx` will work.
+        logger_hook.after_train_iter(runner, batch_idx=10)
+        runner.log_processor.get_log_after_iter.assert_not_called()
+        logger_hook.after_train_iter(runner, batch_idx=9)
+        runner.log_processor.get_log_after_iter.assert_called()
 
         # Test end of the epoch.
- logger_hook = LoggerHook(by_epoch=True, ignore_last=False) - logger_hook._log_train = MagicMock() - runner.cur_dataloader = [0] * 5 - batch_idx = 4 - logger_hook.after_train_iter(runner, batch_idx=batch_idx) - logger_hook._log_train.assert_called() + runner = MagicMock() + runner.log_processor.get_log_after_iter = MagicMock( + return_value=(dict(), 'log_str')) + logger_hook = LoggerHook(ignore_last=False) + runner.train_dataloader = [0] * 5 + logger_hook.after_train_iter(runner, batch_idx=4) + runner.log_processor.get_log_after_iter.assert_called() # Test print exp_name + runner = MagicMock() + runner.log_processor.get_log_after_iter = MagicMock( + return_value=(dict(), 'log_str')) runner.meta = dict(exp_name='retinanet') - logger_hook = LoggerHook() runner.logger = MagicMock() - logger_hook._log_train = MagicMock() - logger_hook.after_train_iter(runner, batch_idx=batch_idx) - runner.logger.info.assert_called_with( - f'Exp name: {runner.meta["exp_name"]}') + logger_hook = LoggerHook() + logger_hook.after_train_iter(runner, batch_idx=999) + runner.logger.info.assert_called() def test_after_val_epoch(self): logger_hook = LoggerHook() runner = MagicMock() - logger_hook._log_val = MagicMock() + runner.log_processor.get_log_after_epoch = MagicMock( + return_value=(dict(), 'string')) logger_hook.after_val_epoch(runner) - logger_hook._log_val.assert_called() + runner.log_processor.get_log_after_epoch.assert_called() + runner.logger.info.assert_called() + runner.writer.add_scalars.assert_called() - @pytest.mark.parametrize('by_epoch', [True, False]) - def test_log_train(self, by_epoch, capsys): - runner = self._setup_runner() - runner.meta = dict(exp_name='retinanet') - # Prepare LoggerHook - logger_hook = LoggerHook(by_epoch=by_epoch) - logger_hook._inner_iter = 1 - logger_hook.writer = MagicMock() - logger_hook.time_sec_tot = 1000 - logger_hook.start_iter = 0 - logger_hook._get_max_memory = MagicMock(return_value='100') - logger_hook.json_log_path = 'tmp.json' - - # Prepare training information. - train_infos = dict( - lr=0.1, momentum=0.9, time=1.0, data_time=1.0, loss_cls=1.0) - logger_hook._collect_info = MagicMock(return_value=train_infos) - logger_hook._log_train(runner) - # Verify that the correct variables have been written. - runner.writer.add_scalars.assert_called_with( - train_infos, step=11, file_path='tmp.json') - # Verify that the correct context have been logged. 
- out, _ = capsys.readouterr() - time_avg = logger_hook.time_sec_tot / ( - runner.iter + 1 - logger_hook.start_iter) - eta_second = time_avg * (runner.train_loop.max_iters - runner.iter - 1) - eta_str = str(datetime.timedelta(seconds=int(eta_second))) - if by_epoch: - if torch.cuda.is_available(): - log_str = 'Epoch [2][2/5]\t' \ - f"lr: {train_infos['lr']:.3e} " \ - f"momentum: {train_infos['momentum']:.3e}, " \ - f'eta: {eta_str}, ' \ - f"time: {train_infos['time']:.3f}, " \ - f"data_time: {train_infos['data_time']:.3f}, " \ - f'memory: 100, ' \ - f"loss_cls: {train_infos['loss_cls']:.4f}\n" - else: - log_str = 'Epoch [2][2/5]\t' \ - f"lr: {train_infos['lr']:.3e} " \ - f"momentum: {train_infos['momentum']:.3e}, " \ - f'eta: {eta_str}, ' \ - f"time: {train_infos['time']:.3f}, " \ - f"data_time: {train_infos['data_time']:.3f}, " \ - f"loss_cls: {train_infos['loss_cls']:.4f}\n" - assert out == log_str - else: - if torch.cuda.is_available(): - log_str = 'Iter [11/50]\t' \ - f"lr: {train_infos['lr']:.3e} " \ - f"momentum: {train_infos['momentum']:.3e}, " \ - f'eta: {eta_str}, ' \ - f"time: {train_infos['time']:.3f}, " \ - f"data_time: {train_infos['data_time']:.3f}, " \ - f'memory: 100, ' \ - f"loss_cls: {train_infos['loss_cls']:.4f}\n" - else: - log_str = 'Iter [11/50]\t' \ - f"lr: {train_infos['lr']:.3e} " \ - f"momentum: {train_infos['momentum']:.3e}, " \ - f'eta: {eta_str}, ' \ - f"time: {train_infos['time']:.3f}, " \ - f"data_time: {train_infos['data_time']:.3f}, " \ - f"loss_cls: {train_infos['loss_cls']:.4f}\n" - assert out == log_str - - @pytest.mark.parametrize('by_epoch', [True, False]) - def test_log_val(self, by_epoch, capsys): - runner = self._setup_runner() - # Prepare LoggerHook. - logger_hook = LoggerHook(by_epoch=by_epoch) - logger_hook.json_log_path = 'tmp.json' - metric = dict(accuracy=0.9, data_time=1.0) - logger_hook._collect_info = MagicMock(return_value=metric) - logger_hook._log_val(runner) - # Verify that the correct context have been logged. - out, _ = capsys.readouterr() - runner.writer.add_scalars.assert_called_with( - metric, step=11, file_path='tmp.json') - if by_epoch: - assert out == 'Epoch(val) [1][5]\taccuracy: 0.9000, ' \ - 'data_time: 1.0000\n' - - else: - assert out == 'Iter(val) [5]\taccuracy: 0.9000, ' \ - 'data_time: 1.0000\n' - - def test_get_window_size(self): - runner = self._setup_runner() - logger_hook = LoggerHook() - logger_hook._inner_iter = 1 - # Test get window size by name. - assert logger_hook._get_window_size(runner, 'epoch') == 2 - assert logger_hook._get_window_size(runner, 'global') == 11 - assert logger_hook._get_window_size(runner, 10) == 10 - # Window size must equal to `logger_hook.interval`. 
- with pytest.raises(AssertionError): - logger_hook._get_window_size(runner, 20) - - with pytest.raises(ValueError): - logger_hook._get_window_size(runner, 'unknwon') - - def test_parse_custom_keys(self): - tag = OrderedDict() - runner = self._setup_runner() - log_buffers = OrderedDict(lr=MagicMock(), loss=MagicMock()) - cfg_dict = dict( - lr=dict(method='min'), - loss=[ - dict(method='min', window_size='global'), - dict(method='max', log_name='loss_max') - ]) - logger_hook = LoggerHook() - for log_key, log_cfg in cfg_dict.items(): - logger_hook._parse_custom_keys(runner, log_key, log_cfg, - log_buffers, tag) - assert list(tag) == ['lr', 'loss', 'loss_max'] - assert log_buffers['lr'].min.assert_called - assert log_buffers['loss'].min.assert_called - assert log_buffers['loss'].max.assert_called - assert log_buffers['loss'].mean.assert_called - # `log_name` Cannot be repeated. - with pytest.raises(KeyError): - cfg_dict = dict(loss=[ - dict(method='min', window_size='global'), - dict(method='max', log_name='loss_max'), - dict(method='mean', log_name='loss_max') - ]) - logger_hook.custom_keys = cfg_dict - for log_key, log_cfg in cfg_dict.items(): - logger_hook._parse_custom_keys(runner, log_key, log_cfg, - log_buffers, tag) - # `log_key` cannot be overwritten multiple times. - with pytest.raises(AssertionError): - cfg_dict = dict(loss=[ - dict(method='min', window_size='global'), - dict(method='max'), - ]) - logger_hook.custom_keys = cfg_dict - for log_key, log_cfg in cfg_dict.items(): - logger_hook._parse_custom_keys(runner, log_key, log_cfg, - log_buffers, tag) - - def test_collect_info(self): - runner = self._setup_runner() - logger_hook = LoggerHook( - custom_keys=dict(time=dict(method='max', log_name='time_max'))) - logger_hook._parse_custom_keys = MagicMock() - # Collect with prefix. - log_buffers = { - 'train/time': MagicMock(), - 'lr': MagicMock(), - 'train/loss_cls': MagicMock(), - 'val/metric': MagicMock() - } - runner.message_hub.log_scalars = log_buffers - tag = logger_hook._collect_info(runner, mode='train') - # Test parse custom_keys - logger_hook._parse_custom_keys.assert_called() - # Test training key in tag. 
- assert list(tag.keys()) == ['time', 'loss_cls'] - # Test statistics lr with `current`, loss and time with 'mean' - log_buffers['train/time'].mean.assert_called() - log_buffers['train/loss_cls'].mean.assert_called() - log_buffers['train/loss_cls'].current.assert_not_called() - - tag = logger_hook._collect_info(runner, mode='val') - assert list(tag.keys()) == ['metric'] - log_buffers['val/metric'].current.assert_called() - - @patch('torch.cuda.max_memory_allocated', MagicMock()) - @patch('torch.cuda.reset_peak_memory_stats', MagicMock()) - def test_get_max_memory(self): + def test_after_test_epoch(self): logger_hook = LoggerHook() runner = MagicMock() - runner.world_size = 1 - runner.model = torch.nn.Linear(1, 1) - logger_hook._get_max_memory(runner) - torch.cuda.max_memory_allocated.assert_called() - torch.cuda.reset_peak_memory_stats.assert_called() + runner.log_processor.get_log_after_epoch = MagicMock( + return_value=(dict(), 'log_str')) + logger_hook.after_test_epoch(runner) + runner.log_processor.get_log_after_epoch.assert_called() + runner.logger.info.assert_called() - def test_get_iter(self): - runner = self._setup_runner() + def test_after_val_iter(self): logger_hook = LoggerHook() - logger_hook._inner_iter = 1 - # Get global iter when `inner_iter=False` - iter = logger_hook._get_iter(runner) - assert iter == 11 - # Get inner iter - iter = logger_hook._get_iter(runner, inner_iter=True) - assert iter == 2 - # Still get global iter when `logger_hook.by_epoch==False` - logger_hook.by_epoch = False - iter = logger_hook._get_iter(runner, inner_iter=True) - assert iter == 11 - - def test_get_epoch(self): - runner = self._setup_runner() - logger_hook = LoggerHook() - epoch = logger_hook._get_epoch(runner, 'train') - assert epoch == 2 - epoch = logger_hook._get_epoch(runner, 'val') - assert epoch == 1 - with pytest.raises(ValueError): - logger_hook._get_epoch(runner, 'test') - - def _setup_runner(self): runner = MagicMock() - runner.epoch = 1 - runner.cur_dataloader = [0] * 5 - runner.iter = 10 - runner.train_loop.max_iters = 50 - logger = logging.getLogger() - logger.setLevel(logging.INFO) - for handler in logger.handlers: - if not isinstance(handler, logging.StreamHandler): - continue - else: - logger.addHandler(logging.StreamHandler(stream=sys.stdout)) - runner.logger = logger - runner.message_hub = MagicMock() - runner.composed_wirter = MagicMock() - return runner + runner.iter = 0 + runner.log_processor.get_log_after_iter = MagicMock( + return_value=(dict(), 'log_str')) + logger_hook.after_val_iter(runner, 1) + runner.log_processor.get_log_after_iter.assert_not_called() + logger_hook.after_val_iter(runner, 9) + runner.log_processor.get_log_after_iter.assert_called() + + def test_after_test_iter(self): + logger_hook = LoggerHook() + runner = MagicMock() + runner.iter = 0 + runner.log_processor.get_log_after_iter = MagicMock( + return_value=(dict(), 'log_str')) + logger_hook.after_test_iter(runner, 1) + runner.log_processor.get_log_after_iter.assert_not_called() + logger_hook.after_test_iter(runner, 9) + runner.log_processor.get_log_after_iter.assert_called() diff --git a/tests/test_hook/test_optimizer_hook.py b/tests/test_hook/test_optimizer_hook.py index 5d04ca3f..dc11ee0f 100644 --- a/tests/test_hook/test_optimizer_hook.py +++ b/tests/test_hook/test_optimizer_hook.py @@ -1,5 +1,5 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from unittest.mock import Mock
+from unittest.mock import MagicMock, Mock
 
 import torch
 from torch import nn
@@ -45,7 +45,7 @@ class TestOptimizerHook:
 
         model = Model()
         x = torch.rand(1, 1, 3, 3)
 
-        dummy_runner = Mock()
+        dummy_runner = MagicMock()
         dummy_runner.optimizer.zero_grad = Mock(return_value=None)
         dummy_runner.optimizer.step = Mock(return_value=None)
         dummy_runner.model = model
diff --git a/tests/test_logging/test_log_processor.py b/tests/test_logging/test_log_processor.py
new file mode 100644
index 00000000..b10cac48
--- /dev/null
+++ b/tests/test_logging/test_log_processor.py
@@ -0,0 +1,242 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import copy
+from unittest.mock import MagicMock, patch
+
+import pytest
+import torch
+
+from mmengine.logging import LogProcessor, MessageHub, MMLogger
+
+
+class TestLogProcessor:
+
+    def test_init(self):
+        log_processor = LogProcessor(
+            window_size=10, by_epoch=True, custom_cfg=None)
+        assert log_processor.by_epoch
+        assert log_processor.window_size == 10
+        assert log_processor.custom_cfg == []
+
+    def test_check_custom_cfg(self):
+        # Setting ``window_size='epoch'`` while ``by_epoch=False`` in the
+        # log config will raise an AssertionError.
+        custom_cfg = [dict(data_src='loss', window_size='epoch')]
+        with pytest.raises(AssertionError):
+            LogProcessor(by_epoch=False, custom_cfg=custom_cfg)
+        # A duplicated log_name will raise an AssertionError.
+        custom_cfg = [
+            dict(data_src='loss', log_name='loss_1'),
+            dict(data_src='loss', log_name='loss_1')
+        ]
+        with pytest.raises(AssertionError):
+            LogProcessor(custom_cfg=custom_cfg)
+        # Overwriting the loss item twice will raise an AssertionError.
+        custom_cfg = [dict(data_src='loss'), dict(data_src='loss')]
+        with pytest.raises(AssertionError):
+            LogProcessor(custom_cfg=custom_cfg)
+
+        custom_cfg = [
+            dict(data_src='loss_cls', window_size=100, method_name='min'),
+            dict(data_src='loss', log_name='loss_min', method_name='max'),
+            dict(data_src='loss', log_name='loss_max', method_name='max')
+        ]
+        LogProcessor(custom_cfg=custom_cfg)
+
+    def test_parse_windows_size(self):
+        log_processor = LogProcessor()
+        # Test parsing 'epoch' window_size.
+        log_processor.custom_cfg = [
+            dict(data_src='loss_cls', window_size='epoch')
+        ]
+        custom_cfg = log_processor._parse_windows_size(self.runner, 1)
+        assert custom_cfg[0]['window_size'] == 2
+
+        # Test parsing 'global' window_size.
+        log_processor.custom_cfg = [
+            dict(data_src='loss_cls', window_size='global')
+        ]
+        custom_cfg = log_processor._parse_windows_size(self.runner, 1)
+        assert custom_cfg[0]['window_size'] == 11
+
+        # Test parsing an int window_size.
+        log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=100)]
+        custom_cfg = log_processor._parse_windows_size(self.runner, 1)
+        assert custom_cfg[0]['window_size'] == 100
+
+        # An invalid window_size type will raise a TypeError.
+        log_processor.custom_cfg = [dict(data_src='loss_cls', window_size=[])]
+        with pytest.raises(TypeError):
+            log_processor._parse_windows_size(self.runner, 1)
+
+    @pytest.mark.parametrize('by_epoch,mode',
+                             ([True, 'train'], [False, 'train'], [True, 'val'],
+                              [False, 'val'], [True, 'test'], [False, 'test']))
+    def test_get_log_after_iter(self, by_epoch, mode):
+        # Prepare LogProcessor.
+        log_processor = LogProcessor(by_epoch=by_epoch)
+        log_processor._get_max_memory = MagicMock(return_value='100')
+        eta = 40
+        self.runner.message_hub.update_info('eta', eta)
+        # Prepare the input scalars for the current mode.
+        if mode == 'train':
+            train_logs = dict(lr=0.1, time=1.0, data_time=1.0, loss_cls=1.0)
+        else:
+            train_logs = dict(time=1.0, data_time=1.0, loss_cls=1.0)
+        log_processor._collect_scalars = MagicMock(return_value=train_logs)
+        tag, out = log_processor.get_log_after_iter(self.runner, 1, mode)
+        # Verify that the correct log string has been generated.
+        cur_loop = log_processor._get_cur_loop(self.runner, mode)
+        if by_epoch:
+            if mode in ['train', 'val']:
+                cur_epoch = log_processor._get_epoch(self.runner, mode)
+                log_str = (f'Epoch({mode}) [{cur_epoch}][2/'
+                           f'{len(cur_loop.dataloader)}] ')
+            else:
+                log_str = (f'Epoch({mode}) [2/{len(cur_loop.dataloader)}] ')
+
+            if mode == 'train':
+                log_str += f"lr: {train_logs['lr']:.3e} "
+            else:
+                log_str += ' '
+
+            log_str += (f'eta: 0:00:40 '
+                        f"time: {train_logs['time']:.3f} "
+                        f"data_time: {train_logs['data_time']:.3f} ")
+
+            if torch.cuda.is_available():
+                log_str += 'memory: 100 '
+            if mode == 'train':
+                log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
+            assert out == log_str
+        else:
+            if mode == 'train':
+                max_iters = self.runner.train_loop.max_iters
+                log_str = f'Iter({mode}) [11/{max_iters}] '
+            else:
+                max_iters = len(cur_loop.dataloader)
+                log_str = f'Iter({mode}) [2/{max_iters}] '
+
+            if mode == 'train':
+                log_str += f"lr: {train_logs['lr']:.3e} "
+            else:
+                log_str += ' '
+
+            log_str += (f'eta: 0:00:40 '
+                        f"time: {train_logs['time']:.3f} "
+                        f"data_time: {train_logs['data_time']:.3f} ")
+
+            if torch.cuda.is_available():
+                log_str += 'memory: 100 '
+
+            if mode == 'train':
+                log_str += f"loss_cls: {train_logs['loss_cls']:.4f}"
+            assert out == log_str
+
+    @pytest.mark.parametrize(
+        'by_epoch,mode',
+        ([True, 'val'], [False, 'val'], [True, 'test'], [False, 'test']))
+    def test_get_log_after_epoch(self, by_epoch, mode):
+        # Prepare LogProcessor.
+        log_processor = LogProcessor(by_epoch=by_epoch)
+        # Prepare validation information.
+        val_logs = dict(accuracy=0.9, data_time=1.0)
+        log_processor._collect_scalars = MagicMock(return_value=val_logs)
+        _, out = log_processor.get_log_after_epoch(self.runner, 2, mode)
+        if by_epoch:
+            if mode == 'test':
+                assert out == 'Epoch(test) [5/5] accuracy: 0.9000'
+            else:
+                assert out == 'Epoch(val) [1][10/10] accuracy: 0.9000'
+        else:
+            if mode == 'test':
+                assert out == 'Iter(test) [5/5] accuracy: 0.9000'
+            else:
+                assert out == 'Iter(val) [10/10] accuracy: 0.9000'
+
+    def test_collect_scalars(self):
+        custom_cfg = [
+            dict(data_src='time', method_name='mean', window_size=100),
+            dict(data_src='time', method_name='max', log_name='time_max')
+        ]
+        log_processor = LogProcessor(custom_cfg=custom_cfg)
+        # Collect with prefix.
+        log_scalars = {
+            'train/time': MagicMock(),
+            'lr': MagicMock(),
+            'train/loss_cls': MagicMock(),
+            'val/metric': MagicMock()
+        }
+        self.runner.message_hub._log_scalars = log_scalars
+        tag = log_processor._collect_scalars(
+            copy.deepcopy(custom_cfg), self.runner, mode='train')
+        # Only training-prefixed keys (plus custom log names) are kept in tag.
+        assert list(tag.keys()) == ['time', 'loss_cls', 'time_max']
+        # 'time' uses its custom statistics; 'loss_cls' defaults to 'mean'.
+        log_scalars['train/time'].statistics.assert_called_with(
+            method_name='max')
+        log_scalars['train/loss_cls'].mean.assert_called()
+
+        tag = log_processor._collect_scalars(
+            copy.deepcopy(custom_cfg), self.runner, mode='val')
+        assert list(tag.keys()) == ['metric']
+        log_scalars['val/metric'].current.assert_called()
+
+    @patch('torch.cuda.max_memory_allocated', MagicMock())
+    @patch('torch.cuda.reset_peak_memory_stats', MagicMock())
+    def test_get_max_memory(self):
+        log_processor = LogProcessor()
+        runner = MagicMock()
+        runner.world_size = 1
+        runner.model = torch.nn.Linear(1, 1)
+        log_processor._get_max_memory(runner)
+        torch.cuda.max_memory_allocated.assert_called()
+        torch.cuda.reset_peak_memory_stats.assert_called()
+
+    def test_get_iter(self):
+        log_processor = LogProcessor()
+        # Get the global iteration when batch_idx is not given.
+        iter = log_processor._get_iter(self.runner)
+        assert iter == 11
+        # Get the inner iteration from batch_idx.
+        iter = log_processor._get_iter(self.runner, 1)
+        assert iter == 2
+        # Still get the global iteration when ``by_epoch=False``.
+        log_processor.by_epoch = False
+        iter = log_processor._get_iter(self.runner, 1)
+        assert iter == 11
+
+    def test_get_epoch(self):
+        log_processor = LogProcessor()
+        epoch = log_processor._get_epoch(self.runner, 'train')
+        assert epoch == 2
+        epoch = log_processor._get_epoch(self.runner, 'val')
+        assert epoch == 1
+        with pytest.raises(ValueError):
+            log_processor._get_epoch(self.runner, 'test')
+
+    def test_get_cur_loop(self):
+        log_processor = LogProcessor()
+        loop = log_processor._get_cur_loop(self.runner, 'train')
+        assert len(loop.dataloader) == 20
+        loop = log_processor._get_cur_loop(self.runner, 'val')
+        assert len(loop.dataloader) == 10
+        loop = log_processor._get_cur_loop(self.runner, 'test')
+        assert len(loop.dataloader) == 5
+
+    def setup(self):
+        runner = MagicMock()
+        runner.epoch = 1
+        runner.iter = 10
+        runner.train_loop.max_iters = 50
+        runner.train_loop.dataloader = [0] * 20
+        runner.val_loop.dataloader = [0] * 10
+        runner.test_loop.dataloader = [0] * 5
+        logger = MMLogger.get_instance('log_processor_test')
+        runner.logger = logger
+        message_hub = MessageHub.get_instance('log_processor_test')
+        for i in range(10):
+            message_hub.update_scalar('train/loss', 10 - i)
+        for i in range(10):
+            message_hub.update_scalar('val/acc', i * 0.1)
+        runner.message_hub = message_hub
+        self.runner = runner
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index 7b497c8e..a2576dec 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -221,7 +221,7 @@ class TestRunner(TestCase):
         self.iter_based_cfg.default_hooks = dict(
             timer=dict(type='IterTimerHook'),
             checkpoint=dict(type='CheckpointHook', interval=1, by_epoch=False),
-            logger=dict(type='LoggerHook', by_epoch=False),
+            logger=dict(type='LoggerHook'),
            optimizer=dict(type='OptimizerHook', grad_clip=None),
            param_scheduler=dict(type='ParamSchedulerHook'))
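
Note: `by_epoch` now belongs to `LogProcessor` rather than `LoggerHook`, which is
why the iter-based config above shrinks to `dict(type='LoggerHook')`. The sketch
below illustrates the new configuration surface, restricted to the constructor
arguments exercised by tests/test_logging/test_log_processor.py; the concrete
values and the comments are illustrative assumptions, not part of this patch.

    from mmengine.logging import LogProcessor

    log_processor = LogProcessor(
        # Default smoothing window for scalar statistics such as time/loss.
        window_size=10,
        # Format log strings per epoch, e.g. 'Epoch(train) [2][2/20]'.
        by_epoch=True,
        custom_cfg=[
            # Without a log_name, this overwrites the default statistic for
            # loss_cls: its minimum over a 100-iteration window.
            dict(data_src='loss_cls', window_size=100, method_name='min'),
            # The same data_src may also be logged under extra distinct names.
            dict(data_src='loss', log_name='loss_min', method_name='min'),
            dict(data_src='loss', log_name='loss_max', method_name='max'),
        ])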