[Feature] Add get_hooks_info() to print hooks messages (#672)

* Add test of get_hooks_info()

* Change to use original Runner for get_hook_info() test

* Change to test after_train_iter hooks for get_hook_info()

* Complement the stages list

* Add logging hooks information in Runner.__init__()

* Rearrange the stages list

* Restore the stages to tuple type

* Clean the unnecessary changes

* Replace  statement with TestCase's methods

* add test stages in method_stages_map

* change the hooks info into a f-string

* return list(trigger_stages) directly

* change keys of method_stages_map

* Fix previous changes to method_stages_map.keys
pull/760/head
songyuc 2022-11-22 20:02:29 +08:00 committed by GitHub
parent b06234cfcd
commit 6636f07cfe
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 76 additions and 0 deletions

View File

@ -1,6 +1,8 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, Optional, Sequence, Union
from mmengine import is_method_overridden
DATA_BATCH = Optional[Union[dict, tuple, list]]
@ -11,6 +13,13 @@ class Hook:
"""
priority = 'NORMAL'
stages = ('before_run', 'after_load_checkpoint', 'before_train',
'before_train_epoch', 'before_train_iter', 'after_train_iter',
'after_train_epoch', 'before_val', 'before_val_epoch',
'before_val_iter', 'after_val_iter', 'after_val_epoch',
'after_val', 'before_save_checkpoint', 'after_train',
'before_test', 'before_test_epoch', 'before_test_iter',
'after_test_iter', 'after_test_epoch', 'after_test', 'after_run')
def before_run(self, runner) -> None:
"""All subclasses should override this method, if they need any
@ -416,3 +425,28 @@ class Hook:
bool: Whether current iteration is the last train iteration.
"""
return runner.iter + 1 == runner.max_iters
def get_triggered_stages(self) -> list:
trigger_stages = set()
for stage in Hook.stages:
if is_method_overridden(stage, Hook, self):
trigger_stages.add(stage)
# some methods will be triggered in multi stages
# use this dict to map method to stages.
method_stages_map = {
'_before_epoch':
['before_train_epoch', 'before_val_epoch', 'before_test_epoch'],
'_after_epoch':
['after_train_epoch', 'after_val_epoch', 'after_test_epoch'],
'_before_iter':
['before_train_iter', 'before_val_iter', 'before_test_iter'],
'_after_iter':
['after_train_iter', 'after_val_iter', 'after_test_iter'],
}
for method, map_stages in method_stages_map.items():
if is_method_overridden(method, Hook, self):
trigger_stages.update(map_stages)
return list(trigger_stages)

View File

@ -415,6 +415,9 @@ class Runner:
self._hooks: List[Hook] = []
# register hooks to `self._hooks`
self.register_hooks(default_hooks, custom_hooks)
# log hooks information
self.logger.info(f'Hooks will be executed in the following '
f'order:\n{self.get_hooks_info()}')
# dump `cfg` to `work_dir`
self.dump_config()
@ -1576,6 +1579,29 @@ class Runner:
return log_processor # type: ignore
def get_hooks_info(self) -> str:
# Get hooks info in each stage
stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages}
for hook in self.hooks:
try:
priority = Priority(hook.priority).name # type: ignore
except ValueError:
priority = hook.priority # type: ignore
classname = hook.__class__.__name__
hook_info = f'({priority:<12}) {classname:<35}'
for trigger_stage in hook.get_triggered_stages():
stage_hook_map[trigger_stage].append(hook_info)
stage_hook_infos = []
for stage in Hook.stages:
hook_infos = stage_hook_map[stage]
if len(hook_infos) > 0:
info = f'{stage}:\n'
info += '\n'.join(hook_infos)
info += '\n -------------------- '
stage_hook_infos.append(info)
return '\n'.join(stage_hook_infos)
def load_or_resume(self) -> None:
"""load or resume checkpoint."""
if self._has_loaded:

View File

@ -2330,3 +2330,19 @@ class TestRunner(TestCase):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_runner2'
assert isinstance(RUNNERS.build(cfg), Runner)
def test_get_hooks_info(self):
# test get_hooks_info() function
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_get_hooks_info_from_test_runner_py'
cfg.runner_type = 'Runner'
runner = RUNNERS.build(cfg)
self.assertIsInstance(runner, Runner)
target_str = ('after_train_iter:\n'
'(VERY_HIGH ) RuntimeInfoHook \n'
'(NORMAL ) IterTimerHook \n'
'(BELOW_NORMAL) LoggerHook \n'
'(LOW ) ParamSchedulerHook \n'
'(VERY_LOW ) CheckpointHook \n')
self.assertIn(target_str, runner.get_hooks_info(),
'target string is not in logged hooks information.')