diff --git a/mmengine/optim/scheduler/lr_scheduler.py b/mmengine/optim/scheduler/lr_scheduler.py index 3c774a67..69ba9f28 100644 --- a/mmengine/optim/scheduler/lr_scheduler.py +++ b/mmengine/optim/scheduler/lr_scheduler.py @@ -1,18 +1,21 @@ # Copyright (c) OpenMMLab. All rights reserved. -from typing import List - -import torch - from mmengine.registry import PARAM_SCHEDULERS -from .param_scheduler import (INF, ConstantParamScheduler, +from .param_scheduler import (ConstantParamScheduler, CosineAnnealingParamScheduler, ExponentialParamScheduler, LinearParamScheduler, MultiStepParamScheduler, PolyParamScheduler, StepParamScheduler) +class LRSchedulerMixin: + """A mixin class for learning rate schedulers.""" + + def __init__(self, optimizer, *args, **kwargs): + super().__init__(optimizer, 'lr', *args, **kwargs) + + @PARAM_SCHEDULERS.register_module() -class ConstantLR(ConstantParamScheduler): +class ConstantLR(LRSchedulerMixin, ConstantParamScheduler): """Decays the learning rate value of each parameter group by a small constant factor until the number of epoch reaches a pre-defined milestone: ``end``. Notice that such decay can happen simultaneously with other @@ -34,27 +37,9 @@ class ConstantLR(ConstantParamScheduler): Defaults to False. """ - def __init__(self, - optimizer: torch.optim.Optimizer, - factor: float = 1.0 / 3, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - super().__init__( - optimizer, - param_name='lr', - factor=factor, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - @PARAM_SCHEDULERS.register_module() -class CosineAnnealingLR(CosineAnnealingParamScheduler): +class CosineAnnealingLR(LRSchedulerMixin, CosineAnnealingParamScheduler): r"""Set the learning rate of each parameter group using a cosine annealing schedule, where :math:`\eta_{max}` is set to the initial value and :math:`T_{cur}` is the number of epochs since the last restart in SGDR: @@ -101,29 +86,9 @@ class CosineAnnealingLR(CosineAnnealingParamScheduler): https://arxiv.org/abs/1608.03983 """ - def __init__(self, - optimizer: torch.optim.Optimizer, - T_max: int, - eta_min: int = 0, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - super().__init__( - optimizer, - param_name='lr', - T_max=T_max, - eta_min=eta_min, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - @PARAM_SCHEDULERS.register_module() -class ExponentialLR(ExponentialParamScheduler): +class ExponentialLR(LRSchedulerMixin, ExponentialParamScheduler): """Decays the learning rate of each parameter group by gamma every epoch. Args: @@ -141,27 +106,9 @@ class ExponentialLR(ExponentialParamScheduler): Defaults to False. """ - def __init__(self, - optimizer: torch.optim.Optimizer, - gamma: float, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - super().__init__( - optimizer, - param_name='lr', - gamma=gamma, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - @PARAM_SCHEDULERS.register_module() -class LinearLR(LinearParamScheduler): +class LinearLR(LRSchedulerMixin, LinearParamScheduler): """Decays the learning rate of each parameter group by linearly changing small multiplicative factor until the number of epoch reaches a pre-defined milestone: ``end``. @@ -187,29 +134,9 @@ class LinearLR(LinearParamScheduler): Defaults to False. 
""" - def __init__(self, - optimizer: torch.optim.Optimizer, - start_factor: float = 1.0 / 3, - end_factor: float = 1.0, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - super().__init__( - optimizer, - param_name='lr', - start_factor=start_factor, - end_factor=end_factor, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - @PARAM_SCHEDULERS.register_module() -class MultiStepLR(MultiStepParamScheduler): +class MultiStepLR(LRSchedulerMixin, MultiStepParamScheduler): """Decays the specified learning rate in each parameter group by gamma once the number of epoch reaches one of the milestones. Notice that such decay can happen simultaneously with other changes to the learning rate from @@ -232,29 +159,9 @@ class MultiStepLR(MultiStepParamScheduler): Defaults to False. """ - def __init__(self, - optimizer: torch.optim.Optimizer, - milestones: List[int], - gamma: float = 0.1, - last_step: int = -1, - begin: int = 0, - end: int = INF, - by_epoch: bool = True, - verbose: bool = False): - super().__init__( - optimizer, - param_name='lr', - milestones=milestones, - gamma=gamma, - last_step=last_step, - begin=begin, - end=end, - by_epoch=by_epoch, - verbose=verbose) - @PARAM_SCHEDULERS.register_module() -class StepLR(StepParamScheduler): +class StepLR(LRSchedulerMixin, StepParamScheduler): """Decays the learning rate of each parameter group by gamma every step_size epochs. Notice that such decay can happen simultaneously with other changes to the learning rate from outside this scheduler. @@ -276,29 +183,9 @@ class StepLR(StepParamScheduler): Defaults to False. """ - def __init__(self, - optimizer: torch.optim.Optimizer, - step_size: int, - gamma: float = 0.1, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - super().__init__( - optimizer, - param_name='lr', - step_size=step_size, - gamma=gamma, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - @PARAM_SCHEDULERS.register_module() -class PolyLR(PolyParamScheduler): +class PolyLR(LRSchedulerMixin, PolyParamScheduler): """Decays the learning rate of each parameter group in a polynomial decay scheme. @@ -321,23 +208,3 @@ class PolyLR(PolyParamScheduler): verbose (bool): Whether to print the value for each update. Defaults to False. """ - - def __init__(self, - optimizer: torch.optim.Optimizer, - eta_min: float = 0, - power: float = 1, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - super().__init__( - optimizer, - param_name='lr', - eta_min=eta_min, - power=power, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) diff --git a/mmengine/optim/scheduler/momentum_scheduler.py b/mmengine/optim/scheduler/momentum_scheduler.py index fa357eb1..5c789b2c 100644 --- a/mmengine/optim/scheduler/momentum_scheduler.py +++ b/mmengine/optim/scheduler/momentum_scheduler.py @@ -1,18 +1,21 @@ # Copyright (c) OpenMMLab. All rights reserved. 
-from typing import List - -import torch - from mmengine.registry import PARAM_SCHEDULERS -from .param_scheduler import (INF, ConstantParamScheduler, +from .param_scheduler import (ConstantParamScheduler, CosineAnnealingParamScheduler, ExponentialParamScheduler, LinearParamScheduler, MultiStepParamScheduler, PolyParamScheduler, StepParamScheduler) +class MomentumSchedulerMixin: + """A mixin class for momentum schedulers.""" + + def __init__(self, optimizer, *args, **kwargs): + super().__init__(optimizer, 'momentum', *args, **kwargs) + + @PARAM_SCHEDULERS.register_module() -class ConstantMomentum(ConstantParamScheduler): +class ConstantMomentum(MomentumSchedulerMixin, ConstantParamScheduler): """Decays the momentum value of each parameter group by a small constant factor until the number of epoch reaches a pre-defined milestone: ``end``. Notice that such decay can happen simultaneously with other changes to the @@ -34,27 +37,10 @@ class ConstantMomentum(ConstantParamScheduler): Defaults to False. """ - def __init__(self, - optimizer: torch.optim.Optimizer, - factor: float = 1.0 / 3, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - super().__init__( - optimizer, - param_name='momentum', - factor=factor, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - @PARAM_SCHEDULERS.register_module() -class CosineAnnealingMomentum(CosineAnnealingParamScheduler): +class CosineAnnealingMomentum(MomentumSchedulerMixin, + CosineAnnealingParamScheduler): r"""Set the momentum of each parameter group using a cosine annealing schedule, where :math:`\eta_{max}` is set to the initial value and :math:`T_{cur}` is the number of epochs since the last restart in SGDR: @@ -101,29 +87,9 @@ class CosineAnnealingMomentum(CosineAnnealingParamScheduler): https://arxiv.org/abs/1608.03983 """ - def __init__(self, - optimizer: torch.optim.Optimizer, - T_max: int, - eta_min: int = 0, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - super().__init__( - optimizer, - param_name='momentum', - T_max=T_max, - eta_min=eta_min, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - @PARAM_SCHEDULERS.register_module() -class ExponentialMomentum(ExponentialParamScheduler): +class ExponentialMomentum(MomentumSchedulerMixin, ExponentialParamScheduler): """Decays the momentum of each parameter group by gamma every epoch. Args: @@ -141,27 +107,9 @@ class ExponentialMomentum(ExponentialParamScheduler): Defaults to False. """ - def __init__(self, - optimizer: torch.optim.Optimizer, - gamma: float, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - super().__init__( - optimizer, - param_name='momentum', - gamma=gamma, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - @PARAM_SCHEDULERS.register_module() -class LinearMomentum(LinearParamScheduler): +class LinearMomentum(MomentumSchedulerMixin, LinearParamScheduler): """Decays the momentum of each parameter group by linearly changing small multiplicative factor until the number of epoch reaches a pre-defined milestone: ``end``. @@ -187,29 +135,9 @@ class LinearMomentum(LinearParamScheduler): Defaults to False. 
""" - def __init__(self, - optimizer: torch.optim.Optimizer, - start_factor: float = 1.0 / 3, - end_factor: float = 1.0, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - super().__init__( - optimizer, - param_name='momentum', - start_factor=start_factor, - end_factor=end_factor, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - @PARAM_SCHEDULERS.register_module() -class MultiStepMomentum(MultiStepParamScheduler): +class MultiStepMomentum(MomentumSchedulerMixin, MultiStepParamScheduler): """Decays the specified momentum in each parameter group by gamma once the number of epoch reaches one of the milestones. Notice that such decay can happen simultaneously with other changes to the momentum from outside this @@ -232,29 +160,9 @@ class MultiStepMomentum(MultiStepParamScheduler): Defaults to False. """ - def __init__(self, - optimizer: torch.optim.Optimizer, - milestones: List[int], - gamma: float = 0.1, - last_step: int = -1, - begin: int = 0, - end: int = INF, - by_epoch: bool = True, - verbose: bool = False): - super().__init__( - optimizer, - param_name='momentum', - milestones=milestones, - gamma=gamma, - last_step=last_step, - begin=begin, - end=end, - by_epoch=by_epoch, - verbose=verbose) - @PARAM_SCHEDULERS.register_module() -class StepMomentum(StepParamScheduler): +class StepMomentum(MomentumSchedulerMixin, StepParamScheduler): """Decays the momentum of each parameter group by gamma every step_size epochs. Notice that such decay can happen simultaneously with other changes to the momentum from outside this scheduler. @@ -276,29 +184,9 @@ class StepMomentum(StepParamScheduler): Defaults to False. """ - def __init__(self, - optimizer: torch.optim.Optimizer, - step_size: int, - gamma: float = 0.1, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - super().__init__( - optimizer, - param_name='momentum', - step_size=step_size, - gamma=gamma, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) - @PARAM_SCHEDULERS.register_module() -class PolyMomentum(PolyParamScheduler): +class PolyMomentum(MomentumSchedulerMixin, PolyParamScheduler): """Decays the momentum of each parameter group in a polynomial decay scheme. @@ -321,23 +209,3 @@ class PolyMomentum(PolyParamScheduler): verbose (bool): Whether to print the value for each update. Defaults to False. """ - - def __init__(self, - optimizer: torch.optim.Optimizer, - eta_min: float = 0, - power: float = 1, - begin: int = 0, - end: int = INF, - last_step: int = -1, - by_epoch: bool = True, - verbose: bool = False): - super().__init__( - optimizer, - param_name='momentum', - eta_min=eta_min, - power=power, - begin=begin, - end=end, - last_step=last_step, - by_epoch=by_epoch, - verbose=verbose) diff --git a/mmengine/optim/scheduler/param_scheduler.py b/mmengine/optim/scheduler/param_scheduler.py index b8b7f260..4c9a2732 100644 --- a/mmengine/optim/scheduler/param_scheduler.py +++ b/mmengine/optim/scheduler/param_scheduler.py @@ -255,6 +255,35 @@ class StepParamScheduler(_ParamScheduler): by_epoch=by_epoch, verbose=verbose) + @classmethod + def build_iter_from_epoch(cls, + *args, + step_size, + begin=0, + end=INF, + by_epoch=True, + epoch_length=None, + **kwargs): + """Build an iter-based instance of this scheduler from an epoch-based + config.""" + assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ + 'be converted to iter-based.' 
+ assert epoch_length is not None and epoch_length > 0, \ + f'`epoch_length` must be a positive integer, ' \ + f'but got {epoch_length}.' + by_epoch = False + step_size = step_size * epoch_length + begin = begin * epoch_length + if end != INF: + end = end * epoch_length + return cls( + *args, + step_size=step_size, + begin=begin, + end=end, + by_epoch=by_epoch, + **kwargs) + def _get_value(self): """Compute value using chainable form of the scheduler.""" if (self.last_step == 0) or (self.last_step % self.step_size != 0): @@ -312,6 +341,35 @@ class MultiStepParamScheduler(_ParamScheduler): by_epoch=by_epoch, verbose=verbose) + @classmethod + def build_iter_from_epoch(cls, + *args, + milestones, + begin=0, + end=INF, + by_epoch=True, + epoch_length=None, + **kwargs): + """Build an iter-based instance of this scheduler from an epoch-based + config.""" + assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ + 'be converted to iter-based.' + assert epoch_length is not None and epoch_length > 0, \ + f'`epoch_length` must be a positive integer, ' \ + f'but got {epoch_length}.' + by_epoch = False + milestones = [i * epoch_length for i in milestones] + begin = begin * epoch_length + if end != INF: + end = end * epoch_length + return cls( + *args, + milestones=milestones, + begin=begin, + end=end, + by_epoch=by_epoch, + **kwargs) + def _get_value(self): """Compute value using chainable form of the scheduler.""" if self.last_step not in self.milestones: @@ -372,6 +430,27 @@ class ConstantParamScheduler(_ParamScheduler): by_epoch=by_epoch, verbose=verbose) + @classmethod + def build_iter_from_epoch(cls, + *args, + begin=0, + end=INF, + by_epoch=True, + epoch_length=None, + **kwargs): + """Build an iter-based instance of this scheduler from an epoch-based + config.""" + assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ + 'be converted to iter-based.' + assert epoch_length is not None and epoch_length > 0, \ + f'`epoch_length` must be a positive integer, ' \ + f'but got {epoch_length}.' + by_epoch = False + begin = begin * epoch_length + if end != INF: + end = end * epoch_length + return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) + def _get_value(self): """Compute value using chainable form of the scheduler.""" if self.last_step == 0: @@ -431,6 +510,27 @@ class ExponentialParamScheduler(_ParamScheduler): by_epoch=by_epoch, verbose=verbose) + @classmethod + def build_iter_from_epoch(cls, + *args, + begin=0, + end=INF, + by_epoch=True, + epoch_length=None, + **kwargs): + """Build an iter-based instance of this scheduler from an epoch-based + config.""" + assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ + 'be converted to iter-based.' + assert epoch_length is not None and epoch_length > 0, \ + f'`epoch_length` must be a positive integer, ' \ + f'but got {epoch_length}.' 
+ by_epoch = False + begin = begin * epoch_length + if end != INF: + end = end * epoch_length + return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) + def _get_value(self): """Compute value using chainable form of the scheduler.""" if self.last_step == 0: @@ -512,6 +612,35 @@ class CosineAnnealingParamScheduler(_ParamScheduler): by_epoch=by_epoch, verbose=verbose) + @classmethod + def build_iter_from_epoch(cls, + *args, + T_max, + begin=0, + end=INF, + by_epoch=True, + epoch_length=None, + **kwargs): + """Build an iter-based instance of this scheduler from an epoch-based + config.""" + assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ + 'be converted to iter-based.' + assert epoch_length is not None and epoch_length > 0, \ + f'`epoch_length` must be a positive integer, ' \ + f'but got {epoch_length}.' + by_epoch = False + T_max = T_max * epoch_length + begin = begin * epoch_length + if end != INF: + end = end * epoch_length + return cls( + *args, + T_max=T_max, + begin=begin, + end=end, + by_epoch=by_epoch, + **kwargs) + def _get_value(self): """Compute value using chainable form of the scheduler.""" if self.last_step == 0: @@ -589,6 +718,27 @@ class LinearParamScheduler(_ParamScheduler): by_epoch=by_epoch, verbose=verbose) + @classmethod + def build_iter_from_epoch(cls, + *args, + begin=0, + end=INF, + by_epoch=True, + epoch_length=None, + **kwargs): + """Build an iter-based instance of this scheduler from an epoch-based + config.""" + assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ + 'be converted to iter-based.' + assert epoch_length is not None and epoch_length > 0, \ + f'`epoch_length` must be a positive integer, ' \ + f'but got {epoch_length}.' + by_epoch = False + begin = begin * epoch_length + if end != INF: + end = end * epoch_length + return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) + def _get_value(self): """Compute value using chainable form of the scheduler.""" if self.last_step == 0: @@ -655,6 +805,27 @@ class PolyParamScheduler(_ParamScheduler): by_epoch=by_epoch, verbose=verbose) + @classmethod + def build_iter_from_epoch(cls, + *args, + begin=0, + end=INF, + by_epoch=True, + epoch_length=None, + **kwargs): + """Build an iter-based instance of this scheduler from an epoch-based + config.""" + assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \ + 'be converted to iter-based.' + assert epoch_length is not None and epoch_length > 0, \ + f'`epoch_length` must be a positive integer, ' \ + f'but got {epoch_length}.' 
+ by_epoch = False + begin = begin * epoch_length + if end != INF: + end = end * epoch_length + return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs) + def _get_value(self): """Compute value using chainable form of the scheduler.""" if self.last_step == 0: diff --git a/mmengine/runner/runner.py b/mmengine/runner/runner.py index ec4fccbb..c9e1893b 100644 --- a/mmengine/runner/runner.py +++ b/mmengine/runner/runner.py @@ -848,10 +848,29 @@ class Runner: if isinstance(_scheduler, _ParamScheduler): param_schedulers.append(_scheduler) elif isinstance(_scheduler, dict): - param_schedulers.append( - PARAM_SCHEDULERS.build( - _scheduler, - default_args=dict(optimizer=self.optimizer))) + convert_to_iter = _scheduler.pop('convert_to_iter_based', + False) + if convert_to_iter: + assert _scheduler.get( + 'by_epoch', True + ), 'only epoch-based parameter scheduler can be ' \ + 'converted to iter-based' + assert isinstance(self.train_loop, BaseLoop), \ + 'Scheduler can only be converted to iter-based ' \ + 'when train loop is built.' + cls = PARAM_SCHEDULERS.get(_scheduler.pop('type')) + param_schedulers.append( + cls.build_iter_from_epoch( # type: ignore + optimizer=self.optimizer, + **_scheduler, + epoch_length=len( + self.train_loop.dataloader), # type: ignore + )) + else: + param_schedulers.append( + PARAM_SCHEDULERS.build( + _scheduler, + default_args=dict(optimizer=self.optimizer))) else: raise TypeError( '_scheduler should be a _ParamScheduler object or dict, ' diff --git a/tests/test_optim/test_scheduler/test_lr_scheduler.py b/tests/test_optim/test_scheduler/test_lr_scheduler.py index 6e8f337d..e263d2be 100644 --- a/tests/test_optim/test_scheduler/test_lr_scheduler.py +++ b/tests/test_optim/test_scheduler/test_lr_scheduler.py @@ -352,6 +352,144 @@ class TestLRScheduler(TestCase): lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002), epochs=10) + def test_step_scheduler_convert_iterbased(self): + # invalid epoch_length + with self.assertRaises(AssertionError): + scheduler = StepLR.build_iter_from_epoch( + self.optimizer, gamma=0.1, step_size=2, epoch_length=-1) + + # lr = 0.05 if epoch < 2 + # lr = 0.005 if 2 <= epoch < 4 + epochs = 4 + epoch_length = 7 + single_targets = [0.05] * 2 * epoch_length + [0.005] * 2 * epoch_length + targets = [ + single_targets, + [x * epochs * epoch_length for x in single_targets] + ] + scheduler = StepLR.build_iter_from_epoch( + self.optimizer, gamma=0.1, step_size=2, epoch_length=epoch_length) + self._test_scheduler_value( + scheduler, targets, epochs * epoch_length, param_name='lr') + + def test_multi_step_scheduler_convert_iterbased(self): + # lr = 0.05 if epoch < 2 + # lr = 0.005 if 2 <= epoch < 5 + # lr = 0.0005 if 5 <= epoch < 9 + # lr = 0.00005 if epoch >= 9 + epochs = 10 + epoch_length = 7 + single_targets = [0.05 + ] * 2 * epoch_length + [0.005] * 3 * epoch_length + [ + 0.0005 + ] * 4 * epoch_length + [0.00005] * 3 * epoch_length + targets = [ + single_targets, + [x * epochs * epoch_length for x in single_targets] + ] + scheduler = MultiStepLR.build_iter_from_epoch( + self.optimizer, + gamma=0.1, + milestones=[2, 5, 9], + epoch_length=epoch_length) + self._test_scheduler_value(scheduler, targets, epochs * epoch_length) + + def test_constant_scheduler_convert_iterbased(self): + # lr = 0.025 if epoch < 5 + # lr = 0.005 if 5 <= epoch + epochs = 10 + epoch_length = 7 + single_targets = [0.025] * (5 * epoch_length - + 1) + [0.05] * (5 * epoch_length + 1) + targets = [ + single_targets, + [x * epochs * epoch_length for x in single_targets] + ] 
+ scheduler = ConstantLR.build_iter_from_epoch( + self.optimizer, factor=1.0 / 2, end=5, epoch_length=epoch_length) + self._test_scheduler_value(scheduler, targets, epochs * epoch_length) + + def test_linear_scheduler_convert_iterbased(self): + epochs = 10 + start_factor = 1.0 / 2 + end = 5 + epoch_length = 11 + + iters = end * epoch_length - 1 + interpolation = [ + start_factor + i * (1 - start_factor) / iters for i in range(iters) + ] + single_targets = [x * 0.05 for x in interpolation] + [0.05] * ( + epochs * epoch_length - iters) + targets = [single_targets, [x * epochs for x in single_targets]] + scheduler = LinearLR.build_iter_from_epoch( + self.optimizer, + start_factor=start_factor, + end=end, + epoch_length=epoch_length) + self._test_scheduler_value(scheduler, targets, epochs) + + def test_exp_scheduler_convert_iterbased(self): + epochs = 10 + epoch_length = 7 + + single_targets = [ + 0.05 * (0.9**x) for x in range(epochs * epoch_length) + ] + targets = [ + single_targets, + [x * epochs * epoch_length for x in single_targets] + ] + scheduler = ExponentialLR.build_iter_from_epoch( + self.optimizer, gamma=0.9, epoch_length=epoch_length) + self._test_scheduler_value(scheduler, targets, epochs * epoch_length) + + def test_cos_anneal_scheduler_convert_iterbased(self): + epochs = 12 + t = 10 + eta_min = 1e-10 + epoch_length = 11 + single_targets = [ + eta_min + (0.05 - eta_min) * + (1 + math.cos(math.pi * x / t / epoch_length)) / 2 + for x in range(epochs * epoch_length) + ] + targets = [ + single_targets, + [x * epochs * epoch_length for x in single_targets] + ] + scheduler = CosineAnnealingLR.build_iter_from_epoch( + self.optimizer, + T_max=t, + eta_min=eta_min, + epoch_length=epoch_length) + self._test_scheduler_value(scheduler, targets, epochs) + + def test_poly_scheduler_convert_iterbased(self): + epochs = 10 + power = 0.9 + min_lr = 0.001 + end = 5 + epoch_length = 11 + + iters = end * epoch_length - 1 + single_targets = [ + min_lr + (0.05 - min_lr) * (1 - i / iters)**power + for i in range(iters) + ] + [min_lr] * ( + epochs - iters) + targets = [ + single_targets, + [x * epochs * epoch_length for x in single_targets] + ] + scheduler = PolyLR.build_iter_from_epoch( + self.optimizer, + power=power, + eta_min=min_lr, + end=end, + epoch_length=epoch_length) + self._test_scheduler_value(scheduler, targets, epochs=10) + def test_multi_scheduler_without_overlap_linear_multi_step(self): # use Linear in the first 5 epochs and then use MultiStep epochs = 12 diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py index c4703392..db8e0e8b 100644 --- a/tests/test_optim/test_scheduler/test_param_scheduler.py +++ b/tests/test_optim/test_scheduler/test_param_scheduler.py @@ -47,7 +47,8 @@ class TestParameterScheduler(TestCase): def test_invalid_optimizer(self): with self.assertRaisesRegex(TypeError, 'should be an Optimizer'): - StepParamScheduler('invalid_optimizer', 'lr', step_size=1) + StepParamScheduler( + 'invalid_optimizer', step_size=1, param_name='lr') def test_overwrite_optimzer_step(self): # raise warning if the counter in optimizer.step() is overwritten @@ -140,7 +141,8 @@ class TestParameterScheduler(TestCase): def test_get_last_value(self): epochs = 10 targets = [[0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005]] - scheduler = StepParamScheduler(self.optimizer, 'lr', 3, gamma=0.1) + scheduler = StepParamScheduler( + self.optimizer, param_name='lr', step_size=3, gamma=0.1) for epoch in range(epochs): result = 
scheduler.get_last_value() self.optimizer.step() @@ -432,6 +434,163 @@ class TestParameterScheduler(TestCase): self.optimizer, param_name='lr', power=0.8, eta_min=0.002), epochs=10) + def test_step_scheduler_convert_iterbased(self): + # invalid epoch_length + with self.assertRaises(AssertionError): + scheduler = StepParamScheduler.build_iter_from_epoch( + self.optimizer, + param_name='momentum', + gamma=0.1, + step_size=2, + epoch_length=-1) + + # momentum = 0.01 if epoch < 2 + # momentum = 0.001 if 2 <= epoch < 4 + epochs = 4 + epoch_length = 7 + single_targets = [0.01] * 2 * epoch_length + [0.001] * 2 * epoch_length + targets = [ + single_targets, + [x * epochs * epoch_length for x in single_targets] + ] + scheduler = StepParamScheduler.build_iter_from_epoch( + self.optimizer, + param_name='momentum', + gamma=0.1, + step_size=2, + epoch_length=epoch_length) + self._test_scheduler_value( + scheduler, targets, epochs * epoch_length, param_name='momentum') + + def test_multi_step_scheduler_convert_iterbased(self): + # lr = 0.05 if epoch < 2 + # lr = 0.005 if 2 <= epoch < 5 + # lr = 0.0005 if 5 <= epoch < 9 + # lr = 0.00005 if epoch >= 9 + epochs = 10 + epoch_length = 7 + single_targets = [0.05 + ] * 2 * epoch_length + [0.005] * 3 * epoch_length + [ + 0.0005 + ] * 4 * epoch_length + [0.00005] * 3 * epoch_length + targets = [ + single_targets, + [x * epochs * epoch_length for x in single_targets] + ] + scheduler = MultiStepParamScheduler.build_iter_from_epoch( + self.optimizer, + param_name='lr', + gamma=0.1, + milestones=[2, 5, 9], + epoch_length=epoch_length) + self._test_scheduler_value(scheduler, targets, epochs * epoch_length) + + def test_constant_scheduler_convert_iterbased(self): + # lr = 0.025 if epoch < 5 + # lr = 0.005 if 5 <= epoch + epochs = 10 + epoch_length = 7 + single_targets = [0.025] * (5 * epoch_length - + 1) + [0.05] * (5 * epoch_length + 1) + targets = [ + single_targets, + [x * epochs * epoch_length for x in single_targets] + ] + scheduler = ConstantParamScheduler.build_iter_from_epoch( + self.optimizer, + param_name='lr', + factor=1.0 / 2, + end=5, + epoch_length=epoch_length) + self._test_scheduler_value(scheduler, targets, epochs * epoch_length) + + def test_linear_scheduler_convert_iterbased(self): + epochs = 10 + start_factor = 1.0 / 2 + end = 5 + epoch_length = 11 + + iters = end * epoch_length - 1 + interpolation = [ + start_factor + i * (1 - start_factor) / iters for i in range(iters) + ] + single_targets = [x * 0.05 for x in interpolation] + [0.05] * ( + epochs * epoch_length - iters) + targets = [single_targets, [x * epochs for x in single_targets]] + scheduler = LinearParamScheduler.build_iter_from_epoch( + self.optimizer, + param_name='lr', + start_factor=start_factor, + end=end, + epoch_length=epoch_length) + self._test_scheduler_value(scheduler, targets, epochs) + + def test_exp_scheduler_convert_iterbased(self): + epochs = 10 + epoch_length = 7 + + single_targets = [ + 0.05 * (0.9**x) for x in range(epochs * epoch_length) + ] + targets = [ + single_targets, + [x * epochs * epoch_length for x in single_targets] + ] + scheduler = ExponentialParamScheduler.build_iter_from_epoch( + self.optimizer, + param_name='lr', + gamma=0.9, + epoch_length=epoch_length) + self._test_scheduler_value(scheduler, targets, epochs * epoch_length) + + def test_cos_anneal_scheduler_convert_iterbased(self): + epochs = 12 + t = 10 + eta_min = 1e-10 + epoch_length = 11 + single_targets = [ + eta_min + (0.05 - eta_min) * + (1 + math.cos(math.pi * x / t / epoch_length)) / 2 + for x in 
range(epochs * epoch_length) + ] + targets = [ + single_targets, + [x * epochs * epoch_length for x in single_targets] + ] + scheduler = CosineAnnealingParamScheduler.build_iter_from_epoch( + self.optimizer, + param_name='lr', + T_max=t, + eta_min=eta_min, + epoch_length=epoch_length) + self._test_scheduler_value(scheduler, targets, epochs) + + def test_poly_scheduler_convert_iterbased(self): + epochs = 10 + power = 0.9 + min_lr = 0.001 + end = 5 + epoch_length = 11 + + iters = end * epoch_length - 1 + single_targets = [ + min_lr + (0.05 - min_lr) * (1 - i / iters)**power + for i in range(iters) + ] + [min_lr] * ( + epochs - iters) + targets = [ + single_targets, + [x * epochs * epoch_length for x in single_targets] + ] + scheduler = PolyParamScheduler.build_iter_from_epoch( + self.optimizer, + param_name='lr', + power=power, + eta_min=min_lr, + end=end, + epoch_length=epoch_length) + self._test_scheduler_value(scheduler, targets, epochs=10) + def test_multi_scheduler_without_overlap_linear_multi_step(self): # use Linear in the first 5 epochs and then use MultiStep epochs = 12 diff --git a/tests/test_runner/test_runner.py b/tests/test_runner/test_runner.py index 9a11720c..de07b614 100644 --- a/tests/test_runner/test_runner.py +++ b/tests/test_runner/test_runner.py @@ -603,6 +603,28 @@ class TestRunner(TestCase): self.assertIsInstance(param_schedulers[0], MultiStepLR) self.assertIsInstance(param_schedulers[1], StepLR) + # train loop should be built before convert scheduler + cfg = dict( + type='MultiStepLR', milestones=[1, 2], convert_to_iter_based=True) + with self.assertRaisesRegex( + AssertionError, + 'Scheduler can only be converted to iter-based when ' + 'train loop is built.'): + param_schedulers = runner.build_param_scheduler(cfg) + + # convert epoch-based to iter-based scheduler + cfg = dict( + type='MultiStepLR', + milestones=[1, 2], + begin=1, + end=7, + convert_to_iter_based=True) + runner.train_loop = runner.build_train_loop(runner.train_loop) + param_schedulers = runner.build_param_scheduler(cfg) + self.assertFalse(param_schedulers[0].by_epoch) + self.assertEqual(param_schedulers[0].begin, 4) + self.assertEqual(param_schedulers[0].end, 28) + def test_build_evaluator(self): cfg = copy.deepcopy(self.epoch_based_cfg) cfg.experiment_name = 'test_build_evaluator'
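
For reference, a minimal usage sketch of the conversion path this patch introduces, assuming an SGD optimizer with a base lr of 0.05, milestones at epochs 8 and 11, and a dataloader of 10 iterations per epoch (all illustrative values, not taken from the patch). The direct classmethod call and the runner-config route below are intended to produce the same iteration-based schedule.

import torch
from mmengine.optim.scheduler.lr_scheduler import MultiStepLR

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05)

# Epoch-based intent: decay lr by 0.1 at epochs 8 and 11.
# build_iter_from_epoch multiplies milestones/begin/end by epoch_length
# and sets by_epoch=False, so the decays land at iterations 80 and 110.
scheduler = MultiStepLR.build_iter_from_epoch(
    optimizer, milestones=[8, 11], gamma=0.1, epoch_length=10)

# Equivalent scheduler config for the runner: build_param_scheduler pops
# `convert_to_iter_based` and reads epoch_length from
# len(train_loop.dataloader), so the train loop must already be built.
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[8, 11],
    gamma=0.1,
    convert_to_iter_based=True)

On the refactor itself: the per-class __init__ overrides can be dropped because LRSchedulerMixin (and MomentumSchedulerMixin) sit first in the MRO and their __init__ forwards the fixed param_name ('lr' or 'momentum') to the underlying parameter scheduler.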