[Fix] param_scheduler can not None when training models (#208)
* [Fix] param_scheduler can not None when training models * update unit tests * fix unit tests * refactor ParamSchedulerHook * refactor unit tests * param_schedulers can be an empty listpull/216/head
parent
6996bdc892
commit
661e759063
|
@ -261,7 +261,10 @@ class Runner:
|
|||
raise ValueError(
|
||||
'param_scheduler should be None when optimizer is None, '
|
||||
f'but got {param_scheduler}')
|
||||
if not isinstance(param_scheduler, Sequence):
|
||||
|
||||
if param_scheduler is None:
|
||||
self.param_schedulers = []
|
||||
elif not isinstance(param_scheduler, Sequence):
|
||||
self.param_schedulers = [param_scheduler]
|
||||
else:
|
||||
self.param_schedulers = param_scheduler
|
||||
|
@ -1022,8 +1025,9 @@ class Runner:
|
|||
# because the latter depends on the former
|
||||
self.optimizer = self.build_optimizer(self.optimizer)
|
||||
|
||||
self.param_schedulers = self.build_param_scheduler( # type: ignore
|
||||
self.param_schedulers) # type: ignore
|
||||
if self.param_schedulers:
|
||||
self.param_schedulers = self.build_param_scheduler( # type: ignore
|
||||
self.param_schedulers) # type: ignore
|
||||
|
||||
return loop # type: ignore
|
||||
|
||||
|
|
|
@ -263,6 +263,7 @@ class TestRunner(TestCase):
|
|||
cfg.pop('param_scheduler')
|
||||
runner = Runner(**cfg)
|
||||
self.assertIsInstance(runner, Runner)
|
||||
self.assertEqual(runner.param_schedulers, [])
|
||||
|
||||
# param_scheduler should be None when optimizer is None
|
||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||
|
@ -681,6 +682,12 @@ class TestRunner(TestCase):
|
|||
# input is a Loop object
|
||||
self.assertEqual(id(runner.build_train_loop(loop)), id(loop))
|
||||
|
||||
# param_schedulers can be []
|
||||
cfg = dict(type='EpochBasedTrainLoop', max_epochs=3)
|
||||
runner.param_schedulers = []
|
||||
loop = runner.build_train_loop(cfg)
|
||||
self.assertIsInstance(loop, EpochBasedTrainLoop)
|
||||
|
||||
# test custom training loop
|
||||
cfg = dict(type='CustomTrainLoop', max_epochs=3)
|
||||
loop = runner.build_train_loop(cfg)
|
||||
|
|
Loading…
Reference in New Issue