[Feature] Add get_hooks_info() to print hooks messages (#672)
* Add test of get_hooks_info() * Change to use original Runner for get_hook_info() test * Change to test after_train_iter hooks for get_hook_info() * Complement the stages list * Add logging hooks information in Runner.__init__() * Rearrange the stages list * Restore the stages to tuple type * Clean the unnecessary changes * Replace statement with TestCase's methods * add test stages in method_stages_map * change the hooks info into a f-string * return list(trigger_stages) directly * change keys of method_stages_map * Fix previous changes to method_stages_map.keyspull/760/head
parent
b06234cfcd
commit
6636f07cfe
|
@ -1,6 +1,8 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Dict, Optional, Sequence, Union
|
||||
|
||||
from mmengine import is_method_overridden
|
||||
|
||||
DATA_BATCH = Optional[Union[dict, tuple, list]]
|
||||
|
||||
|
||||
|
@ -11,6 +13,13 @@ class Hook:
|
|||
"""
|
||||
|
||||
priority = 'NORMAL'
|
||||
stages = ('before_run', 'after_load_checkpoint', 'before_train',
|
||||
'before_train_epoch', 'before_train_iter', 'after_train_iter',
|
||||
'after_train_epoch', 'before_val', 'before_val_epoch',
|
||||
'before_val_iter', 'after_val_iter', 'after_val_epoch',
|
||||
'after_val', 'before_save_checkpoint', 'after_train',
|
||||
'before_test', 'before_test_epoch', 'before_test_iter',
|
||||
'after_test_iter', 'after_test_epoch', 'after_test', 'after_run')
|
||||
|
||||
def before_run(self, runner) -> None:
|
||||
"""All subclasses should override this method, if they need any
|
||||
|
@ -416,3 +425,28 @@ class Hook:
|
|||
bool: Whether current iteration is the last train iteration.
|
||||
"""
|
||||
return runner.iter + 1 == runner.max_iters
|
||||
|
||||
def get_triggered_stages(self) -> list:
|
||||
trigger_stages = set()
|
||||
for stage in Hook.stages:
|
||||
if is_method_overridden(stage, Hook, self):
|
||||
trigger_stages.add(stage)
|
||||
|
||||
# some methods will be triggered in multi stages
|
||||
# use this dict to map method to stages.
|
||||
method_stages_map = {
|
||||
'_before_epoch':
|
||||
['before_train_epoch', 'before_val_epoch', 'before_test_epoch'],
|
||||
'_after_epoch':
|
||||
['after_train_epoch', 'after_val_epoch', 'after_test_epoch'],
|
||||
'_before_iter':
|
||||
['before_train_iter', 'before_val_iter', 'before_test_iter'],
|
||||
'_after_iter':
|
||||
['after_train_iter', 'after_val_iter', 'after_test_iter'],
|
||||
}
|
||||
|
||||
for method, map_stages in method_stages_map.items():
|
||||
if is_method_overridden(method, Hook, self):
|
||||
trigger_stages.update(map_stages)
|
||||
|
||||
return list(trigger_stages)
|
||||
|
|
|
@ -415,6 +415,9 @@ class Runner:
|
|||
self._hooks: List[Hook] = []
|
||||
# register hooks to `self._hooks`
|
||||
self.register_hooks(default_hooks, custom_hooks)
|
||||
# log hooks information
|
||||
self.logger.info(f'Hooks will be executed in the following '
|
||||
f'order:\n{self.get_hooks_info()}')
|
||||
|
||||
# dump `cfg` to `work_dir`
|
||||
self.dump_config()
|
||||
|
@ -1576,6 +1579,29 @@ class Runner:
|
|||
|
||||
return log_processor # type: ignore
|
||||
|
||||
def get_hooks_info(self) -> str:
|
||||
# Get hooks info in each stage
|
||||
stage_hook_map: Dict[str, list] = {stage: [] for stage in Hook.stages}
|
||||
for hook in self.hooks:
|
||||
try:
|
||||
priority = Priority(hook.priority).name # type: ignore
|
||||
except ValueError:
|
||||
priority = hook.priority # type: ignore
|
||||
classname = hook.__class__.__name__
|
||||
hook_info = f'({priority:<12}) {classname:<35}'
|
||||
for trigger_stage in hook.get_triggered_stages():
|
||||
stage_hook_map[trigger_stage].append(hook_info)
|
||||
|
||||
stage_hook_infos = []
|
||||
for stage in Hook.stages:
|
||||
hook_infos = stage_hook_map[stage]
|
||||
if len(hook_infos) > 0:
|
||||
info = f'{stage}:\n'
|
||||
info += '\n'.join(hook_infos)
|
||||
info += '\n -------------------- '
|
||||
stage_hook_infos.append(info)
|
||||
return '\n'.join(stage_hook_infos)
|
||||
|
||||
def load_or_resume(self) -> None:
|
||||
"""load or resume checkpoint."""
|
||||
if self._has_loaded:
|
||||
|
|
|
@ -2330,3 +2330,19 @@ class TestRunner(TestCase):
|
|||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_build_runner2'
|
||||
assert isinstance(RUNNERS.build(cfg), Runner)
|
||||
|
||||
def test_get_hooks_info(self):
|
||||
# test get_hooks_info() function
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.experiment_name = 'test_get_hooks_info_from_test_runner_py'
|
||||
cfg.runner_type = 'Runner'
|
||||
runner = RUNNERS.build(cfg)
|
||||
self.assertIsInstance(runner, Runner)
|
||||
target_str = ('after_train_iter:\n'
|
||||
'(VERY_HIGH ) RuntimeInfoHook \n'
|
||||
'(NORMAL ) IterTimerHook \n'
|
||||
'(BELOW_NORMAL) LoggerHook \n'
|
||||
'(LOW ) ParamSchedulerHook \n'
|
||||
'(VERY_LOW ) CheckpointHook \n')
|
||||
self.assertIn(target_str, runner.get_hooks_info(),
|
||||
'target string is not in logged hooks information.')
|
||||
|
|
Loading…
Reference in New Issue