Mirror of https://github.com/open-mmlab/mmengine.git (synced 2025-06-03 21:54:44 +08:00)
[Refactor] Refactor unit test of ParamSchedulerHook (#809)
* Refactor unit test of param_scheduler hook
parent 29f399441f
commit aa69ba1a86
```diff
@@ -1,19 +1,19 @@
 # Copyright (c) OpenMMLab. All rights reserved.
+import copy
 from unittest.mock import Mock
 
-import pytest
-
 from mmengine.hooks import ParamSchedulerHook
 from mmengine.optim import _ParamScheduler
+from mmengine.testing import RunnerTestCase
 
 
-class TestParamSchedulerHook:
+class TestParamSchedulerHook(RunnerTestCase):
     error_msg = ('runner.param_schedulers should be list of ParamScheduler or '
                  'a dict containing list of ParamScheduler')
 
     def test_after_train_iter(self):
         # runner.param_schedulers should be a list or dict
-        with pytest.raises(TypeError, match=self.error_msg):
+        with self.assertRaisesRegex(TypeError, self.error_msg):
             hook = ParamSchedulerHook()
             runner = Mock()
             scheduler = Mock()
```
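The core of the refactor is visible in this first hunk: the test class now inherits from mmengine's RunnerTestCase, so the pytest-style `pytest.raises(..., match=...)` checks become unittest-style `assertRaisesRegex` calls. The two are equivalent; a minimal self-contained sketch (the class name and message below are hypothetical, not from the diff):

```python
import unittest


class TestAssertionStyles(unittest.TestCase):
    """Hypothetical example; not part of the mmengine test suite."""

    def test_raises_regex(self):
        # assertRaisesRegex checks both the exception type and that the
        # message matches the given regex, mirroring
        # pytest.raises(TypeError, match='should be list of').
        with self.assertRaisesRegex(TypeError, 'should be list of'):
            raise TypeError('param_schedulers should be list of '
                            'ParamScheduler')


if __name__ == '__main__':
    unittest.main()
```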
```diff
@@ -48,7 +48,7 @@ class TestParamSchedulerHook:
 
     def test_after_train_epoch(self):
         # runner.param_schedulers should be a list or dict
-        with pytest.raises(TypeError, match=self.error_msg):
+        with self.assertRaisesRegex(TypeError, self.error_msg):
             hook = ParamSchedulerHook()
             runner = Mock()
             scheduler = Mock()
```
```diff
@@ -93,7 +93,7 @@ class TestParamSchedulerHook:
                 pass
 
         # runner.param_schedulers should be a list or dict
-        with pytest.raises(TypeError, match=self.error_msg):
+        with self.assertRaisesRegex(TypeError, self.error_msg):
             hook = ParamSchedulerHook()
             runner = Mock()
             scheduler = Mock()
```
```diff
@@ -127,3 +127,60 @@ class TestParamSchedulerHook:
         hook.after_val_epoch(runner, metrics)
         scheduler1.step.assert_called_with(metrics)
         scheduler2.step.assert_called_with(metrics)
+
+    def test_with_runner(self):
+        cfg = copy.deepcopy(self.epoch_based_cfg)
+        cfg.train_cfg.max_epochs = 3
+        cfg.param_scheduler = [
+            dict(
+                type='ConstantLR',
+                factor=0.5,
+                begin=0,
+            ),
+            dict(
+                type='ConstantLR',
+                factor=0.5,
+                begin=1,
+            )
+        ]
+        init_lr = cfg.optim_wrapper.optimizer.lr
+        runner = self.build_runner(cfg)
+        runner.train()
+
+        # Length of train log is 4
+        # Learning rate of the first epoch is init_lr*0.5
+        # Learning rate of the second epoch is init_lr*0.5*0.5
+        # Learning rate of the last epoch will be reset to 0.1
+        train_lr = list(runner.message_hub.get_scalar('train/lr')._log_history)
+        target_lr = [init_lr * 0.5] * 4 + \
+                    [init_lr * 0.5 * 0.5] * 4 + \
+                    [init_lr] * 4
+        self.assertListEqual(train_lr, target_lr)
+
+        cfg = copy.deepcopy(self.iter_based_cfg)
+        cfg.param_scheduler = [
+            dict(
+                type='ConstantLR',
+                factor=0.5,
+                begin=0,
+                by_epoch=False,
+            ),
+            dict(
+                type='ConstantLR',
+                factor=0.5,
+                begin=4,
+                by_epoch=False,
+            )
+        ]
+
+        init_lr = cfg.optim_wrapper.optimizer.lr
+        runner = self.build_runner(cfg)
+        runner.train()
+
+        # Learning rate of 1-4 iteration is init_lr*0.5
+        # Learning rate of 5-11 iteration is init_lr*0.5*0.5
+        train_lr = list(runner.message_hub.get_scalar('train/lr')._log_history)
+        target_lr = [init_lr * 0.5] * 4 + \
+                    [init_lr * 0.5 * 0.5] * 7 + \
+                    [init_lr]
+        self.assertListEqual(train_lr, target_lr)
```
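The expected schedules in the new `test_with_runner` follow from composing the two ConstantLR schedulers, as the in-test comments state: while both are active their factors multiply, and once both have ended the learning rate returns to the optimizer's initial value. A standalone sketch of that arithmetic (plain Python, no mmengine; `init_lr = 0.1` is assumed here from the test's "reset to 0.1" comment, whereas the test reads it from `cfg.optim_wrapper.optimizer.lr`):

```python
# Hypothetical standalone check of the expected-LR arithmetic above.
init_lr = 0.1  # assumed; the test uses cfg.optim_wrapper.optimizer.lr

# Epoch-based run (3 epochs, LR logged 4 times per epoch):
#   first epoch:  only the begin=0 scheduler is active -> init_lr * 0.5
#   second epoch: both schedulers are active           -> init_lr * 0.5 * 0.5
#   last epoch:   both have ended, LR is restored      -> init_lr
epoch_target = ([init_lr * 0.5] * 4
                + [init_lr * 0.5 * 0.5] * 4
                + [init_lr] * 4)

# Iteration-based run (12 iterations):
#   iterations 1-4:  only the begin=0 scheduler        -> init_lr * 0.5
#   iterations 5-11: both schedulers                   -> init_lr * 0.5 * 0.5
#   iteration 12:    restored                          -> init_lr
iter_target = ([init_lr * 0.5] * 4
               + [init_lr * 0.5 * 0.5] * 7
               + [init_lr])

print(epoch_target)
print(iter_target)
```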