diff --git a/mmcv/runner/runner.py b/mmcv/runner/runner.py index c29362929..403cdb22f 100644 --- a/mmcv/runner/runner.py +++ b/mmcv/runner/runner.py @@ -388,7 +388,15 @@ class Runner(object): def register_lr_hook(self, lr_config): if isinstance(lr_config, dict): assert 'policy' in lr_config - hook_type = lr_config.pop('policy').title() + 'LrUpdaterHook' + policy_type = lr_config.pop('policy') + # If the type of policy is all in lower case, e.g., 'cyclic', + # then its first letter will be capitalized, e.g., to be 'Cyclic'. + # This is for the convenient usage of Lr updater updater. + # Since this is not applicable for `CosineAnealingLrUpdater`, + # the string will not be changed if it contains capital letters. + if policy_type == policy_type.lower(): + policy_type = policy_type.title() + hook_type = policy_type + 'LrUpdaterHook' lr_config['type'] = hook_type hook = mmcv.build_from_cfg(lr_config, HOOKS) else: @@ -415,13 +423,20 @@ class Runner(object): hook = checkpoint_config self.register_hook(hook) - def register_momentum_hooks(self, momentum_config): + def register_momentum_hook(self, momentum_config): if momentum_config is None: return if isinstance(momentum_config, dict): assert 'policy' in momentum_config - hook_type = momentum_config.pop( - 'policy').title() + 'MomentumUpdaterHook' + policy_type = momentum_config.pop('policy') + # If the type of policy is all in lower case, e.g., 'cyclic', + # then its first letter will be capitalized, e.g., to be 'Cyclic'. + # This is for the convenient usage of momentum updater. + # Since this is not applicable for `CosineAnealingMomentumUpdater`, + # the string will not be changed if it contains capital letters. + if policy_type == policy_type.lower(): + policy_type = policy_type.title() + hook_type = policy_type + 'MomentumUpdaterHook' momentum_config['type'] = hook_type hook = mmcv.build_from_cfg(momentum_config, HOOKS) else: @@ -453,7 +468,7 @@ class Runner(object): - LoggerHook(s) """ self.register_lr_hook(lr_config) - self.register_momentum_hooks(momentum_config) + self.register_momentum_hook(momentum_config) self.register_optimizer_hook(optimizer_config) self.register_checkpoint_hook(checkpoint_config) self.register_hook(IterTimerHook()) diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 649c4b933..811719529 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -30,3 +30,95 @@ def test_save_checkpoint(): assert osp.realpath(latest_path) == osp.realpath(epoch1_path) torch.load(latest_path) + + +def test_build_lr_momentum_hook(): + try: + from torch import nn + except ImportError: + warnings.warn('Skipping test_save_checkpoint in the absense of torch') + return + import mmcv.runner + model = nn.Linear(1, 1) + runner = mmcv.runner.Runner( + model=model, batch_processor=lambda x: x, logger=logging.getLogger()) + + # test policy that is already title + lr_config = dict( + policy='CosineAnealing', + by_epoch=False, + min_lr_ratio=0, + warmup_iters=2, + warmup_ratio=0.9) + runner.register_lr_hook(lr_config) + assert len(runner.hooks) == 1 + + # test policy that is already title + lr_config = dict( + policy='Cyclic', + by_epoch=False, + target_ratio=(10, 1), + cyclic_times=1, + step_ratio_up=0.4) + runner.register_lr_hook(lr_config) + assert len(runner.hooks) == 2 + + # test policy that is not title + lr_config = dict( + policy='cyclic', + by_epoch=False, + target_ratio=(0.85 / 0.95, 1), + cyclic_times=1, + step_ratio_up=0.4) + runner.register_lr_hook(lr_config) + assert len(runner.hooks) == 3 + + # test policy that is title + lr_config = dict( + policy='Step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) + runner.register_lr_hook(lr_config) + assert len(runner.hooks) == 4 + + # test policy that is not title + lr_config = dict( + policy='step', + warmup='linear', + warmup_iters=500, + warmup_ratio=1.0 / 3, + step=[8, 11]) + runner.register_lr_hook(lr_config) + assert len(runner.hooks) == 5 + + # test policy that is already title + mom_config = dict( + policy='CosineAnealing', + min_momentum_ratio=0.99 / 0.95, + by_epoch=False, + warmup_iters=2, + warmup_ratio=0.9 / 0.95) + runner.register_momentum_hook(mom_config) + assert len(runner.hooks) == 6 + + # test policy that is already title + mom_config = dict( + policy='Cyclic', + by_epoch=False, + target_ratio=(0.85 / 0.95, 1), + cyclic_times=1, + step_ratio_up=0.4) + runner.register_momentum_hook(mom_config) + assert len(runner.hooks) == 7 + + # test policy that is already title + mom_config = dict( + policy='cyclic', + by_epoch=False, + target_ratio=(0.85 / 0.95, 1), + cyclic_times=1, + step_ratio_up=0.4) + runner.register_momentum_hook(mom_config) + assert len(runner.hooks) == 8