From 661e75906396eaaf1e09dab979685b7e9ee9c752 Mon Sep 17 00:00:00 2001 From: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Date: Wed, 27 Apr 2022 19:45:27 +0800 Subject: [PATCH] [Fix] param_scheduler can not None when training models (#208) * [Fix] param_scheduler can not None when training models * update unit tests * fix unit tests * refactor ParamSchedulerHook * refactor unit tests * param_schedulers can be an empty list --- mmengine/runner/runner.py | 10 +++++++--- tests/test_runner/test_runner.py | 7 +++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index 5b1dd4c2..b8169545 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -261,7 +261,10 @@ class Runner: raise ValueError( 'param_scheduler should be None when optimizer is None, ' f'but got {param_scheduler}') - if not isinstance(param_scheduler, Sequence): + + if param_scheduler is None: + self.param_schedulers = [] + elif not isinstance(param_scheduler, Sequence): self.param_schedulers = [param_scheduler] else: self.param_schedulers = param_scheduler @@ -1022,8 +1025,9 @@ class Runner: # because the latter depends on the former self.optimizer = self.build_optimizer(self.optimizer) - self.param_schedulers = self.build_param_scheduler( # type: ignore - self.param_schedulers) # type: ignore + if self.param_schedulers: + self.param_schedulers = self.build_param_scheduler( # type: ignore + self.param_schedulers) # type: ignore return loop # type: ignore diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index d2750e94..9a11720c 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -263,6 +263,7 @@ class TestRunner(TestCase): cfg.pop('param_scheduler') runner = Runner(**cfg) self.assertIsInstance(runner, Runner) + self.assertEqual(runner.param_schedulers, []) # param_scheduler should be None when optimizer is None cfg = copy.deepcopy(self.epoch_based_cfg) @@ -681,6 +682,12 @@ class TestRunner(TestCase): # input is a Loop object self.assertEqual(id(runner.build_train_loop(loop)), id(loop)) + # param_schedulers can be [] + cfg = dict(type='EpochBasedTrainLoop', max_epochs=3) + runner.param_schedulers = [] + loop = runner.build_train_loop(cfg) + self.assertIsInstance(loop, EpochBasedTrainLoop) + # test custom training loop cfg = dict(type='CustomTrainLoop', max_epochs=3) loop = runner.build_train_loop(cfg)