mmengine/tests/test_hook/test_param_scheduler_hook.py
Yuan Liu 755f8b5b59
[Refactor]: Change scheduler to param_scheduler (#121)
* [Refactor]: Change scheduler to param_scheduler

* [Fix]: Fix UT of param scheduler hook

Co-authored-by: Your <you@example.com>
2022-03-12 10:47:06 +08:00

28 lines
777 B
Python

# Copyright (c) OpenMMLab. All rights reserved.
from unittest.mock import Mock
from mmengine.hooks import ParamSchedulerHook
class TestParamSchedulerHook:
    """Unit tests for ``ParamSchedulerHook``.

    The runner and the schedulers are replaced by ``Mock`` objects, so only
    the hook's dispatch logic is exercised. Each test registers one
    iteration-based scheduler (``by_epoch=False``) and one epoch-based
    scheduler (``by_epoch=True``). It then checks that exactly the matching
    scheduler is stepped, exactly once, and that the other is left alone.
    """

    @staticmethod
    def _make_runner():
        """Build a mock runner holding one iter-based and one epoch-based
        scheduler.

        Returns:
            tuple: ``(runner, iter_scheduler, epoch_scheduler)``.
        """
        # ``Mock(by_epoch=...)`` sets the attribute directly; ``step`` is
        # auto-created as a child Mock, so it does not need to be assigned.
        iter_scheduler = Mock(by_epoch=False)
        epoch_scheduler = Mock(by_epoch=True)
        runner = Mock()
        runner.param_schedulers = [iter_scheduler, epoch_scheduler]
        return runner, iter_scheduler, epoch_scheduler

    def test_after_iter(self):
        hook = ParamSchedulerHook()
        runner, iter_scheduler, epoch_scheduler = self._make_runner()

        hook.after_train_iter(runner)

        # The iteration-based scheduler advances exactly once per iteration...
        iter_scheduler.step.assert_called_once()
        # ...while the epoch-based scheduler must not be touched mid-epoch.
        epoch_scheduler.step.assert_not_called()

    def test_after_epoch(self):
        hook = ParamSchedulerHook()
        runner, iter_scheduler, epoch_scheduler = self._make_runner()

        hook.after_train_epoch(runner)

        # The epoch-based scheduler advances exactly once per epoch...
        epoch_scheduler.step.assert_called_once()
        # ...and the iteration-based scheduler is not stepped again here.
        iter_scheduler.step.assert_not_called()