From 78fad67d0d70e3f7e501194f876125e2a044ec2d Mon Sep 17 00:00:00 2001 From: Mashiro <57566630+HAOCHENYE@users.noreply.github.com> Date: Thu, 14 Jul 2022 20:13:22 +0800 Subject: [PATCH] [Fix] fix resume message_hub (#353) * fix resume message_hub * add unit test * support resume from messagehub * minor refine * add comment * fix typo * update docstring --- mmengine/logging/message_hub.py | 85 ++++++++++++++++++++------ mmengine/runner/runner.py | 4 +- tests/test_logging/test_message_hub.py | 43 ++++++++++++- tests/test_runner/test_runner.py | 30 +++++++-- 4 files changed, 135 insertions(+), 27 deletions(-) diff --git a/mmengine/logging/message_hub.py b/mmengine/logging/message_hub.py index 4489b6df..a3083b10 100644 --- a/mmengine/logging/message_hub.py +++ b/mmengine/logging/message_hub.py @@ -1,4 +1,6 @@ # Copyright (c) OpenMMLab. All rights reserved. +import copy +import logging from collections import OrderedDict from typing import Any, Optional, Union @@ -7,6 +9,7 @@ import torch from mmengine.utils import ManagerMixin from .history_buffer import HistoryBuffer +from .logger import print_log class MessageHub(ManagerMixin): @@ -217,7 +220,7 @@ class MessageHub(ManagerMixin): else: assert self._resumed_keys[key] == resumed, \ f'{key} used to be {self._resumed_keys[key]}, but got ' \ - '{resumed} now. resumed keys cannot be modified repeatedly' + '{resumed} now. resumed keys cannot be modified repeatedly.' 
def state_dict(self) -> dict:
    """Return a dictionary containing the log scalars, runtime information
    and resumed keys that should be saved into a checkpoint.

    Only entries whose key is marked ``True`` in ``self._resumed_keys``
    are included, so data registered with ``resumed=False`` never leaks
    into checkpoints. The returned ``state_dict`` can be loaded by
    :meth:`load_state_dict`.

    Returns:
        dict: A dictionary with keys ``log_scalars``, ``runtime_info`` and
        ``resumed_keys``.
    """
    saved_scalars = OrderedDict()
    saved_info = OrderedDict()

    # Log scalars are HistoryBuffer instances and are always deep-copyable.
    for key, value in self._log_scalars.items():
        if self._resumed_keys.get(key, False):
            saved_scalars[key] = copy.deepcopy(value)

    # Runtime info may hold arbitrary user objects; fall back to storing
    # a reference when an object cannot be deep-copied.
    for key, value in self._runtime_info.items():
        if self._resumed_keys.get(key, False):
            try:
                saved_info[key] = copy.deepcopy(value)
            except Exception:
                print_log(
                    f'{key} in message_hub cannot be copied, '
                    f'just return its reference. ',
                    logger='current',
                    level=logging.WARNING)
                # BUG FIX: the uncopyable reference belongs in
                # ``saved_info`` — it was previously written to
                # ``saved_scalars``, silently dropping it from
                # ``runtime_info`` and polluting ``log_scalars``.
                saved_info[key] = value
    return dict(
        log_scalars=saved_scalars,
        runtime_info=saved_info,
        resumed_keys=self._resumed_keys)

def load_state_dict(self, state_dict: Union['MessageHub', dict]) -> None:
    """Load log scalars, runtime information and resumed keys from a
    ``state_dict`` or another ``MessageHub`` instance.

    If ``state_dict`` is a dictionary returned by :meth:`state_dict`, it
    will only make copies of data which should be resumed from the source
    ``message_hub``.

    If ``state_dict`` is a ``MessageHub`` instance, it will make copies of
    all data from the source message_hub. We suggest to load data from
    ``dict`` rather than a ``MessageHub`` instance.

    Args:
        state_dict (dict or MessageHub): A dictionary containing keys
            ``log_scalars``, ``runtime_info`` and ``resumed_keys``, or a
            ``MessageHub`` instance.
    """
    if isinstance(state_dict, dict):
        # Fail fast on malformed checkpoints rather than resuming a
        # partially-populated hub.
        for key in ('log_scalars', 'runtime_info', 'resumed_keys'):
            assert key in state_dict, (
                'The loaded `state_dict` of `MessageHub` must contain '
                f'key: `{key}`')
        self._log_scalars = copy.deepcopy(state_dict['log_scalars'])
        self._runtime_info = copy.deepcopy(state_dict['runtime_info'])
        self._resumed_keys = copy.deepcopy(state_dict['resumed_keys'])
    # Since some checkpoints saved serialized `message_hub` instance,
    # `load_state_dict` support loading `message_hub` instance for
    # compatibility
    else:
        self._log_scalars = copy.deepcopy(state_dict._log_scalars)
        self._runtime_info = copy.deepcopy(state_dict._runtime_info)
        self._resumed_keys = copy.deepcopy(state_dict._resumed_keys)
def test_state_dict(self):
    """``state_dict()`` keeps resumed entries and drops ``resumed=False`` ones."""
    message_hub = MessageHub.get_instance('test_state_dict')
    # update log_scalars.
    message_hub.update_scalar('loss', 0.1)
    message_hub.update_scalar('lr', 0.1, resumed=False)
    # update runtime information
    message_hub.update_info('iter', 1, resumed=True)
    message_hub.update_info('tensor', [1, 2, 3], resumed=False)
    state_dict = message_hub.state_dict()
    assert state_dict['log_scalars']['loss'].data == (np.array([0.1]),
                                                      np.array([1]))
    assert 'lr' not in state_dict['log_scalars']
    assert state_dict['runtime_info']['iter'] == 1
    # BUG FIX: check membership inside ``runtime_info`` — the previous
    # ``'tensor' not in state_dict`` passed trivially because the
    # top-level keys of the state dict are fixed.
    assert 'tensor' not in state_dict['runtime_info']

def test_load_state_dict(self):
    """``load_state_dict()`` accepts both a state dict and a MessageHub."""
    message_hub1 = MessageHub.get_instance('test_load_state_dict1')
    # update log_scalars.
    message_hub1.update_scalar('loss', 0.1)
    message_hub1.update_scalar('lr', 0.1, resumed=False)
    # update runtime information
    message_hub1.update_info('iter', 1, resumed=True)
    message_hub1.update_info('tensor', [1, 2, 3], resumed=False)
    state_dict = message_hub1.state_dict()

    # Resume from state_dict
    message_hub2 = MessageHub.get_instance('test_load_state_dict2')
    message_hub2.load_state_dict(state_dict)
    assert message_hub2.get_scalar('loss').data == (np.array([0.1]),
                                                    np.array([1]))
    assert message_hub2.get_info('iter') == 1

    # Test resume from `MessageHub` instance.
    message_hub3 = MessageHub.get_instance('test_load_state_dict3')
    # BUG FIX: pass the MessageHub instance (not the dict) so the
    # instance-loading compatibility branch is actually exercised,
    # matching the comment above.
    message_hub3.load_state_dict(message_hub1)
    assert message_hub3.get_scalar('loss').data == (np.array([0.1]),
                                                    np.array([1]))
    assert message_hub3.get_info('iter') == 1
@@ -116,8 +155,8 @@ class TestMessageHub: def test_get_instance(self): # Test get root mmengine message hub. MessageHub._instance_dict = OrderedDict() - root_logger = MessageHub.get_current_instance() - assert id(MessageHub.get_instance('mmengine')) == id(root_logger) + message_hub = MessageHub.get_current_instance() + assert id(MessageHub.get_instance('mmengine')) == id(message_hub) # Test original `get_current_instance` function. MessageHub.get_instance('mmdet') assert MessageHub.get_current_instance().instance_name == 'mmdet' diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 53d632b0..c9dcea4a 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -1689,9 +1689,11 @@ class TestRunner(TestCase): self.assertEqual(ckpt['meta']['seed'], runner.seed) assert isinstance(ckpt['optimizer'], dict) assert isinstance(ckpt['param_schedulers'], list) - self.assertIsInstance(ckpt['message_hub'], MessageHub) - self.assertEqual(ckpt['message_hub'].get_info('epoch'), 2) - self.assertEqual(ckpt['message_hub'].get_info('iter'), 11) + self.assertIsInstance(ckpt['message_hub'], dict) + message_hub = MessageHub.get_instance('test_ckpt') + message_hub.load_state_dict(ckpt['message_hub']) + self.assertEqual(message_hub.get_info('epoch'), 2) + self.assertEqual(message_hub.get_info('iter'), 11) # 1.2 test `load_checkpoint` cfg = copy.deepcopy(self.epoch_based_cfg) @@ -1728,6 +1730,10 @@ class TestRunner(TestCase): self.assertIsInstance(runner.message_hub, MessageHub) self.assertEqual(runner.message_hub.get_info('epoch'), 2) self.assertEqual(runner.message_hub.get_info('iter'), 11) + self.assertEqual(MessageHub.get_current_instance().get_info('epoch'), + 2) + self.assertEqual(MessageHub.get_current_instance().get_info('iter'), + 11) # 1.3.2 test resume with unmatched dataset_meta ckpt_modified = copy.deepcopy(ckpt) @@ -1856,9 +1862,10 @@ class TestRunner(TestCase): self.assertEqual(ckpt['meta']['iter'], 12) assert 
isinstance(ckpt['optimizer'], dict) assert isinstance(ckpt['param_schedulers'], list) - self.assertIsInstance(ckpt['message_hub'], MessageHub) - self.assertEqual(ckpt['message_hub'].get_info('epoch'), 0) - self.assertEqual(ckpt['message_hub'].get_info('iter'), 11) + self.assertIsInstance(ckpt['message_hub'], dict) + message_hub.load_state_dict(ckpt['message_hub']) + self.assertEqual(message_hub.get_info('epoch'), 0) + self.assertEqual(message_hub.get_info('iter'), 11) # 2.2 test `load_checkpoint` cfg = copy.deepcopy(self.iter_based_cfg) @@ -1907,6 +1914,17 @@ class TestRunner(TestCase): self.assertIsInstance(runner.optim_wrapper.optimizer, SGD) self.assertIsInstance(runner.param_schedulers[0], MultiStepLR) + # 2.6 test resumed message_hub has the history value. + cfg = copy.deepcopy(self.iter_based_cfg) + cfg.experiment_name = 'test_checkpoint13' + cfg.resume = True + cfg.load_from = osp.join(self.temp_dir, 'iter_3.pth') + runner = Runner.from_cfg(cfg) + runner.load_or_resume() + assert len(runner.message_hub.log_scalars['train/lr'].data[1]) == 3 + assert len(MessageHub.get_current_instance().log_scalars['train/lr']. + data[1]) == 3 + def test_build_runner(self): # No need to test other cases which have been tested in # `test_build_from_cfg`