mmengine/tests/test_hooks/test_param_scheduler_hook.py
LEFTeyes 0b59a90a21
[Feature] Support ReduceOnPlateauParamScheduler(#819)
* [Feature] Add ReduceOnPlateauParamScheduler and change ParamSchedulerHook

* [Feature] add ReduceOnPlateauLR and ReduceOnPlateauMomentum

* pre-commit check

* add a little docs

* change position

* fix the conflict between isort and yapf

* fix ParamSchedulerHook.after_val_epoch running when train_loop and param_schedulers have not been built

* Apply suggestions from code review

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>

* update ReduceOnPlateauParamScheduler, ReduceOnPlateauMomentum and ParamSchedulerHook

* fix attribute error when getting need_step_args in ParamSchedulerHook

* fix load_state_dict error for rule in ReduceOnPlateauParamScheduler

* add docs for ParamSchedulerHook and fix a few codes

* [Docs] add ReduceOnPlateauParamScheduler, ReduceOnPlateauMomentum and ReduceOnPlateauLR docs

* [Refactor] adjust the order of import

* [Fix] add init check for threshold in ReduceOnPlateauParamScheduler

* [Test] add test for ReduceOnPlateauParamScheduler, ReduceOnPlateauLR and ReduceOnPlateauMomentum

* [Fix] fix no attribute self.min_value

* [Fix] fix numerical problem in tests

* [Fix] fix error in tests

* [Fix] fix ignore first param in tests

* [Fix] fix bug in tests

* [Fix] fix bug in tests

* [Fix] fix bug in tests

* [Fix] increase coverage

* [Fix] fix self._global_step counting bug and docs

* [Fix] fix tests

* [Fix] modified ParamSchedulerHook test

* Update mmengine/optim/scheduler/param_scheduler.py

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>

* Apply suggestions from code review

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>

* [Fix] modify code according to review comments

* [Docs] add api for en and zh_cn

* [Fix] fix bug in test_param_scheduler_hook.py

* [Test] support more complicated test modes(less, greater, rel, abs) for ReduceOnPlateauParamScheduler

* [Docs] add docs for rule

* [Fix] fix pop from empty list bug in test

* [Fix] fix check param_schedulers is not built bug

* [Fix] fix step_args bug and without runner._train_loop bug

* [Fix] fix step_args bug and without runner._train_loop bug

* [Fix] fix scheduler type bug

* [Test] rename step_args to step_kwargs

* [Fix] remove redundancy check

* [Test] remove redundancy check

* Apply suggestions from code review

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* [Test] fix some defects

Co-authored-by: Mashiro <57566630+HAOCHENYE@users.noreply.github.com>
Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
2023-01-16 11:39:03 +08:00

130 lines
4.5 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock
import pytest
from mmengine.hooks import ParamSchedulerHook
from mmengine.optim import _ParamScheduler
class TestParamSchedulerHook:
    """Tests for :class:`ParamSchedulerHook`.

    Every hook entry point is exercised with three shapes of
    ``runner.param_schedulers``: a bare scheduler (invalid, must raise
    ``TypeError``), a list of schedulers, and a dict whose values are lists
    of schedulers.
    """

    # Pattern that ``pytest.raises(match=...)`` searches for in the TypeError
    # message when ``runner.param_schedulers`` is neither a list nor a dict.
    error_msg = ('runner.param_schedulers should be list of ParamScheduler or '
                 'a dict containing list of ParamScheduler')

    @staticmethod
    def _mock_scheduler(by_epoch):
        """Build a ``Mock`` scheduler with a fresh ``step`` mock.

        Args:
            by_epoch (bool): Value assigned to ``scheduler.by_epoch``.

        Returns:
            Mock: The configured scheduler mock.
        """
        scheduler = Mock()
        scheduler.step = Mock()
        scheduler.by_epoch = by_epoch
        return scheduler

    def test_after_train_iter(self):
        # runner.param_schedulers must be a list or dict. Only the call that
        # is expected to raise sits inside ``pytest.raises``: any assertion
        # placed after it inside the same block would never execute.
        hook = ParamSchedulerHook()
        runner = Mock()
        scheduler = self._mock_scheduler(by_epoch=False)
        runner.param_schedulers = scheduler
        with pytest.raises(TypeError, match=self.error_msg):
            hook.after_train_iter(runner, 0)
        scheduler.step.assert_not_called()

        # runner.param_schedulers is a list of schedulers
        hook = ParamSchedulerHook()
        runner = Mock()
        scheduler = self._mock_scheduler(by_epoch=False)
        runner.param_schedulers = [scheduler]
        hook.after_train_iter(runner, 0)
        scheduler.step.assert_called_once()

        # runner.param_schedulers is a dict containing list of schedulers
        scheduler1 = self._mock_scheduler(by_epoch=False)
        scheduler2 = self._mock_scheduler(by_epoch=False)
        runner.param_schedulers = dict(key1=[scheduler1], key2=[scheduler2])
        # Iteration-based schedulers must not be stepped at epoch end ...
        hook.after_train_epoch(runner)
        scheduler1.step.assert_not_called()
        scheduler2.step.assert_not_called()
        # ... only after each training iteration.
        hook.after_train_iter(runner, 0)
        scheduler1.step.assert_called_once()
        scheduler2.step.assert_called_once()

    def test_after_train_epoch(self):
        # runner.param_schedulers must be a list or dict; keep only the
        # raising call inside ``pytest.raises``.
        hook = ParamSchedulerHook()
        runner = Mock()
        scheduler = self._mock_scheduler(by_epoch=True)
        runner.param_schedulers = scheduler
        with pytest.raises(TypeError, match=self.error_msg):
            hook.after_train_epoch(runner)
        scheduler.step.assert_not_called()

        # runner.param_schedulers is a list of schedulers
        hook = ParamSchedulerHook()
        runner = Mock()
        scheduler = self._mock_scheduler(by_epoch=True)
        runner.param_schedulers = [scheduler]
        hook.after_train_epoch(runner)
        scheduler.step.assert_called_once()

        # runner.param_schedulers is a dict containing list of schedulers
        scheduler1 = self._mock_scheduler(by_epoch=True)
        scheduler2 = self._mock_scheduler(by_epoch=True)
        runner.param_schedulers = dict(key1=[scheduler1], key2=[scheduler2])
        # Epoch-based schedulers must not be stepped after an iteration ...
        hook.after_train_iter(runner, 0)
        scheduler1.step.assert_not_called()
        scheduler2.step.assert_not_called()
        # ... only at the end of a training epoch.
        hook.after_train_epoch(runner)
        scheduler1.step.assert_called_once()
        scheduler2.step.assert_called_once()

    def test_after_val_epoch(self):
        metrics = dict(loss=1.0)

        class MockParamScheduler(_ParamScheduler):
            """Minimal concrete subclass of ``_ParamScheduler``.

            Used for the list/dict cases instead of a bare ``Mock`` --
            presumably because the hook only steps genuine
            ``_ParamScheduler`` instances after validation
            (NOTE(review): confirm against
            ``ParamSchedulerHook.after_val_epoch``). ``__init__`` skips the
            base initializer and ``_get_value`` is stubbed out.
            """

            def __init__(self):
                pass

            def _get_value(self):
                pass

        def make_val_scheduler():
            """Return a ``MockParamScheduler`` that wants val metrics."""
            val_scheduler = MockParamScheduler()
            val_scheduler.step = Mock()
            val_scheduler.by_epoch = True
            val_scheduler.need_val_args = True
            return val_scheduler

        # runner.param_schedulers must be a list or dict; keep only the
        # raising call inside ``pytest.raises``.
        hook = ParamSchedulerHook()
        runner = Mock()
        scheduler = self._mock_scheduler(by_epoch=True)
        scheduler.need_val_args = True
        runner.param_schedulers = scheduler
        with pytest.raises(TypeError, match=self.error_msg):
            hook.after_val_epoch(runner, metrics)
        scheduler.step.assert_not_called()

        # runner.param_schedulers is a list of schedulers
        hook = ParamSchedulerHook()
        runner = Mock()
        scheduler = make_val_scheduler()
        runner.param_schedulers = [scheduler]
        hook.after_val_epoch(runner, metrics)
        scheduler.step.assert_called_once_with(metrics)

        # runner.param_schedulers is a dict containing list of schedulers
        scheduler1 = make_val_scheduler()
        scheduler2 = make_val_scheduler()
        runner.param_schedulers = dict(key1=[scheduler1], key2=[scheduler2])
        hook.after_val_epoch(runner, metrics)
        scheduler1.step.assert_called_once_with(metrics)
        scheduler2.step.assert_called_once_with(metrics)