diff --git a/mmengine/logging/history_buffer.py b/mmengine/logging/history_buffer.py
index 34a78ac6..58effa81 100644
--- a/mmengine/logging/history_buffer.py
+++ b/mmengine/logging/history_buffer.py
@@ -207,3 +207,23 @@ class HistoryBuffer:
             raise ValueError('HistoryBuffer._log_history is an empty array! '
                              'please call update first')
         return self._log_history[-1]
+
+    def __getstate__(self) -> dict:
+        """Make ``_statistics_methods`` can be resumed.
+
+        Returns:
+            dict: State dict including statistics_methods.
+        """
+        self.__dict__.update(statistics_methods=self._statistics_methods)
+        return self.__dict__
+
+    def __setstate__(self, state):
+        """Try to load ``_statistics_methods`` from state.
+
+        Args:
+            state (dict): State dict.
+        """
+        statistics_methods = state.pop('statistics_methods', {})
+        self._set_default_statistics()
+        self._statistics_methods.update(statistics_methods)
+        self.__dict__.update(state)
diff --git a/mmengine/runner/log_processor.py b/mmengine/runner/log_processor.py
index 75047b37..f915a222 100644
--- a/mmengine/runner/log_processor.py
+++ b/mmengine/runner/log_processor.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 import copy
 import datetime
+import re
 from collections import OrderedDict
 from itertools import chain
 from typing import List, Optional, Tuple
@@ -59,6 +60,9 @@ class LogProcessor:
             ``train/loss``, and accuracy will be saved as ``val/accuracy``.
             Defaults to False.
             `New in version 0.7.0.`
+        mean_pattern (str): This is a regular expression used to match the log
+            that need to be included in the smoothing statistics.
+            `New in version 0.7.3.`
 
     Examples:
         >>> # `log_name` is defined, `loss_large_window` will be an additional
@@ -106,12 +110,14 @@ class LogProcessor:
                  by_epoch=True,
                  custom_cfg: Optional[List[dict]] = None,
                  num_digits: int = 4,
-                 log_with_hierarchy: bool = False):
+                 log_with_hierarchy: bool = False,
+                 mean_pattern=r'.*(loss|time|data_time|grad_norm).*'):
         self.window_size = window_size
         self.by_epoch = by_epoch
         self.custom_cfg = custom_cfg if custom_cfg else []
         self.num_digits = num_digits
         self.log_with_hierarchy = log_with_hierarchy
