mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Refactor] Refactor unit test of ParamSchedulerHook (#809)
* Refactor unit test of param_schemeduler hook * Refactor unit test of param_schemeduler hook
This commit is contained in:
parent
29f399441f
commit
aa69ba1a86
@ -1,19 +1,19 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
|
||||
from mmengine.hooks import ParamSchedulerHook
|
||||
from mmengine.optim import _ParamScheduler
|
||||
from mmengine.testing import RunnerTestCase
|
||||
|
||||
|
||||
class TestParamSchedulerHook:
|
||||
class TestParamSchedulerHook(RunnerTestCase):
|
||||
error_msg = ('runner.param_schedulers should be list of ParamScheduler or '
|
||||
'a dict containing list of ParamScheduler')
|
||||
|
||||
def test_after_train_iter(self):
|
||||
# runner.param_schedulers should be a list or dict
|
||||
with pytest.raises(TypeError, match=self.error_msg):
|
||||
with self.assertRaisesRegex(TypeError, self.error_msg):
|
||||
hook = ParamSchedulerHook()
|
||||
runner = Mock()
|
||||
scheduler = Mock()
|
||||
@ -48,7 +48,7 @@ class TestParamSchedulerHook:
|
||||
|
||||
def test_after_train_epoch(self):
|
||||
# runner.param_schedulers should be a list or dict
|
||||
with pytest.raises(TypeError, match=self.error_msg):
|
||||
with self.assertRaisesRegex(TypeError, self.error_msg):
|
||||
hook = ParamSchedulerHook()
|
||||
runner = Mock()
|
||||
scheduler = Mock()
|
||||
@ -93,7 +93,7 @@ class TestParamSchedulerHook:
|
||||
pass
|
||||
|
||||
# runner.param_schedulers should be a list or dict
|
||||
with pytest.raises(TypeError, match=self.error_msg):
|
||||
with self.assertRaisesRegex(TypeError, self.error_msg):
|
||||
hook = ParamSchedulerHook()
|
||||
runner = Mock()
|
||||
scheduler = Mock()
|
||||
@ -127,3 +127,60 @@ class TestParamSchedulerHook:
|
||||
hook.after_val_epoch(runner, metrics)
|
||||
scheduler1.step.assert_called_with(metrics)
|
||||
scheduler2.step.assert_called_with(metrics)
|
||||
|
||||
def test_with_runner(self):
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
cfg.train_cfg.max_epochs = 3
|
||||
cfg.param_scheduler = [
|
||||
dict(
|
||||
type='ConstantLR',
|
||||
factor=0.5,
|
||||
begin=0,
|
||||
),
|
||||
dict(
|
||||
type='ConstantLR',
|
||||
factor=0.5,
|
||||
begin=1,
|
||||
)
|
||||
]
|
||||
init_lr = cfg.optim_wrapper.optimizer.lr
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
|
||||
# Length of train log is 4
|
||||
# Learning rate of the first epoch is init_lr*0.5
|
||||
# Learning rate of the second epoch is init_lr*0.5*0.5
|
||||
# Learning rate of the last epoch will be reset to 0.1
|
||||
train_lr = list(runner.message_hub.get_scalar('train/lr')._log_history)
|
||||
target_lr = [init_lr * 0.5] * 4 + \
|
||||
[init_lr * 0.5 * 0.5] * 4 + \
|
||||
[init_lr] * 4
|
||||
self.assertListEqual(train_lr, target_lr)
|
||||
|
||||
cfg = copy.deepcopy(self.iter_based_cfg)
|
||||
cfg.param_scheduler = [
|
||||
dict(
|
||||
type='ConstantLR',
|
||||
factor=0.5,
|
||||
begin=0,
|
||||
by_epoch=False,
|
||||
),
|
||||
dict(
|
||||
type='ConstantLR',
|
||||
factor=0.5,
|
||||
begin=4,
|
||||
by_epoch=False,
|
||||
)
|
||||
]
|
||||
|
||||
init_lr = cfg.optim_wrapper.optimizer.lr
|
||||
runner = self.build_runner(cfg)
|
||||
runner.train()
|
||||
|
||||
# Learning rate of 1-4 iteration is init_lr*0.5
|
||||
# Learning rate of 5-11 iteration is init_lr*0.5*0.5
|
||||
train_lr = list(runner.message_hub.get_scalar('train/lr')._log_history)
|
||||
target_lr = [init_lr * 0.5] * 4 + \
|
||||
[init_lr * 0.5 * 0.5] * 7 + \
|
||||
[init_lr]
|
||||
self.assertListEqual(train_lr, target_lr)
|
||||
|
Loading…
x
Reference in New Issue
Block a user