[Fix] Fix the resuming error caused by HistoryBuffer (#1078)

This commit is contained in:
Mashiro 2023-04-21 17:23:38 +08:00 committed by GitHub
parent 5b9a1544b0
commit 17c5414d16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 61 additions and 19 deletions

View File

@ -207,3 +207,23 @@ class HistoryBuffer:
raise ValueError('HistoryBuffer._log_history is an empty array! '
'please call update first')
return self._log_history[-1]
def __getstate__(self) -> dict:
"""Make ``_statistics_methods`` can be resumed.
Returns:
dict: State dict including statistics_methods.
"""
self.__dict__.update(statistics_methods=self._statistics_methods)
return self.__dict__
def __setstate__(self, state):
    """Restore the buffer, merging any saved ``statistics_methods``.

    Args:
        state (dict): State dict produced by ``__getstate__``.
    """
    # Pull out the methods captured at pickling time (older checkpoints
    # may not carry them, hence the empty-dict fallback), re-register
    # the built-in statistics, then overlay the saved ones on top.
    restored_methods = state.pop('statistics_methods', {})
    self._set_default_statistics()
    self._statistics_methods.update(restored_methods)
    self.__dict__.update(state)

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import datetime
import re
from collections import OrderedDict
from itertools import chain
from typing import List, Optional, Tuple
@ -59,6 +60,9 @@ class LogProcessor:
``train/loss``, and accuracy will be saved as ``val/accuracy``.
Defaults to False.
`New in version 0.7.0.`
mean_pattern (str): This is a regular expression used to match the logs
    that need to be included in the smoothing statistics.
`New in version 0.7.3.`
Examples:
>>> # `log_name` is defined, `loss_large_window` will be an additional
@ -106,12 +110,14 @@ class LogProcessor:
by_epoch=True,
custom_cfg: Optional[List[dict]] = None,
num_digits: int = 4,
log_with_hierarchy: bool = False):
log_with_hierarchy: bool = False,
mean_pattern=r'.*(loss|time|data_time|grad_norm).*'):
self.window_size = window_size
self.by_epoch = by_epoch
self.custom_cfg = custom_cfg if custom_cfg else []
self.num_digits = num_digits
self.log_with_hierarchy = log_with_hierarchy
self.mean_pattern = re.compile(mean_pattern)
self._check_custom_cfg()
def get_log_after_iter(self, runner, batch_idx: int,
@ -280,14 +286,11 @@ class LogProcessor:
# Count the averaged time and data_time by epoch
if 'time' not in custom_keys:
custom_cfg_copy.append(
dict(
data_src=f'{mode}/time',
window_size='epoch',
method_name='mean'))
dict(data_src='time', window_size='epoch', method_name='mean'))
if 'data_time' not in custom_keys:
custom_cfg_copy.append(
dict(
data_src=f'{mode}/data_time',
data_src='data_time',
window_size='epoch',
method_name='mean'))
parsed_cfg = self._parse_windows_size(runner, batch_idx,
@ -358,18 +361,19 @@ class LogProcessor:
mode_history_scalars[key] = log_buffer
for key in mode_history_scalars:
# Update the latest learning rate and smoothed time logs.
if 'loss' in key or key in ('time', 'data_time', 'grad_norm'):
if re.search(self.mean_pattern, key) is not None:
tag[key] = mode_history_scalars[key].mean(self.window_size)
else:
# Default statistic method is current.
tag[key] = mode_history_scalars[key].current()
# Update custom keys.
for log_cfg in custom_cfg:
data_src = log_cfg.pop('data_src')
if 'log_name' in log_cfg:
log_name = log_cfg.pop('log_name')
if not reserve_prefix:
data_src = log_cfg.pop('data_src')
log_name = f"{log_cfg.pop('log_name', data_src)}"
else:
log_name = data_src
data_src = f"{mode}/{log_cfg.pop('data_src')}"
log_name = f"{mode}/{log_cfg.pop('log_name', data_src)}"
# log item in custom_cfg could only exist in train or val
# mode.
if data_src in mode_history_scalars:

View File

@ -13,6 +13,11 @@ else:
array_method.append(torch.tensor)
@HistoryBuffer.register_statistics
def custom_statistics(self):
    # Test fixture: registered at module import time so the
    # 'custom_statistics' method is available on every HistoryBuffer
    # instance used by the tests below. Always returns a sentinel value.
    return -1
class TestLoggerBuffer:
def test_init(self):
@ -112,10 +117,5 @@ class TestLoggerBuffer:
log_buffer.statistics('unknown')
def test_register_statistics(self):
@HistoryBuffer.register_statistics
def custom_statistics(self):
return -1
log_buffer = HistoryBuffer()
assert log_buffer.statistics('custom_statistics') == -1

View File

@ -201,18 +201,32 @@ class TestLogProcessor(RunnerTestCase):
self.runner.message_hub._log_scalars = log_scalars
tag = log_processor._collect_scalars(
copy.deepcopy(custom_cfg), self.runner, mode='train')
# Test training key in tag.
# Training key in tag.
assert list(tag.keys()) == ['time', 'loss_cls', 'time_max']
# Test statistics lr with `current`, loss and time with 'mean'
assert tag['time'] == time_scalars[-10:].mean()
assert tag['time_max'] == time_scalars.max()
assert tag['loss_cls'] == loss_cls_scalars[-10:].mean()
# Validation key in tag
tag = log_processor._collect_scalars(
copy.deepcopy(custom_cfg), self.runner, mode='val')
assert list(tag.keys()) == ['metric']
assert tag['metric'] == metric_scalars[-1]
# reserve_prefix=True
tag = log_processor._collect_scalars(
copy.deepcopy(custom_cfg),
self.runner,
mode='train',
reserve_prefix=True)
assert list(
tag.keys()) == ['train/time', 'train/loss_cls', 'train/time_max']
# Test statistics lr with `current`, loss and time with 'mean'
assert tag['train/time'] == time_scalars[-10:].mean()
assert tag['train/time_max'] == time_scalars.max()
assert tag['train/loss_cls'] == loss_cls_scalars[-10:].mean()
def test_collect_non_scalars(self):
metric1 = np.random.rand(10)
metric2 = torch.tensor(10)

View File

@ -21,7 +21,7 @@ from mmengine.evaluator import BaseMetric, Evaluator
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
IterTimerHook, LoggerHook, ParamSchedulerHook,
RuntimeInfoHook)
from mmengine.logging import MessageHub, MMLogger
from mmengine.logging import HistoryBuffer, MessageHub, MMLogger
from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
OptimWrapper, OptimWrapperDict, StepLR)
@ -2290,7 +2290,7 @@ class TestRunner(TestCase):
runner = Runner.from_cfg(cfg)
runner.train()
# 2.1 test `save_checkpoint` which is called by `CheckpointHook`
# 2.1.1 test `save_checkpoint` which is called by `CheckpointHook`
path = osp.join(self.temp_dir, 'iter_12.pth')
self.assertTrue(osp.exists(path))
self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth')))
@ -2304,6 +2304,10 @@ class TestRunner(TestCase):
message_hub.load_state_dict(ckpt['message_hub'])
self.assertEqual(message_hub.get_info('epoch'), 0)
self.assertEqual(message_hub.get_info('iter'), 11)
# 2.1.2 check class attribute _statistic_methods can be saved
HistoryBuffer._statistics_methods.clear()
ckpt = torch.load(path)
self.assertIn('min', HistoryBuffer._statistics_methods)
# 2.2 test `load_checkpoint`
cfg = copy.deepcopy(self.iter_based_cfg)