Add syncbuffer hook (#443)

* reformat

* reformat

* Add register hook from cfg

* docstring

* change according to comments
pull/446/head
Wang Xinjiang 2020-07-24 14:15:44 +08:00 committed by GitHub
parent 903091effc
commit 66604e83de
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 62 additions and 16 deletions

View File

@@ -271,6 +271,22 @@ class BaseRunner(metaclass=ABCMeta):
if not inserted:
self._hooks.insert(0, hook)
def register_hook_from_cfg(self, hook_cfg):
    """Build a hook from a config dict and register it on this runner.

    Args:
        hook_cfg (dict): Hook config. It should have at least keys 'type'
            and 'priority' indicating its type and priority.

    Notes:
        The specific hook class to register should not use 'type' and
        'priority' arguments during initialization.
    """
    # Work on a copy so the caller's config dict is left untouched.
    cfg = hook_cfg.copy()
    # 'priority' is consumed here, not forwarded to the hook constructor.
    hook_priority = cfg.pop('priority', 'NORMAL')
    built_hook = mmcv.build_from_cfg(cfg, HOOKS)
    self.register_hook(built_hook, priority=hook_priority)
def call_hook(self, fn_name):
"""Call all hooks.

View File

@@ -0,0 +1,26 @@
# Copyright (c) Open-MMLab. All rights reserved.
import torch.distributed as dist
from .hook import HOOKS, Hook
@HOOKS.register_module()
class SyncBuffersHook(Hook):
    """Synchronize model buffers such as running_mean and running_var in BN at
    the end of each epoch.

    Args:
        distributed (bool): Whether distributed training is used. It is
            effective only for distributed training. Defaults to True.
    """

    def __init__(self, distributed=True):
        # When False, after_epoch becomes a no-op.
        self.distributed = distributed

    def after_epoch(self, runner):
        """All-reduce model buffers at the end of each epoch."""
        if not self.distributed:
            return
        world_size = dist.get_world_size()
        # Average each buffer across all ranks: divide in place, then sum.
        for buf in runner.model.buffers():
            buf.div_(world_size)
            dist.all_reduce(buf)

View File

@@ -18,11 +18,7 @@ from torch.utils.data import DataLoader
from mmcv.runner import (EpochBasedRunner, IterTimerHook, MlflowLoggerHook,
PaviLoggerHook, WandbLoggerHook)
from mmcv.runner.hooks.lr_updater import (CosineAnnealingLrUpdaterHook,
CosineRestartLrUpdaterHook,
CyclicLrUpdaterHook)
from mmcv.runner.hooks.momentum_updater import (
CosineAnnealingMomentumUpdaterHook, CyclicMomentumUpdaterHook)
from mmcv.runner.hooks.lr_updater import CosineRestartLrUpdaterHook
def test_pavi_hook():
@@ -53,21 +49,23 @@ def test_momentum_runner_hook():
runner = _build_demo_runner()
# add momentum scheduler
hook = CyclicMomentumUpdaterHook(
hook_cfg = dict(
type='CyclicMomentumUpdaterHook',
by_epoch=False,
target_ratio=(0.85 / 0.95, 1),
cyclic_times=1,
step_ratio_up=0.4)
runner.register_hook(hook)
runner.register_hook_from_cfg(hook_cfg)
# add momentum LR scheduler
hook = CyclicLrUpdaterHook(
hook_cfg = dict(
type='CyclicLrUpdaterHook',
by_epoch=False,
target_ratio=(10, 1),
cyclic_times=1,
step_ratio_up=0.4)
runner.register_hook(hook)
runner.register_hook(IterTimerHook())
runner.register_hook_from_cfg(hook_cfg)
runner.register_hook_from_cfg(dict(type='IterTimerHook'))
# add pavi hook
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
@@ -101,19 +99,25 @@ def test_cosine_runner_hook():
runner = _build_demo_runner()
# add momentum scheduler
hook = CosineAnnealingMomentumUpdaterHook(
hook_cfg = dict(
type='CosineAnnealingMomentumUpdaterHook',
min_momentum_ratio=0.99 / 0.95,
by_epoch=False,
warmup_iters=2,
warmup_ratio=0.9 / 0.95)
runner.register_hook(hook)
runner.register_hook_from_cfg(hook_cfg)
# add momentum LR scheduler
hook = CosineAnnealingLrUpdaterHook(
by_epoch=False, min_lr_ratio=0, warmup_iters=2, warmup_ratio=0.9)
runner.register_hook(hook)
hook_cfg = dict(
type='CosineAnnealingLrUpdaterHook',
by_epoch=False,
min_lr_ratio=0,
warmup_iters=2,
warmup_ratio=0.9)
runner.register_hook_from_cfg(hook_cfg)
runner.register_hook_from_cfg(dict(type='IterTimerHook'))
runner.register_hook(IterTimerHook())
# add pavi hook
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
runner.register_hook(hook)