mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix the resuming error caused by HistoryBuffer (#1078)
This commit is contained in:
parent
5b9a1544b0
commit
17c5414d16
@ -207,3 +207,23 @@ class HistoryBuffer:
|
|||||||
raise ValueError('HistoryBuffer._log_history is an empty array! '
|
raise ValueError('HistoryBuffer._log_history is an empty array! '
|
||||||
'please call update first')
|
'please call update first')
|
||||||
return self._log_history[-1]
|
return self._log_history[-1]
|
||||||
|
|
||||||
|
def __getstate__(self) -> dict:
|
||||||
|
"""Make ``_statistics_methods`` resumable.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
dict: State dict including statistics_methods.
|
||||||
|
"""
|
||||||
|
self.__dict__.update(statistics_methods=self._statistics_methods)
|
||||||
|
return self.__dict__
|
||||||
|
|
||||||
|
def __setstate__(self, state):
|
||||||
|
"""Try to load ``_statistics_methods`` from state.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state (dict): State dict.
|
||||||
|
"""
|
||||||
|
statistics_methods = state.pop('statistics_methods', {})
|
||||||
|
self._set_default_statistics()
|
||||||
|
self._statistics_methods.update(statistics_methods)
|
||||||
|
self.__dict__.update(state)
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
import copy
|
import copy
|
||||||
import datetime
|
import datetime
|
||||||
|
import re
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from typing import List, Optional, Tuple
|
from typing import List, Optional, Tuple
|
||||||
@ -59,6 +60,9 @@ class LogProcessor:
|
|||||||
``train/loss``, and accuracy will be saved as ``val/accuracy``.
|
``train/loss``, and accuracy will be saved as ``val/accuracy``.
|
||||||
Defaults to False.
|
Defaults to False.
|
||||||
`New in version 0.7.0.`
|
`New in version 0.7.0.`
|
||||||
|
mean_pattern (str): A regular expression used to match the logs
|
||||||
|
that need to be included in the smoothing statistics.
|
||||||
|
`New in version 0.7.3.`
|
||||||
|
|
||||||
Examples:
|
Examples:
|
||||||
>>> # `log_name` is defined, `loss_large_window` will be an additional
|
>>> # `log_name` is defined, `loss_large_window` will be an additional
|
||||||
@ -106,12 +110,14 @@ class LogProcessor:
|
|||||||
by_epoch=True,
|
by_epoch=True,
|
||||||
custom_cfg: Optional[List[dict]] = None,
|
custom_cfg: Optional[List[dict]] = None,
|
||||||
num_digits: int = 4,
|
num_digits: int = 4,
|
||||||
log_with_hierarchy: bool = False):
|
log_with_hierarchy: bool = False,
|
||||||
|
mean_pattern=r'.*(loss|time|data_time|grad_norm).*'):
|
||||||
self.window_size = window_size
|
self.window_size = window_size
|
||||||
self.by_epoch = by_epoch
|
self.by_epoch = by_epoch
|
||||||
self.custom_cfg = custom_cfg if custom_cfg else []
|
self.custom_cfg = custom_cfg if custom_cfg else []
|
||||||
self.num_digits = num_digits
|
self.num_digits = num_digits
|
||||||
self.log_with_hierarchy = log_with_hierarchy
|
self.log_with_hierarchy = log_with_hierarchy
|
||||||
|
self.mean_pattern = re.compile(mean_pattern)
|
||||||
self._check_custom_cfg()
|
self._check_custom_cfg()
|
||||||
|
|
||||||
def get_log_after_iter(self, runner, batch_idx: int,
|
def get_log_after_iter(self, runner, batch_idx: int,
|
||||||
@ -280,14 +286,11 @@ class LogProcessor:
|
|||||||
# Count the averaged time and data_time by epoch
|
# Count the averaged time and data_time by epoch
|
||||||
if 'time' not in custom_keys:
|
if 'time' not in custom_keys:
|
||||||
custom_cfg_copy.append(
|
custom_cfg_copy.append(
|
||||||
dict(
|
dict(data_src='time', window_size='epoch', method_name='mean'))
|
||||||
data_src=f'{mode}/time',
|
|
||||||
window_size='epoch',
|
|
||||||
method_name='mean'))
|
|
||||||
if 'data_time' not in custom_keys:
|
if 'data_time' not in custom_keys:
|
||||||
custom_cfg_copy.append(
|
custom_cfg_copy.append(
|
||||||
dict(
|
dict(
|
||||||
data_src=f'{mode}/data_time',
|
data_src='data_time',
|
||||||
window_size='epoch',
|
window_size='epoch',
|
||||||
method_name='mean'))
|
method_name='mean'))
|
||||||
parsed_cfg = self._parse_windows_size(runner, batch_idx,
|
parsed_cfg = self._parse_windows_size(runner, batch_idx,
|
||||||
@ -358,18 +361,19 @@ class LogProcessor:
|
|||||||
mode_history_scalars[key] = log_buffer
|
mode_history_scalars[key] = log_buffer
|
||||||
for key in mode_history_scalars:
|
for key in mode_history_scalars:
|
||||||
# Update the latest learning rate and smoothed time logs.
|
# Update the latest learning rate and smoothed time logs.
|
||||||
if 'loss' in key or key in ('time', 'data_time', 'grad_norm'):
|
if re.search(self.mean_pattern, key) is not None:
|
||||||
tag[key] = mode_history_scalars[key].mean(self.window_size)
|
tag[key] = mode_history_scalars[key].mean(self.window_size)
|
||||||
else:
|
else:
|
||||||
# Default statistic method is current.
|
# Default statistic method is current.
|
||||||
tag[key] = mode_history_scalars[key].current()
|
tag[key] = mode_history_scalars[key].current()
|
||||||
# Update custom keys.
|
# Update custom keys.
|
||||||
for log_cfg in custom_cfg:
|
for log_cfg in custom_cfg:
|
||||||
data_src = log_cfg.pop('data_src')
|
if not reserve_prefix:
|
||||||
if 'log_name' in log_cfg:
|
data_src = log_cfg.pop('data_src')
|
||||||
log_name = log_cfg.pop('log_name')
|
log_name = f"{log_cfg.pop('log_name', data_src)}"
|
||||||
else:
|
else:
|
||||||
log_name = data_src
|
data_src = f"{mode}/{log_cfg.pop('data_src')}"
|
||||||
|
log_name = f"{mode}/{log_cfg.pop('log_name', data_src)}"
|
||||||
# log item in custom_cfg could only exist in train or val
|
# log item in custom_cfg could only exist in train or val
|
||||||
# mode.
|
# mode.
|
||||||
if data_src in mode_history_scalars:
|
if data_src in mode_history_scalars:
|
||||||
|
@ -13,6 +13,11 @@ else:
|
|||||||
array_method.append(torch.tensor)
|
array_method.append(torch.tensor)
|
||||||
|
|
||||||
|
|
||||||
|
@HistoryBuffer.register_statistics
|
||||||
|
def custom_statistics(self):
|
||||||
|
return -1
|
||||||
|
|
||||||
|
|
||||||
class TestLoggerBuffer:
|
class TestLoggerBuffer:
|
||||||
|
|
||||||
def test_init(self):
|
def test_init(self):
|
||||||
@ -112,10 +117,5 @@ class TestLoggerBuffer:
|
|||||||
log_buffer.statistics('unknown')
|
log_buffer.statistics('unknown')
|
||||||
|
|
||||||
def test_register_statistics(self):
|
def test_register_statistics(self):
|
||||||
|
|
||||||
@HistoryBuffer.register_statistics
|
|
||||||
def custom_statistics(self):
|
|
||||||
return -1
|
|
||||||
|
|
||||||
log_buffer = HistoryBuffer()
|
log_buffer = HistoryBuffer()
|
||||||
assert log_buffer.statistics('custom_statistics') == -1
|
assert log_buffer.statistics('custom_statistics') == -1
|
||||||
|
@ -201,18 +201,32 @@ class TestLogProcessor(RunnerTestCase):
|
|||||||
self.runner.message_hub._log_scalars = log_scalars
|
self.runner.message_hub._log_scalars = log_scalars
|
||||||
tag = log_processor._collect_scalars(
|
tag = log_processor._collect_scalars(
|
||||||
copy.deepcopy(custom_cfg), self.runner, mode='train')
|
copy.deepcopy(custom_cfg), self.runner, mode='train')
|
||||||
# Test training key in tag.
|
# Training key in tag.
|
||||||
assert list(tag.keys()) == ['time', 'loss_cls', 'time_max']
|
assert list(tag.keys()) == ['time', 'loss_cls', 'time_max']
|
||||||
# Test statistics lr with `current`, loss and time with 'mean'
|
# Test statistics lr with `current`, loss and time with 'mean'
|
||||||
assert tag['time'] == time_scalars[-10:].mean()
|
assert tag['time'] == time_scalars[-10:].mean()
|
||||||
assert tag['time_max'] == time_scalars.max()
|
assert tag['time_max'] == time_scalars.max()
|
||||||
assert tag['loss_cls'] == loss_cls_scalars[-10:].mean()
|
assert tag['loss_cls'] == loss_cls_scalars[-10:].mean()
|
||||||
|
|
||||||
|
# Validation key in tag
|
||||||
tag = log_processor._collect_scalars(
|
tag = log_processor._collect_scalars(
|
||||||
copy.deepcopy(custom_cfg), self.runner, mode='val')
|
copy.deepcopy(custom_cfg), self.runner, mode='val')
|
||||||
assert list(tag.keys()) == ['metric']
|
assert list(tag.keys()) == ['metric']
|
||||||
assert tag['metric'] == metric_scalars[-1]
|
assert tag['metric'] == metric_scalars[-1]
|
||||||
|
|
||||||
|
# reserve_prefix=True
|
||||||
|
tag = log_processor._collect_scalars(
|
||||||
|
copy.deepcopy(custom_cfg),
|
||||||
|
self.runner,
|
||||||
|
mode='train',
|
||||||
|
reserve_prefix=True)
|
||||||
|
assert list(
|
||||||
|
tag.keys()) == ['train/time', 'train/loss_cls', 'train/time_max']
|
||||||
|
# Test statistics lr with `current`, loss and time with 'mean'
|
||||||
|
assert tag['train/time'] == time_scalars[-10:].mean()
|
||||||
|
assert tag['train/time_max'] == time_scalars.max()
|
||||||
|
assert tag['train/loss_cls'] == loss_cls_scalars[-10:].mean()
|
||||||
|
|
||||||
def test_collect_non_scalars(self):
|
def test_collect_non_scalars(self):
|
||||||
metric1 = np.random.rand(10)
|
metric1 = np.random.rand(10)
|
||||||
metric2 = torch.tensor(10)
|
metric2 = torch.tensor(10)
|
||||||
|
@ -21,7 +21,7 @@ from mmengine.evaluator import BaseMetric, Evaluator
|
|||||||
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
|
from mmengine.hooks import (CheckpointHook, DistSamplerSeedHook, Hook,
|
||||||
IterTimerHook, LoggerHook, ParamSchedulerHook,
|
IterTimerHook, LoggerHook, ParamSchedulerHook,
|
||||||
RuntimeInfoHook)
|
RuntimeInfoHook)
|
||||||
from mmengine.logging import MessageHub, MMLogger
|
from mmengine.logging import HistoryBuffer, MessageHub, MMLogger
|
||||||
from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
|
from mmengine.model import BaseDataPreprocessor, BaseModel, ImgDataPreprocessor
|
||||||
from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
|
from mmengine.optim import (DefaultOptimWrapperConstructor, MultiStepLR,
|
||||||
OptimWrapper, OptimWrapperDict, StepLR)
|
OptimWrapper, OptimWrapperDict, StepLR)
|
||||||
@ -2290,7 +2290,7 @@ class TestRunner(TestCase):
|
|||||||
runner = Runner.from_cfg(cfg)
|
runner = Runner.from_cfg(cfg)
|
||||||
runner.train()
|
runner.train()
|
||||||
|
|
||||||
# 2.1 test `save_checkpoint` which is called by `CheckpointHook`
|
# 2.1.1 test `save_checkpoint` which is called by `CheckpointHook`
|
||||||
path = osp.join(self.temp_dir, 'iter_12.pth')
|
path = osp.join(self.temp_dir, 'iter_12.pth')
|
||||||
self.assertTrue(osp.exists(path))
|
self.assertTrue(osp.exists(path))
|
||||||
self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth')))
|
self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth')))
|
||||||
@ -2304,6 +2304,10 @@ class TestRunner(TestCase):
|
|||||||
message_hub.load_state_dict(ckpt['message_hub'])
|
message_hub.load_state_dict(ckpt['message_hub'])
|
||||||
self.assertEqual(message_hub.get_info('epoch'), 0)
|
self.assertEqual(message_hub.get_info('epoch'), 0)
|
||||||
self.assertEqual(message_hub.get_info('iter'), 11)
|
self.assertEqual(message_hub.get_info('iter'), 11)
|
||||||
|
# 2.1.2 check that the class attribute _statistics_methods can be saved
|
||||||
|
HistoryBuffer._statistics_methods.clear()
|
||||||
|
ckpt = torch.load(path)
|
||||||
|
self.assertIn('min', HistoryBuffer._statistics_methods)
|
||||||
|
|
||||||
# 2.2 test `load_checkpoint`
|
# 2.2 test `load_checkpoint`
|
||||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user