mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix the resuming error caused by HistoryBuffer (#1078)
This commit is contained in:
parent
5b9a1544b0
commit
17c5414d16
@ -207,3 +207,23 @@ class HistoryBuffer:
|
||||
raise ValueError('HistoryBuffer._log_history is an empty array! '
|
||||
'please call update first')
|
||||
return self._log_history[-1]
|
||||
|
||||
def __getstate__(self) -> dict:
    """Return the picklable state, including ``_statistics_methods``.

    ``_statistics_methods`` holds the registered statistics methods and
    lives on the class, so it would normally be lost when an instance is
    pickled.  Embedding it in the returned state lets ``__setstate__``
    restore it when a checkpoint is resumed.

    Returns:
        dict: A copy of ``self.__dict__`` extended with a
        ``statistics_methods`` entry.
    """
    # Build the state on a copy so that pickling does not permanently
    # inject a ``statistics_methods`` key into the live instance's
    # ``__dict__`` (the original in-place ``update`` had that side
    # effect).
    state = self.__dict__.copy()
    state['statistics_methods'] = self._statistics_methods
    return state
|
||||
|
||||
def __setstate__(self, state):
    """Restore instance state and re-register ``_statistics_methods``.

    Args:
        state (dict): State dict produced by ``__getstate__``.
    """
    # Older checkpoints may lack the key, hence the empty default.
    saved_methods = state.pop('statistics_methods', {})
    # Install the built-in statistics first, then layer any serialized
    # (e.g. user-registered) methods on top of them.
    self._set_default_statistics()
    self._statistics_methods.update(saved_methods)
    # Everything else in the state maps straight onto instance
    # attributes.
    self.__dict__.update(state)
|
||||
|
@ -1,6 +1,7 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import datetime
|
||||
import re
|
||||
from collections import OrderedDict
|
||||
from itertools import chain
|
||||
from typing import List, Optional, Tuple
|
||||
@ -59,6 +60,9 @@ class LogProcessor:
|
||||
``train/loss``, and accuracy will be saved as ``val/accuracy``.
|
||||
Defaults to False.
|
||||
`New in version 0.7.0.`
|
||||
mean_pattern (str): This is a regular expression used to match the log
|
||||
that need to be included in the smoothing statistics.
|
||||
`New in version 0.7.3.`
|
||||
|
||||
Examples:
|
||||
>>> # `log_name` is defined, `loss_large_window` will be an additional
|
||||
@ -106,12 +110,14 @@ class LogProcessor:
|
||||
by_epoch=True,
|
||||
custom_cfg: Optional[List[dict]] = None,
|
||||
num_digits: int = 4,
|
||||
log_with_hierarchy: bool = False):
|
||||
log_with_hierarchy: bool = False,
|
||||
mean_pattern=r'.*(loss|time|data_time|grad_norm).*'):
|
||||
self.window_size = window_size
|
||||
self.by_epoch = by_epoch
|
||||
self.custom_cfg = custom_cfg if custom_cfg else []
|
||||
self.num_digits = num_digits
|
||||
self.log_with_hierarchy = log_with_hierarchy
|
||||
self.mean_pattern = re.compile(mean_pattern)
|
||||
self._check_custom_cfg()
|
||||
|
||||
def get_log_after_iter(self, runner, batch_idx: int,
|
||||
@ -280,14 +286,11 @@ class LogProcessor:
|
||||
# Count the averaged time and data_time by epoch
|
||||
if 'time' not in custom_keys:
|
||||
custom_cfg_copy.append(
|
||||
dict(
|
||||
data_src=f'{mode}/time',
|
||||
window_size='epoch',
|
||||
method_name='mean'))
|
||||
dict(data_src='time', window_size='epoch', method_name='mean'))
|
||||
if 'data_time' not in custom_keys:
|
||||
custom_cfg_copy.append(
|
||||
dict(
|
||||
data_src=f'{mode}/data_time',
|
||||
data_src='data_time',
|
||||
window_size='epoch',
|
||||
method_name='mean'))
|
||||
parsed_cfg = self._parse_windows_size(runner, batch_idx,
|
||||
@ -358,18 +361,19 @@ class LogProcessor:
|
||||
mode_history_scalars[key] = log_buffer
|
||||
for key in mode_history_scalars:
|
||||
# Update the latest learning rate and smoothed time logs.
|
||||
if 'loss' in key or key in ('time', 'data_time', 'grad_norm'):
|
||||
if re.search(self.mean_pattern, key) is not None:
|
||||
tag[key] = mode_history_scalars[key].mean(self.window_size)
|
||||
else:
|
||||
# Default statistic method is current.
|
||||
tag[key] = mode_history_scalars[key].current()
|
||||
# Update custom keys.
|
||||
for log_cfg in custom_cfg:
|
||||
data_src = log_cfg.pop('data_src')
|
||||
if 'log_name' in log_cfg:
|
||||
log_name = log_cfg.pop('log_name')
|
||||
if not reserve_prefix:
|
||||
data_src = log_cfg.pop('data_src')
|
||||
log_name = f"{log_cfg.pop('log_name', data_src)}"
|
||||
else:
|
||||
log_name = data_src
|
||||
data_src = f"{mode}/{log_cfg.pop('data_src')}"
|
||||
log_name = f"{mode}/{log_cfg.pop('log_name', data_src)}"
|
||||
# log item in custom_cfg could only exist in train or val
|
||||
# mode.
|
||||
if data_src in mode_history_scalars:
|
||||
|
@ -13,6 +13,11 @@ else:
|
||||
array_method.append(torch.tensor)
|
||||
|
||||
|
||||
# Registered at import time so that ``HistoryBuffer._statistics_methods``
# contains a user-defined entry shared by the tests in this module.
@HistoryBuffer.register_statistics
def custom_statistics(self):
    # Sentinel value: tests only check that the registered method is
    # dispatched by name, not that it computes anything meaningful.
    return -1
|
||||
|
||||
|
||||
class TestLoggerBuffer:
|
||||
|
||||
def test_init(self):
|
||||
@ -112,10 +117,5 @@ class TestLoggerBuffer:
|
||||
log_buffer.statistics('unknown')
|
||||
|
||||
def test_register_statistics(self):
    """A freshly registered statistics method is dispatchable by name."""

    @HistoryBuffer.register_statistics
    def custom_statistics(self):
        return -1

    buf = HistoryBuffer()
    result = buf.statistics('custom_statistics')
    assert result == -1
|
||||
|
@ -201,18 +201,32 @@ class TestLogProcessor(RunnerTestCase):
|
||||
self.runner.message_hub._log_scalars = log_scalars
|
||||
tag = log_processor._collect_scalars(
|
||||
copy.deepcopy(custom_cfg), self.runner, mode='train')
|
||||
# Test training key in tag.
|
||||
# Training key in tag.
|
||||
assert list(tag.keys()) == ['time', 'loss_cls', 'time_max']
|
||||
# Test statistics lr with `current`, loss and time with 'mean'
|
||||
assert tag['time'] == time_scalars[-10:].mean()
|
||||
assert tag['time_max'] == time_scalars.max()
|
||||
assert tag['loss_cls'] == loss_cls_scalars[-10:].mean()
|
||||
|
||||
# Validation key in tag
|
||||
tag = log_processor._collect_scalars(
|
||||
copy.deepcopy(custom_cfg), self.runner, mode='val')
|
||||
assert list(tag.keys()) == ['metric']
|
||||
assert tag['metric'] == metric_scalars[-1]
|
||||
|
||||
# reserve_prefix=True
|
||||
tag = log_processor._collect_scalars(
|
||||
copy.deepcopy(custom_cfg),
|
||||
self.runner,
|
||||
mode='train',
|
||||
reserve_prefix=True)
|
||||
assert list(
|
||||
tag.keys()) == ['train/time', 'train/loss_cls', 'train/time_max']
|
||||
# Test statistics lr with `current`, loss and time with 'mean'
|
||||
assert tag['train/time'] == time_scalars[-10:].mean()
|
||||
assert tag['train/time_max'] == time_scalars.max()
|
||||
assert tag['train/loss_cls'] == loss_cls_scalars[-10:].mean()
|
||||
|
||||
def test_collect_non_scalars(self):
|
||||
metric1 = np.random.rand(10)
|
||||
metric2 = torch.tensor(10)
|
||||
|
@ -21,7 +21,7 @@ from mmengine.evaluator import BaseMetric, Evaluator
|
||||
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
|
||||
IterTimerHook, LoggerHook, ParamSchedulerHook,
|
||||
RuntimeInfoHook)
|
||||
from mmengine.logging import MessageHub, MMLogger
|
||||
from mmengine.logging import HistoryBuffer, MessageHub, MMLogger
|
||||
from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
|
||||
from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
|
||||
OptimWrapper, OptimWrapperDict, StepLR)
|
||||
@ -2290,7 +2290,7 @@ class TestRunner(TestCase):
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.train()
|
||||
|
||||
# 2.1 test `save_checkpoint` which is called by `CheckpointHook`
|
||||
# 2.1.1 test `save_checkpoint` which is called by `CheckpointHook`
|
||||
path = osp.join(self.temp_dir, 'iter_12.pth')
|
||||
self.assertTrue(osp.exists(path))
|
||||
self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth')))
|
||||
@ -2304,6 +2304,10 @@ class TestRunner(TestCase):
|
||||
message_hub.load_state_dict(ckpt['message_hub'])
|
||||
self.assertEqual(message_hub.get_info('epoch'), 0)
|
||||
self.assertEqual(message_hub.get_info('iter'), 11)
|
||||
# 2.1.2 check that the class attribute _statistics_methods can be saved
|
||||
HistoryBuffer._statistics_methods.clear()
|
||||
ckpt = torch.load(path)
|
||||
self.assertIn('min', HistoryBuffer._statistics_methods)
|
||||
|
||||
# 2.2 test `load_checkpoint`
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
|
Loading…
x
Reference in New Issue
Block a user