mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
Fix CosineAnealingLr register bug (#265)
* Fall back to CosineLr * Fix consineanealing with unittest * Cover momentum hook * Add comments to explain
This commit is contained in:
parent
65fbc75689
commit
19e4a06cbc
@ -388,7 +388,15 @@ class Runner(object):
|
||||
def register_lr_hook(self, lr_config):
|
||||
if isinstance(lr_config, dict):
|
||||
assert 'policy' in lr_config
|
||||
hook_type = lr_config.pop('policy').title() + 'LrUpdaterHook'
|
||||
policy_type = lr_config.pop('policy')
|
||||
# If the type of policy is all in lower case, e.g., 'cyclic',
|
||||
# then its first letter will be capitalized, e.g., to be 'Cyclic'.
|
||||
# This is for the convenient usage of Lr updater updater.
|
||||
# Since this is not applicable for `CosineAnealingLrUpdater`,
|
||||
# the string will not be changed if it contains capital letters.
|
||||
if policy_type == policy_type.lower():
|
||||
policy_type = policy_type.title()
|
||||
hook_type = policy_type + 'LrUpdaterHook'
|
||||
lr_config['type'] = hook_type
|
||||
hook = mmcv.build_from_cfg(lr_config, HOOKS)
|
||||
else:
|
||||
@ -415,13 +423,20 @@ class Runner(object):
|
||||
hook = checkpoint_config
|
||||
self.register_hook(hook)
|
||||
|
||||
def register_momentum_hooks(self, momentum_config):
|
||||
def register_momentum_hook(self, momentum_config):
|
||||
if momentum_config is None:
|
||||
return
|
||||
if isinstance(momentum_config, dict):
|
||||
assert 'policy' in momentum_config
|
||||
hook_type = momentum_config.pop(
|
||||
'policy').title() + 'MomentumUpdaterHook'
|
||||
policy_type = momentum_config.pop('policy')
|
||||
# If the type of policy is all in lower case, e.g., 'cyclic',
|
||||
# then its first letter will be capitalized, e.g., to be 'Cyclic'.
|
||||
# This is for the convenient usage of momentum updater.
|
||||
# Since this is not applicable for `CosineAnealingMomentumUpdater`,
|
||||
# the string will not be changed if it contains capital letters.
|
||||
if policy_type == policy_type.lower():
|
||||
policy_type = policy_type.title()
|
||||
hook_type = policy_type + 'MomentumUpdaterHook'
|
||||
momentum_config['type'] = hook_type
|
||||
hook = mmcv.build_from_cfg(momentum_config, HOOKS)
|
||||
else:
|
||||
@ -453,7 +468,7 @@ class Runner(object):
|
||||
- LoggerHook(s)
|
||||
"""
|
||||
self.register_lr_hook(lr_config)
|
||||
self.register_momentum_hooks(momentum_config)
|
||||
self.register_momentum_hook(momentum_config)
|
||||
self.register_optimizer_hook(optimizer_config)
|
||||
self.register_checkpoint_hook(checkpoint_config)
|
||||
self.register_hook(IterTimerHook())
|
||||
|
@ -30,3 +30,95 @@ def test_save_checkpoint():
|
||||
assert osp.realpath(latest_path) == osp.realpath(epoch1_path)
|
||||
|
||||
torch.load(latest_path)
|
||||
|
||||
|
||||
def test_build_lr_momentum_hook():
|
||||
try:
|
||||
from torch import nn
|
||||
except ImportError:
|
||||
warnings.warn('Skipping test_save_checkpoint in the absense of torch')
|
||||
return
|
||||
import mmcv.runner
|
||||
model = nn.Linear(1, 1)
|
||||
runner = mmcv.runner.Runner(
|
||||
model=model, batch_processor=lambda x: x, logger=logging.getLogger())
|
||||
|
||||
# test policy that is already title
|
||||
lr_config = dict(
|
||||
policy='CosineAnealing',
|
||||
by_epoch=False,
|
||||
min_lr_ratio=0,
|
||||
warmup_iters=2,
|
||||
warmup_ratio=0.9)
|
||||
runner.register_lr_hook(lr_config)
|
||||
assert len(runner.hooks) == 1
|
||||
|
||||
# test policy that is already title
|
||||
lr_config = dict(
|
||||
policy='Cyclic',
|
||||
by_epoch=False,
|
||||
target_ratio=(10, 1),
|
||||
cyclic_times=1,
|
||||
step_ratio_up=0.4)
|
||||
runner.register_lr_hook(lr_config)
|
||||
assert len(runner.hooks) == 2
|
||||
|
||||
# test policy that is not title
|
||||
lr_config = dict(
|
||||
policy='cyclic',
|
||||
by_epoch=False,
|
||||
target_ratio=(0.85 / 0.95, 1),
|
||||
cyclic_times=1,
|
||||
step_ratio_up=0.4)
|
||||
runner.register_lr_hook(lr_config)
|
||||
assert len(runner.hooks) == 3
|
||||
|
||||
# test policy that is title
|
||||
lr_config = dict(
|
||||
policy='Step',
|
||||
warmup='linear',
|
||||
warmup_iters=500,
|
||||
warmup_ratio=1.0 / 3,
|
||||
step=[8, 11])
|
||||
runner.register_lr_hook(lr_config)
|
||||
assert len(runner.hooks) == 4
|
||||
|
||||
# test policy that is not title
|
||||
lr_config = dict(
|
||||
policy='step',
|
||||
warmup='linear',
|
||||
warmup_iters=500,
|
||||
warmup_ratio=1.0 / 3,
|
||||
step=[8, 11])
|
||||
runner.register_lr_hook(lr_config)
|
||||
assert len(runner.hooks) == 5
|
||||
|
||||
# test policy that is already title
|
||||
mom_config = dict(
|
||||
policy='CosineAnealing',
|
||||
min_momentum_ratio=0.99 / 0.95,
|
||||
by_epoch=False,
|
||||
warmup_iters=2,
|
||||
warmup_ratio=0.9 / 0.95)
|
||||
runner.register_momentum_hook(mom_config)
|
||||
assert len(runner.hooks) == 6
|
||||
|
||||
# test policy that is already title
|
||||
mom_config = dict(
|
||||
policy='Cyclic',
|
||||
by_epoch=False,
|
||||
target_ratio=(0.85 / 0.95, 1),
|
||||
cyclic_times=1,
|
||||
step_ratio_up=0.4)
|
||||
runner.register_momentum_hook(mom_config)
|
||||
assert len(runner.hooks) == 7
|
||||
|
||||
# test policy that is already title
|
||||
mom_config = dict(
|
||||
policy='cyclic',
|
||||
by_epoch=False,
|
||||
target_ratio=(0.85 / 0.95, 1),
|
||||
cyclic_times=1,
|
||||
step_ratio_up=0.4)
|
||||
runner.register_momentum_hook(mom_config)
|
||||
assert len(runner.hooks) == 8
|
||||
|
Loading…
x
Reference in New Issue
Block a user