mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] fix resume message_hub (#353)
* fix resume message_hub * add unit test * support resume from messagehub * minor refine * add comment * fix typo * update docstring
This commit is contained in:
parent
c9c6d454f1
commit
78fad67d0d
@ -1,4 +1,6 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
import logging
|
||||
from collections import OrderedDict
|
||||
from typing import Any, Optional, Union
|
||||
|
||||
@ -7,6 +9,7 @@ import torch
|
||||
|
||||
from mmengine.utils import ManagerMixin
|
||||
from .history_buffer import HistoryBuffer
|
||||
from .logger import print_log
|
||||
|
||||
|
||||
class MessageHub(ManagerMixin):
|
||||
@ -217,7 +220,7 @@ class MessageHub(ManagerMixin):
|
||||
else:
|
||||
assert self._resumed_keys[key] == resumed, \
|
||||
f'{key} used to be {self._resumed_keys[key]}, but got ' \
|
||||
'{resumed} now. resumed keys cannot be modified repeatedly'
|
||||
'{resumed} now. resumed keys cannot be modified repeatedly.'
|
||||
|
||||
@property
|
||||
def log_scalars(self) -> OrderedDict:
|
||||
@ -301,20 +304,68 @@ class MessageHub(ManagerMixin):
|
||||
assert isinstance(value, (int, float))
|
||||
return value # type: ignore
|
||||
|
||||
def __getstate__(self):
|
||||
for key in list(self._log_scalars.keys()):
|
||||
assert key in self._resumed_keys, (
|
||||
f'Cannot found {key} in {self}._resumed_keys, '
|
||||
'please make sure you do not change the _resumed_keys '
|
||||
'outside the class')
|
||||
if not self._resumed_keys[key]:
|
||||
self._log_scalars.pop(key)
|
||||
def state_dict(self) -> dict:
|
||||
"""Returns a dictionary containing log scalars, runtime information and
|
||||
resumed keys, which should be resumed.
|
||||
|
||||
for key in list(self._runtime_info.keys()):
|
||||
assert key in self._resumed_keys, (
|
||||
f'Cannot found {key} in {self}._resumed_keys, '
|
||||
'please make sure you do not change the _resumed_keys '
|
||||
'outside the class')
|
||||
if not self._resumed_keys[key]:
|
||||
self._runtime_info.pop(key)
|
||||
return self.__dict__
|
||||
The returned ``state_dict`` can be loaded by :meth:`load_state_dict`.
|
||||
|
||||
Returns:
|
||||
dict: A dictionary contains ``log_scalars``, ``runtime_info`` and
|
||||
``resumed_keys``.
|
||||
"""
|
||||
saved_scalars = OrderedDict()
|
||||
saved_info = OrderedDict()
|
||||
|
||||
for key, value in self._log_scalars.items():
|
||||
if self._resumed_keys.get(key, False):
|
||||
saved_scalars[key] = copy.deepcopy(value)
|
||||
|
||||
for key, value in self._runtime_info.items():
|
||||
if self._resumed_keys.get(key, False):
|
||||
try:
|
||||
saved_info[key] = copy.deepcopy(value)
|
||||
except: # noqa: E722
|
||||
print_log(
|
||||
f'{key} in message_hub cannot be copied, '
|
||||
f'just return its reference. ',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
saved_scalars[key] = value
|
||||
return dict(
|
||||
log_scalars=saved_scalars,
|
||||
runtime_info=saved_info,
|
||||
resumed_keys=self._resumed_keys)
|
||||
|
||||
def load_state_dict(self, state_dict: Union['MessageHub', dict]) -> None:
|
||||
"""Loads log scalars, runtime information and resumed keys from
|
||||
``state_dict`` or ``message_hub``.
|
||||
|
||||
If ``state_dict`` is a dictionary returned by :meth:`state_dict`, it
|
||||
will only make copies of data which should be resumed from the source
|
||||
``message_hub``.
|
||||
|
||||
If ``state_dict`` is a ``message_hub`` instance, it will make copies of
|
||||
all data from the source message_hub. We suggest to load data from
|
||||
``dict`` rather than a ``MessageHub`` instance.
|
||||
|
||||
Args:
|
||||
state_dict (dict or MessageHub): A dictionary contains key
|
||||
``log_scalars`` ``runtime_info`` and ``resumed_keys``, or a
|
||||
MessageHub instance.
|
||||
"""
|
||||
if isinstance(state_dict, dict):
|
||||
for key in ('log_scalars', 'runtime_info', 'resumed_keys'):
|
||||
assert key in state_dict, (
|
||||
'The loaded `state_dict` of `MessageHub` must contain '
|
||||
f'key: `{key}`')
|
||||
self._log_scalars = copy.deepcopy(state_dict['log_scalars'])
|
||||
self._runtime_info = copy.deepcopy(state_dict['runtime_info'])
|
||||
self._resumed_keys = copy.deepcopy(state_dict['resumed_keys'])
|
||||
# Since some checkpoints saved serialized `message_hub` instance,
|
||||
# `load_state_dict` support loading `message_hub` instance for
|
||||
# compatibility
|
||||
else:
|
||||
self._log_scalars = copy.deepcopy(state_dict._log_scalars)
|
||||
self._runtime_info = copy.deepcopy(state_dict._runtime_info)
|
||||
self._resumed_keys = copy.deepcopy(state_dict._resumed_keys)
|
||||
|
@ -1881,7 +1881,7 @@ class Runner:
|
||||
'check the correctness of the checkpoint or the training '
|
||||
'dataset.')
|
||||
|
||||
self.message_hub = checkpoint['message_hub']
|
||||
self.message_hub.load_state_dict(checkpoint['message_hub'])
|
||||
|
||||
# resume optimizer
|
||||
if 'optimizer' in checkpoint and resume_optimizer:
|
||||
@ -2008,7 +2008,7 @@ class Runner:
|
||||
checkpoint = {
|
||||
'meta': meta,
|
||||
'state_dict': weights_to_cpu(get_state_dict(model)),
|
||||
'message_hub': self.message_hub
|
||||
'message_hub': self.message_hub.state_dict()
|
||||
}
|
||||
# save optimizer state dict to checkpoint
|
||||
if save_optimizer:
|
||||
|
@ -94,6 +94,45 @@ class TestMessageHub:
|
||||
loss_dict = dict(error_type=dict(count=1))
|
||||
message_hub.update_scalars(loss_dict)
|
||||
|
||||
def test_state_dict(self):
|
||||
message_hub = MessageHub.get_instance('test_state_dict')
|
||||
# update log_scalars.
|
||||
message_hub.update_scalar('loss', 0.1)
|
||||
message_hub.update_scalar('lr', 0.1, resumed=False)
|
||||
# update runtime information
|
||||
message_hub.update_info('iter', 1, resumed=True)
|
||||
message_hub.update_info('tensor', [1, 2, 3], resumed=False)
|
||||
state_dict = message_hub.state_dict()
|
||||
assert state_dict['log_scalars']['loss'].data == (np.array([0.1]),
|
||||
np.array([1]))
|
||||
assert 'lr' not in state_dict['log_scalars']
|
||||
assert state_dict['runtime_info']['iter'] == 1
|
||||
assert 'tensor' not in state_dict
|
||||
|
||||
def test_load_state_dict(self):
|
||||
message_hub1 = MessageHub.get_instance('test_load_state_dict1')
|
||||
# update log_scalars.
|
||||
message_hub1.update_scalar('loss', 0.1)
|
||||
message_hub1.update_scalar('lr', 0.1, resumed=False)
|
||||
# update runtime information
|
||||
message_hub1.update_info('iter', 1, resumed=True)
|
||||
message_hub1.update_info('tensor', [1, 2, 3], resumed=False)
|
||||
state_dict = message_hub1.state_dict()
|
||||
|
||||
# Resume from state_dict
|
||||
message_hub2 = MessageHub.get_instance('test_load_state_dict2')
|
||||
message_hub2.load_state_dict(state_dict)
|
||||
assert message_hub2.get_scalar('loss').data == (np.array([0.1]),
|
||||
np.array([1]))
|
||||
assert message_hub2.get_info('iter') == 1
|
||||
|
||||
# Test resume from `MessageHub` instance.
|
||||
message_hub3 = MessageHub.get_instance('test_load_state_dict3')
|
||||
message_hub3.load_state_dict(state_dict)
|
||||
assert message_hub3.get_scalar('loss').data == (np.array([0.1]),
|
||||
np.array([1]))
|
||||
assert message_hub3.get_info('iter') == 1
|
||||
|
||||
def test_getstate(self):
|
||||
message_hub = MessageHub.get_instance('name')
|
||||
# update log_scalars.
|
||||
@ -116,8 +155,8 @@ class TestMessageHub:
|
||||
def test_get_instance(self):
|
||||
# Test get root mmengine message hub.
|
||||
MessageHub._instance_dict = OrderedDict()
|
||||
root_logger = MessageHub.get_current_instance()
|
||||
assert id(MessageHub.get_instance('mmengine')) == id(root_logger)
|
||||
message_hub = MessageHub.get_current_instance()
|
||||
assert id(MessageHub.get_instance('mmengine')) == id(message_hub)
|
||||
# Test original `get_current_instance` function.
|
||||
MessageHub.get_instance('mmdet')
|
||||
assert MessageHub.get_current_instance().instance_name == 'mmdet'
|
||||
|
@ -1689,9 +1689,11 @@ class TestRunner(TestCase):
|
||||
self.assertEqual(ckpt['meta']['seed'], runner.seed)
|
||||
assert isinstance(ckpt['optimizer'], dict)
|
||||
assert isinstance(ckpt['param_schedulers'], list)
|
||||
self.assertIsInstance(ckpt['message_hub'], MessageHub)
|
||||
self.assertEqual(ckpt['message_hub'].get_info('epoch'), 2)
|
||||
self.assertEqual(ckpt['message_hub'].get_info('iter'), 11)
|
||||
self.assertIsInstance(ckpt['message_hub'], dict)
|
||||
message_hub = MessageHub.get_instance('test_ckpt')
|
||||
message_hub.load_state_dict(ckpt['message_hub'])
|
||||
self.assertEqual(message_hub.get_info('epoch'), 2)
|
||||
self.assertEqual(message_hub.get_info('iter'), 11)
|
||||
|
||||
# 1.2 test `load_checkpoint`
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
@ -1728,6 +1730,10 @@ class TestRunner(TestCase):
|
||||
self.assertIsInstance(runner.message_hub, MessageHub)
|
||||
self.assertEqual(runner.message_hub.get_info('epoch'), 2)
|
||||
self.assertEqual(runner.message_hub.get_info('iter'), 11)
|
||||
self.assertEqual(MessageHub.get_current_instance().get_info('epoch'),
|
||||
2)
|
||||
self.assertEqual(MessageHub.get_current_instance().get_info('iter'),
|
||||
11)
|
||||
|
||||
# 1.3.2 test resume with unmatched dataset_meta
|
||||
ckpt_modified = copy.deepcopy(ckpt)
|
||||
@ -1856,9 +1862,10 @@ class TestRunner(TestCase):
|
||||
self.assertEqual(ckpt['meta']['iter'], 12)
|
||||
assert isinstance(ckpt['optimizer'], dict)
|
||||
assert isinstance(ckpt['param_schedulers'], list)
|
||||
self.assertIsInstance(ckpt['message_hub'], MessageHub)
|
||||
self.assertEqual(ckpt['message_hub'].get_info('epoch'), 0)
|
||||
self.assertEqual(ckpt['message_hub'].get_info('iter'), 11)
|
||||
self.assertIsInstance(ckpt['message_hub'], dict)
|
||||
message_hub.load_state_dict(ckpt['message_hub'])
|
||||
self.assertEqual(message_hub.get_info('epoch'), 0)
|
||||
self.assertEqual(message_hub.get_info('iter'), 11)
|
||||
|
||||
# 2.2 test `load_checkpoint`
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
@ -1907,6 +1914,17 @@ class TestRunner(TestCase):
|
||||
self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
|
||||
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||
|
||||
# 2.6 test resumed message_hub has the history value.
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.experiment_name = 'test_checkpoint13'
|
||||
cfg.resume = True
|
||||
cfg.load_from = osp.join(self.temp_dir, 'iter_3.pth')
|
||||
runner = Runner.from_cfg(cfg)
|
||||
runner.load_or_resume()
|
||||
assert len(runner.message_hub.log_scalars['train/lr'].data[1]) == 3
|
||||
assert len(MessageHub.get_current_instance().log_scalars['train/lr'].
|
||||
data[1]) == 3
|
||||
|
||||
def test_build_runner(self):
|
||||
# No need to test other cases which have been tested in
|
||||
# `test_build_from_cfg`
|
||||
|
Loading…
x
Reference in New Issue
Block a user