[Feature]: Support custom hooks in IterBasedRunner. (#1193)

This commit is contained in:
Ma Zerun 2021-07-13 14:09:56 +08:00 committed by GitHub
parent 44e19ff68c
commit a5684b0de2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -225,28 +225,48 @@ class IterBasedRunner(BaseRunner):
optimizer_config=None, optimizer_config=None,
checkpoint_config=None, checkpoint_config=None,
log_config=None, log_config=None,
momentum_config=None): momentum_config=None,
custom_hooks_config=None):
"""Register default hooks for iter-based training. """Register default hooks for iter-based training.
Checkpoint hook, optimizer stepper hook and logger hooks will be set to
`by_epoch=False` by default.
Default hooks include: Default hooks include:
- LrUpdaterHook +----------------------+-------------------------+
- MomentumUpdaterHook | Hooks | Priority |
- OptimizerStepperHook +======================+=========================+
- CheckpointSaverHook | LrUpdaterHook | VERY_HIGH (10) |
- IterTimerHook +----------------------+-------------------------+
- LoggerHook(s) | MomentumUpdaterHook | HIGH (30) |
+----------------------+-------------------------+
| OptimizerStepperHook | ABOVE_NORMAL (40) |
+----------------------+-------------------------+
| CheckpointSaverHook | NORMAL (50) |
+----------------------+-------------------------+
| IterTimerHook | LOW (70) |
+----------------------+-------------------------+
| LoggerHook(s) | VERY_LOW (90) |
+----------------------+-------------------------+
| CustomHook(s) | defaults to NORMAL (50) |
+----------------------+-------------------------+
If custom hooks have same priority with default hooks, custom hooks
will be triggered after default hooks.
""" """
if checkpoint_config is not None: if checkpoint_config is not None:
checkpoint_config.setdefault('by_epoch', False) checkpoint_config.setdefault('by_epoch', False)
if lr_config is not None: if lr_config is not None:
lr_config.setdefault('by_epoch', False) lr_config.setdefault('by_epoch', False)
self.register_lr_hook(lr_config)
self.register_momentum_hook(momentum_config)
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_hook(IterTimerHook())
if log_config is not None: if log_config is not None:
for info in log_config['hooks']: for info in log_config['hooks']:
info.setdefault('by_epoch', False) info.setdefault('by_epoch', False)
self.register_logger_hooks(log_config) super(IterBasedRunner, self).register_training_hooks(
lr_config=lr_config,
momentum_config=momentum_config,
optimizer_config=optimizer_config,
checkpoint_config=checkpoint_config,
log_config=log_config,
timer_config=IterTimerHook(),
custom_hooks_config=custom_hooks_config)