mirror of https://github.com/open-mmlab/mmcv.git
Add syncbuffer hook (#443)
* reformat * reformat * Add register hook from cfg * docstring * change according to commentspull/446/head
parent
903091effc
commit
66604e83de
|
@ -271,6 +271,22 @@ class BaseRunner(metaclass=ABCMeta):
|
|||
if not inserted:
|
||||
self._hooks.insert(0, hook)
|
||||
|
||||
def register_hook_from_cfg(self, hook_cfg):
|
||||
"""Register a hook from its cfg.
|
||||
|
||||
Args:
|
||||
hook_cfg (dict): Hook config. It should have at least keys 'type'
|
||||
and 'priority' indicating its type and priority.
|
||||
|
||||
Notes:
|
||||
The specific hook class to register should not use 'type' and
|
||||
'priority' arguments during initialization.
|
||||
"""
|
||||
hook_cfg = hook_cfg.copy()
|
||||
priority = hook_cfg.pop('priority', 'NORMAL')
|
||||
hook = mmcv.build_from_cfg(hook_cfg, HOOKS)
|
||||
self.register_hook(hook, priority=priority)
|
||||
|
||||
def call_hook(self, fn_name):
|
||||
"""Call all hooks.
|
||||
|
||||
|
|
|
@ -0,0 +1,26 @@
|
|||
# Copyright (c) Open-MMLab. All rights reserved.
|
||||
import torch.distributed as dist
|
||||
|
||||
from .hook import HOOKS, Hook
|
||||
|
||||
|
||||
@HOOKS.register_module()
|
||||
class SyncBuffersHook(Hook):
|
||||
"""Synchronize model buffers such as running_mean and running_var in BN at
|
||||
the end of each epoch.
|
||||
|
||||
Args:
|
||||
distributed (bool): Whether distributed training is used. It is
|
||||
effective only for distributed training. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, distributed=True):
|
||||
self.distributed = distributed
|
||||
|
||||
def after_epoch(self, runner):
|
||||
"""All-reduce model buffers at the end of each epoch."""
|
||||
if self.distributed:
|
||||
buffers = runner.model.buffers()
|
||||
world_size = dist.get_world_size()
|
||||
for tensor in buffers:
|
||||
dist.all_reduce(tensor.div_(world_size))
|
|
@ -18,11 +18,7 @@ from torch.utils.data import DataLoader
|
|||
|
||||
from mmcv.runner import (EpochBasedRunner, IterTimerHook, MlflowLoggerHook,
|
||||
PaviLoggerHook, WandbLoggerHook)
|
||||
from mmcv.runner.hooks.lr_updater import (CosineAnnealingLrUpdaterHook,
|
||||
CosineRestartLrUpdaterHook,
|
||||
CyclicLrUpdaterHook)
|
||||
from mmcv.runner.hooks.momentum_updater import (
|
||||
CosineAnnealingMomentumUpdaterHook, CyclicMomentumUpdaterHook)
|
||||
from mmcv.runner.hooks.lr_updater import CosineRestartLrUpdaterHook
|
||||
|
||||
|
||||
def test_pavi_hook():
|
||||
|
@ -53,21 +49,23 @@ def test_momentum_runner_hook():
|
|||
runner = _build_demo_runner()
|
||||
|
||||
# add momentum scheduler
|
||||
hook = CyclicMomentumUpdaterHook(
|
||||
hook_cfg = dict(
|
||||
type='CyclicMomentumUpdaterHook',
|
||||
by_epoch=False,
|
||||
target_ratio=(0.85 / 0.95, 1),
|
||||
cyclic_times=1,
|
||||
step_ratio_up=0.4)
|
||||
runner.register_hook(hook)
|
||||
runner.register_hook_from_cfg(hook_cfg)
|
||||
|
||||
# add momentum LR scheduler
|
||||
hook = CyclicLrUpdaterHook(
|
||||
hook_cfg = dict(
|
||||
type='CyclicLrUpdaterHook',
|
||||
by_epoch=False,
|
||||
target_ratio=(10, 1),
|
||||
cyclic_times=1,
|
||||
step_ratio_up=0.4)
|
||||
runner.register_hook(hook)
|
||||
runner.register_hook(IterTimerHook())
|
||||
runner.register_hook_from_cfg(hook_cfg)
|
||||
runner.register_hook_from_cfg(dict(type='IterTimerHook'))
|
||||
|
||||
# add pavi hook
|
||||
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
|
||||
|
@ -101,19 +99,25 @@ def test_cosine_runner_hook():
|
|||
runner = _build_demo_runner()
|
||||
|
||||
# add momentum scheduler
|
||||
hook = CosineAnnealingMomentumUpdaterHook(
|
||||
|
||||
hook_cfg = dict(
|
||||
type='CosineAnnealingMomentumUpdaterHook',
|
||||
min_momentum_ratio=0.99 / 0.95,
|
||||
by_epoch=False,
|
||||
warmup_iters=2,
|
||||
warmup_ratio=0.9 / 0.95)
|
||||
runner.register_hook(hook)
|
||||
runner.register_hook_from_cfg(hook_cfg)
|
||||
|
||||
# add momentum LR scheduler
|
||||
hook = CosineAnnealingLrUpdaterHook(
|
||||
by_epoch=False, min_lr_ratio=0, warmup_iters=2, warmup_ratio=0.9)
|
||||
runner.register_hook(hook)
|
||||
hook_cfg = dict(
|
||||
type='CosineAnnealingLrUpdaterHook',
|
||||
by_epoch=False,
|
||||
min_lr_ratio=0,
|
||||
warmup_iters=2,
|
||||
warmup_ratio=0.9)
|
||||
runner.register_hook_from_cfg(hook_cfg)
|
||||
runner.register_hook_from_cfg(dict(type='IterTimerHook'))
|
||||
runner.register_hook(IterTimerHook())
|
||||
|
||||
# add pavi hook
|
||||
hook = PaviLoggerHook(interval=1, add_graph=False, add_last_ckpt=True)
|
||||
runner.register_hook(hook)
|
||||
|
|
Loading…
Reference in New Issue