[Fix] param_scheduler cannot be None when training models (#208)

* [Fix] param_scheduler cannot be None when training models

* update unit tests

* fix unit tests

* refactor ParamSchedulerHook

* refactor unit tests

* param_schedulers can be an empty list
pull/216/head
Zaida Zhou 2022-04-27 19:45:27 +08:00 committed by GitHub
parent 6996bdc892
commit 661e759063
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 14 additions and 3 deletions

View File

@@ -261,7 +261,10 @@ class Runner:
             raise ValueError(
                 'param_scheduler should be None when optimizer is None, '
                 f'but got {param_scheduler}')
-        if not isinstance(param_scheduler, Sequence):
+        if param_scheduler is None:
+            self.param_schedulers = []
+        elif not isinstance(param_scheduler, Sequence):
             self.param_schedulers = [param_scheduler]
         else:
             self.param_schedulers = param_scheduler
@@ -1022,8 +1025,9 @@ class Runner:
         # because the latter depends on the former
         self.optimizer = self.build_optimizer(self.optimizer)
-        self.param_schedulers = self.build_param_scheduler(  # type: ignore
-            self.param_schedulers)  # type: ignore
+        if self.param_schedulers:
+            self.param_schedulers = self.build_param_scheduler(  # type: ignore
+                self.param_schedulers)  # type: ignore
         return loop  # type: ignore

View File

@@ -263,6 +263,7 @@ class TestRunner(TestCase):
         cfg.pop('param_scheduler')
         runner = Runner(**cfg)
         self.assertIsInstance(runner, Runner)
+        self.assertEqual(runner.param_schedulers, [])

         # param_scheduler should be None when optimizer is None
         cfg = copy.deepcopy(self.epoch_based_cfg)
@@ -681,6 +682,12 @@ class TestRunner(TestCase):
         # input is a Loop object
         self.assertEqual(id(runner.build_train_loop(loop)), id(loop))

+        # param_schedulers can be []
+        cfg = dict(type='EpochBasedTrainLoop', max_epochs=3)
+        runner.param_schedulers = []
+        loop = runner.build_train_loop(cfg)
+        self.assertIsInstance(loop, EpochBasedTrainLoop)
+
         # test custom training loop
         cfg = dict(type='CustomTrainLoop', max_epochs=3)
         loop = runner.build_train_loop(cfg)