[Refactor] Refactor unit test of ParamSchedulerHook (#809)

* Refactor unit test of param_schemeduler hook

* Refactor unit test of param_schemeduler hook
This commit is contained in:
Mashiro 2022-12-30 10:49:43 +08:00 committed by Zaida Zhou
parent 29f399441f
commit aa69ba1a86

View File

@ -1,19 +1,19 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
from unittest.mock import Mock
import pytest
from mmengine.hooks import ParamSchedulerHook
from mmengine.optim import _ParamScheduler
from mmengine.testing import RunnerTestCase
class TestParamSchedulerHook:
class TestParamSchedulerHook(RunnerTestCase):
error_msg = ('runner.param_schedulers should be list of ParamScheduler or '
'a dict containing list of ParamScheduler')
def test_after_train_iter(self):
# runner.param_schedulers should be a list or dict
with pytest.raises(TypeError, match=self.error_msg):
with self.assertRaisesRegex(TypeError, self.error_msg):
hook = ParamSchedulerHook()
runner = Mock()
scheduler = Mock()
@ -48,7 +48,7 @@ class TestParamSchedulerHook:
def test_after_train_epoch(self):
# runner.param_schedulers should be a list or dict
with pytest.raises(TypeError, match=self.error_msg):
with self.assertRaisesRegex(TypeError, self.error_msg):
hook = ParamSchedulerHook()
runner = Mock()
scheduler = Mock()
@ -93,7 +93,7 @@ class TestParamSchedulerHook:
pass
# runner.param_schedulers should be a list or dict
with pytest.raises(TypeError, match=self.error_msg):
with self.assertRaisesRegex(TypeError, self.error_msg):
hook = ParamSchedulerHook()
runner = Mock()
scheduler = Mock()
@ -127,3 +127,60 @@ class TestParamSchedulerHook:
hook.after_val_epoch(runner, metrics)
scheduler1.step.assert_called_with(metrics)
scheduler2.step.assert_called_with(metrics)
def test_with_runner(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.train_cfg.max_epochs = 3
cfg.param_scheduler = [
dict(
type='ConstantLR',
factor=0.5,
begin=0,
),
dict(
type='ConstantLR',
factor=0.5,
begin=1,
)
]
init_lr = cfg.optim_wrapper.optimizer.lr
runner = self.build_runner(cfg)
runner.train()
# Length of train log is 4
# Learning rate of the first epoch is init_lr*0.5
# Learning rate of the second epoch is init_lr*0.5*0.5
# Learning rate of the last epoch will be reset to 0.1
train_lr = list(runner.message_hub.get_scalar('train/lr')._log_history)
target_lr = [init_lr * 0.5] * 4 + \
[init_lr * 0.5 * 0.5] * 4 + \
[init_lr] * 4
self.assertListEqual(train_lr, target_lr)
cfg = copy.deepcopy(self.iter_based_cfg)
cfg.param_scheduler = [
dict(
type='ConstantLR',
factor=0.5,
begin=0,
by_epoch=False,
),
dict(
type='ConstantLR',
factor=0.5,
begin=4,
by_epoch=False,
)
]
init_lr = cfg.optim_wrapper.optimizer.lr
runner = self.build_runner(cfg)
runner.train()
# Learning rate of 1-4 iteration is init_lr*0.5
# Learning rate of 5-11 iteration is init_lr*0.5*0.5
train_lr = list(runner.message_hub.get_scalar('train/lr')._log_history)
target_lr = [init_lr * 0.5] * 4 + \
[init_lr * 0.5 * 0.5] * 7 + \
[init_lr]
self.assertListEqual(train_lr, target_lr)