Add the loop stage in message_hub (#1277)
parent
237aee3866
commit
2df93eb51f
|
@ -54,12 +54,15 @@ class RuntimeInfoHook(Hook):
|
|||
mmengine_version=__version__ + get_git_hash())
|
||||
runner.message_hub.update_info_dict(metainfo)
|
||||
|
||||
self.last_loop_stage = None
|
||||
|
||||
def before_train(self, runner) -> None:
|
||||
"""Update resumed training state.
|
||||
|
||||
Args:
|
||||
runner (Runner): The runner of the training process.
|
||||
"""
|
||||
runner.message_hub.update_info('loop_stage', 'train')
|
||||
runner.message_hub.update_info('epoch', runner.epoch)
|
||||
runner.message_hub.update_info('iter', runner.iter)
|
||||
runner.message_hub.update_info('max_epochs', runner.max_epochs)
|
||||
|
@ -68,6 +71,9 @@ class RuntimeInfoHook(Hook):
|
|||
runner.message_hub.update_info(
|
||||
'dataset_meta', runner.train_dataloader.dataset.metainfo)
|
||||
|
||||
def after_train(self, runner) -> None:
|
||||
runner.message_hub.pop_info('loop_stage')
|
||||
|
||||
def before_train_epoch(self, runner) -> None:
|
||||
"""Update current epoch information before every epoch.
|
||||
|
||||
|
@ -119,6 +125,10 @@ class RuntimeInfoHook(Hook):
|
|||
for key, value in outputs.items():
|
||||
runner.message_hub.update_scalar(f'train/{key}', value)
|
||||
|
||||
def before_val(self, runner) -> None:
|
||||
self.last_loop_stage = runner.message_hub.get_info('loop_stage')
|
||||
runner.message_hub.update_info('loop_stage', 'val')
|
||||
|
||||
def after_val_epoch(self,
|
||||
runner,
|
||||
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||
|
@ -138,6 +148,22 @@ class RuntimeInfoHook(Hook):
|
|||
else:
|
||||
runner.message_hub.update_info(f'val/{key}', value)
|
||||
|
||||
def after_val(self, runner) -> None:
|
||||
# ValLoop may be called within the TrainLoop, so we need to reset
|
||||
# the loop_stage
|
||||
# workflow: before_train -> before_val -> after_val -> after_train
|
||||
if self.last_loop_stage == 'train':
|
||||
runner.message_hub.update_info('loop_stage', self.last_loop_stage)
|
||||
self.last_loop_stage = None
|
||||
else:
|
||||
runner.message_hub.pop_info('loop_stage')
|
||||
|
||||
def before_test(self, runner) -> None:
|
||||
runner.message_hub.update_info('loop_stage', 'test')
|
||||
|
||||
def after_test(self, runner) -> None:
|
||||
runner.message_hub.pop_info('loop_stage')
|
||||
|
||||
def after_test_epoch(self,
|
||||
runner,
|
||||
metrics: Optional[Dict[str, float]] = None) -> None:
|
||||
|
|
|
@ -202,6 +202,20 @@ class MessageHub(ManagerMixin):
|
|||
self._set_resumed_keys(key, resumed)
|
||||
self._runtime_info[key] = value
|
||||
|
||||
def pop_info(self, key: str, default: Optional[Any] = None) -> Any:
|
||||
"""Remove runtime information by key. If the key does not exist, this
|
||||
method will return the default value.
|
||||
|
||||
Args:
|
||||
key (str): Key of runtime information.
|
||||
default (Any, optional): The default returned value for the
|
||||
given key.
|
||||
|
||||
Returns:
|
||||
Any: The runtime information if the key exists.
|
||||
"""
|
||||
return self._runtime_info.pop(key, default)
|
||||
|
||||
def update_info_dict(self, info_dict: dict, resumed: bool = True) -> None:
|
||||
"""Update runtime information with dictionary.
|
||||
|
||||
|
@ -289,7 +303,7 @@ class MessageHub(ManagerMixin):
|
|||
return self.log_scalars[key]
|
||||
|
||||
def get_info(self, key: str, default: Optional[Any] = None) -> Any:
|
||||
"""Get runtime information by key. if the key does not exist, this
|
||||
"""Get runtime information by key. If the key does not exist, this
|
||||
method will return default information.
|
||||
|
||||
Args:
|
||||
|
|
|
@ -36,17 +36,20 @@ class TestRuntimeInfoHook(RunnerTestCase):
|
|||
DATASETS.module_dict.pop('DatasetWithMetainfo')
|
||||
return super().tearDown()
|
||||
|
||||
def test_before_train(self):
|
||||
def test_before_and_after_train(self):
|
||||
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.train_dataloader.dataset.type = 'DatasetWithoutMetainfo'
|
||||
runner = self.build_runner(cfg)
|
||||
hook = self._get_runtime_info_hook(runner)
|
||||
hook.before_train(runner)
|
||||
self.assertEqual(runner.message_hub.get_info('loop_stage'), 'train')
|
||||
self.assertEqual(runner.message_hub.get_info('epoch'), 0)
|
||||
self.assertEqual(runner.message_hub.get_info('iter'), 0)
|
||||
self.assertEqual(runner.message_hub.get_info('max_epochs'), 2)
|
||||
self.assertEqual(runner.message_hub.get_info('max_iters'), 8)
|
||||
hook.after_train(runner)
|
||||
self.assertIsNone(runner.message_hub.get_info('loop_stage'))
|
||||
|
||||
cfg.train_dataloader.dataset.type = 'DatasetWithMetainfo'
|
||||
runner = self.build_runner(cfg)
|
||||
|
@ -110,6 +113,28 @@ class TestRuntimeInfoHook(RunnerTestCase):
|
|||
self.assertEqual(
|
||||
runner.message_hub.get_scalar('train/loss_cls').current(), 1.111)
|
||||
|
||||
def test_before_and_after_val(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
hook = self._get_runtime_info_hook(runner)
|
||||
hook.before_val(runner)
|
||||
self.assertEqual(runner.message_hub.get_info('loop_stage'), 'val')
|
||||
self.assertIsNone(hook.last_loop_stage)
|
||||
hook.after_val(runner)
|
||||
self.assertIsNone(runner.message_hub.get_info('loop_stage'))
|
||||
|
||||
# Simulate the workflow of calling the ValLoop within the TrainLoop
|
||||
runner = self.build_runner(cfg)
|
||||
hook = self._get_runtime_info_hook(runner)
|
||||
hook.before_train(runner)
|
||||
self.assertEqual(runner.message_hub.get_info('loop_stage'), 'train')
|
||||
hook.before_val(runner)
|
||||
self.assertEqual(runner.message_hub.get_info('loop_stage'), 'val')
|
||||
self.assertEqual(hook.last_loop_stage, 'train')
|
||||
hook.after_val(runner)
|
||||
self.assertEqual(runner.message_hub.get_info('loop_stage'), 'train')
|
||||
self.assertIsNone(hook.last_loop_stage)
|
||||
|
||||
def test_after_val_epoch(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
|
@ -118,6 +143,15 @@ class TestRuntimeInfoHook(RunnerTestCase):
|
|||
self.assertEqual(
|
||||
runner.message_hub.get_scalar('val/acc').current(), 0.8)
|
||||
|
||||
def test_before_and_after_test(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
hook = self._get_runtime_info_hook(runner)
|
||||
hook.before_test(runner)
|
||||
self.assertEqual(runner.message_hub.get_info('loop_stage'), 'test')
|
||||
hook.after_test(runner)
|
||||
self.assertIsNone(runner.message_hub.get_info('loop_stage'))
|
||||
|
||||
def test_after_test_epoch(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
runner = self.build_runner(cfg)
|
||||
|
|
|
@ -55,6 +55,14 @@ class TestMessageHub:
|
|||
message_hub.update_info('key', 1)
|
||||
assert message_hub.runtime_info['key'] == 1
|
||||
|
||||
def test_pop_info(self):
|
||||
message_hub = MessageHub.get_instance('mmengine')
|
||||
message_hub.update_info('pop_key', 'pop_info')
|
||||
assert message_hub.runtime_info['pop_key'] == 'pop_info'
|
||||
assert message_hub.pop_info('pop_key') == 'pop_info'
|
||||
|
||||
assert message_hub.pop_info('not_existed_key', 'info') == 'info'
|
||||
|
||||
def test_update_infos(self):
|
||||
message_hub = MessageHub.get_instance('mmengine')
|
||||
# test runtime value can be overwritten.
|
||||
|
|
Loading…
Reference in New Issue