[Feature]: Support custom hooks in IterBasedRunner. (#1193)

This commit is contained in:
Ma Zerun 2021-07-13 14:09:56 +08:00 committed by GitHub
parent 44e19ff68c
commit a5684b0de2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -225,28 +225,48 @@ class IterBasedRunner(BaseRunner):
optimizer_config=None,
checkpoint_config=None,
log_config=None,
momentum_config=None):
momentum_config=None,
custom_hooks_config=None):
"""Register default hooks for iter-based training.
Checkpoint hook, optimizer stepper hook and logger hooks will be set to
`by_epoch=False` by default.
Default hooks include:
- LrUpdaterHook
- MomentumUpdaterHook
- OptimizerStepperHook
- CheckpointSaverHook
- IterTimerHook
- LoggerHook(s)
+----------------------+-------------------------+
| Hooks | Priority |
+======================+=========================+
| LrUpdaterHook | VERY_HIGH (10) |
+----------------------+-------------------------+
| MomentumUpdaterHook | HIGH (30) |
+----------------------+-------------------------+
| OptimizerStepperHook | ABOVE_NORMAL (40) |
+----------------------+-------------------------+
| CheckpointSaverHook | NORMAL (50) |
+----------------------+-------------------------+
| IterTimerHook | LOW (70) |
+----------------------+-------------------------+
| LoggerHook(s) | VERY_LOW (90) |
+----------------------+-------------------------+
| CustomHook(s) | defaults to NORMAL (50) |
+----------------------+-------------------------+
If custom hooks have same priority with default hooks, custom hooks
will be triggered after default hooks.
"""
if checkpoint_config is not None:
checkpoint_config.setdefault('by_epoch', False)
if lr_config is not None:
lr_config.setdefault('by_epoch', False)
self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_hook(IterTimerHook())
if log_config is not None:
for info in log_config['hooks']:
info.setdefault('by_epoch', False)
self.register_logger_hooks(log_config)
super(IterBasedRunner, self).register_training_hooks(
lr_config=lr_config,
momentum_config=momentum_config,
optimizer_config=optimizer_config,
checkpoint_config=checkpoint_config,
log_config=log_config,
timer_config=IterTimerHook(),
custom_hooks_config=custom_hooks_config)