mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] fix resume message_hub (#353)
* fix resume message_hub * add unit test * support resume from messagehub * minor refine * add comment * fix typo * update docstring
This commit is contained in:
parent
c9c6d454f1
commit
78fad67d0d
@ -1,4 +1,6 @@
|
|||||||
# Copyright (c) OpenMMLab. All rights reserved.
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
import copy
|
||||||
|
import logging
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
from typing import Any, Optional, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
@ -7,6 +9,7 @@ import torch
|
|||||||
|
|
||||||
from mmengine.utils import ManagerMixin
|
from mmengine.utils import ManagerMixin
|
||||||
from .history_buffer import HistoryBuffer
|
from .history_buffer import HistoryBuffer
|
||||||
|
from .logger import print_log
|
||||||
|
|
||||||
|
|
||||||
class MessageHub(ManagerMixin):
|
class MessageHub(ManagerMixin):
|
||||||
@ -217,7 +220,7 @@ class MessageHub(ManagerMixin):
|
|||||||
else:
|
else:
|
||||||
assert self._resumed_keys[key] == resumed, \
|
assert self._resumed_keys[key] == resumed, \
|
||||||
f'{key} used to be {self._resumed_keys[key]}, but got ' \
|
f'{key} used to be {self._resumed_keys[key]}, but got ' \
|
||||||
'{resumed} now. resumed keys cannot be modified repeatedly'
|
'{resumed} now. resumed keys cannot be modified repeatedly.'
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def log_scalars(self) -> OrderedDict:
|
def log_scalars(self) -> OrderedDict:
|
||||||
@ -301,20 +304,68 @@ class MessageHub(ManagerMixin):
|
|||||||
assert isinstance(value, (int, float))
|
assert isinstance(value, (int, float))
|
||||||
return value # type: ignore
|
return value # type: ignore
|
||||||
|
|
||||||
def __getstate__(self):
|
def state_dict(self) -> dict:
|
||||||
for key in list(self._log_scalars.keys()):
|
"""Returns a dictionary containing log scalars, runtime information and
|
||||||
assert key in self._resumed_keys, (
|
resumed keys, which should be resumed.
|
||||||
f'Cannot found {key} in {self}._resumed_keys, '
|
|
||||||
'please make sure you do not change the _resumed_keys '
|
|
||||||
'outside the class')
|
|
||||||
if not self._resumed_keys[key]:
|
|
||||||
self._log_scalars.pop(key)
|
|
||||||
|
|
||||||
for key in list(self._runtime_info.keys()):
|
The returned ``state_dict`` can be loaded by :meth:`load_state_dict`.
|
||||||
assert key in self._resumed_keys, (
|
|
||||||
f'Cannot found {key} in {self}._resumed_keys, '
|
Returns:
|
||||||
'please make sure you do not change the _resumed_keys '
|
dict: A dictionary contains ``log_scalars``, ``runtime_info`` and
|
||||||
'outside the class')
|
``resumed_keys``.
|
||||||
if not self._resumed_keys[key]:
|
"""
|
||||||
self._runtime_info.pop(key)
|
saved_scalars = OrderedDict()
|
||||||
return self.__dict__
|
saved_info = OrderedDict()
|
||||||
|
|
||||||
|
for key, value in self._log_scalars.items():
|
||||||
|
if self._resumed_keys.get(key, False):
|
||||||
|
saved_scalars[key] = copy.deepcopy(value)
|
||||||
|
|
||||||
|
for key, value in self._runtime_info.items():
|
||||||
|
if self._resumed_keys.get(key, False):
|
||||||
|
try:
|
||||||
|
saved_info[key] = copy.deepcopy(value)
|
||||||
|
except: # noqa: E722
|
||||||
|
print_log(
|
||||||
|
f'{key} in message_hub cannot be copied, '
|
||||||
|
f'just return its reference. ',
|
||||||
|
logger='current',
|
||||||
|
level=logging.WARNING)
|
||||||
|
saved_scalars[key] = value
|
||||||
|
return dict(
|
||||||
|
log_scalars=saved_scalars,
|
||||||
|
runtime_info=saved_info,
|
||||||
|
resumed_keys=self._resumed_keys)
|
||||||
|
|
||||||
|
def load_state_dict(self, state_dict: Union['MessageHub', dict]) -> None:
|
||||||
|
"""Loads log scalars, runtime information and resumed keys from
|
||||||
|
``state_dict`` or ``message_hub``.
|
||||||
|
|
||||||
|
If ``state_dict`` is a dictionary returned by :meth:`state_dict`, it
|
||||||
|
will only make copies of data which should be resumed from the source
|
||||||
|
``message_hub``.
|
||||||
|
|
||||||
|
If ``state_dict`` is a ``message_hub`` instance, it will make copies of
|
||||||
|
all data from the source message_hub. We suggest to load data from
|
||||||
|
``dict`` rather than a ``MessageHub`` instance.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
state_dict (dict or MessageHub): A dictionary contains key
|
||||||
|
``log_scalars`` ``runtime_info`` and ``resumed_keys``, or a
|
||||||
|
MessageHub instance.
|
||||||
|
"""
|
||||||
|
if isinstance(state_dict, dict):
|
||||||
|
for key in ('log_scalars', 'runtime_info', 'resumed_keys'):
|
||||||
|
assert key in state_dict, (
|
||||||
|
'The loaded `state_dict` of `MessageHub` must contain '
|
||||||
|
f'key: `{key}`')
|
||||||
|
self._log_scalars = copy.deepcopy(state_dict['log_scalars'])
|
||||||
|
self._runtime_info = copy.deepcopy(state_dict['runtime_info'])
|
||||||
|
self._resumed_keys = copy.deepcopy(state_dict['resumed_keys'])
|
||||||
|
# Since some checkpoints saved serialized `message_hub` instance,
|
||||||
|
# `load_state_dict` support loading `message_hub` instance for
|
||||||
|
# compatibility
|
||||||
|
else:
|
||||||
|
self._log_scalars = copy.deepcopy(state_dict._log_scalars)
|
||||||
|
self._runtime_info = copy.deepcopy(state_dict._runtime_info)
|
||||||
|
self._resumed_keys = copy.deepcopy(state_dict._resumed_keys)
|
||||||
|
@ -1881,7 +1881,7 @@ class Runner:
|
|||||||
'check the correctness of the checkpoint or the training '
|
'check the correctness of the checkpoint or the training '
|
||||||
'dataset.')
|
'dataset.')
|
||||||
|
|
||||||
self.message_hub = checkpoint['message_hub']
|
self.message_hub.load_state_dict(checkpoint['message_hub'])
|
||||||
|
|
||||||
# resume optimizer
|
# resume optimizer
|
||||||
if 'optimizer' in checkpoint and resume_optimizer:
|
if 'optimizer' in checkpoint and resume_optimizer:
|
||||||
@ -2008,7 +2008,7 @@ class Runner:
|
|||||||
checkpoint = {
|
checkpoint = {
|
||||||
'meta': meta,
|
'meta': meta,
|
||||||
'state_dict': weights_to_cpu(get_state_dict(model)),
|
'state_dict': weights_to_cpu(get_state_dict(model)),
|
||||||
'message_hub': self.message_hub
|
'message_hub': self.message_hub.state_dict()
|
||||||
}
|
}
|
||||||
# save optimizer state dict to checkpoint
|
# save optimizer state dict to checkpoint
|
||||||
if save_optimizer:
|
if save_optimizer:
|
||||||
|
@ -94,6 +94,45 @@ class TestMessageHub:
|
|||||||
loss_dict = dict(error_type=dict(count=1))
|
loss_dict = dict(error_type=dict(count=1))
|
||||||
message_hub.update_scalars(loss_dict)
|
message_hub.update_scalars(loss_dict)
|
||||||
|
|
||||||
|
def test_state_dict(self):
|
||||||
|
message_hub = MessageHub.get_instance('test_state_dict')
|
||||||
|
# update log_scalars.
|
||||||
|
message_hub.update_scalar('loss', 0.1)
|
||||||
|
message_hub.update_scalar('lr', 0.1, resumed=False)
|
||||||
|
# update runtime information
|
||||||
|
message_hub.update_info('iter', 1, resumed=True)
|
||||||
|
message_hub.update_info('tensor', [1, 2, 3], resumed=False)
|
||||||
|
state_dict = message_hub.state_dict()
|
||||||
|
assert state_dict['log_scalars']['loss'].data == (np.array([0.1]),
|
||||||
|
np.array([1]))
|
||||||
|
assert 'lr' not in state_dict['log_scalars']
|
||||||
|
assert state_dict['runtime_info']['iter'] == 1
|
||||||
|
assert 'tensor' not in state_dict
|
||||||
|
|
||||||
|
def test_load_state_dict(self):
|
||||||
|
message_hub1 = MessageHub.get_instance('test_load_state_dict1')
|
||||||
|
# update log_scalars.
|
||||||
|
message_hub1.update_scalar('loss', 0.1)
|
||||||
|
message_hub1.update_scalar('lr', 0.1, resumed=False)
|
||||||
|
# update runtime information
|
||||||
|
message_hub1.update_info('iter', 1, resumed=True)
|
||||||
|
message_hub1.update_info('tensor', [1, 2, 3], resumed=False)
|
||||||
|
state_dict = message_hub1.state_dict()
|
||||||
|
|
||||||
|
# Resume from state_dict
|
||||||
|
message_hub2 = MessageHub.get_instance('test_load_state_dict2')
|
||||||
|
message_hub2.load_state_dict(state_dict)
|
||||||
|
assert message_hub2.get_scalar('loss').data == (np.array([0.1]),
|
||||||
|
np.array([1]))
|
||||||
|
assert message_hub2.get_info('iter') == 1
|
||||||
|
|
||||||
|
# Test resume from `MessageHub` instance.
|
||||||
|
message_hub3 = MessageHub.get_instance('test_load_state_dict3')
|
||||||
|
message_hub3.load_state_dict(state_dict)
|
||||||
|
assert message_hub3.get_scalar('loss').data == (np.array([0.1]),
|
||||||
|
np.array([1]))
|
||||||
|
assert message_hub3.get_info('iter') == 1
|
||||||
|
|
||||||
def test_getstate(self):
|
def test_getstate(self):
|
||||||
message_hub = MessageHub.get_instance('name')
|
message_hub = MessageHub.get_instance('name')
|
||||||
# update log_scalars.
|
# update log_scalars.
|
||||||
@ -116,8 +155,8 @@ class TestMessageHub:
|
|||||||
def test_get_instance(self):
|
def test_get_instance(self):
|
||||||
# Test get root mmengine message hub.
|
# Test get root mmengine message hub.
|
||||||
MessageHub._instance_dict = OrderedDict()
|
MessageHub._instance_dict = OrderedDict()
|
||||||
root_logger = MessageHub.get_current_instance()
|
message_hub = MessageHub.get_current_instance()
|
||||||
assert id(MessageHub.get_instance('mmengine')) == id(root_logger)
|
assert id(MessageHub.get_instance('mmengine')) == id(message_hub)
|
||||||
# Test original `get_current_instance` function.
|
# Test original `get_current_instance` function.
|
||||||
MessageHub.get_instance('mmdet')
|
MessageHub.get_instance('mmdet')
|
||||||
assert MessageHub.get_current_instance().instance_name == 'mmdet'
|
assert MessageHub.get_current_instance().instance_name == 'mmdet'
|
||||||
|
@ -1689,9 +1689,11 @@ class TestRunner(TestCase):
|
|||||||
self.assertEqual(ckpt['meta']['seed'], runner.seed)
|
self.assertEqual(ckpt['meta']['seed'], runner.seed)
|
||||||
assert isinstance(ckpt['optimizer'], dict)
|
assert isinstance(ckpt['optimizer'], dict)
|
||||||
assert isinstance(ckpt['param_schedulers'], list)
|
assert isinstance(ckpt['param_schedulers'], list)
|
||||||
self.assertIsInstance(ckpt['message_hub'], MessageHub)
|
self.assertIsInstance(ckpt['message_hub'], dict)
|
||||||
self.assertEqual(ckpt['message_hub'].get_info('epoch'), 2)
|
message_hub = MessageHub.get_instance('test_ckpt')
|
||||||
self.assertEqual(ckpt['message_hub'].get_info('iter'), 11)
|
message_hub.load_state_dict(ckpt['message_hub'])
|
||||||
|
self.assertEqual(message_hub.get_info('epoch'), 2)
|
||||||
|
self.assertEqual(message_hub.get_info('iter'), 11)
|
||||||
|
|
||||||
# 1.2 test `load_checkpoint`
|
# 1.2 test `load_checkpoint`
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
@ -1728,6 +1730,10 @@ class TestRunner(TestCase):
|
|||||||
self.assertIsInstance(runner.message_hub, MessageHub)
|
self.assertIsInstance(runner.message_hub, MessageHub)
|
||||||
self.assertEqual(runner.message_hub.get_info('epoch'), 2)
|
self.assertEqual(runner.message_hub.get_info('epoch'), 2)
|
||||||
self.assertEqual(runner.message_hub.get_info('iter'), 11)
|
self.assertEqual(runner.message_hub.get_info('iter'), 11)
|
||||||
|
self.assertEqual(MessageHub.get_current_instance().get_info('epoch'),
|
||||||
|
2)
|
||||||
|
self.assertEqual(MessageHub.get_current_instance().get_info('iter'),
|
||||||
|
11)
|
||||||
|
|
||||||
# 1.3.2 test resume with unmatched dataset_meta
|
# 1.3.2 test resume with unmatched dataset_meta
|
||||||
ckpt_modified = copy.deepcopy(ckpt)
|
ckpt_modified = copy.deepcopy(ckpt)
|
||||||
@ -1856,9 +1862,10 @@ class TestRunner(TestCase):
|
|||||||
self.assertEqual(ckpt['meta']['iter'], 12)
|
self.assertEqual(ckpt['meta']['iter'], 12)
|
||||||
assert isinstance(ckpt['optimizer'], dict)
|
assert isinstance(ckpt['optimizer'], dict)
|
||||||
assert isinstance(ckpt['param_schedulers'], list)
|
assert isinstance(ckpt['param_schedulers'], list)
|
||||||
self.assertIsInstance(ckpt['message_hub'], MessageHub)
|
self.assertIsInstance(ckpt['message_hub'], dict)
|
||||||
self.assertEqual(ckpt['message_hub'].get_info('epoch'), 0)
|
message_hub.load_state_dict(ckpt['message_hub'])
|
||||||
self.assertEqual(ckpt['message_hub'].get_info('iter'), 11)
|
self.assertEqual(message_hub.get_info('epoch'), 0)
|
||||||
|
self.assertEqual(message_hub.get_info('iter'), 11)
|
||||||
|
|
||||||
# 2.2 test `load_checkpoint`
|
# 2.2 test `load_checkpoint`
|
||||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
@ -1907,6 +1914,17 @@ class TestRunner(TestCase):
|
|||||||
self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
|
self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
|
||||||
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
|
||||||
|
|
||||||
|
# 2.6 test resumed message_hub has the history value.
|
||||||
|
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||||
|
cfg.experiment_name = 'test_checkpoint13'
|
||||||
|
cfg.resume = True
|
||||||
|
cfg.load_from = osp.join(self.temp_dir, 'iter_3.pth')
|
||||||
|
runner = Runner.from_cfg(cfg)
|
||||||
|
runner.load_or_resume()
|
||||||
|
assert len(runner.message_hub.log_scalars['train/lr'].data[1]) == 3
|
||||||
|
assert len(MessageHub.get_current_instance().log_scalars['train/lr'].
|
||||||
|
data[1]) == 3
|
||||||
|
|
||||||
def test_build_runner(self):
|
def test_build_runner(self):
|
||||||
# No need to test other cases which have been tested in
|
# No need to test other cases which have been tested in
|
||||||
# `test_build_from_cfg`
|
# `test_build_from_cfg`
|
||||||
|
Loading…
x
Reference in New Issue
Block a user