[Fix] param_scheduler can not None when training models (#208)

* [Fix] param_scheduler can not None when training models

* update unit tests

* fix unit tests

* refactor ParamSchedulerHook

* refactor unit tests

* param_schedulers can be an empty list
pull/216/head
Zaida Zhou 2022-04-27 19:45:27 +08:00 committed by GitHub
parent 6996bdc892
commit 661e759063
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 3 deletions

View File

@ -261,7 +261,10 @@ class Runner:
raise ValueError(
'param_scheduler should be None when optimizer is None, '
f'but got {param_scheduler}')
if not isinstance(param_scheduler, Sequence):
if param_scheduler is None:
self.param_schedulers = []
elif not isinstance(param_scheduler, Sequence):
self.param_schedulers = [param_scheduler]
else:
self.param_schedulers = param_scheduler
@ -1022,8 +1025,9 @@ class Runner:
# because the latter depends on the former
self.optimizer = self.build_optimizer(self.optimizer)
self.param_schedulers = self.build_param_scheduler( # type: ignore
self.param_schedulers) # type: ignore
if self.param_schedulers:
self.param_schedulers = self.build_param_scheduler( # type: ignore
self.param_schedulers) # type: ignore
return loop # type: ignore

View File

@ -263,6 +263,7 @@ class TestRunner(TestCase):
cfg.pop('param_scheduler')
runner = Runner(**cfg)
self.assertIsInstance(runner, Runner)
self.assertEqual(runner.param_schedulers, [])
# param_scheduler should be None when optimizer is None
cfg = copy.deepcopy(self.epoch_based_cfg)
@ -681,6 +682,12 @@ class TestRunner(TestCase):
# input is a Loop object
self.assertEqual(id(runner.build_train_loop(loop)), id(loop))
# param_schedulers can be []
cfg = dict(type='EpochBasedTrainLoop', max_epochs=3)
runner.param_schedulers = []
loop = runner.build_train_loop(cfg)
self.assertIsInstance(loop, EpochBasedTrainLoop)
# test custom training loop
cfg = dict(type='CustomTrainLoop', max_epochs=3)
loop = runner.build_train_loop(cfg)