[Feature] Support convert epoch-based schedulers to iter-based. (#221)
* [Feature] Support converting epoch-based schedulers to iter-based ones.
* Support the conversion and refactor the LR and momentum schedulers into mixin classes.
* Add unit tests.
* Fix args and add runner unit tests.
* Resolve review comments.
parent 92b94e8e60
commit 1912660db9
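For orientation, a minimal usage sketch of the new conversion (illustrative only; the import path and the numbers below are assumptions, not part of this commit):

import torch
from mmengine.optim.scheduler import MultiStepLR  # import path assumed

model = torch.nn.Linear(2, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

# Direct conversion: epoch-based arguments (milestones, begin, end) are
# multiplied by `epoch_length` (iterations per epoch) to build an
# iter-based scheduler.
scheduler = MultiStepLR.build_iter_from_epoch(
    optimizer, milestones=[8, 11], epoch_length=100)

# Equivalent runner config: the runner pops `convert_to_iter_based` and
# performs the same conversion with epoch_length taken from the train
# dataloader length.
param_scheduler = dict(
    type='MultiStepLR', milestones=[8, 11], convert_to_iter_based=True)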
Changes to the learning-rate scheduler module (the per-class ``__init__`` boilerplate is replaced by ``LRSchedulerMixin``):

@@ -1,18 +1,21 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import List
-
-import torch
-
 from mmengine.registry import PARAM_SCHEDULERS
-from .param_scheduler import (INF, ConstantParamScheduler,
+from .param_scheduler import (ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
                               MultiStepParamScheduler, PolyParamScheduler,
                               StepParamScheduler)
 
 
+class LRSchedulerMixin:
+    """A mixin class for learning rate schedulers."""
+
+    def __init__(self, optimizer, *args, **kwargs):
+        super().__init__(optimizer, 'lr', *args, **kwargs)
+
+
 @PARAM_SCHEDULERS.register_module()
-class ConstantLR(ConstantParamScheduler):
+class ConstantLR(LRSchedulerMixin, ConstantParamScheduler):
     """Decays the learning rate value of each parameter group by a small
     constant factor until the number of epoch reaches a pre-defined milestone:
     ``end``. Notice that such decay can happen simultaneously with other
@@ -34,27 +37,9 @@ class ConstantLR(ConstantParamScheduler):
             Defaults to False.
     """
 
-    def __init__(self,
-                 optimizer: torch.optim.Optimizer,
-                 factor: float = 1.0 / 3,
-                 begin: int = 0,
-                 end: int = INF,
-                 last_step: int = -1,
-                 by_epoch: bool = True,
-                 verbose: bool = False):
-        super().__init__(
-            optimizer,
-            param_name='lr',
-            factor=factor,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose)
-
 
 @PARAM_SCHEDULERS.register_module()
-class CosineAnnealingLR(CosineAnnealingParamScheduler):
+class CosineAnnealingLR(LRSchedulerMixin, CosineAnnealingParamScheduler):
     r"""Set the learning rate of each parameter group using a cosine annealing
     schedule, where :math:`\eta_{max}` is set to the initial value and
     :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
@@ -101,29 +86,9 @@ class CosineAnnealingLR(CosineAnnealingParamScheduler):
         https://arxiv.org/abs/1608.03983
     """
 
-    def __init__(self,
-                 optimizer: torch.optim.Optimizer,
-                 T_max: int,
-                 eta_min: int = 0,
-                 begin: int = 0,
-                 end: int = INF,
-                 last_step: int = -1,
-                 by_epoch: bool = True,
-                 verbose: bool = False):
-        super().__init__(
-            optimizer,
-            param_name='lr',
-            T_max=T_max,
-            eta_min=eta_min,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose)
-
 
 @PARAM_SCHEDULERS.register_module()
-class ExponentialLR(ExponentialParamScheduler):
+class ExponentialLR(LRSchedulerMixin, ExponentialParamScheduler):
     """Decays the learning rate of each parameter group by gamma every epoch.
 
     Args:
@@ -141,27 +106,9 @@ class ExponentialLR(ExponentialParamScheduler):
             Defaults to False.
     """
 
-    def __init__(self,
-                 optimizer: torch.optim.Optimizer,
-                 gamma: float,
-                 begin: int = 0,
-                 end: int = INF,
-                 last_step: int = -1,
-                 by_epoch: bool = True,
-                 verbose: bool = False):
-        super().__init__(
-            optimizer,
-            param_name='lr',
-            gamma=gamma,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose)
-
 
 @PARAM_SCHEDULERS.register_module()
-class LinearLR(LinearParamScheduler):
+class LinearLR(LRSchedulerMixin, LinearParamScheduler):
     """Decays the learning rate of each parameter group by linearly changing
     small multiplicative factor until the number of epoch reaches a pre-defined
     milestone: ``end``.
@@ -187,29 +134,9 @@ class LinearLR(LinearParamScheduler):
             Defaults to False.
     """
 
-    def __init__(self,
-                 optimizer: torch.optim.Optimizer,
-                 start_factor: float = 1.0 / 3,
-                 end_factor: float = 1.0,
-                 begin: int = 0,
-                 end: int = INF,
-                 last_step: int = -1,
-                 by_epoch: bool = True,
-                 verbose: bool = False):
-        super().__init__(
-            optimizer,
-            param_name='lr',
-            start_factor=start_factor,
-            end_factor=end_factor,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose)
-
 
 @PARAM_SCHEDULERS.register_module()
-class MultiStepLR(MultiStepParamScheduler):
+class MultiStepLR(LRSchedulerMixin, MultiStepParamScheduler):
     """Decays the specified learning rate in each parameter group by gamma once
     the number of epoch reaches one of the milestones. Notice that such decay
     can happen simultaneously with other changes to the learning rate from
@@ -232,29 +159,9 @@ class MultiStepLR(MultiStepParamScheduler):
             Defaults to False.
     """
 
-    def __init__(self,
-                 optimizer: torch.optim.Optimizer,
-                 milestones: List[int],
-                 gamma: float = 0.1,
-                 last_step: int = -1,
-                 begin: int = 0,
-                 end: int = INF,
-                 by_epoch: bool = True,
-                 verbose: bool = False):
-        super().__init__(
-            optimizer,
-            param_name='lr',
-            milestones=milestones,
-            gamma=gamma,
-            last_step=last_step,
-            begin=begin,
-            end=end,
-            by_epoch=by_epoch,
-            verbose=verbose)
-
 
 @PARAM_SCHEDULERS.register_module()
-class StepLR(StepParamScheduler):
+class StepLR(LRSchedulerMixin, StepParamScheduler):
     """Decays the learning rate of each parameter group by gamma every
     step_size epochs. Notice that such decay can happen simultaneously with
     other changes to the learning rate from outside this scheduler.
@@ -276,29 +183,9 @@ class StepLR(StepParamScheduler):
             Defaults to False.
     """
 
-    def __init__(self,
-                 optimizer: torch.optim.Optimizer,
-                 step_size: int,
-                 gamma: float = 0.1,
-                 begin: int = 0,
-                 end: int = INF,
-                 last_step: int = -1,
-                 by_epoch: bool = True,
-                 verbose: bool = False):
-        super().__init__(
-            optimizer,
-            param_name='lr',
-            step_size=step_size,
-            gamma=gamma,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose)
-
 
 @PARAM_SCHEDULERS.register_module()
-class PolyLR(PolyParamScheduler):
+class PolyLR(LRSchedulerMixin, PolyParamScheduler):
     """Decays the learning rate of each parameter group in a polynomial decay
     scheme.
 
@@ -321,23 +208,3 @@ class PolyLR(PolyParamScheduler):
         verbose (bool): Whether to print the value for each update.
            Defaults to False.
     """
-
-    def __init__(self,
-                 optimizer: torch.optim.Optimizer,
-                 eta_min: float = 0,
-                 power: float = 1,
-                 begin: int = 0,
-                 end: int = INF,
-                 last_step: int = -1,
-                 by_epoch: bool = True,
-                 verbose: bool = False):
-        super().__init__(
-            optimizer,
-            param_name='lr',
-            eta_min=eta_min,
-            power=power,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose)
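Before the momentum counterpart below, a self-contained sketch of the mixin pattern used above (simplified stand-ins, not the actual mmengine classes), showing why the concrete LR and momentum classes no longer need their own __init__:

class ParamSchedulerSketch:
    """Stand-in for ConstantParamScheduler and friends."""

    def __init__(self, optimizer, param_name, factor=1.0 / 3):
        self.optimizer = optimizer
        self.param_name = param_name
        self.factor = factor


class LRMixinSketch:
    """Stand-in for LRSchedulerMixin: injects param_name='lr'."""

    def __init__(self, optimizer, *args, **kwargs):
        # Cooperative call: the next class in the MRO receives 'lr'.
        super().__init__(optimizer, 'lr', *args, **kwargs)


class ConstantLRSketch(LRMixinSketch, ParamSchedulerSketch):
    """No __init__ needed; the MRO routes through the mixin first."""


sched = ConstantLRSketch('dummy_optimizer', factor=0.5)
assert sched.param_name == 'lr' and sched.factor == 0.5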
The same refactor applied to the momentum scheduler module (boilerplate replaced by ``MomentumSchedulerMixin``):

@@ -1,18 +1,21 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from typing import List
-
-import torch
-
 from mmengine.registry import PARAM_SCHEDULERS
-from .param_scheduler import (INF, ConstantParamScheduler,
+from .param_scheduler import (ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
                               MultiStepParamScheduler, PolyParamScheduler,
                               StepParamScheduler)
 
 
+class MomentumSchedulerMixin:
+    """A mixin class for momentum schedulers."""
+
+    def __init__(self, optimizer, *args, **kwargs):
+        super().__init__(optimizer, 'momentum', *args, **kwargs)
+
+
 @PARAM_SCHEDULERS.register_module()
-class ConstantMomentum(ConstantParamScheduler):
+class ConstantMomentum(MomentumSchedulerMixin, ConstantParamScheduler):
     """Decays the momentum value of each parameter group by a small constant
     factor until the number of epoch reaches a pre-defined milestone: ``end``.
     Notice that such decay can happen simultaneously with other changes to the
@@ -34,27 +37,10 @@ class ConstantMomentum(ConstantParamScheduler):
             Defaults to False.
     """
 
-    def __init__(self,
-                 optimizer: torch.optim.Optimizer,
-                 factor: float = 1.0 / 3,
-                 begin: int = 0,
-                 end: int = INF,
-                 last_step: int = -1,
-                 by_epoch: bool = True,
-                 verbose: bool = False):
-        super().__init__(
-            optimizer,
-            param_name='momentum',
-            factor=factor,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose)
-
 
 @PARAM_SCHEDULERS.register_module()
-class CosineAnnealingMomentum(CosineAnnealingParamScheduler):
+class CosineAnnealingMomentum(MomentumSchedulerMixin,
+                              CosineAnnealingParamScheduler):
     r"""Set the momentum of each parameter group using a cosine annealing
     schedule, where :math:`\eta_{max}` is set to the initial value and
     :math:`T_{cur}` is the number of epochs since the last restart in SGDR:
@@ -101,29 +87,9 @@ class CosineAnnealingMomentum(CosineAnnealingParamScheduler):
         https://arxiv.org/abs/1608.03983
     """
 
-    def __init__(self,
-                 optimizer: torch.optim.Optimizer,
-                 T_max: int,
-                 eta_min: int = 0,
-                 begin: int = 0,
-                 end: int = INF,
-                 last_step: int = -1,
-                 by_epoch: bool = True,
-                 verbose: bool = False):
-        super().__init__(
-            optimizer,
-            param_name='momentum',
-            T_max=T_max,
-            eta_min=eta_min,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose)
-
 
 @PARAM_SCHEDULERS.register_module()
-class ExponentialMomentum(ExponentialParamScheduler):
+class ExponentialMomentum(MomentumSchedulerMixin, ExponentialParamScheduler):
     """Decays the momentum of each parameter group by gamma every epoch.
 
     Args:
@@ -141,27 +107,9 @@ class ExponentialMomentum(ExponentialParamScheduler):
             Defaults to False.
    """
 
-    def __init__(self,
-                 optimizer: torch.optim.Optimizer,
-                 gamma: float,
-                 begin: int = 0,
-                 end: int = INF,
-                 last_step: int = -1,
-                 by_epoch: bool = True,
-                 verbose: bool = False):
-        super().__init__(
-            optimizer,
-            param_name='momentum',
-            gamma=gamma,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose)
-
 
 @PARAM_SCHEDULERS.register_module()
-class LinearMomentum(LinearParamScheduler):
+class LinearMomentum(MomentumSchedulerMixin, LinearParamScheduler):
     """Decays the momentum of each parameter group by linearly changing
     small multiplicative factor until the number of epoch reaches a pre-defined
     milestone: ``end``.
@@ -187,29 +135,9 @@ class LinearMomentum(LinearParamScheduler):
             Defaults to False.
     """
 
-    def __init__(self,
-                 optimizer: torch.optim.Optimizer,
-                 start_factor: float = 1.0 / 3,
-                 end_factor: float = 1.0,
-                 begin: int = 0,
-                 end: int = INF,
-                 last_step: int = -1,
-                 by_epoch: bool = True,
-                 verbose: bool = False):
-        super().__init__(
-            optimizer,
-            param_name='momentum',
-            start_factor=start_factor,
-            end_factor=end_factor,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose)
-
 
 @PARAM_SCHEDULERS.register_module()
-class MultiStepMomentum(MultiStepParamScheduler):
+class MultiStepMomentum(MomentumSchedulerMixin, MultiStepParamScheduler):
     """Decays the specified momentum in each parameter group by gamma once the
     number of epoch reaches one of the milestones. Notice that such decay can
     happen simultaneously with other changes to the momentum from outside this
@@ -232,29 +160,9 @@ class MultiStepMomentum(MultiStepParamScheduler):
             Defaults to False.
     """
 
-    def __init__(self,
-                 optimizer: torch.optim.Optimizer,
-                 milestones: List[int],
-                 gamma: float = 0.1,
-                 last_step: int = -1,
-                 begin: int = 0,
-                 end: int = INF,
-                 by_epoch: bool = True,
-                 verbose: bool = False):
-        super().__init__(
-            optimizer,
-            param_name='momentum',
-            milestones=milestones,
-            gamma=gamma,
-            last_step=last_step,
-            begin=begin,
-            end=end,
-            by_epoch=by_epoch,
-            verbose=verbose)
-
 
 @PARAM_SCHEDULERS.register_module()
-class StepMomentum(StepParamScheduler):
+class StepMomentum(MomentumSchedulerMixin, StepParamScheduler):
     """Decays the momentum of each parameter group by gamma every step_size
     epochs. Notice that such decay can happen simultaneously with other changes
     to the momentum from outside this scheduler.
@@ -276,29 +184,9 @@ class StepMomentum(StepParamScheduler):
             Defaults to False.
     """
 
-    def __init__(self,
-                 optimizer: torch.optim.Optimizer,
-                 step_size: int,
-                 gamma: float = 0.1,
-                 begin: int = 0,
-                 end: int = INF,
-                 last_step: int = -1,
-                 by_epoch: bool = True,
-                 verbose: bool = False):
-        super().__init__(
-            optimizer,
-            param_name='momentum',
-            step_size=step_size,
-            gamma=gamma,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose)
-
 
 @PARAM_SCHEDULERS.register_module()
-class PolyMomentum(PolyParamScheduler):
+class PolyMomentum(MomentumSchedulerMixin, PolyParamScheduler):
     """Decays the momentum of each parameter group in a polynomial decay
     scheme.
 
@@ -321,23 +209,3 @@ class PolyMomentum(PolyParamScheduler):
         verbose (bool): Whether to print the value for each update.
            Defaults to False.
     """
-
-    def __init__(self,
-                 optimizer: torch.optim.Optimizer,
-                 eta_min: float = 0,
-                 power: float = 1,
-                 begin: int = 0,
-                 end: int = INF,
-                 last_step: int = -1,
-                 by_epoch: bool = True,
-                 verbose: bool = False):
-        super().__init__(
-            optimizer,
-            param_name='momentum',
-            eta_min=eta_min,
-            power=power,
-            begin=begin,
-            end=end,
-            last_step=last_step,
-            by_epoch=by_epoch,
-            verbose=verbose)
New ``build_iter_from_epoch`` classmethods on the parameter schedulers:

@@ -255,6 +255,35 @@ class StepParamScheduler(_ParamScheduler):
             by_epoch=by_epoch,
             verbose=verbose)
 
+    @classmethod
+    def build_iter_from_epoch(cls,
+                              *args,
+                              step_size,
+                              begin=0,
+                              end=INF,
+                              by_epoch=True,
+                              epoch_length=None,
+                              **kwargs):
+        """Build an iter-based instance of this scheduler from an epoch-based
+        config."""
+        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
+            'be converted to iter-based.'
+        assert epoch_length is not None and epoch_length > 0, \
+            f'`epoch_length` must be a positive integer, ' \
+            f'but got {epoch_length}.'
+        by_epoch = False
+        step_size = step_size * epoch_length
+        begin = begin * epoch_length
+        if end != INF:
+            end = end * epoch_length
+        return cls(
+            *args,
+            step_size=step_size,
+            begin=begin,
+            end=end,
+            by_epoch=by_epoch,
+            **kwargs)
+
     def _get_value(self):
         """Compute value using chainable form of the scheduler."""
         if (self.last_step == 0) or (self.last_step % self.step_size != 0):
@@ -312,6 +341,35 @@ class MultiStepParamScheduler(_ParamScheduler):
             by_epoch=by_epoch,
             verbose=verbose)
 
+    @classmethod
+    def build_iter_from_epoch(cls,
+                              *args,
+                              milestones,
+                              begin=0,
+                              end=INF,
+                              by_epoch=True,
+                              epoch_length=None,
+                              **kwargs):
+        """Build an iter-based instance of this scheduler from an epoch-based
+        config."""
+        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
+            'be converted to iter-based.'
+        assert epoch_length is not None and epoch_length > 0, \
+            f'`epoch_length` must be a positive integer, ' \
+            f'but got {epoch_length}.'
+        by_epoch = False
+        milestones = [i * epoch_length for i in milestones]
+        begin = begin * epoch_length
+        if end != INF:
+            end = end * epoch_length
+        return cls(
+            *args,
+            milestones=milestones,
+            begin=begin,
+            end=end,
+            by_epoch=by_epoch,
+            **kwargs)
+
     def _get_value(self):
         """Compute value using chainable form of the scheduler."""
         if self.last_step not in self.milestones:
@@ -372,6 +430,27 @@ class ConstantParamScheduler(_ParamScheduler):
             by_epoch=by_epoch,
             verbose=verbose)
 
+    @classmethod
+    def build_iter_from_epoch(cls,
+                              *args,
+                              begin=0,
+                              end=INF,
+                              by_epoch=True,
+                              epoch_length=None,
+                              **kwargs):
+        """Build an iter-based instance of this scheduler from an epoch-based
+        config."""
+        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
+            'be converted to iter-based.'
+        assert epoch_length is not None and epoch_length > 0, \
+            f'`epoch_length` must be a positive integer, ' \
+            f'but got {epoch_length}.'
+        by_epoch = False
+        begin = begin * epoch_length
+        if end != INF:
+            end = end * epoch_length
+        return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
+
     def _get_value(self):
         """Compute value using chainable form of the scheduler."""
         if self.last_step == 0:
@@ -431,6 +510,27 @@ class ExponentialParamScheduler(_ParamScheduler):
             by_epoch=by_epoch,
             verbose=verbose)
 
+    @classmethod
+    def build_iter_from_epoch(cls,
+                              *args,
+                              begin=0,
+                              end=INF,
+                              by_epoch=True,
+                              epoch_length=None,
+                              **kwargs):
+        """Build an iter-based instance of this scheduler from an epoch-based
+        config."""
+        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
+            'be converted to iter-based.'
+        assert epoch_length is not None and epoch_length > 0, \
+            f'`epoch_length` must be a positive integer, ' \
+            f'but got {epoch_length}.'
+        by_epoch = False
+        begin = begin * epoch_length
+        if end != INF:
+            end = end * epoch_length
+        return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
+
     def _get_value(self):
         """Compute value using chainable form of the scheduler."""
         if self.last_step == 0:
@@ -512,6 +612,35 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
             by_epoch=by_epoch,
             verbose=verbose)
 
+    @classmethod
+    def build_iter_from_epoch(cls,
+                              *args,
+                              T_max,
+                              begin=0,
+                              end=INF,
+                              by_epoch=True,
+                              epoch_length=None,
+                              **kwargs):
+        """Build an iter-based instance of this scheduler from an epoch-based
+        config."""
+        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
+            'be converted to iter-based.'
+        assert epoch_length is not None and epoch_length > 0, \
+            f'`epoch_length` must be a positive integer, ' \
+            f'but got {epoch_length}.'
+        by_epoch = False
+        T_max = T_max * epoch_length
+        begin = begin * epoch_length
+        if end != INF:
+            end = end * epoch_length
+        return cls(
+            *args,
+            T_max=T_max,
+            begin=begin,
+            end=end,
+            by_epoch=by_epoch,
+            **kwargs)
+
     def _get_value(self):
         """Compute value using chainable form of the scheduler."""
         if self.last_step == 0:
@@ -589,6 +718,27 @@ class LinearParamScheduler(_ParamScheduler):
             by_epoch=by_epoch,
             verbose=verbose)
 
+    @classmethod
+    def build_iter_from_epoch(cls,
+                              *args,
+                              begin=0,
+                              end=INF,
+                              by_epoch=True,
+                              epoch_length=None,
+                              **kwargs):
+        """Build an iter-based instance of this scheduler from an epoch-based
+        config."""
+        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
+            'be converted to iter-based.'
+        assert epoch_length is not None and epoch_length > 0, \
+            f'`epoch_length` must be a positive integer, ' \
+            f'but got {epoch_length}.'
+        by_epoch = False
+        begin = begin * epoch_length
+        if end != INF:
+            end = end * epoch_length
+        return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
+
     def _get_value(self):
         """Compute value using chainable form of the scheduler."""
         if self.last_step == 0:
@@ -655,6 +805,27 @@ class PolyParamScheduler(_ParamScheduler):
             by_epoch=by_epoch,
             verbose=verbose)
 
+    @classmethod
+    def build_iter_from_epoch(cls,
+                              *args,
+                              begin=0,
+                              end=INF,
+                              by_epoch=True,
+                              epoch_length=None,
+                              **kwargs):
+        """Build an iter-based instance of this scheduler from an epoch-based
+        config."""
+        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
+            'be converted to iter-based.'
+        assert epoch_length is not None and epoch_length > 0, \
+            f'`epoch_length` must be a positive integer, ' \
+            f'but got {epoch_length}.'
+        by_epoch = False
+        begin = begin * epoch_length
+        if end != INF:
+            end = end * epoch_length
+        return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
+
     def _get_value(self):
         """Compute value using chainable form of the scheduler."""
        if self.last_step == 0:
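A short worked example of the arithmetic performed by build_iter_from_epoch (numbers are illustrative, not from this commit):

epoch_length = 100                  # iterations per epoch
step_size, begin, end = 3, 1, 10    # epoch-based configuration
step_size, begin, end = (step_size * epoch_length,
                         begin * epoch_length,
                         end * epoch_length)
assert (step_size, begin, end) == (300, 100, 1000)  # now iteration-based
# An `end` of INF is left untouched, and `by_epoch` is flipped to False.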
Runner: build an iter-based scheduler from an epoch-based config when ``convert_to_iter_based=True``:

@@ -848,10 +848,29 @@ class Runner:
             if isinstance(_scheduler, _ParamScheduler):
                 param_schedulers.append(_scheduler)
             elif isinstance(_scheduler, dict):
-                param_schedulers.append(
-                    PARAM_SCHEDULERS.build(
-                        _scheduler,
-                        default_args=dict(optimizer=self.optimizer)))
+                convert_to_iter = _scheduler.pop('convert_to_iter_based',
+                                                 False)
+                if convert_to_iter:
+                    assert _scheduler.get(
+                        'by_epoch', True
+                    ), 'only epoch-based parameter scheduler can be ' \
+                        'converted to iter-based'
+                    assert isinstance(self.train_loop, BaseLoop), \
+                        'Scheduler can only be converted to iter-based ' \
+                        'when train loop is built.'
+                    cls = PARAM_SCHEDULERS.get(_scheduler.pop('type'))
+                    param_schedulers.append(
+                        cls.build_iter_from_epoch(  # type: ignore
+                            optimizer=self.optimizer,
+                            **_scheduler,
+                            epoch_length=len(
+                                self.train_loop.dataloader),  # type: ignore
+                        ))
+                else:
+                    param_schedulers.append(
+                        PARAM_SCHEDULERS.build(
+                            _scheduler,
+                            default_args=dict(optimizer=self.optimizer)))
             else:
                 raise TypeError(
                     '_scheduler should be a _ParamScheduler object or dict, '
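A sanity check of the numbers the runner produces (the epoch length of 4 is an assumed example value; the runner takes it from len(train_loop.dataloader)):

epoch_length = 4          # assumed iterations per epoch
begin, end = 1, 7         # epoch-based scheduler config
begin, end = begin * epoch_length, end * epoch_length
assert (begin, end) == (4, 28)  # matches the runner unit test below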
Unit tests added for the converted LR schedulers:

@@ -352,6 +352,144 @@ class TestLRScheduler(TestCase):
             lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002),
             epochs=10)
 
+    def test_step_scheduler_convert_iterbased(self):
+        # invalid epoch_length
+        with self.assertRaises(AssertionError):
+            scheduler = StepLR.build_iter_from_epoch(
+                self.optimizer, gamma=0.1, step_size=2, epoch_length=-1)
+
+        # lr = 0.05 if epoch < 2
+        # lr = 0.005 if 2 <= epoch < 4
+        epochs = 4
+        epoch_length = 7
+        single_targets = [0.05] * 2 * epoch_length + [0.005] * 2 * epoch_length
+        targets = [
+            single_targets,
+            [x * epochs * epoch_length for x in single_targets]
+        ]
+        scheduler = StepLR.build_iter_from_epoch(
+            self.optimizer, gamma=0.1, step_size=2, epoch_length=epoch_length)
+        self._test_scheduler_value(
+            scheduler, targets, epochs * epoch_length, param_name='lr')
+
+    def test_multi_step_scheduler_convert_iterbased(self):
+        # lr = 0.05 if epoch < 2
+        # lr = 0.005 if 2 <= epoch < 5
+        # lr = 0.0005 if 5 <= epoch < 9
+        # lr = 0.00005 if epoch >= 9
+        epochs = 10
+        epoch_length = 7
+        single_targets = [0.05
+                          ] * 2 * epoch_length + [0.005] * 3 * epoch_length + [
+                              0.0005
+                          ] * 4 * epoch_length + [0.00005] * 3 * epoch_length
+        targets = [
+            single_targets,
+            [x * epochs * epoch_length for x in single_targets]
+        ]
+        scheduler = MultiStepLR.build_iter_from_epoch(
+            self.optimizer,
+            gamma=0.1,
+            milestones=[2, 5, 9],
+            epoch_length=epoch_length)
+        self._test_scheduler_value(scheduler, targets, epochs * epoch_length)
+
+    def test_constant_scheduler_convert_iterbased(self):
+        # lr = 0.025 if epoch < 5
+        # lr = 0.005 if 5 <= epoch
+        epochs = 10
+        epoch_length = 7
+        single_targets = [0.025] * (5 * epoch_length -
+                                    1) + [0.05] * (5 * epoch_length + 1)
+        targets = [
+            single_targets,
+            [x * epochs * epoch_length for x in single_targets]
+        ]
+        scheduler = ConstantLR.build_iter_from_epoch(
+            self.optimizer, factor=1.0 / 2, end=5, epoch_length=epoch_length)
+        self._test_scheduler_value(scheduler, targets, epochs * epoch_length)
+
+    def test_linear_scheduler_convert_iterbased(self):
+        epochs = 10
+        start_factor = 1.0 / 2
+        end = 5
+        epoch_length = 11
+
+        iters = end * epoch_length - 1
+        interpolation = [
+            start_factor + i * (1 - start_factor) / iters for i in range(iters)
+        ]
+        single_targets = [x * 0.05 for x in interpolation] + [0.05] * (
+            epochs * epoch_length - iters)
+        targets = [single_targets, [x * epochs for x in single_targets]]
+        scheduler = LinearLR.build_iter_from_epoch(
+            self.optimizer,
+            start_factor=start_factor,
+            end=end,
+            epoch_length=epoch_length)
+        self._test_scheduler_value(scheduler, targets, epochs)
+
+    def test_exp_scheduler_convert_iterbased(self):
+        epochs = 10
+        epoch_length = 7
+
+        single_targets = [
+            0.05 * (0.9**x) for x in range(epochs * epoch_length)
+        ]
+        targets = [
+            single_targets,
+            [x * epochs * epoch_length for x in single_targets]
+        ]
+        scheduler = ExponentialLR.build_iter_from_epoch(
+            self.optimizer, gamma=0.9, epoch_length=epoch_length)
+        self._test_scheduler_value(scheduler, targets, epochs * epoch_length)
+
+    def test_cos_anneal_scheduler_convert_iterbased(self):
+        epochs = 12
+        t = 10
+        eta_min = 1e-10
+        epoch_length = 11
+        single_targets = [
+            eta_min + (0.05 - eta_min) *
+            (1 + math.cos(math.pi * x / t / epoch_length)) / 2
+            for x in range(epochs * epoch_length)
+        ]
+        targets = [
+            single_targets,
+            [x * epochs * epoch_length for x in single_targets]
+        ]
+        scheduler = CosineAnnealingLR.build_iter_from_epoch(
+            self.optimizer,
+            T_max=t,
+            eta_min=eta_min,
+            epoch_length=epoch_length)
+        self._test_scheduler_value(scheduler, targets, epochs)
+
+    def test_poly_scheduler_convert_iterbased(self):
+        epochs = 10
+        power = 0.9
+        min_lr = 0.001
+        end = 5
+        epoch_length = 11
+
+        iters = end * epoch_length - 1
+        single_targets = [
+            min_lr + (0.05 - min_lr) * (1 - i / iters)**power
+            for i in range(iters)
+        ] + [min_lr] * (
+            epochs - iters)
+        targets = [
+            single_targets,
+            [x * epochs * epoch_length for x in single_targets]
+        ]
+        scheduler = PolyLR.build_iter_from_epoch(
+            self.optimizer,
+            power=power,
+            eta_min=min_lr,
+            end=end,
+            epoch_length=epoch_length)
+        self._test_scheduler_value(scheduler, targets, epochs=10)
+
     def test_multi_scheduler_without_overlap_linear_multi_step(self):
         # use Linear in the first 5 epochs and then use MultiStep
         epochs = 12
Parameter scheduler unit tests: constructor calls switched to keyword arguments, and conversion tests added:

@@ -47,7 +47,8 @@ class TestParameterScheduler(TestCase):
 
     def test_invalid_optimizer(self):
         with self.assertRaisesRegex(TypeError, 'should be an Optimizer'):
-            StepParamScheduler('invalid_optimizer', 'lr', step_size=1)
+            StepParamScheduler(
+                'invalid_optimizer', step_size=1, param_name='lr')
 
     def test_overwrite_optimzer_step(self):
         # raise warning if the counter in optimizer.step() is overwritten
@@ -140,7 +141,8 @@ class TestParameterScheduler(TestCase):
     def test_get_last_value(self):
         epochs = 10
         targets = [[0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005]]
-        scheduler = StepParamScheduler(self.optimizer, 'lr', 3, gamma=0.1)
+        scheduler = StepParamScheduler(
+            self.optimizer, param_name='lr', step_size=3, gamma=0.1)
         for epoch in range(epochs):
             result = scheduler.get_last_value()
             self.optimizer.step()
@@ -432,6 +434,163 @@ class TestParameterScheduler(TestCase):
                 self.optimizer, param_name='lr', power=0.8, eta_min=0.002),
             epochs=10)
 
+    def test_step_scheduler_convert_iterbased(self):
+        # invalid epoch_length
+        with self.assertRaises(AssertionError):
+            scheduler = StepParamScheduler.build_iter_from_epoch(
+                self.optimizer,
+                param_name='momentum',
+                gamma=0.1,
+                step_size=2,
+                epoch_length=-1)
+
+        # momentum = 0.01 if epoch < 2
+        # momentum = 0.001 if 2 <= epoch < 4
+        epochs = 4
+        epoch_length = 7
+        single_targets = [0.01] * 2 * epoch_length + [0.001] * 2 * epoch_length
+        targets = [
+            single_targets,
+            [x * epochs * epoch_length for x in single_targets]
+        ]
+        scheduler = StepParamScheduler.build_iter_from_epoch(
+            self.optimizer,
+            param_name='momentum',
+            gamma=0.1,
+            step_size=2,
+            epoch_length=epoch_length)
+        self._test_scheduler_value(
+            scheduler, targets, epochs * epoch_length, param_name='momentum')
+
+    def test_multi_step_scheduler_convert_iterbased(self):
+        # lr = 0.05 if epoch < 2
+        # lr = 0.005 if 2 <= epoch < 5
+        # lr = 0.0005 if 5 <= epoch < 9
+        # lr = 0.00005 if epoch >= 9
+        epochs = 10
+        epoch_length = 7
+        single_targets = [0.05
+                          ] * 2 * epoch_length + [0.005] * 3 * epoch_length + [
+                              0.0005
+                          ] * 4 * epoch_length + [0.00005] * 3 * epoch_length
+        targets = [
+            single_targets,
+            [x * epochs * epoch_length for x in single_targets]
+        ]
+        scheduler = MultiStepParamScheduler.build_iter_from_epoch(
+            self.optimizer,
+            param_name='lr',
+            gamma=0.1,
+            milestones=[2, 5, 9],
+            epoch_length=epoch_length)
+        self._test_scheduler_value(scheduler, targets, epochs * epoch_length)
+
+    def test_constant_scheduler_convert_iterbased(self):
+        # lr = 0.025 if epoch < 5
+        # lr = 0.005 if 5 <= epoch
+        epochs = 10
+        epoch_length = 7
+        single_targets = [0.025] * (5 * epoch_length -
+                                    1) + [0.05] * (5 * epoch_length + 1)
+        targets = [
+            single_targets,
+            [x * epochs * epoch_length for x in single_targets]
+        ]
+        scheduler = ConstantParamScheduler.build_iter_from_epoch(
+            self.optimizer,
+            param_name='lr',
+            factor=1.0 / 2,
+            end=5,
+            epoch_length=epoch_length)
+        self._test_scheduler_value(scheduler, targets, epochs * epoch_length)
+
+    def test_linear_scheduler_convert_iterbased(self):
+        epochs = 10
+        start_factor = 1.0 / 2
+        end = 5
+        epoch_length = 11
+
+        iters = end * epoch_length - 1
+        interpolation = [
+            start_factor + i * (1 - start_factor) / iters for i in range(iters)
+        ]
+        single_targets = [x * 0.05 for x in interpolation] + [0.05] * (
+            epochs * epoch_length - iters)
+        targets = [single_targets, [x * epochs for x in single_targets]]
+        scheduler = LinearParamScheduler.build_iter_from_epoch(
+            self.optimizer,
+            param_name='lr',
+            start_factor=start_factor,
+            end=end,
+            epoch_length=epoch_length)
+        self._test_scheduler_value(scheduler, targets, epochs)
+
+    def test_exp_scheduler_convert_iterbased(self):
+        epochs = 10
+        epoch_length = 7
+
+        single_targets = [
+            0.05 * (0.9**x) for x in range(epochs * epoch_length)
+        ]
+        targets = [
+            single_targets,
+            [x * epochs * epoch_length for x in single_targets]
+        ]
+        scheduler = ExponentialParamScheduler.build_iter_from_epoch(
+            self.optimizer,
+            param_name='lr',
+            gamma=0.9,
+            epoch_length=epoch_length)
+        self._test_scheduler_value(scheduler, targets, epochs * epoch_length)
+
+    def test_cos_anneal_scheduler_convert_iterbased(self):
+        epochs = 12
+        t = 10
+        eta_min = 1e-10
+        epoch_length = 11
+        single_targets = [
+            eta_min + (0.05 - eta_min) *
+            (1 + math.cos(math.pi * x / t / epoch_length)) / 2
+            for x in range(epochs * epoch_length)
+        ]
+        targets = [
+            single_targets,
+            [x * epochs * epoch_length for x in single_targets]
+        ]
+        scheduler = CosineAnnealingParamScheduler.build_iter_from_epoch(
+            self.optimizer,
+            param_name='lr',
+            T_max=t,
+            eta_min=eta_min,
+            epoch_length=epoch_length)
+        self._test_scheduler_value(scheduler, targets, epochs)
+
+    def test_poly_scheduler_convert_iterbased(self):
+        epochs = 10
+        power = 0.9
+        min_lr = 0.001
+        end = 5
+        epoch_length = 11
+
+        iters = end * epoch_length - 1
+        single_targets = [
+            min_lr + (0.05 - min_lr) * (1 - i / iters)**power
+            for i in range(iters)
+        ] + [min_lr] * (
+            epochs - iters)
+        targets = [
+            single_targets,
+            [x * epochs * epoch_length for x in single_targets]
+        ]
+        scheduler = PolyParamScheduler.build_iter_from_epoch(
+            self.optimizer,
+            param_name='lr',
+            power=power,
+            eta_min=min_lr,
+            end=end,
+            epoch_length=epoch_length)
+        self._test_scheduler_value(scheduler, targets, epochs=10)
+
     def test_multi_scheduler_without_overlap_linear_multi_step(self):
         # use Linear in the first 5 epochs and then use MultiStep
         epochs = 12
Runner unit test covering the conversion:

@@ -603,6 +603,28 @@ class TestRunner(TestCase):
         self.assertIsInstance(param_schedulers[0], MultiStepLR)
         self.assertIsInstance(param_schedulers[1], StepLR)
 
+        # train loop should be built before convert scheduler
+        cfg = dict(
+            type='MultiStepLR', milestones=[1, 2], convert_to_iter_based=True)
+        with self.assertRaisesRegex(
+                AssertionError,
+                'Scheduler can only be converted to iter-based when '
+                'train loop is built.'):
+            param_schedulers = runner.build_param_scheduler(cfg)
+
+        # convert epoch-based to iter-based scheduler
+        cfg = dict(
+            type='MultiStepLR',
+            milestones=[1, 2],
+            begin=1,
+            end=7,
+            convert_to_iter_based=True)
+        runner.train_loop = runner.build_train_loop(runner.train_loop)
+        param_schedulers = runner.build_param_scheduler(cfg)
+        self.assertFalse(param_schedulers[0].by_epoch)
+        self.assertEqual(param_schedulers[0].begin, 4)
+        self.assertEqual(param_schedulers[0].end, 28)
+
     def test_build_evaluator(self):
         cfg = copy.deepcopy(self.epoch_based_cfg)
         cfg.experiment_name = 'test_build_evaluator'