[Fix] Fix resume `message_hub` and save `metainfo` in message_hub. (#394)
* Fix resume message hub and save metainfo in messagehub * fix as commentpull/401/head
parent
eb25129935
commit
df4e6e3294
|
@ -1,7 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
from mmengine.registry import HOOKS
|
||||
from ..registry import HOOKS
|
||||
from ..utils import get_git_hash
|
||||
from .hook import Hook
|
||||
|
||||
DATA_BATCH = Optional[Sequence[dict]]
|
||||
|
@ -18,12 +19,23 @@ class RuntimeInfoHook(Hook):
|
|||
|
||||
priority = 'VERY_HIGH'
|
||||
|
||||
def before_run(self, runner) -> None:
|
||||
import mmengine
|
||||
metainfo = dict(
|
||||
cfg=runner.cfg.pretty_text,
|
||||
seed=runner.seed,
|
||||
experiment_name=runner.experiment_name,
|
||||
mmengine_version=mmengine.__version__ + get_git_hash())
|
||||
runner.message_hub.update_info_dict(metainfo)
|
||||
|
||||
def before_train(self, runner) -> None:
|
||||
"""Update resumed training state."""
|
||||
runner.message_hub.update_info('epoch', runner.epoch)
|
||||
runner.message_hub.update_info('iter', runner.iter)
|
||||
runner.message_hub.update_info('max_epochs', runner.max_epochs)
|
||||
runner.message_hub.update_info('max_iters', runner.max_iters)
|
||||
runner.message_hub.update_info(
|
||||
'dataset_meta', runner.train_dataloader.dataset.metainfo)
|
||||
|
||||
def before_train_epoch(self, runner) -> None:
|
||||
"""Update current epoch information before every epoch."""
|
||||
|
|
|
@ -121,7 +121,8 @@ class MessageHub(ManagerMixin):
|
|||
keys cannot be modified repeatedly'
|
||||
|
||||
Note:
|
||||
resumed cannot be set repeatedly for the same key.
|
||||
The ``resumed`` argument needs to be consistent for the same
|
||||
``key``.
|
||||
|
||||
Args:
|
||||
key (str): Key of ``HistoryBuffer``.
|
||||
|
@ -149,6 +150,10 @@ class MessageHub(ManagerMixin):
|
|||
be ``dict(value=xxx) or dict(value=xxx, count=xxx)``. Item in
|
||||
``log_dict`` has the same resume option.
|
||||
|
||||
Note:
|
||||
The ``resumed`` argument needs to be consistent for the same
|
||||
``log_dict``.
|
||||
|
||||
Args:
|
||||
log_dict (str): Used for batch updating :attr:`_log_scalars`.
|
||||
resumed (bool): Whether all ``HistoryBuffer`` referred in
|
||||
|
@ -187,7 +192,8 @@ class MessageHub(ManagerMixin):
|
|||
time calling ``update_info``.
|
||||
|
||||
Note:
|
||||
resumed cannot be set repeatedly for the same key.
|
||||
The ``resumed`` argument needs to be consistent for the same
|
||||
``key``.
|
||||
|
||||
Examples:
|
||||
>>> message_hub = MessageHub()
|
||||
|
@ -203,6 +209,31 @@ class MessageHub(ManagerMixin):
|
|||
self._resumed_keys[key] = resumed
|
||||
self._runtime_info[key] = value
|
||||
|
||||
def update_info_dict(self, info_dict: dict, resumed: bool = True) -> None:
|
||||
"""Update runtime information with dictionary.
|
||||
|
||||
The key corresponding runtime information will be overwritten each
|
||||
time calling ``update_info``.
|
||||
|
||||
Note:
|
||||
The ``resumed`` argument needs to be consistent for the same
|
||||
``info_dict``.
|
||||
|
||||
Examples:
|
||||
>>> message_hub = MessageHub()
|
||||
>>> message_hub.update_info({'iter': 100})
|
||||
|
||||
Args:
|
||||
info_dict (str): Runtime information dictionary.
|
||||
resumed (bool): Whether the corresponding ``HistoryBuffer``
|
||||
could be resumed.
|
||||
"""
|
||||
assert isinstance(info_dict, dict), ('`log_dict` must be a dict!, '
|
||||
f'but got {type(info_dict)}')
|
||||
for key, value in info_dict.items():
|
||||
self._set_resumed_keys(key, resumed)
|
||||
self.update_info(key, value, resumed=resumed)
|
||||
|
||||
def _set_resumed_keys(self, key: str, resumed: bool) -> None:
|
||||
"""Set corresponding resumed keys.
|
||||
|
||||
|
@ -331,7 +362,7 @@ class MessageHub(ManagerMixin):
|
|||
f'just return its reference. ',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
saved_scalars[key] = value
|
||||
saved_info[key] = value
|
||||
return dict(
|
||||
log_scalars=saved_scalars,
|
||||
runtime_info=saved_info,
|
||||
|
@ -359,9 +390,48 @@ class MessageHub(ManagerMixin):
|
|||
assert key in state_dict, (
|
||||
'The loaded `state_dict` of `MessageHub` must contain '
|
||||
f'key: `{key}`')
|
||||
self._log_scalars = copy.deepcopy(state_dict['log_scalars'])
|
||||
self._runtime_info = copy.deepcopy(state_dict['runtime_info'])
|
||||
self._resumed_keys = copy.deepcopy(state_dict['resumed_keys'])
|
||||
# The old `MessageHub` could save non-HistoryBuffer `log_scalars`,
|
||||
# therefore the loaded `log_scalars` needs to be filtered.
|
||||
for key, value in state_dict['log_scalars'].items():
|
||||
if not isinstance(value, HistoryBuffer):
|
||||
print_log(
|
||||
f'{key} in message_hub is not HistoryBuffer, '
|
||||
f'just skip resuming it.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
continue
|
||||
self.log_scalars[key] = value
|
||||
|
||||
for key, value in state_dict['runtime_info'].items():
|
||||
try:
|
||||
self._runtime_info[key] = copy.deepcopy(value)
|
||||
except: # noqa: E722
|
||||
print_log(
|
||||
f'{key} in message_hub cannot be copied, '
|
||||
f'just return its reference.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
self._runtime_info[key] = value
|
||||
|
||||
for key, value in state_dict['resumed_keys'].items():
|
||||
if key not in set(self.log_scalars.keys()) | \
|
||||
set(self._runtime_info.keys()):
|
||||
print_log(
|
||||
f'resumed key: {key} is not defined in message_hub, '
|
||||
f'just skip resuming this key.',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
continue
|
||||
elif not value:
|
||||
print_log(
|
||||
f'Although resumed key: {key} is False, {key} '
|
||||
'will still be loaded this time. This key will '
|
||||
'not be saved by the next calling of '
|
||||
'`MessageHub.state_dict()`',
|
||||
logger='current',
|
||||
level=logging.WARNING)
|
||||
self._resumed_keys[key] = value
|
||||
|
||||
# Since some checkpoints saved serialized `message_hub` instance,
|
||||
# `load_state_dict` support loading `message_hub` instance for
|
||||
# compatibility
|
||||
|
|
|
@ -282,7 +282,6 @@ class IterBasedTrainLoop(BaseLoop):
|
|||
# outputs should be a dict of loss.
|
||||
outputs = self.runner.model.train_step(
|
||||
data_batch, optim_wrapper=self.runner.optim_wrapper)
|
||||
self.runner.message_hub.update_info('train_logs', outputs)
|
||||
|
||||
self.runner.call_hook(
|
||||
'after_train_iter',
|
||||
|
|
|
@ -6,7 +6,13 @@ import numpy as np
|
|||
import pytest
|
||||
import torch
|
||||
|
||||
from mmengine import MessageHub
|
||||
from mmengine import HistoryBuffer, MessageHub
|
||||
|
||||
|
||||
class NoDeepCopy:
|
||||
|
||||
def __deepcopy__(self, memodict={}):
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class TestMessageHub:
|
||||
|
@ -45,6 +51,15 @@ class TestMessageHub:
|
|||
message_hub.update_info('key', 1)
|
||||
assert message_hub.runtime_info['key'] == 1
|
||||
|
||||
def test_update_infos(self):
|
||||
message_hub = MessageHub.get_instance('mmengine')
|
||||
# test runtime value can be overwritten.
|
||||
message_hub.update_info_dict({'a': 2, 'b': 3})
|
||||
assert message_hub.runtime_info['a'] == 2
|
||||
assert message_hub.runtime_info['b'] == 3
|
||||
assert message_hub._resumed_keys['a']
|
||||
assert message_hub._resumed_keys['b']
|
||||
|
||||
def test_get_scalar(self):
|
||||
message_hub = MessageHub.get_instance('mmengine')
|
||||
# Get undefined key will raise error
|
||||
|
@ -102,14 +117,18 @@ class TestMessageHub:
|
|||
# update runtime information
|
||||
message_hub.update_info('iter', 1, resumed=True)
|
||||
message_hub.update_info('tensor', [1, 2, 3], resumed=False)
|
||||
no_copy = NoDeepCopy()
|
||||
message_hub.update_info('no_copy', no_copy, resumed=True)
|
||||
state_dict = message_hub.state_dict()
|
||||
|
||||
assert state_dict['log_scalars']['loss'].data == (np.array([0.1]),
|
||||
np.array([1]))
|
||||
assert 'lr' not in state_dict['log_scalars']
|
||||
assert state_dict['runtime_info']['iter'] == 1
|
||||
assert 'tensor' not in state_dict
|
||||
assert 'tensor' not in state_dict['runtime_info']
|
||||
assert state_dict['runtime_info']['no_copy'] is no_copy
|
||||
|
||||
def test_load_state_dict(self):
|
||||
def test_load_state_dict(self, capsys):
|
||||
message_hub1 = MessageHub.get_instance('test_load_state_dict1')
|
||||
# update log_scalars.
|
||||
message_hub1.update_scalar('loss', 0.1)
|
||||
|
@ -133,6 +152,23 @@ class TestMessageHub:
|
|||
np.array([1]))
|
||||
assert message_hub3.get_info('iter') == 1
|
||||
|
||||
# Test resume custom state_dict
|
||||
state_dict = OrderedDict()
|
||||
state_dict['log_scalars'] = dict(a=1, b=HistoryBuffer())
|
||||
state_dict['runtime_info'] = dict(c=1, d=NoDeepCopy(), e=1)
|
||||
state_dict['resumed_keys'] = dict(
|
||||
a=True, b=True, c=True, e=False, f=True)
|
||||
|
||||
message_hub4 = MessageHub.get_instance('test_load_state_dict4')
|
||||
message_hub4.load_state_dict(state_dict)
|
||||
assert 'a' not in message_hub4.log_scalars and 'b' in \
|
||||
message_hub4.log_scalars
|
||||
assert 'c' in message_hub4.runtime_info and \
|
||||
state_dict['runtime_info']['d'] is \
|
||||
message_hub4.runtime_info['d']
|
||||
assert message_hub4._resumed_keys == OrderedDict(
|
||||
b=True, c=True, e=False)
|
||||
|
||||
def test_getstate(self):
|
||||
message_hub = MessageHub.get_instance('name')
|
||||
# update log_scalars.
|
||||
|
|
Loading…
Reference in New Issue