[Fix] param_scheduler can not None when training models (#208)
* [Fix] param_scheduler can not None when training models * update unit tests * fix unit tests * refactor ParamSchedulerHook * refactor unit tests * param_schedulers can be an empty listpull/216/head
parent
6996bdc892
commit
661e759063
|
@ -261,7 +261,10 @@ class Runner:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
'param_scheduler should be None when optimizer is None, '
|
'param_scheduler should be None when optimizer is None, '
|
||||||
f'but got {param_scheduler}')
|
f'but got {param_scheduler}')
|
||||||
if not isinstance(param_scheduler, Sequence):
|
|
||||||
|
if param_scheduler is None:
|
||||||
|
self.param_schedulers = []
|
||||||
|
elif not isinstance(param_scheduler, Sequence):
|
||||||
self.param_schedulers = [param_scheduler]
|
self.param_schedulers = [param_scheduler]
|
||||||
else:
|
else:
|
||||||
self.param_schedulers = param_scheduler
|
self.param_schedulers = param_scheduler
|
||||||
|
@ -1022,8 +1025,9 @@ class Runner:
|
||||||
# because the latter depends on the former
|
# because the latter depends on the former
|
||||||
self.optimizer = self.build_optimizer(self.optimizer)
|
self.optimizer = self.build_optimizer(self.optimizer)
|
||||||
|
|
||||||
self.param_schedulers = self.build_param_scheduler( # type: ignore
|
if self.param_schedulers:
|
||||||
self.param_schedulers) # type: ignore
|
self.param_schedulers = self.build_param_scheduler( # type: ignore
|
||||||
|
self.param_schedulers) # type: ignore
|
||||||
|
|
||||||
return loop # type: ignore
|
return loop # type: ignore
|
||||||
|
|
||||||
|
|
|
@ -263,6 +263,7 @@ class TestRunner(TestCase):
|
||||||
cfg.pop('param_scheduler')
|
cfg.pop('param_scheduler')
|
||||||
runner = Runner(**cfg)
|
runner = Runner(**cfg)
|
||||||
self.assertIsInstance(runner, Runner)
|
self.assertIsInstance(runner, Runner)
|
||||||
|
self.assertEqual(runner.param_schedulers, [])
|
||||||
|
|
||||||
# param_scheduler should be None when optimizer is None
|
# param_scheduler should be None when optimizer is None
|
||||||
cfg = copy.deepcopy(self.epoch_based_cfg)
|
cfg = copy.deepcopy(self.epoch_based_cfg)
|
||||||
|
@ -681,6 +682,12 @@ class TestRunner(TestCase):
|
||||||
# input is a Loop object
|
# input is a Loop object
|
||||||
self.assertEqual(id(runner.build_train_loop(loop)), id(loop))
|
self.assertEqual(id(runner.build_train_loop(loop)), id(loop))
|
||||||
|
|
||||||
|
# param_schedulers can be []
|
||||||
|
cfg = dict(type='EpochBasedTrainLoop', max_epochs=3)
|
||||||
|
runner.param_schedulers = []
|
||||||
|
loop = runner.build_train_loop(cfg)
|
||||||
|
self.assertIsInstance(loop, EpochBasedTrainLoop)
|
||||||
|
|
||||||
# test custom training loop
|
# test custom training loop
|
||||||
cfg = dict(type='CustomTrainLoop', max_epochs=3)
|
cfg = dict(type='CustomTrainLoop', max_epochs=3)
|
||||||
loop = runner.build_train_loop(cfg)
|
loop = runner.build_train_loop(cfg)
|
||||||
|
|
Loading…
Reference in New Issue