Add the loop stage in message_hub (#1277)

pull/1279/head
Zaida Zhou 2023-07-31 14:22:49 +08:00 committed by GitHub
parent 237aee3866
commit 2df93eb51f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 84 additions and 2 deletions

View File

@ -54,12 +54,15 @@ class RuntimeInfoHook(Hook):
mmengine_version=__version__ + get_git_hash())
runner.message_hub.update_info_dict(metainfo)
self.last_loop_stage = None
def before_train(self, runner) -> None:
"""Update resumed training state.
Args:
runner (Runner): The runner of the training process.
"""
runner.message_hub.update_info('loop_stage', 'train')
runner.message_hub.update_info('epoch', runner.epoch)
runner.message_hub.update_info('iter', runner.iter)
runner.message_hub.update_info('max_epochs', runner.max_epochs)
@ -68,6 +71,9 @@ class RuntimeInfoHook(Hook):
runner.message_hub.update_info(
'dataset_meta', runner.train_dataloader.dataset.metainfo)
def after_train(self, runner) -> None:
runner.message_hub.pop_info('loop_stage')
def before_train_epoch(self, runner) -> None:
"""Update current epoch information before every epoch.
@ -119,6 +125,10 @@ class RuntimeInfoHook(Hook):
for key, value in outputs.items():
runner.message_hub.update_scalar(f'train/{key}', value)
def before_val(self, runner) -> None:
self.last_loop_stage = runner.message_hub.get_info('loop_stage')
runner.message_hub.update_info('loop_stage', 'val')
def after_val_epoch(self,
runner,
metrics: Optional[Dict[str, float]] = None) -> None:
@ -138,6 +148,22 @@ class RuntimeInfoHook(Hook):
else:
runner.message_hub.update_info(f'val/{key}', value)
def after_val(self, runner) -> None:
# ValLoop may be called within the TrainLoop, so we need to reset
# the loop_stage
# workflow: before_train -> before_val -> after_val -> after_train
if self.last_loop_stage == 'train':
runner.message_hub.update_info('loop_stage', self.last_loop_stage)
self.last_loop_stage = None
else:
runner.message_hub.pop_info('loop_stage')
def before_test(self, runner) -> None:
runner.message_hub.update_info('loop_stage', 'test')
def after_test(self, runner) -> None:
runner.message_hub.pop_info('loop_stage')
def after_test_epoch(self,
runner,
metrics: Optional[Dict[str, float]] = None) -> None:

View File

@ -202,6 +202,20 @@ class MessageHub(ManagerMixin):
self._set_resumed_keys(key, resumed)
self._runtime_info[key] = value
def pop_info(self, key: str, default: Optional[Any] = None) -> Any:
"""Remove runtime information by key. If the key does not exist, this
method will return the default value.
Args:
key (str): Key of runtime information.
default (Any, optional): The default returned value for the
given key.
Returns:
Any: The runtime information if the key exists.
"""
return self._runtime_info.pop(key, default)
def update_info_dict(self, info_dict: dict, resumed: bool = True) -> None:
"""Update runtime information with dictionary.
@ -289,7 +303,7 @@ class MessageHub(ManagerMixin):
return self.log_scalars[key]
def get_info(self, key: str, default: Optional[Any] = None) -> Any:
"""Get runtime information by key. if the key does not exist, this
"""Get runtime information by key. If the key does not exist, this
method will return default information.
Args:

View File

@ -36,17 +36,20 @@ class TestRuntimeInfoHook(RunnerTestCase):
DATASETS.module_dict.pop('DatasetWithMetainfo')
return super().tearDown()
def test_before_train(self):
def test_before_and_after_train(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.train_dataloader.dataset.type = 'DatasetWithoutMetainfo'
runner = self.build_runner(cfg)
hook = self._get_runtime_info_hook(runner)
hook.before_train(runner)
self.assertEqual(runner.message_hub.get_info('loop_stage'), 'train')
self.assertEqual(runner.message_hub.get_info('epoch'), 0)
self.assertEqual(runner.message_hub.get_info('iter'), 0)
self.assertEqual(runner.message_hub.get_info('max_epochs'), 2)
self.assertEqual(runner.message_hub.get_info('max_iters'), 8)
hook.after_train(runner)
self.assertIsNone(runner.message_hub.get_info('loop_stage'))
cfg.train_dataloader.dataset.type = 'DatasetWithMetainfo'
runner = self.build_runner(cfg)
@ -110,6 +113,28 @@ class TestRuntimeInfoHook(RunnerTestCase):
self.assertEqual(
runner.message_hub.get_scalar('train/loss_cls').current(), 1.111)
def test_before_and_after_val(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
runner = self.build_runner(cfg)
hook = self._get_runtime_info_hook(runner)
hook.before_val(runner)
self.assertEqual(runner.message_hub.get_info('loop_stage'), 'val')
self.assertIsNone(hook.last_loop_stage)
hook.after_val(runner)
self.assertIsNone(runner.message_hub.get_info('loop_stage'))
# Simulate the workflow of calling the ValLoop within the TrainLoop
runner = self.build_runner(cfg)
hook = self._get_runtime_info_hook(runner)
hook.before_train(runner)
self.assertEqual(runner.message_hub.get_info('loop_stage'), 'train')
hook.before_val(runner)
self.assertEqual(runner.message_hub.get_info('loop_stage'), 'val')
self.assertEqual(hook.last_loop_stage, 'train')
hook.after_val(runner)
self.assertEqual(runner.message_hub.get_info('loop_stage'), 'train')
self.assertIsNone(hook.last_loop_stage)
def test_after_val_epoch(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
runner = self.build_runner(cfg)
@ -118,6 +143,15 @@ class TestRuntimeInfoHook(RunnerTestCase):
self.assertEqual(
runner.message_hub.get_scalar('val/acc').current(), 0.8)
def test_before_and_after_test(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
runner = self.build_runner(cfg)
hook = self._get_runtime_info_hook(runner)
hook.before_test(runner)
self.assertEqual(runner.message_hub.get_info('loop_stage'), 'test')
hook.after_test(runner)
self.assertIsNone(runner.message_hub.get_info('loop_stage'))
def test_after_test_epoch(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
runner = self.build_runner(cfg)

View File

@ -55,6 +55,14 @@ class TestMessageHub:
message_hub.update_info('key', 1)
assert message_hub.runtime_info['key'] == 1
def test_pop_info(self):
message_hub = MessageHub.get_instance('mmengine')
message_hub.update_info('pop_key', 'pop_info')
assert message_hub.runtime_info['pop_key'] == 'pop_info'
assert message_hub.pop_info('pop_key') == 'pop_info'
assert message_hub.pop_info('not_existed_key', 'info') == 'info'
def test_update_infos(self):
message_hub = MessageHub.get_instance('mmengine')
# test runtime value can be overwritten.