[Fix] Fix the resuming error caused by HistoryBuffer (#1078)

This commit is contained in:
Mashiro 2023-04-21 17:23:38 +08:00 committed by GitHub
parent 5b9a1544b0
commit 17c5414d16
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 61 additions and 19 deletions

View File

@ -207,3 +207,23 @@ class HistoryBuffer:
raise ValueError('HistoryBuffer._log_history is an empty array! ' raise ValueError('HistoryBuffer._log_history is an empty array! '
'please call update first') 'please call update first')
return self._log_history[-1] return self._log_history[-1]
def __getstate__(self) -> dict:
    """Return the state used for pickling, including ``_statistics_methods``.

    ``_statistics_methods`` is a class attribute (see the companion test,
    which accesses ``HistoryBuffer._statistics_methods``), so it is not part
    of ``self.__dict__`` and would otherwise be lost when a checkpoint is
    saved. Including it here lets statistics methods registered at runtime
    be resumed.

    Returns:
        dict: State dict including ``statistics_methods``.
    """
    # Build a shallow copy instead of mutating ``self.__dict__`` in place:
    # pickling should have no side effect on the live instance (the original
    # implementation permanently injected a ``statistics_methods`` key into
    # the instance dict).
    state = dict(self.__dict__)
    state['statistics_methods'] = self._statistics_methods
    return state
def __setstate__(self, state):
    """Restore the buffer from a pickled ``state`` dict.

    Re-registers the built-in statistics methods first, then layers any
    methods that were saved in the checkpoint on top of them.

    Args:
        state (dict): State dict produced by ``__getstate__``.
    """
    resumed_methods = state.pop('statistics_methods', {})
    # Re-create the default statistics before merging, so a checkpoint saved
    # without custom methods still yields a fully functional buffer.
    self._set_default_statistics()
    self._statistics_methods.update(resumed_methods)
    self.__dict__.update(state)

View File

@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved. # Copyright (c) OpenMMLab. All rights reserved.
import copy import copy
import datetime import datetime
import re
from collections import OrderedDict from collections import OrderedDict
from itertools import chain from itertools import chain
from typing import List, Optional, Tuple from typing import List, Optional, Tuple
@ -59,6 +60,9 @@ class LogProcessor:
``train/loss``, and accuracy will be saved as ``val/accuracy``. ``train/loss``, and accuracy will be saved as ``val/accuracy``.
Defaults to False. Defaults to False.
`New in version 0.7.0.` `New in version 0.7.0.`
mean_pattern (str): This is a regular expression used to match the logs
that need to be included in the smoothing statistics.
`New in version 0.7.3.`
Examples: Examples:
>>> # `log_name` is defined, `loss_large_window` will be an additional >>> # `log_name` is defined, `loss_large_window` will be an additional
@ -106,12 +110,14 @@ class LogProcessor:
by_epoch=True, by_epoch=True,
custom_cfg: Optional[List[dict]] = None, custom_cfg: Optional[List[dict]] = None,
num_digits: int = 4, num_digits: int = 4,
log_with_hierarchy: bool = False): log_with_hierarchy: bool = False,
mean_pattern=r'.*(loss|time|data_time|grad_norm).*'):
self.window_size = window_size self.window_size = window_size
self.by_epoch = by_epoch self.by_epoch = by_epoch
self.custom_cfg = custom_cfg if custom_cfg else [] self.custom_cfg = custom_cfg if custom_cfg else []
self.num_digits = num_digits self.num_digits = num_digits
self.log_with_hierarchy = log_with_hierarchy self.log_with_hierarchy = log_with_hierarchy
self.mean_pattern = re.compile(mean_pattern)
self._check_custom_cfg() self._check_custom_cfg()
def get_log_after_iter(self, runner, batch_idx: int, def get_log_after_iter(self, runner, batch_idx: int,
@ -280,14 +286,11 @@ class LogProcessor:
# Count the averaged time and data_time by epoch # Count the averaged time and data_time by epoch
if 'time' not in custom_keys: if 'time' not in custom_keys:
custom_cfg_copy.append( custom_cfg_copy.append(
dict( dict(data_src='time', window_size='epoch', method_name='mean'))
data_src=f'{mode}/time',
window_size='epoch',
method_name='mean'))
if 'data_time' not in custom_keys: if 'data_time' not in custom_keys:
custom_cfg_copy.append( custom_cfg_copy.append(
dict( dict(
data_src=f'{mode}/data_time', data_src='data_time',
window_size='epoch', window_size='epoch',
method_name='mean')) method_name='mean'))
parsed_cfg = self._parse_windows_size(runner, batch_idx, parsed_cfg = self._parse_windows_size(runner, batch_idx,
@ -358,18 +361,19 @@ class LogProcessor:
mode_history_scalars[key] = log_buffer mode_history_scalars[key] = log_buffer
for key in mode_history_scalars: for key in mode_history_scalars:
# Update the latest learning rate and smoothed time logs. # Update the latest learning rate and smoothed time logs.
if 'loss' in key or key in ('time', 'data_time', 'grad_norm'): if re.search(self.mean_pattern, key) is not None:
tag[key] = mode_history_scalars[key].mean(self.window_size) tag[key] = mode_history_scalars[key].mean(self.window_size)
else: else:
# Default statistic method is current. # Default statistic method is current.
tag[key] = mode_history_scalars[key].current() tag[key] = mode_history_scalars[key].current()
# Update custom keys. # Update custom keys.
for log_cfg in custom_cfg: for log_cfg in custom_cfg:
data_src = log_cfg.pop('data_src') if not reserve_prefix:
if 'log_name' in log_cfg: data_src = log_cfg.pop('data_src')
log_name = log_cfg.pop('log_name') log_name = f"{log_cfg.pop('log_name', data_src)}"
else: else:
log_name = data_src data_src = f"{mode}/{log_cfg.pop('data_src')}"
log_name = f"{mode}/{log_cfg.pop('log_name', data_src)}"
# log item in custom_cfg could only exist in train or val # log item in custom_cfg could only exist in train or val
# mode. # mode.
if data_src in mode_history_scalars: if data_src in mode_history_scalars:

View File

@ -13,6 +13,11 @@ else:
array_method.append(torch.tensor) array_method.append(torch.tensor)
# NOTE(review): registered at module level (it was previously defined inside
# test_register_statistics, removed further down in this diff) — presumably so
# the custom statistic is already registered before any HistoryBuffer is
# unpickled during the resume tests; confirm against the test flow.
@HistoryBuffer.register_statistics
def custom_statistics(self):
return -1
class TestLoggerBuffer: class TestLoggerBuffer:
def test_init(self): def test_init(self):
@ -112,10 +117,5 @@ class TestLoggerBuffer:
log_buffer.statistics('unknown') log_buffer.statistics('unknown')
def test_register_statistics(self): def test_register_statistics(self):
@HistoryBuffer.register_statistics
def custom_statistics(self):
return -1
log_buffer = HistoryBuffer() log_buffer = HistoryBuffer()
assert log_buffer.statistics('custom_statistics') == -1 assert log_buffer.statistics('custom_statistics') == -1

View File

@ -201,18 +201,32 @@ class TestLogProcessor(RunnerTestCase):
self.runner.message_hub._log_scalars = log_scalars self.runner.message_hub._log_scalars = log_scalars
tag = log_processor._collect_scalars( tag = log_processor._collect_scalars(
copy.deepcopy(custom_cfg), self.runner, mode='train') copy.deepcopy(custom_cfg), self.runner, mode='train')
# Test training key in tag. # Training key in tag.
assert list(tag.keys()) == ['time', 'loss_cls', 'time_max'] assert list(tag.keys()) == ['time', 'loss_cls', 'time_max']
# Test statistics lr with `current`, loss and time with 'mean' # Test statistics lr with `current`, loss and time with 'mean'
assert tag['time'] == time_scalars[-10:].mean() assert tag['time'] == time_scalars[-10:].mean()
assert tag['time_max'] == time_scalars.max() assert tag['time_max'] == time_scalars.max()
assert tag['loss_cls'] == loss_cls_scalars[-10:].mean() assert tag['loss_cls'] == loss_cls_scalars[-10:].mean()
# Validation key in tag
tag = log_processor._collect_scalars( tag = log_processor._collect_scalars(
copy.deepcopy(custom_cfg), self.runner, mode='val') copy.deepcopy(custom_cfg), self.runner, mode='val')
assert list(tag.keys()) == ['metric'] assert list(tag.keys()) == ['metric']
assert tag['metric'] == metric_scalars[-1] assert tag['metric'] == metric_scalars[-1]
# reserve_prefix=True
tag = log_processor._collect_scalars(
copy.deepcopy(custom_cfg),
self.runner,
mode='train',
reserve_prefix=True)
assert list(
tag.keys()) == ['train/time', 'train/loss_cls', 'train/time_max']
# Test statistics lr with `current`, loss and time with 'mean'
assert tag['train/time'] == time_scalars[-10:].mean()
assert tag['train/time_max'] == time_scalars.max()
assert tag['train/loss_cls'] == loss_cls_scalars[-10:].mean()
def test_collect_non_scalars(self): def test_collect_non_scalars(self):
metric1 = np.random.rand(10) metric1 = np.random.rand(10)
metric2 = torch.tensor(10) metric2 = torch.tensor(10)

View File

@ -21,7 +21,7 @@ from mmengine.evaluator import BaseMetric, Evaluator
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook, from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
IterTimerHook, LoggerHook, ParamSchedulerHook, IterTimerHook, LoggerHook, ParamSchedulerHook,
RuntimeInfoHook) RuntimeInfoHook)
from mmengine.logging import MessageHub, MMLogger from mmengine.logging import HistoryBuffer, MessageHub, MMLogger
from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR, from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
OptimWrapper, OptimWrapperDict, StepLR) OptimWrapper, OptimWrapperDict, StepLR)
@ -2290,7 +2290,7 @@ class TestRunner(TestCase):
runner = Runner.from_cfg(cfg) runner = Runner.from_cfg(cfg)
runner.train() runner.train()
# 2.1 test `save_checkpoint` which is called by `CheckpointHook` # 2.1.1 test `save_checkpoint` which is called by `CheckpointHook`
path = osp.join(self.temp_dir, 'iter_12.pth') path = osp.join(self.temp_dir, 'iter_12.pth')
self.assertTrue(osp.exists(path)) self.assertTrue(osp.exists(path))
self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth'))) self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth')))
@ -2304,6 +2304,10 @@ class TestRunner(TestCase):
message_hub.load_state_dict(ckpt['message_hub']) message_hub.load_state_dict(ckpt['message_hub'])
self.assertEqual(message_hub.get_info('epoch'), 0) self.assertEqual(message_hub.get_info('epoch'), 0)
self.assertEqual(message_hub.get_info('iter'), 11) self.assertEqual(message_hub.get_info('iter'), 11)
# 2.1.2 check class attribute _statistics_methods can be saved
HistoryBuffer._statistics_methods.clear()
ckpt = torch.load(path)
self.assertIn('min', HistoryBuffer._statistics_methods)
# 2.2 test `load_checkpoint` # 2.2 test `load_checkpoint`
cfg = copy.deepcopy(self.iter_based_cfg) cfg = copy.deepcopy(self.iter_based_cfg)