Mirror of https://github.com/open-mmlab/mmengine.git, synced 2025-06-03 21:54:44 +08:00
[Feature]: Add parameter schedulers. (#22)
* [Feature]: Add parameter schedulers.
* update
* update
* update
* update
* add docstring to lr and momentum
* resolve comments
This commit is contained in:
parent 41e1191cbc
commit 7905f039b6

mmengine/optim/scheduler/__init__.py (new file, 20 lines)
@@ -0,0 +1,20 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR,
                           LinearLR, MultiStepLR, StepLR)
from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
                                 ExponentialMomentum, LinearMomentum,
                                 MultiStepMomentum, StepMomentum)
from .param_scheduler import (ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
                              MultiStepParamScheduler, StepParamScheduler,
                              _ParamScheduler)

__all__ = [
    'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
    'MultiStepLR', 'StepLR', 'ConstantMomentum', 'CosineAnnealingMomentum',
    'ExponentialMomentum', 'LinearMomentum', 'MultiStepMomentum',
    'StepMomentum', 'ConstantParamScheduler', 'CosineAnnealingParamScheduler',
    'ExponentialParamScheduler', 'LinearParamScheduler',
    'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler'
]
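The package re-exports every scheduler, so they can be used directly on a torch optimizer much like the ``torch.optim.lr_scheduler`` classes. A minimal usage sketch, not part of the diff, assuming the subpackage is importable as ``mmengine.optim.scheduler``; the model and hyperparameters are illustrative:

import torch

from mmengine.optim.scheduler import StepLR

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
# Halve the lr every 3 epochs.
scheduler = StepLR(optimizer, step_size=3, gamma=0.5)

for epoch in range(10):
    optimizer.step()    # optimizer.step() first, then scheduler.step()
    scheduler.step()
    print(epoch, scheduler.get_last_value())  # last lr of each param group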
mmengine/optim/scheduler/lr_scheduler.py (new file, 296 lines)
@@ -0,0 +1,296 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch

from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (INF, ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
                              MultiStepParamScheduler, StepParamScheduler)


@PARAM_SCHEDULERS.register_module()
class ConstantLR(ConstantParamScheduler):
    """Decays the learning rate value of each parameter group by a small
    constant factor until the number of epoch reaches a pre-defined milestone:
    ``end``. Notice that such decay can happen simultaneously with other
    changes to the learning rate value from outside this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        factor (float): The number we multiply learning rate until the
            milestone. Defaults to 1./3.
        begin (int): Step at which to start updating the learning rate.
            Defaults to 0.
        end (int): Step at which to stop updating the learning rate.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without state
            dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled learning rate is updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the learning rate for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 factor: float = 1.0 / 3,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='lr',
            factor=factor,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)


@PARAM_SCHEDULERS.register_module()
class CosineAnnealingLR(CosineAnnealingParamScheduler):
    r"""Set the learning rate of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial value and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::
        \begin{aligned}
            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
            & T_{cur} \neq (2k+1)T_{max}; \\
            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
            & T_{cur} = (2k+1)T_{max}.
        \end{aligned}

    Notice that because the schedule
    is defined recursively, the learning rate can be simultaneously modified
    outside this scheduler by other operators. If the learning rate is set
    solely by this scheduler, the learning rate at each step becomes:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this
    only implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum learning rate. Defaults to 0.
        begin (int): Step at which to start updating the learning rate.
            Defaults to 0.
        end (int): Step at which to stop updating the learning rate.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled learning rate is updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the learning rate for each update.
            Defaults to False.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 T_max: int,
                 eta_min: int = 0,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='lr',
            T_max=T_max,
            eta_min=eta_min,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)


@PARAM_SCHEDULERS.register_module()
class ExponentialLR(ExponentialParamScheduler):
    """Decays the learning rate of each parameter group by gamma every epoch.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of learning rate decay.
        begin (int): Step at which to start updating the learning rate.
            Defaults to 0.
        end (int): Step at which to stop updating the learning rate.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled learning rate is updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the learning rate for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 gamma: float,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='lr',
            gamma=gamma,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)


@PARAM_SCHEDULERS.register_module()
class LinearLR(LinearParamScheduler):
    """Decays the learning rate of each parameter group by linearly changing
    small multiplicative factor until the number of epoch reaches a pre-defined
    milestone: ``end``.

    Notice that such decay can happen simultaneously with other changes to the
    learning rate from outside this scheduler.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        start_factor (float): The number we multiply learning rate in the
            first epoch. The multiplication factor changes towards end_factor
            in the following epochs. Defaults to 1./3.
        end_factor (float): The number we multiply learning rate at the end
            of linear changing process. Defaults to 1.0.
        begin (int): Step at which to start updating the learning rate.
            Defaults to 0.
        end (int): Step at which to stop updating the learning rate.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled learning rate is updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the learning rate for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 start_factor: float = 1.0 / 3,
                 end_factor: float = 1.0,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='lr',
            start_factor=start_factor,
            end_factor=end_factor,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)


@PARAM_SCHEDULERS.register_module()
class MultiStepLR(MultiStepParamScheduler):
    """Decays the specified learning rate in each parameter group by gamma once
    the number of epoch reaches one of the milestones. Notice that such decay
    can happen simultaneously with other changes to the learning rate from
    outside this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of epoch indices. Must be increasing.
        gamma (float): Multiplicative factor of learning rate decay.
            Defaults to 0.1.
        begin (int): Step at which to start updating the learning rate.
            Defaults to 0.
        end (int): Step at which to stop updating the learning rate.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled learning rate is updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the learning rate for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 milestones: List[int],
                 gamma: float = 0.1,
                 last_step: int = -1,
                 begin: int = 0,
                 end: int = INF,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='lr',
            milestones=milestones,
            gamma=gamma,
            last_step=last_step,
            begin=begin,
            end=end,
            by_epoch=by_epoch,
            verbose=verbose)


@PARAM_SCHEDULERS.register_module()
class StepLR(StepParamScheduler):
    """Decays the learning rate of each parameter group by gamma every
    step_size epochs. Notice that such decay can happen simultaneously with
    other changes to the learning rate from outside this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of learning rate decay.
        gamma (float): Multiplicative factor of learning rate decay.
            Defaults to 0.1.
        begin (int): Step at which to start updating the learning rate.
            Defaults to 0.
        end (int): Step at which to stop updating the learning rate.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled learning rate is updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the learning rate for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 step_size: int,
                 gamma: float = 0.1,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='lr',
            step_size=step_size,
            gamma=gamma,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)
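The classes above are thin wrappers that pin ``param_name='lr'`` on the generic schedulers from ``param_scheduler.py``. A hedged sketch of the same pattern applied to another hyperparameter stored in ``param_groups``; ``WeightDecayStep`` is hypothetical, not part of this commit, and the import path simply mirrors the module added in this diff:

from mmengine.registry import PARAM_SCHEDULERS
from mmengine.optim.scheduler.param_scheduler import INF, StepParamScheduler


@PARAM_SCHEDULERS.register_module()
class WeightDecayStep(StepParamScheduler):
    """Hypothetical example: decay ``weight_decay`` by gamma every
    ``step_size`` epochs, reusing the generic scheduler unchanged."""

    def __init__(self,
                 optimizer,
                 step_size: int,
                 gamma: float = 0.1,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='weight_decay',  # any key present in param_groups
            step_size=step_size,
            gamma=gamma,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)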
mmengine/optim/scheduler/momentum_scheduler.py (new file, 296 lines)
@@ -0,0 +1,296 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch

from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (INF, ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
                              MultiStepParamScheduler, StepParamScheduler)


@PARAM_SCHEDULERS.register_module()
class ConstantMomentum(ConstantParamScheduler):
    """Decays the momentum value of each parameter group by a small constant
    factor until the number of epoch reaches a pre-defined milestone: ``end``.
    Notice that such decay can happen simultaneously with other changes to the
    momentum value from outside this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        factor (float): The number we multiply momentum until the milestone.
            Defaults to 1./3.
        begin (int): Step at which to start updating the momentum.
            Defaults to 0.
        end (int): Step at which to stop updating the momentum.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without state
            dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled momentum is updated by epochs.
            Defaults to True.
        verbose (bool): Whether to print the momentum for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 factor: float = 1.0 / 3,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='momentum',
            factor=factor,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)


@PARAM_SCHEDULERS.register_module()
class CosineAnnealingMomentum(CosineAnnealingParamScheduler):
    r"""Set the momentum of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial value and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::
        \begin{aligned}
            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
            & T_{cur} \neq (2k+1)T_{max}; \\
            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
            & T_{cur} = (2k+1)T_{max}.
        \end{aligned}

    Notice that because the schedule
    is defined recursively, the momentum can be simultaneously modified
    outside this scheduler by other operators. If the momentum is set
    solely by this scheduler, the momentum at each step becomes:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this
    only implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum momentum value. Defaults to 0.
        begin (int): Step at which to start updating the momentum.
            Defaults to 0.
        end (int): Step at which to stop updating the momentum.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled momentum is updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the momentum for each update.
            Defaults to False.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 T_max: int,
                 eta_min: int = 0,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='momentum',
            T_max=T_max,
            eta_min=eta_min,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)


@PARAM_SCHEDULERS.register_module()
class ExponentialMomentum(ExponentialParamScheduler):
    """Decays the momentum of each parameter group by gamma every epoch.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of momentum value decay.
        begin (int): Step at which to start updating the momentum.
            Defaults to 0.
        end (int): Step at which to stop updating the momentum.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled momentum is updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the momentum for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 gamma: float,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='momentum',
            gamma=gamma,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)


@PARAM_SCHEDULERS.register_module()
class LinearMomentum(LinearParamScheduler):
    """Decays the momentum of each parameter group by linearly changing
    small multiplicative factor until the number of epoch reaches a pre-defined
    milestone: ``end``.

    Notice that such decay can happen simultaneously with other changes to the
    momentum from outside this scheduler.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        start_factor (float): The number we multiply momentum in the
            first epoch. The multiplication factor changes towards end_factor
            in the following epochs. Defaults to 1./3.
        end_factor (float): The number we multiply momentum at the end
            of linear changing process. Defaults to 1.0.
        begin (int): Step at which to start updating the momentum.
            Defaults to 0.
        end (int): Step at which to stop updating the momentum.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled momentum is updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the momentum for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 start_factor: float = 1.0 / 3,
                 end_factor: float = 1.0,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='momentum',
            start_factor=start_factor,
            end_factor=end_factor,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)


@PARAM_SCHEDULERS.register_module()
class MultiStepMomentum(MultiStepParamScheduler):
    """Decays the specified momentum in each parameter group by gamma once the
    number of epoch reaches one of the milestones. Notice that such decay can
    happen simultaneously with other changes to the momentum from outside this
    scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of epoch indices. Must be increasing.
        gamma (float): Multiplicative factor of momentum value decay.
            Defaults to 0.1.
        begin (int): Step at which to start updating the momentum.
            Defaults to 0.
        end (int): Step at which to stop updating the momentum.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled momentum is updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the momentum for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 milestones: List[int],
                 gamma: float = 0.1,
                 last_step: int = -1,
                 begin: int = 0,
                 end: int = INF,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='momentum',
            milestones=milestones,
            gamma=gamma,
            last_step=last_step,
            begin=begin,
            end=end,
            by_epoch=by_epoch,
            verbose=verbose)


@PARAM_SCHEDULERS.register_module()
class StepMomentum(StepParamScheduler):
    """Decays the momentum of each parameter group by gamma every step_size
    epochs. Notice that such decay can happen simultaneously with other changes
    to the momentum from outside this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of momentum value decay.
        gamma (float): Multiplicative factor of momentum value decay.
            Defaults to 0.1.
        begin (int): Step at which to start updating the momentum.
            Defaults to 0.
        end (int): Step at which to stop updating the momentum.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled momentum is updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the momentum for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 step_size: int,
                 gamma: float = 0.1,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='momentum',
            step_size=step_size,
            gamma=gamma,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)
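As with the LR variants, these classes assume the wrapped optimizer's ``param_groups`` carry a ``momentum`` key (e.g. SGD with momentum); an optimizer without it would hit a ``KeyError`` in ``_ParamScheduler.__init__``. A brief hedged sketch with illustrative numbers:

import torch

from mmengine.optim.scheduler import CosineAnnealingMomentum

optimizer = torch.optim.SGD(
    torch.nn.Linear(2, 2).parameters(), lr=0.1, momentum=0.95)
# Anneal momentum from its initial 0.95 towards eta_min=0.85 over 50 steps.
scheduler = CosineAnnealingMomentum(optimizer, T_max=50, eta_min=0.85)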
mmengine/optim/scheduler/param_scheduler.py (new file, 600 lines)
@@ -0,0 +1,600 @@
# Copyright (c) OpenMMLab. All rights reserved.
import math
import warnings
import weakref
from collections import Counter
from functools import wraps
from typing import Callable, List

from torch.optim import Optimizer

from mmengine.registry import PARAM_SCHEDULERS

INF = int(1e9)


class _ParamScheduler:
    """Base class for parameter schedulers.

    It should be inherited by all schedulers that schedule parameters in the
    optimizer's ``param_groups``. All subclasses should overwrite the
    ``_get_value()`` according to their own schedule strategy.
    The implementation is motivated by
    https://github.com/pytorch/pytorch/blob/master/torch/optim/lr_scheduler.py.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resuming without
            state dict. Default value ``-1`` means the ``step`` function has
            never been called before. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """  # noqa: E501

    def __init__(self,
                 optimizer: Optimizer,
                 param_name: str,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):

        # Attach optimizer
        if not isinstance(optimizer, Optimizer):
            raise TypeError('``optimizer`` should be an Optimizer, '
                            'but got {}'.format(type(optimizer).__name__))
        self.optimizer = optimizer
        self.param_name = param_name

        if end <= begin:
            raise ValueError('end should be larger than begin, but got'
                             ' begin={}, end={}'.format(begin, end))
        self.begin = begin
        self.end = end

        self.by_epoch = by_epoch

        assert isinstance(last_step, int) and last_step >= -1
        # Initialize valid step count and base values
        if last_step == -1:
            for group in optimizer.param_groups:
                # If the param has never been scheduled, record the current
                # value as the initial value.
                group.setdefault(f'initial_{param_name}', group[param_name])
        else:
            for i, group in enumerate(optimizer.param_groups):
                if f'initial_{param_name}' not in group:
                    raise KeyError(
                        f"param 'initial_{param_name}' is not specified "
                        'in param_groups[{}] when resuming an optimizer'.
                        format(i))
        self.base_values = [
            group[f'initial_{param_name}'] for group in optimizer.param_groups
        ]
        self.last_step = last_step

        # Following https://github.com/pytorch/pytorch/issues/20124
        # We would like to ensure that `scheduler.step()` is called after
        # `optimizer.step()`
        def with_counter(method: Callable):
            if getattr(method, '_with_counter', False):
                # `optimizer.step()` has already been replaced, return.
                return method

            # Keep a weak reference to the optimizer instance to prevent
            # cyclic references.
            instance_ref = weakref.ref(method.__self__)  # type: ignore
            # Get the unbound method for the same purpose.
            func = method.__func__  # type: ignore
            cls = instance_ref().__class__  # type: ignore
            del method

            @wraps(func)
            def wrapper(*args, **kwargs):
                instance = instance_ref()
                instance._global_step += 1
                wrapped = func.__get__(instance, cls)
                return wrapped(*args, **kwargs)

            # Note that the returned function here is no longer a bound method,
            # so attributes like `__func__` and `__self__` no longer exist.
            wrapper._with_counter = True  # type: ignore
            return wrapper

        # add counter to optimizer
        self.optimizer.step = with_counter(self.optimizer.step)
        self.optimizer._global_step = -1

        self._global_step = -1
        self.verbose = verbose

        self.step()

    def state_dict(self) -> dict:
        """Returns the state of the scheduler as a :class:`dict`.

        It contains an entry for every variable in self.__dict__ which is not
        the optimizer.

        Returns:
            dict: scheduler state.
        """
        return {
            key: value
            for key, value in self.__dict__.items() if key != 'optimizer'
        }

    def load_state_dict(self, state_dict: dict):
        """Loads the schedulers state.

        Args:
            state_dict (dict): scheduler state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        self.__dict__.update(state_dict)

    def get_last_value(self):
        """Return the last computed value by current scheduler.

        Returns:
            list: A list of the last computed value of the optimizer's
            ``param_group``.
        """
        return self._last_value

    def _get_value(self):
        """Compute value using chainable form of the scheduler."""
        raise NotImplementedError

    def print_value(self, is_verbose: bool, group: int, value: float):
        """Display the current parameter value.

        Args:
            is_verbose (bool): Whether to print the value.
            group (int): The index of the current ``param_group``.
            value (float): The parameter value.
        """
        if is_verbose:
            print('Adjusting parameter value'
                  ' of group {} to {:.4e}.'.format(group, value))

    def step(self):
        """Adjusts the parameter value of each parameter group based on the
        specified schedule."""
        # Raise a warning if old pattern is detected
        # https://github.com/pytorch/pytorch/issues/20124
        if self._global_step == 0:
            if not hasattr(self.optimizer.step, '_with_counter'):
                warnings.warn(
                    'Seems like `optimizer.step()` has been overridden after '
                    'parameter value scheduler initialization. Please, make '
                    'sure to call `optimizer.step()` before '
                    '`scheduler.step()`. See more details at '
                    'https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate',  # noqa: E501
                    UserWarning)

            # Just check if there were two first scheduler.step() calls
            # before optimizer.step()
            elif self.optimizer._global_step < 0:
                warnings.warn(
                    'Detected call of `scheduler.step()` before '
                    '`optimizer.step()`. In PyTorch 1.1.0 and later, you '
                    'should call them in the opposite order: '
                    '`optimizer.step()` before `scheduler.step()`. '
                    'Failure to do this will result in PyTorch skipping '
                    'the first value of the parameter value schedule. '
                    'See more details at https://pytorch.org/docs/stable/optim.html#how-to-adjust-learning-rate',  # noqa: E501
                    UserWarning)
        self._global_step += 1

        # Compute parameter value per param group in the effective range
        if self.begin <= self._global_step < self.end:
            self.last_step += 1
            values = self._get_value()

            for i, data in enumerate(zip(self.optimizer.param_groups, values)):
                param_group, value = data
                param_group[self.param_name] = value
                self.print_value(self.verbose, i, value)

        self._last_value = [
            group[self.param_name] for group in self.optimizer.param_groups
        ]


@PARAM_SCHEDULERS.register_module()
class StepParamScheduler(_ParamScheduler):
    """Decays the parameter value of each parameter group by gamma every
    step_size epochs. Notice that such decay can happen simultaneously with
    other changes to the parameter value from outside this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        step_size (int): Period of parameter value decay.
        gamma (float): Multiplicative factor of parameter value decay.
            Defaults to 0.1.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 param_name: str,
                 step_size: int,
                 gamma: float = 0.1,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        self.step_size = step_size
        self.gamma = gamma
        super().__init__(
            optimizer=optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    def _get_value(self):
        if (self.last_step == 0) or (self.last_step % self.step_size != 0):
            return [
                group[self.param_name] for group in self.optimizer.param_groups
            ]
        return [
            group[self.param_name] * self.gamma
            for group in self.optimizer.param_groups
        ]


@PARAM_SCHEDULERS.register_module()
class MultiStepParamScheduler(_ParamScheduler):
    """Decays the specified parameter in each parameter group by gamma once the
    number of epoch reaches one of the milestones. Notice that such decay can
    happen simultaneously with other changes to the parameter from outside this
    scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        milestones (list): List of epoch indices. Must be increasing.
        gamma (float): Multiplicative factor of parameter value decay.
            Defaults to 0.1.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 param_name: str,
                 milestones: List[int],
                 gamma: float = 0.1,
                 last_step: int = -1,
                 begin: int = 0,
                 end: int = INF,
                 by_epoch: bool = True,
                 verbose: bool = False):
        self.milestones = Counter(milestones)
        self.gamma = gamma
        super().__init__(
            optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    def _get_value(self):
        if self.last_step not in self.milestones:
            return [
                group[self.param_name] for group in self.optimizer.param_groups
            ]
        return [
            group[self.param_name] *
            self.gamma**self.milestones[self.last_step]
            for group in self.optimizer.param_groups
        ]


@PARAM_SCHEDULERS.register_module()
class ConstantParamScheduler(_ParamScheduler):
    """Decays the parameter value of each parameter group by a small constant
    factor until the number of epoch reaches a pre-defined milestone: ``end``.
    Notice that such decay can happen simultaneously with other changes to the
    parameter value from outside this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        factor (float): The number we multiply parameter value until the
            milestone. Defaults to 1./3.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 param_name: str,
                 factor: float = 1.0 / 3,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        if factor > 1.0 or factor < 0:
            raise ValueError(
                'Constant multiplicative factor should between 0 and 1.')

        self.factor = factor
        self.total_iters = end - begin - 1
        super().__init__(
            optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    def _get_value(self):
        if self.last_step == 0:
            return [
                group[self.param_name] * self.factor
                for group in self.optimizer.param_groups
            ]

        if (self.last_step > self.total_iters
                or (self.last_step != self.total_iters)):
            return [
                group[self.param_name] for group in self.optimizer.param_groups
            ]

        if self.last_step == self.total_iters:
            return [
                group[self.param_name] * (1.0 / self.factor)
                for group in self.optimizer.param_groups
            ]


@PARAM_SCHEDULERS.register_module()
class ExponentialParamScheduler(_ParamScheduler):
    """Decays the parameter value of each parameter group by gamma every epoch.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        gamma (float): Multiplicative factor of parameter value decay.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 param_name: str,
                 gamma: float,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        self.gamma = gamma
        super().__init__(
            optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    def _get_value(self):
        if self.last_step == 0:
            return [
                group[self.param_name] for group in self.optimizer.param_groups
            ]
        return [
            group[self.param_name] * self.gamma
            for group in self.optimizer.param_groups
        ]


@PARAM_SCHEDULERS.register_module()
class CosineAnnealingParamScheduler(_ParamScheduler):
    r"""Set the parameter value of each parameter group using a cosine annealing
    schedule, where :math:`\eta_{max}` is set to the initial value and
    :math:`T_{cur}` is the number of epochs since the last restart in SGDR:

    .. math::
        \begin{aligned}
            \eta_t & = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1
            + \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right),
            & T_{cur} \neq (2k+1)T_{max}; \\
            \eta_{t+1} & = \eta_{t} + \frac{1}{2}(\eta_{max} - \eta_{min})
            \left(1 - \cos\left(\frac{1}{T_{max}}\pi\right)\right),
            & T_{cur} = (2k+1)T_{max}.
        \end{aligned}

    Notice that because the schedule
    is defined recursively, the parameter value can be simultaneously modified
    outside this scheduler by other operators. If the parameter value is set
    solely by this scheduler, the parameter value at each step becomes:

    .. math::
        \eta_t = \eta_{min} + \frac{1}{2}(\eta_{max} - \eta_{min})\left(1 +
        \cos\left(\frac{T_{cur}}{T_{max}}\pi\right)\right)

    It has been proposed in
    `SGDR: Stochastic Gradient Descent with Warm Restarts`_. Note that this
    only implements the cosine annealing part of SGDR, and not the restarts.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        T_max (int): Maximum number of iterations.
        eta_min (float): Minimum parameter value. Defaults to 0.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.

    .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
        https://arxiv.org/abs/1608.03983
    """

    def __init__(self,
                 optimizer: Optimizer,
                 param_name: str,
                 T_max: int,
                 eta_min: float = 0.,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        self.T_max = T_max
        self.eta_min = eta_min
        super().__init__(
            optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    def _get_value(self):
        if self.last_step == 0:
            return [
                group[self.param_name] for group in self.optimizer.param_groups
            ]
        elif (self.last_step - 1 - self.T_max) % (2 * self.T_max) == 0:
            return [
                group[self.param_name] + (base_value - self.eta_min) *
                (1 - math.cos(math.pi / self.T_max)) / 2
                for base_value, group in zip(self.base_values,
                                             self.optimizer.param_groups)
            ]
        return [(1 + math.cos(math.pi * self.last_step / self.T_max)) /
                (1 + math.cos(math.pi * (self.last_step - 1) / self.T_max)) *
                (group[self.param_name] - self.eta_min) + self.eta_min
                for group in self.optimizer.param_groups]


@PARAM_SCHEDULERS.register_module()
class LinearParamScheduler(_ParamScheduler):
    """Decays the parameter value of each parameter group by linearly changing
    small multiplicative factor until the number of epoch reaches a pre-defined
    milestone: ``end``.

    Notice that such decay can happen simultaneously with other changes to the
    parameter value from outside this scheduler.
    Args:
        optimizer (Optimizer): Wrapped optimizer.
        start_factor (float): The number we multiply parameter value in the
            first epoch. The multiplication factor changes towards end_factor
            in the following epochs. Defaults to 1./3.
        end_factor (float): The number we multiply parameter value at the end
            of linear changing process. Defaults to 1.0.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 param_name: str,
                 start_factor: float = 1.0 / 3,
                 end_factor: float = 1.0,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        if start_factor > 1.0 or start_factor < 0:
            raise ValueError(
                'Starting multiplicative factor should between 0 and 1.')

        if end_factor > 1.0 or end_factor < 0:
            raise ValueError(
                'Ending multiplicative factor should between 0 and 1.')

        self.start_factor = start_factor
        self.end_factor = end_factor
        self.total_iters = end - begin - 1
        super().__init__(
            optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    def _get_value(self):
        if self.last_step == 0:
            return [
                group[self.param_name] * self.start_factor
                for group in self.optimizer.param_groups
            ]

        return [
            group[self.param_name] *
            (1. + (self.end_factor - self.start_factor) /
             (self.total_iters * self.start_factor + (self.last_step - 1) *
              (self.end_factor - self.start_factor)))
            for group in self.optimizer.param_groups
        ]
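The ``begin``/``end`` arguments deserve a note: as ``step()`` above reads, a scheduler only touches the optimizer while ``begin <= self._global_step < end``, and ``last_step`` only advances inside that window, so milestone-style schedulers count from their own ``begin``. A hedged sketch of chaining a warmup scheduler with a decay scheduler on one optimizer (how the runner drives such a list is outside this commit; numbers are illustrative):

import torch

from mmengine.optim.scheduler import LinearLR, MultiStepLR

optimizer = torch.optim.SGD(
    torch.nn.Linear(2, 2).parameters(), lr=0.1, momentum=0.9)

# Warm the lr up from 0.1 * lr during steps [0, 5), then decay it by 10x
# twice; the milestones are counted from the decay scheduler's own `begin`.
warmup = LinearLR(optimizer, start_factor=0.1, begin=0, end=5)
decay = MultiStepLR(optimizer, milestones=[3, 6], gamma=0.1, begin=5)

for epoch in range(12):
    optimizer.step()
    warmup.step()
    decay.step()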
@@ -1,11 +1,12 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .registry import Registry, build_from_cfg
 from .root import (DATA_SAMPLERS, DATASETS, HOOKS, MODELS,
-                   OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, RUNNER_CONSTRUCTORS,
-                   RUNNERS, TASK_UTILS, TRANSFORMS, WEIGHT_INITIALIZERS)
+                   OPTIMIZER_CONSTRUCTORS, OPTIMIZERS, PARAM_SCHEDULERS,
+                   RUNNER_CONSTRUCTORS, RUNNERS, TASK_UTILS, TRANSFORMS,
+                   WEIGHT_INITIALIZERS)
 
 __all__ = [
     'Registry', 'build_from_cfg', 'RUNNERS', 'RUNNER_CONSTRUCTORS', 'HOOKS',
     'DATASETS', 'DATA_SAMPLERS', 'TRANSFORMS', 'MODELS', 'WEIGHT_INITIALIZERS',
-    'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS'
+    'OPTIMIZERS', 'OPTIMIZER_CONSTRUCTORS', 'TASK_UTILS', 'PARAM_SCHEDULERS'
 ]
@@ -29,6 +29,8 @@ WEIGHT_INITIALIZERS = Registry('weight initializer')
 OPTIMIZERS = Registry('optimizer')
 # manage constructors that customize the optimization hyperparameters.
 OPTIMIZER_CONSTRUCTORS = Registry('optimizer constructor')
+# manage all kinds of parameter schedulers like `MultiStepLR`
+PARAM_SCHEDULERS = Registry('parameter scheduler')
 
 # manage task-specific modules like anchor generators and box coders
 TASK_UTILS = Registry('task util')
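With the ``PARAM_SCHEDULERS`` registry in place, a scheduler can also be built from a config dict. A hedged sketch using ``build_from_cfg`` as exported above; the exact builder entry point used by the runner may differ:

import torch

from mmengine.registry import PARAM_SCHEDULERS, build_from_cfg

optimizer = torch.optim.SGD(
    torch.nn.Linear(2, 2).parameters(), lr=0.1, momentum=0.9)

# `type` selects the registered class; remaining keys become kwargs.
cfg = dict(type='MultiStepLR', milestones=[8, 11], gamma=0.1)
scheduler = build_from_cfg(
    cfg, PARAM_SCHEDULERS, default_args=dict(optimizer=optimizer))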
@@ -1,3 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
 import math
 from unittest import TestCase
 
@@ -39,7 +40,7 @@ class TestLRScheduler(TestCase):
             _ParamScheduler(self.optimizer, param_name='lr')
 
     def test_invalid_optimizer(self):
-        with self.assertRaisesRegex(TypeError, 'is not an Optimizer'):
+        with self.assertRaisesRegex(TypeError, 'should be an Optimizer'):
             StepLR('invalid_optimizer', step_size=1)
 
     def test_overwrite_optimzer_step(self):
@@ -1,3 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
 import math
 from unittest import TestCase
 
@@ -37,7 +38,7 @@ class TestMomentumScheduler(TestCase):
             self.model.parameters(), lr=0.01, momentum=0.05, weight_decay=5e-4)
 
     def test_invalid_optimizer(self):
-        with self.assertRaisesRegex(TypeError, 'is not an Optimizer'):
+        with self.assertRaisesRegex(TypeError, 'should be an Optimizer'):
             StepMomentum('invalid_optimizer', step_size=1)
 
     def test_overwrite_optimzer_step(self):
@@ -1,3 +1,4 @@
+# Copyright (c) OpenMMLab. All rights reserved.
 import math
 from unittest import TestCase
 
@@ -42,7 +43,7 @@ class TestParameterScheduler(TestCase):
             _ParamScheduler(self.optimizer, param_name='lr')
 
     def test_invalid_optimizer(self):
-        with self.assertRaisesRegex(TypeError, 'is not an Optimizer'):
+        with self.assertRaisesRegex(TypeError, 'should be an Optimizer'):
             StepParamScheduler('invalid_optimizer', 'lr', step_size=1)
 
     def test_overwrite_optimzer_step(self):