[Fix] fix resume message_hub (#353)

* fix resume message_hub

* add unit test

* support resume from messagehub

* minor refine

* add comment

* fix typo

* update docstring
This commit is contained in:
Mashiro 2022-07-14 20:13:22 +08:00 committed by GitHub
parent c9c6d454f1
commit 78fad67d0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 135 additions and 27 deletions

View File

@ -1,4 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import logging
from collections import OrderedDict
from typing import Any, Optional, Union
@ -7,6 +9,7 @@ import torch
from mmengine.utils import ManagerMixin
from .history_buffer import HistoryBuffer
from .logger import print_log
class MessageHub(ManagerMixin):
@ -217,7 +220,7 @@ class MessageHub(ManagerMixin):
else:
assert self._resumed_keys[key] == resumed, \
f'{key} used to be {self._resumed_keys[key]}, but got ' \
'{resumed} now. resumed keys cannot be modified repeatedly'
'{resumed} now. resumed keys cannot be modified repeatedly.'
@property
def log_scalars(self) -> OrderedDict:
@ -301,20 +304,68 @@ class MessageHub(ManagerMixin):
assert isinstance(value, (int, float))
return value # type: ignore
def state_dict(self) -> dict:
    """Return a dictionary containing log scalars, runtime information and
    resumed keys, which should be resumed.

    The returned ``state_dict`` can be loaded by :meth:`load_state_dict`.

    Returns:
        dict: A dictionary contains ``log_scalars``, ``runtime_info`` and
        ``resumed_keys``.
    """
    saved_scalars = OrderedDict()
    saved_info = OrderedDict()

    # Only entries explicitly marked as resumable are exported; keys with
    # ``resumed=False`` (or unknown keys) are dropped from the checkpoint.
    for key, value in self._log_scalars.items():
        if self._resumed_keys.get(key, False):
            saved_scalars[key] = copy.deepcopy(value)

    for key, value in self._runtime_info.items():
        if self._resumed_keys.get(key, False):
            try:
                saved_info[key] = copy.deepcopy(value)
            except:  # noqa: E722
                print_log(
                    f'{key} in message_hub cannot be copied, '
                    f'just return its reference. ',
                    logger='current',
                    level=logging.WARNING)
                # Bug fix: the fallback reference must be stored in
                # ``saved_info`` — the original stored it in
                # ``saved_scalars``, leaking runtime info into log scalars.
                saved_info[key] = value
    return dict(
        log_scalars=saved_scalars,
        runtime_info=saved_info,
        resumed_keys=self._resumed_keys)
def load_state_dict(self, state_dict: Union['MessageHub', dict]) -> None:
    """Load log scalars, runtime information and resumed keys from
    ``state_dict`` or from another ``message_hub`` instance.

    When ``state_dict`` is a dictionary produced by :meth:`state_dict`,
    it already contains only the data that should be resumed, and every
    entry is copied in. When it is a ``MessageHub`` instance, all of its
    data is copied. Prefer loading from a ``dict`` rather than from a
    ``MessageHub`` instance.

    Args:
        state_dict (dict or MessageHub): A dictionary contains key
            ``log_scalars`` ``runtime_info`` and ``resumed_keys``, or a
            MessageHub instance.
    """
    if not isinstance(state_dict, dict):
        # Some old checkpoints saved a serialized `message_hub` instance
        # directly; keep accepting those for backward compatibility.
        self._log_scalars = copy.deepcopy(state_dict._log_scalars)
        self._runtime_info = copy.deepcopy(state_dict._runtime_info)
        self._resumed_keys = copy.deepcopy(state_dict._resumed_keys)
        return

    # Validate the checkpoint layout before mutating any state.
    for field in ('log_scalars', 'runtime_info', 'resumed_keys'):
        assert field in state_dict, (
            'The loaded `state_dict` of `MessageHub` must contain '
            f'key: `{field}`')

    self._log_scalars = copy.deepcopy(state_dict['log_scalars'])
    self._runtime_info = copy.deepcopy(state_dict['runtime_info'])
    self._resumed_keys = copy.deepcopy(state_dict['resumed_keys'])

View File

@ -1881,7 +1881,7 @@ class Runner:
'check the correctness of the checkpoint or the training '
'dataset.')
self.message_hub = checkpoint['message_hub']
self.message_hub.load_state_dict(checkpoint['message_hub'])
# resume optimizer
if 'optimizer' in checkpoint and resume_optimizer:
@ -2008,7 +2008,7 @@ class Runner:
checkpoint = {
'meta': meta,
'state_dict': weights_to_cpu(get_state_dict(model)),
'message_hub': self.message_hub
'message_hub': self.message_hub.state_dict()
}
# save optimizer state dict to checkpoint
if save_optimizer:

View File

@ -94,6 +94,45 @@ class TestMessageHub:
loss_dict = dict(error_type=dict(count=1))
message_hub.update_scalars(loss_dict)
def test_state_dict(self):
    """``state_dict`` should export only entries marked ``resumed=True``."""
    message_hub = MessageHub.get_instance('test_state_dict')
    # update log_scalars.
    message_hub.update_scalar('loss', 0.1)
    message_hub.update_scalar('lr', 0.1, resumed=False)
    # update runtime information.
    message_hub.update_info('iter', 1, resumed=True)
    message_hub.update_info('tensor', [1, 2, 3], resumed=False)
    state_dict = message_hub.state_dict()

    assert state_dict['log_scalars']['loss'].data == (np.array([0.1]),
                                                      np.array([1]))
    assert 'lr' not in state_dict['log_scalars']
    assert state_dict['runtime_info']['iter'] == 1
    # Bug fix: check inside ``runtime_info`` — ``'tensor' not in
    # state_dict`` was vacuously true since the top-level dict only has
    # the keys log_scalars/runtime_info/resumed_keys.
    assert 'tensor' not in state_dict['runtime_info']
