# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest.mock import Mock

from mmengine.hooks import ParamSchedulerHook
from mmengine.optim import _ParamScheduler
from mmengine.testing import RunnerTestCase


class TestParamSchedulerHook(RunnerTestCase):
    error_msg = ('runner.param_schedulers should be list of ParamScheduler or '
                 'a dict containing list of ParamScheduler')

    def test_after_train_iter(self):
        # runner.param_schedulers should be a list or dict
        with self.assertRaisesRegex(TypeError, self.error_msg):
            hook = ParamSchedulerHook()
            runner = Mock()
            scheduler = Mock()
            scheduler.step = Mock()
            scheduler.by_epoch = False
            runner.param_schedulers = scheduler
            hook.after_train_iter(runner, 0)
            scheduler.step.assert_called()

        # runner.param_schedulers is a list of schedulers
        hook = ParamSchedulerHook()
        runner = Mock()
        scheduler = Mock()
        scheduler.step = Mock()
        scheduler.by_epoch = False
        runner.param_schedulers = [scheduler]
        hook.after_train_iter(runner, 0)
        scheduler.step.assert_called()

        # runner.param_schedulers is a dict containing a list of schedulers
        scheduler1 = Mock()
        scheduler1.step = Mock()
        scheduler1.by_epoch = False
        scheduler2 = Mock()
        scheduler2.step = Mock()
        scheduler2.by_epoch = False
        runner.param_schedulers = dict(key1=[scheduler1], key2=[scheduler2])
        hook.after_train_iter(runner, 0)
        scheduler1.step.assert_called()
        scheduler2.step.assert_called()

    def test_after_train_epoch(self):
        # runner.param_schedulers should be a list or dict
        with self.assertRaisesRegex(TypeError, self.error_msg):
            hook = ParamSchedulerHook()
            runner = Mock()
            scheduler = Mock()
            scheduler.step = Mock()
            scheduler.by_epoch = True
            runner.param_schedulers = scheduler
            hook.after_train_epoch(runner)
            scheduler.step.assert_called()

        # runner.param_schedulers is a list of schedulers
        hook = ParamSchedulerHook()
        runner = Mock()
        scheduler = Mock()
        scheduler.step = Mock()
        scheduler.by_epoch = True
        runner.param_schedulers = [scheduler]
        hook.after_train_epoch(runner)
        scheduler.step.assert_called()

        # runner.param_schedulers is a dict containing a list of schedulers
        scheduler1 = Mock()
        scheduler1.step = Mock()
        scheduler1.by_epoch = True
        scheduler2 = Mock()
        scheduler2.step = Mock()
        scheduler2.by_epoch = True
        runner.param_schedulers = dict(key1=[scheduler1], key2=[scheduler2])
        hook.after_train_epoch(runner)
        scheduler1.step.assert_called()
        scheduler2.step.assert_called()

    def test_after_val_epoch(self):
        metrics = dict(loss=1.0)

        # Minimal _ParamScheduler subclass; overriding __init__ skips the
        # base-class setup, so no optimizer is needed.
        class MockParamScheduler(_ParamScheduler):

            def __init__(self):
                pass

            def _get_value(self):
                pass

        # runner.param_schedulers should be a list or dict
        with self.assertRaisesRegex(TypeError, self.error_msg):
            hook = ParamSchedulerHook()
            runner = Mock()
            scheduler = Mock()
            scheduler.step = Mock()
            scheduler.by_epoch = True
            scheduler.need_val_args = True
            runner.param_schedulers = scheduler
            hook.after_val_epoch(runner, metrics)

        # runner.param_schedulers is a list of schedulers
        hook = ParamSchedulerHook()
        runner = Mock()
        scheduler = MockParamScheduler()
        scheduler.step = Mock()
        scheduler.by_epoch = True
        scheduler.need_val_args = True
        runner.param_schedulers = [scheduler]
        hook.after_val_epoch(runner, metrics)
        scheduler.step.assert_called_with(metrics)

        # runner.param_schedulers is a dict containing a list of schedulers
        scheduler1 = MockParamScheduler()
        scheduler1.step = Mock()
        scheduler1.by_epoch = True
        scheduler1.need_val_args = True
        scheduler2 = MockParamScheduler()
        scheduler2.step = Mock()
        scheduler2.by_epoch = True
        scheduler2.need_val_args = True
        runner.param_schedulers = dict(key1=[scheduler1], key2=[scheduler2])
        hook.after_val_epoch(runner, metrics)
        scheduler1.step.assert_called_with(metrics)
        scheduler2.step.assert_called_with(metrics)

    def test_with_runner(self):
        cfg = copy.deepcopy(self.epoch_based_cfg)
        cfg.train_cfg.max_epochs = 3
        cfg.param_scheduler = [
            dict(
                type='ConstantLR',
                factor=0.5,
                begin=0,
            ),
            dict(
                type='ConstantLR',
                factor=0.5,
                begin=1,
            )
        ]
        init_lr = cfg.optim_wrapper.optimizer.lr
        runner = self.build_runner(cfg)
        runner.train()

        # Each of the 3 epochs logs 4 learning rates (4 iterations per epoch).
        # Learning rate of the first epoch is init_lr*0.5
        # Learning rate of the second epoch is init_lr*0.5*0.5
        # Learning rate of the last epoch is reset to init_lr
        train_lr = list(
            runner.message_hub.get_scalar('train/lr')._log_history)
        target_lr = [init_lr * 0.5] * 4 + \
                    [init_lr * 0.5 * 0.5] * 4 + \
                    [init_lr] * 4
        self.assertListEqual(train_lr, target_lr)

        cfg = copy.deepcopy(self.iter_based_cfg)
        cfg.param_scheduler = [
            dict(
                type='ConstantLR',
                factor=0.5,
                begin=0,
                by_epoch=False,
            ),
            dict(
                type='ConstantLR',
                factor=0.5,
                begin=4,
                by_epoch=False,
            )
        ]
        init_lr = cfg.optim_wrapper.optimizer.lr
        runner = self.build_runner(cfg)
        runner.train()

        # Learning rate of iterations 1-4 is init_lr*0.5
        # Learning rate of iterations 5-11 is init_lr*0.5*0.5
        # Learning rate of the last iteration is reset to init_lr
        train_lr = list(
            runner.message_hub.get_scalar('train/lr')._log_history)
        target_lr = [init_lr * 0.5] * 4 + \
                    [init_lr * 0.5 * 0.5] * 7 + \
                    [init_lr]
        self.assertListEqual(train_lr, target_lr)
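

# For reference, the dispatch behaviour the tests above exercise is roughly
# the following (a simplified sketch, NOT mmengine's actual implementation;
# the helper name is hypothetical): the hook accepts either a bare list of
# schedulers or a dict mapping names to lists of schedulers, raises TypeError
# for anything else, and steps only the schedulers whose `by_epoch` flag
# matches the hook point.
def _step_schedulers_sketch(param_schedulers, by_epoch):
    if isinstance(param_schedulers, list):
        schedulers = param_schedulers
    elif isinstance(param_schedulers, dict):
        # Flatten the dict-of-lists form into a single list of schedulers.
        schedulers = [
            scheduler for group in param_schedulers.values()
            for scheduler in group
        ]
    else:
        raise TypeError(
            'runner.param_schedulers should be list of ParamScheduler or '
            'a dict containing list of ParamScheduler')
    for scheduler in schedulers:
        # after_train_iter steps by_epoch=False schedulers;
        # after_train_epoch steps by_epoch=True schedulers.
        if scheduler.by_epoch == by_epoch:
            scheduler.step()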