+        self.mean_pattern = re.compile(mean_pattern)
         self._check_custom_cfg()
 
     def get_log_after_iter(self, runner, batch_idx: int,
@@ -280,14 +286,11 @@
         # Count the averaged time and data_time by epoch
         if 'time' not in custom_keys:
             custom_cfg_copy.append(
-                dict(
-                    data_src=f'{mode}/time',
-                    window_size='epoch',
-                    method_name='mean'))
+                dict(data_src='time', window_size='epoch', method_name='mean'))
         if 'data_time' not in custom_keys:
             custom_cfg_copy.append(
                 dict(
-                    data_src=f'{mode}/data_time',
+                    data_src='data_time',
                     window_size='epoch',
                     method_name='mean'))
         parsed_cfg = self._parse_windows_size(runner, batch_idx,
@@ -358,18 +361,19 @@
                 mode_history_scalars[key] = log_buffer
         for key in mode_history_scalars:
             # Update the latest learning rate and smoothed time logs.
-            if 'loss' in key or key in ('time', 'data_time', 'grad_norm'):
+            if re.search(self.mean_pattern, key) is not None:
                 tag[key] = mode_history_scalars[key].mean(self.window_size)
             else:
                 # Default statistic method is current.
                 tag[key] = mode_history_scalars[key].current()
         # Update custom keys.
         for log_cfg in custom_cfg:
-            data_src = log_cfg.pop('data_src')
-            if 'log_name' in log_cfg:
-                log_name = log_cfg.pop('log_name')
+            if not reserve_prefix:
+                data_src = log_cfg.pop('data_src')
+                log_name = f"{log_cfg.pop('log_name', data_src)}"
             else:
-                log_name = data_src
+                data_src = f"{mode}/{log_cfg.pop('data_src')}"
+                log_name = f"{mode}/{log_cfg.pop('log_name', data_src)}"
             # log item in custom_cfg could only exist in train or val
             # mode.
             if data_src in mode_history_scalars:
diff --git a/tests/test_logging/test_history_buffer.py b/tests/test_logging/test_history_buffer.py
index b8aaca4a..99c03165 100644
--- a/tests/test_logging/test_history_buffer.py
+++ b/tests/test_logging/test_history_buffer.py
@@ -13,6 +13,11 @@ else:
     array_method.append(torch.tensor)
 
 
+@HistoryBuffer.register_statistics
+def custom_statistics(self):
+    return -1
+
+
 class TestLoggerBuffer:
 
     def test_init(self):
@@ -112,10 +117,5 @@
             log_buffer.statistics('unknown')
 
     def test_register_statistics(self):
-
-        @HistoryBuffer.register_statistics
-        def custom_statistics(self):
-            return -1
-
         log_buffer = HistoryBuffer()
         assert log_buffer.statistics('custom_statistics') == -1
diff --git a/tests/test_runner/test_log_processor.py b/tests/test_runner/test_log_processor.py
index 8a51c5c1..b0e1382b 100644
--- a/tests/test_runner/test_log_processor.py
+++ b/tests/test_runner/test_log_processor.py
@@ -201,18 +201,32 @@ class TestLogProcessor(RunnerTestCase):
         self.runner.message_hub._log_scalars = log_scalars
         tag = log_processor._collect_scalars(
             copy.deepcopy(custom_cfg), self.runner, mode='train')
-        # Test training key in tag.
+        # Training key in tag.
         assert list(tag.keys()) == ['time', 'loss_cls', 'time_max']
         # Test statistics lr with `current`, loss and time with 'mean'
         assert tag['time'] == time_scalars[-10:].mean()
         assert tag['time_max'] == time_scalars.max()
         assert tag['loss_cls'] == loss_cls_scalars[-10:].mean()
 
+        # Validation key in tag
         tag = log_processor._collect_scalars(
             copy.deepcopy(custom_cfg), self.runner, mode='val')
         assert list(tag.keys()) == ['metric']
         assert tag['metric'] == metric_scalars[-1]
+        # reserve_prefix=True
+        tag = log_processor._collect_scalars(
+            copy.deepcopy(custom_cfg),
+            self.runner,
+            mode='train',
+            reserve_prefix=True)
+        assert list(
+            tag.keys()) == ['train/time', 'train/loss_cls', 'train/time_max']
+        # Test statistics lr with `current`, loss and time with 'mean'
+        assert tag['train/time'] == time_scalars[-10:].mean()
+        assert tag['train/time_max'] == time_scalars.max()
+        assert tag['train/loss_cls'] == loss_cls_scalars[-10:].mean()
+
 
     def test_collect_non_scalars(self):
         metric1 = np.random.rand(10)
         metric2 = torch.tensor(10)
diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py
index fc00efab..611be6e1 100644
--- a/tests/test_runner/test_runner.py
+++ b/tests/test_runner/test_runner.py
@@ -21,7 +21,7 @@ from mmengine.evaluator import BaseMetric, Evaluator
 from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
                             IterTimerHook, LoggerHook, ParamSchedulerHook,
                             RuntimeInfoHook)
-from mmengine.logging import MessageHub, MMLogger
+from mmengine.logging import HistoryBuffer, MessageHub, MMLogger
 from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
 from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
                             OptimWrapper, OptimWrapperDict, StepLR)
@@ -2290,7 +2290,7 @@ class TestRunner(TestCase):
         runner = Runner.from_cfg(cfg)
         runner.train()
 
-        # 2.1 test `save_checkpoint` which is called by `CheckpointHook`
+        # 2.1.1 test `save_checkpoint` which is called by `CheckpointHook`
         path = osp.join(self.temp_dir, 'iter_12.pth')
         self.assertTrue(osp.exists(path))
         self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth')))
@@ -2304,6 +2304,10 @@ class TestRunner(TestCase):
         message_hub.load_state_dict(ckpt['message_hub'])
         self.assertEqual(message_hub.get_info('epoch'), 0)
         self.assertEqual(message_hub.get_info('iter'), 11)
+        # 2.1.2 check class attribute _statistic_methods can be saved
+        HistoryBuffer._statistics_methods.clear()
+        ckpt = torch.load(path)
+        self.assertIn('min', HistoryBuffer._statistics_methods)
 
         # 2.2 test `load_checkpoint`
         cfg = copy.deepcopy(self.iter_based_cfg)