def test_load_state_dict(self):
    """``load_state_dict`` accepts both a state dict and a MessageHub."""
    message_hub1 = MessageHub.get_instance('test_load_state_dict1')
    # update log_scalars.
    message_hub1.update_scalar('loss', 0.1)
    message_hub1.update_scalar('lr', 0.1, resumed=False)
    # update runtime information.
    message_hub1.update_info('iter', 1, resumed=True)
    message_hub1.update_info('tensor', [1, 2, 3], resumed=False)
    state_dict = message_hub1.state_dict()

    # Resume from state_dict
    message_hub2 = MessageHub.get_instance('test_load_state_dict2')
    message_hub2.load_state_dict(state_dict)
    assert message_hub2.get_scalar('loss').data == (np.array([0.1]),
                                                    np.array([1]))
    assert message_hub2.get_info('iter') == 1

    # Test resume from `MessageHub` instance.
    # Bug fix: pass the instance itself — the original passed the dict
    # again, so the instance-compatibility branch was never exercised.
    message_hub3 = MessageHub.get_instance('test_load_state_dict3')
    message_hub3.load_state_dict(message_hub1)
    assert message_hub3.get_scalar('loss').data == (np.array([0.1]),
                                                    np.array([1]))
    assert message_hub3.get_info('iter') == 1
def test_getstate(self):
message_hub = MessageHub.get_instance('name')
# update log_scalars.
@ -116,8 +155,8 @@ class TestMessageHub:
def test_get_instance(self):
    """The first created instance (root 'mmengine') is current until a
    newer instance supersedes it."""
    # Test get root mmengine message hub.
    MessageHub._instance_dict = OrderedDict()
    # Note: diff residue kept the pre-rename `root_logger` lines; only the
    # renamed `message_hub` pair belongs in the final file.
    message_hub = MessageHub.get_current_instance()
    assert id(MessageHub.get_instance('mmengine')) == id(message_hub)
    # Test original `get_current_instance` function.
    MessageHub.get_instance('mmdet')
    assert MessageHub.get_current_instance().instance_name == 'mmdet'

View File

@ -1689,9 +1689,11 @@ class TestRunner(TestCase):
self.assertEqual(ckpt['meta']['seed'], runner.seed)
assert isinstance(ckpt['optimizer'], dict)
assert isinstance(ckpt['param_schedulers'], list)
self.assertIsInstance(ckpt['message_hub'], MessageHub)
self.assertEqual(ckpt['message_hub'].get_info('epoch'), 2)
self.assertEqual(ckpt['message_hub'].get_info('iter'), 11)
self.assertIsInstance(ckpt['message_hub'], dict)
message_hub = MessageHub.get_instance('test_ckpt')
message_hub.load_state_dict(ckpt['message_hub'])
self.assertEqual(message_hub.get_info('epoch'), 2)
self.assertEqual(message_hub.get_info('iter'), 11)
# 1.2 test `load_checkpoint`
cfg = copy.deepcopy(self.epoch_based_cfg)
@ -1728,6 +1730,10 @@ class TestRunner(TestCase):
self.assertIsInstance(runner.message_hub, MessageHub)
self.assertEqual(runner.message_hub.get_info('epoch'), 2)
self.assertEqual(runner.message_hub.get_info('iter'), 11)
self.assertEqual(MessageHub.get_current_instance().get_info('epoch'),
2)
self.assertEqual(MessageHub.get_current_instance().get_info('iter'),
11)
# 1.3.2 test resume with unmatched dataset_meta
ckpt_modified = copy.deepcopy(ckpt)
@ -1856,9 +1862,10 @@ class TestRunner(TestCase):
self.assertEqual(ckpt['meta']['iter'], 12)
assert isinstance(ckpt['optimizer'], dict)
assert isinstance(ckpt['param_schedulers'], list)
self.assertIsInstance(ckpt['message_hub'], MessageHub)
self.assertEqual(ckpt['message_hub'].get_info('epoch'), 0)
self.assertEqual(ckpt['message_hub'].get_info('iter'), 11)
self.assertIsInstance(ckpt['message_hub'], dict)
message_hub.load_state_dict(ckpt['message_hub'])
self.assertEqual(message_hub.get_info('epoch'), 0)
self.assertEqual(message_hub.get_info('iter'), 11)
# 2.2 test `load_checkpoint`
cfg = copy.deepcopy(self.iter_based_cfg)
@ -1907,6 +1914,17 @@ class TestRunner(TestCase):
self.assertIsInstance(runner.optim_wrapper.optimizer, SGD)
self.assertIsInstance(runner.param_schedulers[0], MultiStepLR)
# 2.6 test resumed message_hub has the history value.
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.experiment_name = 'test_checkpoint13'
cfg.resume = True
cfg.load_from = osp.join(self.temp_dir, 'iter_3.pth')
runner = Runner.from_cfg(cfg)
runner.load_or_resume()
assert len(runner.message_hub.log_scalars['train/lr'].data[1]) == 3
assert len(MessageHub.get_current_instance().log_scalars['train/lr'].
data[1]) == 3
def test_build_runner(self):
# No need to test other cases which have been tested in
# `test_build_from_cfg`