[Feature] Support CosineRestartParamScheduler. (#397)
* [Feature] Support CosineRestartParamScheduler.
* add ut and docstring
* add docstring

parent: b14cbc2576
commit: 813f49bf23
mmengine/optim/scheduler/__init__.py

@@ -1,16 +1,22 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR,
-                           LinearLR, MultiStepLR, OneCycleLR, PolyLR, StepLR)
+# yapf: disable
+from .lr_scheduler import (ConstantLR, CosineAnnealingLR, CosineRestartLR,
+                           ExponentialLR, LinearLR, MultiStepLR, OneCycleLR,
+                           PolyLR, StepLR)
 from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
-                                 ExponentialMomentum, LinearMomentum,
-                                 MultiStepMomentum, PolyMomentum, StepMomentum)
+                                 CosineRestartMomentum, ExponentialMomentum,
+                                 LinearMomentum, MultiStepMomentum,
+                                 PolyMomentum, StepMomentum)
 from .param_scheduler import (ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
+                              CosineRestartParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
                               MultiStepParamScheduler, OneCycleParamScheduler,
                               PolyParamScheduler, StepParamScheduler,
                               _ParamScheduler)
+# yapf: enable

 __all__ = [
     'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
     'MultiStepLR', 'StepLR', 'ConstantMomentum', 'CosineAnnealingMomentum',
@@ -19,5 +25,6 @@ __all__ = [
     'ExponentialParamScheduler', 'LinearParamScheduler',
     'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler',
     'PolyParamScheduler', 'PolyLR', 'PolyMomentum', 'OneCycleParamScheduler',
-    'OneCycleLR'
+    'OneCycleLR', 'CosineRestartParamScheduler', 'CosineRestartLR',
+    'CosineRestartMomentum'
 ]
mmengine/optim/scheduler/lr_scheduler.py

@@ -2,6 +2,7 @@
 from mmengine.registry import PARAM_SCHEDULERS
 from .param_scheduler import (ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
+                              CosineRestartParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
                               MultiStepParamScheduler, OneCycleParamScheduler,
                               PolyParamScheduler, StepParamScheduler)
@@ -277,3 +278,35 @@ class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler):
     .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
         https://arxiv.org/abs/1708.07120
     """# noqa E501
+
+
+@PARAM_SCHEDULERS.register_module()
+class CosineRestartLR(LRSchedulerMixin, CosineRestartParamScheduler):
+    """Sets the learning rate of each parameter group according to the cosine
+    annealing with restarts scheme. The cosine restart policy anneals the
+    learning rate from the initial value to `eta_min` with a cosine annealing
+    schedule and then restarts another period from the maximum value
+    multiplied by `restart_weight`.
+
+    Args:
+        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
+            optimizer.
+        periods (list[int]): Periods for each cosine annealing cycle.
+        restart_weights (list[float]): Restart weights at each
+            restart iteration. Defaults to [1].
+        eta_min (float): Minimum parameter value at the end of scheduling.
+            Defaults to None.
+        eta_min_ratio (float, optional): The ratio of the minimum parameter
+            value to the base parameter value. Either `eta_min` or
+            `eta_min_ratio` should be specified. Defaults to None.
+        begin (int): Step at which to start updating the parameters.
+            Defaults to 0.
+        end (int): Step at which to stop updating the parameters.
+            Defaults to INF.
+        last_step (int): The index of last step. Used for resuming without
+            state dict. Defaults to -1.
+        by_epoch (bool): Whether the scheduled parameters are updated by
+            epochs. Defaults to True.
+        verbose (bool): Whether to print the value for each update.
+            Defaults to False.
+    """
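For orientation, a minimal usage sketch of the new scheduler. It is not part of the diff; the model and training loop are assumed stand-ins, and the schedule mirrors the one exercised in the unit tests below.

    import torch

    from mmengine.optim.scheduler import CosineRestartLR

    model = torch.nn.Linear(4, 2)  # hypothetical stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
    scheduler = CosineRestartLR(
        optimizer,
        periods=[4, 5],            # cycle lengths in epochs (by_epoch=True)
        restart_weights=[1, 0.5],  # second cycle restarts at 0.5 * base lr
        eta_min=0)

    for epoch in range(10):
        # ... run one training epoch, calling optimizer.step() per batch ...
        scheduler.step()  # anneals 0.05 -> ~0, restarts at 0.025, anneals again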
mmengine/optim/scheduler/momentum_scheduler.py

@@ -2,6 +2,7 @@
 from mmengine.registry import PARAM_SCHEDULERS
 from .param_scheduler import (ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
+                              CosineRestartParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
                               MultiStepParamScheduler, PolyParamScheduler,
                               StepParamScheduler)
@@ -243,3 +244,36 @@ class PolyMomentum(MomentumSchedulerMixin, PolyParamScheduler):
         verbose (bool): Whether to print the value for each update.
             Defaults to False.
     """
+
+
+@PARAM_SCHEDULERS.register_module()
+class CosineRestartMomentum(MomentumSchedulerMixin,
+                            CosineRestartParamScheduler):
+    """Sets the momentum of each parameter group according to the cosine
+    annealing with restarts scheme. The cosine restart policy anneals the
+    momentum from the initial value to `eta_min` with a cosine annealing
+    schedule and then restarts another period from the maximum value
+    multiplied by `restart_weight`.
+
+    Args:
+        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
+            optimizer.
+        periods (list[int]): Periods for each cosine annealing cycle.
+        restart_weights (list[float]): Restart weights at each
+            restart iteration. Defaults to [1].
+        eta_min (float): Minimum parameter value at the end of scheduling.
+            Defaults to None.
+        eta_min_ratio (float, optional): The ratio of the minimum parameter
+            value to the base parameter value. Either `eta_min` or
+            `eta_min_ratio` should be specified. Defaults to None.
+        begin (int): Step at which to start updating the parameters.
+            Defaults to 0.
+        end (int): Step at which to stop updating the parameters.
+            Defaults to INF.
+        last_step (int): The index of last step. Used for resuming without
+            state dict. Defaults to -1.
+        by_epoch (bool): Whether the scheduled parameters are updated by
+            epochs. Defaults to True.
+        verbose (bool): Whether to print the value for each update.
+            Defaults to False.
+    """
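The momentum variant is driven the same way; a brief sketch with assumed names (the mixin resolves whether the optimizer exposes `momentum` or, for Adam-style optimizers, `betas`):

    import torch

    from mmengine.optim.scheduler import CosineRestartMomentum

    model = torch.nn.Linear(4, 2)  # hypothetical stand-in model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01, momentum=0.05)
    scheduler = CosineRestartMomentum(
        optimizer, periods=[4, 5], restart_weights=[1, 0.5], eta_min=0)

    for epoch in range(10):
        # ... one training epoch ...
        scheduler.step()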
mmengine/optim/scheduler/param_scheduler.py

@@ -9,7 +9,7 @@ import warnings
 import weakref
 from collections import Counter
 from functools import wraps
-from typing import Callable, List, Optional, Union
+from typing import Callable, List, Optional, Sequence, Union

 from torch.optim import Optimizer

@@ -227,6 +227,8 @@ class StepParamScheduler(_ParamScheduler):

     Args:
         optimizer (OptimWrapper or Optimizer): Wrapped optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         step_size (int): Period of parameter value decay.
         gamma (float): Multiplicative factor of parameter value decay.
             Defaults to 0.1.
@@ -313,6 +315,8 @@ class MultiStepParamScheduler(_ParamScheduler):

     Args:
         optimizer (OptimWrapper or Optimizer): Wrapped optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         milestones (list): List of epoch indices. Must be increasing.
         gamma (float): Multiplicative factor of parameter value decay.
             Defaults to 0.1.
@@ -401,6 +405,8 @@ class ConstantParamScheduler(_ParamScheduler):
     Args:
         optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
             optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         factor (float): The number we multiply parameter value until the
             milestone. Defaults to 1./3.
         begin (int): Step at which to start updating the parameters.
@@ -488,6 +494,8 @@ class ExponentialParamScheduler(_ParamScheduler):
     Args:
         optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
             optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         gamma (float): Multiplicative factor of parameter value decay.
         begin (int): Step at which to start updating the parameters.
             Defaults to 0.
@@ -585,6 +593,8 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
     Args:
         optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
             optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         T_max (int, optional): Maximum number of iterations. If not specified,
             use ``end - begin``. Defaults to None.
         eta_min (float): Minimum parameter value. Defaults to 0.
@@ -684,6 +694,8 @@ class LinearParamScheduler(_ParamScheduler):
     Args:
         optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
             optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         start_factor (float): The number we multiply parameter value in the
             first epoch. The multiplication factor changes towards end_factor
             in the following epochs. Defaults to 1./3.
@@ -780,6 +792,8 @@ class PolyParamScheduler(_ParamScheduler):
     Args:
         optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
             optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         eta_min (float): Minimum parameter value at the end of scheduling.
             Defaults to 0.
         power (float): The power of the polynomial. Defaults to 1.0.
@@ -882,6 +896,8 @@ class OneCycleParamScheduler(_ParamScheduler):

     Args:
         optimizer (Optimizer): Wrapped optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
         eta_max (float or list): Upper parameter value boundaries in the cycle
             for each parameter group.
         total_steps (int): The total number of steps in the cycle. Note that
@@ -1094,3 +1110,159 @@ class OneCycleParamScheduler(_ParamScheduler):
             params.append(computed_param)

         return params
+
+
+@PARAM_SCHEDULERS.register_module()
+class CosineRestartParamScheduler(_ParamScheduler):
+    """Sets the parameters of each parameter group according to the cosine
+    annealing with restarts scheme. The cosine restart policy anneals the
+    parameter from the initial value to `eta_min` with a cosine annealing
+    schedule and then restarts another period from the maximum value
+    multiplied by `restart_weight`.
+
+    Args:
+        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
+            optimizer.
+        param_name (str): Name of the parameter to be adjusted, such as
+            ``lr``, ``momentum``.
+        periods (list[int]): Periods for each cosine annealing cycle.
+        restart_weights (list[float]): Restart weights at each
+            restart iteration. Defaults to [1].
+        eta_min (float): Minimum parameter value at the end of scheduling.
+            Defaults to None.
+        eta_min_ratio (float, optional): The ratio of the minimum parameter
+            value to the base parameter value. Either `eta_min` or
+            `eta_min_ratio` should be specified. Defaults to None.
+        begin (int): Step at which to start updating the parameters.
+            Defaults to 0.
+        end (int): Step at which to stop updating the parameters.
+            Defaults to INF.
+        last_step (int): The index of last step. Used for resuming without
+            state dict. Defaults to -1.
+        by_epoch (bool): Whether the scheduled parameters are updated by
+            epochs. Defaults to True.
+        verbose (bool): Whether to print the value for each update.
+            Defaults to False.
+    """
+
+    def __init__(self,
+                 optimizer: Union[Optimizer, OptimWrapper],
+                 param_name: str,
+                 periods: List[int],
+                 restart_weights: Sequence[float] = (1, ),
+                 eta_min: Optional[float] = None,
+                 eta_min_ratio: Optional[float] = None,
+                 begin: int = 0,
+                 end: int = INF,
+                 last_step: int = -1,
+                 by_epoch: bool = True,
+                 verbose: bool = False):
+        assert (eta_min is None) ^ (eta_min_ratio is None)
+        self.periods = periods
+        self.eta_min = eta_min
+        self.eta_min_ratio = eta_min_ratio
+        self.restart_weights = restart_weights
+        assert (len(self.periods) == len(self.restart_weights)
+                ), 'periods and restart_weights should have the same length.'
+        self.cumulative_periods = [
+            sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
+        ]
+
+        super().__init__(
+            optimizer,
+            param_name=param_name,
+            begin=begin,
+            end=end,
+            last_step=last_step,
+            by_epoch=by_epoch,
+            verbose=verbose)
+
+    @classmethod
+    def build_iter_from_epoch(cls,
+                              *args,
+                              periods,
+                              begin=0,
+                              end=INF,
+                              by_epoch=True,
+                              epoch_length=None,
+                              **kwargs):
+        """Build an iter-based instance of this scheduler from an epoch-based
+        config."""
+        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
+                         'be converted to iter-based.'
+        assert epoch_length is not None and epoch_length > 0, \
+            f'`epoch_length` must be a positive integer, ' \
+            f'but got {epoch_length}.'
+        periods = [p * epoch_length for p in periods]
+        by_epoch = False
+        begin = int(begin * epoch_length)
+        if end != INF:
+            end = int(end * epoch_length)
+        return cls(
+            *args,
+            periods=periods,
+            begin=begin,
+            end=end,
+            by_epoch=by_epoch,
+            **kwargs)
+
+    def _get_value(self):
+        """Compute value using chainable form of the scheduler."""
+        idx = self.get_position_from_periods(self.last_step,
+                                             self.cumulative_periods)
+        # if current step is not in the periods, return origin parameters
+        if idx is None:
+            return [
+                group[self.param_name]
+                for group in self.optimizer.param_groups
+            ]
+        current_weight = self.restart_weights[idx]
+        nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1]
+        current_periods = self.periods[idx]
+        step = self.last_step - nearest_restart
+        values = []
+        for base_value, group in zip(self.base_values,
+                                     self.optimizer.param_groups):
+            eta_max = base_value * current_weight
+            if self.eta_min_ratio is None:
+                eta_min = self.eta_min * (1 - current_weight)
+            else:
+                eta_min = base_value * self.eta_min_ratio * (
+                    1 - current_weight)
+            if step == 0:
+                values.append(eta_max)
+            elif (step - 1 - current_periods) % (2 * current_periods) == 0:
+                values.append(group[self.param_name] + (eta_max - eta_min) *
+                              (1 - math.cos(math.pi / current_periods)) / 2)
+            else:
+                values.append(
+                    (1 + math.cos(math.pi * step / current_periods)) /
+                    (1 + math.cos(math.pi * (step - 1) / current_periods)) *
+                    (group[self.param_name] - eta_min) + eta_min)
+
+        return values
+
+    @staticmethod
+    def get_position_from_periods(
+            iteration: int, cumulative_periods: List[int]) -> Optional[int]:
+        """Get the position from a period list.
+
+        It will return the index of the right-closest number in the period
+        list.
+        For example, if cumulative_periods = [100, 200, 300, 400]:
+        if iteration == 50, return 0;
+        if iteration == 210, return 2;
+        if iteration == 300, return 3.
+
+        Args:
+            iteration (int): Current iteration.
+            cumulative_periods (list[int]): Cumulative period list.
+
+        Returns:
+            Optional[int]: The position of the right-closest number in the
+                period list. If not in the period, return None.
+        """
+        for i, period in enumerate(cumulative_periods):
+            if iteration < period:
+                return i
+        return None
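Within each period, the chainable update in `_get_value` reproduces the usual closed-form cosine annealing, eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi * t / T)) / 2. A self-contained sketch (not part of the diff) that checks the two forms against each other for one period:

    import math

    def closed_form(t, T, eta_max, eta_min=0.0):
        """eta_t = eta_min + (eta_max - eta_min) * (1 + cos(pi*t/T)) / 2."""
        return eta_min + (eta_max - eta_min) * (
            1 + math.cos(math.pi * t / T)) / 2

    def chained(T, eta_max, eta_min=0.0):
        """The recurrence used by _get_value: rescale the previous value."""
        value, out = eta_max, [eta_max]  # step 0 starts at the weighted max
        for t in range(1, T):
            value = ((1 + math.cos(math.pi * t / T)) /
                     (1 + math.cos(math.pi * (t - 1) / T)) *
                     (value - eta_min) + eta_min)
            out.append(value)
        return out

    assert all(
        math.isclose(v, closed_form(t, 4, 0.05))
        for t, v in enumerate(chained(4, 0.05)))
    # both yield [0.05, 0.0426776..., 0.025, 0.0073223...] for T=4

`build_iter_from_epoch` does not change this arithmetic: it multiplies `periods`, `begin`, and `end` by `epoch_length` and flips `by_epoch` to False, so the same recurrence runs per iteration instead of per epoch.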
tests/test_optim/test_scheduler/test_lr_scheduler.py

@@ -7,8 +7,8 @@ import torch.nn.functional as F
 import torch.optim as optim

 from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
-                                      ExponentialLR, LinearLR, MultiStepLR,
-                                      OneCycleLR, PolyLR, StepLR,
+                                      CosineRestartLR, ExponentialLR, LinearLR,
+                                      MultiStepLR, OneCycleLR, PolyLR, StepLR,
                                       _ParamScheduler)
 from mmengine.testing import assert_allclose

@@ -333,6 +333,34 @@ class TestLRScheduler(TestCase):
             self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
         self._test_scheduler_value(scheduler, targets, epochs=10)

+    def test_cosine_restart_scheduler(self):
+        with self.assertRaises(AssertionError):
+            CosineRestartLR(
+                self.optimizer,
+                periods=[4, 5],
+                restart_weights=[1, 0.5],
+                eta_min=0,
+                eta_min_ratio=0.1)
+        with self.assertRaises(AssertionError):
+            CosineRestartLR(
+                self.optimizer,
+                periods=[4, 5],
+                restart_weights=[1, 0.5, 0.0],
+                eta_min=0)
+        single_targets = [
+            0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271,
+            0.0086372, 0.0023872, 0.0023872
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+        scheduler = CosineRestartLR(
+            self.optimizer,
+            periods=[4, 5],
+            restart_weights=[1, 0.5],
+            eta_min=0)
+        self._test_scheduler_value(scheduler, targets, epochs=10)
+
     def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
         scheduler = construct()
         for _ in range(epochs):
@@ -387,6 +415,20 @@ class TestLRScheduler(TestCase):
             lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002),
             epochs=10)

+    def test_cosine_restart_scheduler_state_dict(self):
+        self._check_scheduler_state_dict(
+            lambda: CosineRestartLR(
+                self.optimizer,
+                periods=[4, 5],
+                restart_weights=[1, 0.5],
+                eta_min=0),
+            lambda: CosineRestartLR(
+                self.optimizer,
+                periods=[4, 6],
+                restart_weights=[1, 0.5],
+                eta_min=0),
+            epochs=10)
+
     def test_step_scheduler_convert_iterbased(self):
         # invalid epoch_length
         with self.assertRaises(AssertionError):
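The hard-coded `single_targets` follow directly from the closed form above. A sketch (assuming base lr 0.05 and `eta_min=0`, as in the test fixture) that regenerates them:

    import math

    base_lr, periods, weights = 0.05, [4, 5], [1, 0.5]
    targets = []
    for T, w in zip(periods, weights):
        eta_max = base_lr * w  # each cycle restarts at restart_weight * base
        targets += [
            eta_max * (1 + math.cos(math.pi * t / T)) / 2 for t in range(T)
        ]
    targets.append(targets[-1])  # past the last period the value is frozen
    # matches single_targets above, to the precision used in the test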
tests/test_optim/test_scheduler/test_momentum_scheduler.py

@@ -8,6 +8,7 @@ import torch.optim as optim

 from mmengine.optim.scheduler import (ConstantMomentum,
                                       CosineAnnealingMomentum,
+                                      CosineRestartMomentum,
                                       ExponentialMomentum, LinearMomentum,
                                       MultiStepMomentum, PolyMomentum,
                                       StepMomentum, _ParamScheduler)
@@ -399,6 +400,43 @@ class TestMomentumScheduler(TestCase):
         self._test_scheduler_value(
             self.optimizer_with_betas, scheduler, targets, epochs=10)

+    def test_cosine_restart_scheduler(self):
+        with self.assertRaises(AssertionError):
+            CosineRestartMomentum(
+                self.optimizer,
+                periods=[4, 5],
+                restart_weights=[1, 0.5],
+                eta_min=0,
+                eta_min_ratio=0.1)
+        with self.assertRaises(AssertionError):
+            CosineRestartMomentum(
+                self.optimizer,
+                periods=[4, 5],
+                restart_weights=[1, 0.5, 0.0],
+                eta_min=0)
+        single_targets = [
+            0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271,
+            0.0086372, 0.0023872, 0.0023872
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+        scheduler = CosineRestartMomentum(
+            self.optimizer,
+            periods=[4, 5],
+            restart_weights=[1, 0.5],
+            eta_min=0)
+        self._test_scheduler_value(
+            self.optimizer, scheduler, targets, epochs=10)
+
+        scheduler = CosineRestartMomentum(
+            self.optimizer_with_betas,
+            periods=[4, 5],
+            restart_weights=[1, 0.5],
+            eta_min=0)
+        self._test_scheduler_value(
+            self.optimizer_with_betas, scheduler, targets, epochs=10)
+
     def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
         scheduler = construct()
         for _ in range(epochs):
@@ -454,6 +492,20 @@ class TestMomentumScheduler(TestCase):
             lambda: PolyMomentum(self.optimizer, power=0.8, eta_min=0.002),
             epochs=10)

+    def test_cosine_restart_scheduler_state_dict(self):
+        self._check_scheduler_state_dict(
+            lambda: CosineRestartMomentum(
+                self.optimizer,
+                periods=[4, 5],
+                restart_weights=[1, 0.5],
+                eta_min=0),
+            lambda: CosineRestartMomentum(
+                self.optimizer,
+                periods=[4, 6],
+                restart_weights=[1, 0.5],
+                eta_min=0),
+            epochs=10)
+
     def test_multi_scheduler_without_overlap_linear_multi_step(self):
         # use Linear in the first 5 epochs and then use MultiStep
         epochs = 12
tests/test_optim/test_scheduler/test_param_scheduler.py

@@ -12,12 +12,13 @@ from mmengine.optim import OptimWrapper
 # yapf: disable
 from mmengine.optim.scheduler import (ConstantParamScheduler,
                                       CosineAnnealingParamScheduler,
+                                      CosineRestartParamScheduler,
                                       ExponentialParamScheduler,
                                       LinearParamScheduler,
                                       MultiStepParamScheduler,
+                                      OneCycleParamScheduler,
                                       PolyParamScheduler, StepParamScheduler,
                                       _ParamScheduler)
-from mmengine.optim.scheduler.param_scheduler import OneCycleParamScheduler
 # yapf: enable
 from mmengine.testing import assert_allclose

@@ -406,6 +407,37 @@ class TestParameterScheduler(TestCase):
             end=iters + 1)
         self._test_scheduler_value(scheduler, targets, epochs=10)

+    def test_cosine_restart_scheduler(self):
+        with self.assertRaises(AssertionError):
+            CosineRestartParamScheduler(
+                self.optimizer,
+                param_name='lr',
+                periods=[4, 5],
+                restart_weights=[1, 0.5],
+                eta_min=0,
+                eta_min_ratio=0.1)
+        with self.assertRaises(AssertionError):
+            CosineRestartParamScheduler(
+                self.optimizer,
+                param_name='lr',
+                periods=[4, 5],
+                restart_weights=[1, 0.5, 0.0],
+                eta_min=0)
+        single_targets = [
+            0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271,
+            0.0086372, 0.0023872, 0.0023872
+        ]
+        targets = [
+            single_targets, [t * self.layer2_mult for t in single_targets]
+        ]
+        scheduler = CosineRestartParamScheduler(
+            self.optimizer,
+            param_name='lr',
+            periods=[4, 5],
+            restart_weights=[1, 0.5],
+            eta_min=0)
+        self._test_scheduler_value(scheduler, targets, epochs=10)
+
     def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
         scheduler = construct()
         for _ in range(epochs):
@@ -483,6 +515,22 @@ class TestParameterScheduler(TestCase):
                 self.optimizer, param_name='lr', power=0.8, eta_min=0.002),
             epochs=10)

+    def test_cosine_restart_scheduler_state_dict(self):
+        self._check_scheduler_state_dict(
+            lambda: CosineRestartParamScheduler(
+                self.optimizer,
+                param_name='lr',
+                periods=[4, 5],
+                restart_weights=[1, 0.5],
+                eta_min=0),
+            lambda: CosineRestartParamScheduler(
+                self.optimizer,
+                param_name='lr',
+                periods=[4, 6],
+                restart_weights=[1, 0.5],
+                eta_min=0),
+            epochs=10)
+
     def test_step_scheduler_convert_iterbased(self):
         # invalid epoch_length
         with self.assertRaises(AssertionError):