add register_itertimer_hook function (#838)

* add register_itertimer_hook function

* set default value

* revise minors

* revise according to comments

* fix according to comments

* update

* update
pull/842/head
Miao Zheng 2021-02-20 10:46:31 +08:00 committed by GitHub
parent f75a88c297
commit 7fa78e7a1c
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 15 additions and 3 deletions

View File

@ -1,4 +1,5 @@
# Copyright (c) Open-MMLab. All rights reserved.
import copy
import logging
import os.path as osp
import warnings
@ -11,7 +12,7 @@ import mmcv
from ..parallel import is_module_wrapper
from .checkpoint import load_checkpoint
from .dist_utils import get_dist_info
from .hooks import HOOKS, Hook, IterTimerHook
from .hooks import HOOKS, Hook
from .log_buffer import LogBuffer
from .priority import get_priority
from .utils import get_time_str
@ -414,12 +415,23 @@ class BaseRunner(metaclass=ABCMeta):
info, HOOKS, default_args=dict(interval=log_interval))
self.register_hook(logger_hook, priority='VERY_LOW')
def register_timer_hook(self, timer_config):
if timer_config is None:
return
if isinstance(timer_config, dict):
timer_config_ = copy.deepcopy(timer_config)
hook = mmcv.buid_from_cfg(timer_config_, HOOKS)
else:
hook = timer_config
self.register_hook(hook)
def register_training_hooks(self,
lr_config,
optimizer_config=None,
checkpoint_config=None,
log_config=None,
momentum_config=None):
momentum_config=None,
timer_config=dict(type='IterTimerHook')):
"""Register default hooks for training.
Default hooks include:
@ -435,5 +447,5 @@ class BaseRunner(metaclass=ABCMeta):
self.register_momentum_hook(momentum_config)
self.register_optimizer_hook(optimizer_config)
self.register_checkpoint_hook(checkpoint_config)
self.register_hook(IterTimerHook())
self.register_timer_hook(timer_config)
self.register_logger_hooks(log_config)