[Feature] Support CosineRestartParamScheduler. (#397)

* [Feature] Support CosineRestartParamScheduler.

* add ut and docstring

* add docstring
RangiLyu 2022-08-11 17:57:35 +08:00 committed by GitHub
parent b14cbc2576
commit 813f49bf23
7 changed files with 397 additions and 9 deletions


@ -1,16 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR,
LinearLR, MultiStepLR, OneCycleLR, PolyLR, StepLR)
# yapf: disable
from .lr_scheduler import (ConstantLR, CosineAnnealingLR, CosineRestartLR,
ExponentialLR, LinearLR, MultiStepLR, OneCycleLR,
PolyLR, StepLR)
from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
ExponentialMomentum, LinearMomentum,
MultiStepMomentum, PolyMomentum, StepMomentum)
CosineRestartMomentum, ExponentialMomentum,
LinearMomentum, MultiStepMomentum,
PolyMomentum, StepMomentum)
from .param_scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
CosineRestartParamScheduler,
ExponentialParamScheduler, LinearParamScheduler,
MultiStepParamScheduler, OneCycleParamScheduler,
PolyParamScheduler, StepParamScheduler,
_ParamScheduler)
# yapf: enable
__all__ = [
'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
'MultiStepLR', 'StepLR', 'ConstantMomentum', 'CosineAnnealingMomentum',
@ -19,5 +25,6 @@ __all__ = [
'ExponentialParamScheduler', 'LinearParamScheduler',
'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler',
'PolyParamScheduler', 'PolyLR', 'PolyMomentum', 'OneCycleParamScheduler',
'OneCycleLR'
'OneCycleLR', 'CosineRestartParamScheduler', 'CosineRestartLR',
'CosineRestartMomentum'
]


@ -2,6 +2,7 @@
from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
CosineRestartParamScheduler,
ExponentialParamScheduler, LinearParamScheduler,
MultiStepParamScheduler, OneCycleParamScheduler,
PolyParamScheduler, StepParamScheduler)
@ -277,3 +278,35 @@ class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler):
.. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
https://arxiv.org/abs/1708.07120
"""# noqa E501
@PARAM_SCHEDULERS.register_module()
class CosineRestartLR(LRSchedulerMixin, CosineRestartParamScheduler):
"""Sets the learning rate of each parameter group according to the cosine
annealing with restarts scheme. The cosine restart policy anneals the
learning rate from the initial value to `eta_min` with a cosine annealing
schedule, then starts the next period from the maximum value multiplied
by the corresponding `restart_weight`.
Args:
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
periods (list[int]): Periods for each cosine annealing cycle.
restart_weights (list[float]): Restart weights at each
restart iteration. Defaults to [1].
eta_min (float): Minimum parameter value at the end of scheduling.
Defaults to None.
eta_min_ratio (float, optional): The ratio of minimum parameter value
to the base parameter value. Either `eta_min` or `eta_min_ratio`
should be specified. Defaults to None.
begin (int): Step at which to start updating the parameters.
Defaults to 0.
end (int): Step at which to stop updating the parameters.
Defaults to INF.
last_step (int): The index of last step. Used for resume without
state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""


@ -2,6 +2,7 @@
from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
CosineRestartParamScheduler,
ExponentialParamScheduler, LinearParamScheduler,
MultiStepParamScheduler, PolyParamScheduler,
StepParamScheduler)
@ -243,3 +244,36 @@ class PolyMomentum(MomentumSchedulerMixin, PolyParamScheduler):
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""
@PARAM_SCHEDULERS.register_module()
class CosineRestartMomentum(MomentumSchedulerMixin,
CosineRestartParamScheduler):
"""Sets the momentum of each parameter group according to the cosine
annealing with restarts scheme. The cosine restart policy anneals the
momentum from the initial value to `eta_min` with a cosine annealing
schedule, then starts the next period from the maximum value multiplied
by the corresponding `restart_weight`.
Args:
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
periods (list[int]): Periods for each cosine annealing cycle.
restart_weights (list[float]): Restart weights at each
restart iteration. Defaults to [1].
eta_min (float): Minimum parameter value at the end of scheduling.
Defaults to None.
eta_min_ratio (float, optional): The ratio of minimum parameter value
to the base parameter value. Either `eta_min` or `eta_min_ratio`
should be specified. Defaults to None.
begin (int): Step at which to start updating the parameters.
Defaults to 0.
end (int): Step at which to stop updating the parameters.
Defaults to INF.
last_step (int): The index of last step. Used for resume without
state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""


@ -9,7 +9,7 @@ import warnings
import weakref
from collections import Counter
from functools import wraps
from typing import Callable, List, Optional, Union
from typing import Callable, List, Optional, Sequence, Union
from torch.optim import Optimizer
@ -227,6 +227,8 @@ class StepParamScheduler(_ParamScheduler):
Args:
optimizer (OptimWrapper or Optimizer): Wrapped optimizer.
param_name (str): Name of the parameter to be adjusted, such as
``lr``, ``momentum``.
step_size (int): Period of parameter value decay.
gamma (float): Multiplicative factor of parameter value decay.
Defaults to 0.1.
@ -313,6 +315,8 @@ class MultiStepParamScheduler(_ParamScheduler):
Args:
optimizer (OptimWrapper or Optimizer): Wrapped optimizer.
param_name (str): Name of the parameter to be adjusted, such as
``lr``, ``momentum``.
milestones (list): List of epoch indices. Must be increasing.
gamma (float): Multiplicative factor of parameter value decay.
Defaults to 0.1.
@ -401,6 +405,8 @@ class ConstantParamScheduler(_ParamScheduler):
Args:
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
param_name (str): Name of the parameter to be adjusted, such as
``lr``, ``momentum``.
factor (float): The number we multiply parameter value until the
milestone. Defaults to 1./3.
begin (int): Step at which to start updating the parameters.
@ -488,6 +494,8 @@ class ExponentialParamScheduler(_ParamScheduler):
Args:
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
param_name (str): Name of the parameter to be adjusted, such as
``lr``, ``momentum``.
gamma (float): Multiplicative factor of parameter value decay.
begin (int): Step at which to start updating the parameters.
Defaults to 0.
@ -585,6 +593,8 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
Args:
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
param_name (str): Name of the parameter to be adjusted, such as
``lr``, ``momentum``.
T_max (int, optional): Maximum number of iterations. If not specified,
use ``end - begin``. Defaults to None.
eta_min (float): Minimum parameter value. Defaults to 0.
@ -684,6 +694,8 @@ class LinearParamScheduler(_ParamScheduler):
Args:
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
param_name (str): Name of the parameter to be adjusted, such as
``lr``, ``momentum``.
start_factor (float): The number we multiply parameter value in the
first epoch. The multiplication factor changes towards end_factor
in the following epochs. Defaults to 1./3.
@ -780,6 +792,8 @@ class PolyParamScheduler(_ParamScheduler):
Args:
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
param_name (str): Name of the parameter to be adjusted, such as
``lr``, ``momentum``.
eta_min (float): Minimum parameter value at the end of scheduling.
Defaults to 0.
power (float): The power of the polynomial. Defaults to 1.0.
@ -882,6 +896,8 @@ class OneCycleParamScheduler(_ParamScheduler):
Args:
optimizer (Optimizer): Wrapped optimizer.
param_name (str): Name of the parameter to be adjusted, such as
``lr``, ``momentum``.
eta_max (float or list): Upper parameter value boundaries in the cycle
for each parameter group.
total_steps (int): The total number of steps in the cycle. Note that
@ -1094,3 +1110,159 @@ class OneCycleParamScheduler(_ParamScheduler):
params.append(computed_param)
return params
@PARAM_SCHEDULERS.register_module()
class CosineRestartParamScheduler(_ParamScheduler):
"""Sets the parameters of each parameter group according to the cosine
annealing with restarts scheme. The cosine restart policy anneals the
parameter from the initial value to `eta_min` with a cosine annealing
schedule, then starts the next period from the maximum value multiplied
by the corresponding `restart_weight`.
Args:
optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
optimizer.
param_name (str): Name of the parameter to be adjusted, such as
``lr``, ``momentum``.
periods (list[int]): Periods for each cosine annealing cycle.
restart_weights (list[float]): Restart weights at each
restart iteration. Defaults to [1].
eta_min (float): Minimum parameter value at the end of scheduling.
Defaults to None.
eta_min_ratio (float, optional): The ratio of minimum parameter value
to the base parameter value. Either `eta_min` or `eta_min_ratio`
should be specified. Defaults to None.
begin (int): Step at which to start updating the parameters.
Defaults to 0.
end (int): Step at which to stop updating the parameters.
Defaults to INF.
last_step (int): The index of last step. Used for resume without
state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""
def __init__(self,
optimizer: Union[Optimizer, OptimWrapper],
param_name: str,
periods: List[int],
restart_weights: Sequence[float] = (1, ),
eta_min: Optional[float] = None,
eta_min_ratio: Optional[float] = None,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
assert (eta_min is None) ^ (eta_min_ratio is None)
self.periods = periods
self.eta_min = eta_min
self.eta_min_ratio = eta_min_ratio
self.restart_weights = restart_weights
assert (len(self.periods) == len(self.restart_weights)
), 'periods and restart_weights should have the same length.'
self.cumulative_periods = [
sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
]
super().__init__(
optimizer,
param_name=param_name,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
@classmethod
def build_iter_from_epoch(cls,
*args,
periods,
begin=0,
end=INF,
by_epoch=True,
epoch_length=None,
**kwargs):
"""Build an iter-based instance of this scheduler from an epoch-based
config."""
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
'be converted to iter-based.'
assert epoch_length is not None and epoch_length > 0, \
f'`epoch_length` must be a positive integer, ' \
f'but got {epoch_length}.'
periods = [p * epoch_length for p in periods]
by_epoch = False
begin = int(begin * epoch_length)
if end != INF:
end = int(end * epoch_length)
return cls(
*args,
periods=periods,
begin=begin,
end=end,
by_epoch=by_epoch,
**kwargs)
def _get_value(self):
"""Compute value using chainable form of the scheduler."""
idx = self.get_position_from_periods(self.last_step,
self.cumulative_periods)
# if the current step is beyond all periods, keep the current values
if idx is None:
return [
group[self.param_name] for group in self.optimizer.param_groups
]
current_weight = self.restart_weights[idx]
nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1]
current_periods = self.periods[idx]
step = self.last_step - nearest_restart
values = []
for base_value, group in zip(self.base_values,
self.optimizer.param_groups):
eta_max = base_value * current_weight
if self.eta_min_ratio is None:
eta_min = self.eta_min * (1 - current_weight)
else:
eta_min = base_value * self.eta_min_ratio * (1 -
current_weight)
if step == 0:
values.append(eta_max)
elif (step - 1 - current_periods) % (2 * current_periods) == 0:
values.append(group[self.param_name] + (eta_max - eta_min) *
(1 - math.cos(math.pi / current_periods)) / 2)
else:
values.append(
(1 + math.cos(math.pi * step / current_periods)) /
(1 + math.cos(math.pi * (step - 1) / current_periods)) *
(group[self.param_name] - eta_min) + eta_min)
return values
@staticmethod
def get_position_from_periods(
iteration: int, cumulative_periods: List[int]) -> Optional[int]:
"""Get the position from a period list.
It returns the index of the first cumulative period that is greater
than the given iteration.
For example, the cumulative_periods = [100, 200, 300, 400],
if iteration == 50, return 0;
if iteration == 210, return 2;
if iteration == 300, return 3.
Args:
iteration (int): Current iteration.
cumulative_periods (list[int]): Cumulative period list.
Returns:
Optional[int]: The position of the right-closest number in the
period list. If not in the period, return None.
"""
for i, period in enumerate(cumulative_periods):
if iteration < period:
return i
return None
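A small sketch of the bookkeeping above, with `periods=[4, 5]` assumed purely for illustration: `cumulative_periods` becomes `[4, 9]`, and `get_position_from_periods` maps a step to its cycle index, or `None` once every period has been consumed (at which point `_get_value` leaves the parameter untouched):

from mmengine.optim.scheduler import CosineRestartParamScheduler

periods = [4, 5]
cumulative = [sum(periods[:i + 1]) for i in range(len(periods))]  # [4, 9]

pos = CosineRestartParamScheduler.get_position_from_periods
assert pos(0, cumulative) == 0      # inside the first 4-step cycle
assert pos(4, cumulative) == 1      # first step of the second cycle
assert pos(9, cumulative) is None   # past all periods: value is frozen

# build_iter_from_epoch only rescales the schedule: assuming an epoch_length
# of 100 iterations, periods=[4, 5] would become [400, 500], and begin/end
# are multiplied by the same factor while by_epoch is switched to False.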


@ -7,8 +7,8 @@ import torch.nn.functional as F
import torch.optim as optim
from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
ExponentialLR, LinearLR, MultiStepLR,
OneCycleLR, PolyLR, StepLR,
CosineRestartLR, ExponentialLR, LinearLR,
MultiStepLR, OneCycleLR, PolyLR, StepLR,
_ParamScheduler)
from mmengine.testing import assert_allclose
@ -333,6 +333,34 @@ class TestLRScheduler(TestCase):
self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
self._test_scheduler_value(scheduler, targets, epochs=10)
def test_cosine_restart_scheduler(self):
with self.assertRaises(AssertionError):
CosineRestartLR(
self.optimizer,
periods=[4, 5],
restart_weights=[1, 0.5],
eta_min=0,
eta_min_ratio=0.1)
with self.assertRaises(AssertionError):
CosineRestartLR(
self.optimizer,
periods=[4, 5],
restart_weights=[1, 0.5, 0.0],
eta_min=0)
single_targets = [
0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271,
0.0086372, 0.0023872, 0.0023872
]
targets = [
single_targets, [t * self.layer2_mult for t in single_targets]
]
scheduler = CosineRestartLR(
self.optimizer,
periods=[4, 5],
restart_weights=[1, 0.5],
eta_min=0)
self._test_scheduler_value(scheduler, targets, epochs=10)
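The expected values in `single_targets` follow the closed-form cosine-restart schedule; a short sketch reproducing them, assuming the test fixture uses a base learning rate of 0.05 (which is what the first target suggests):

import math

base_lr, eta_min = 0.05, 0
periods, weights = [4, 5], [1, 0.5]

expected = []
for period, weight in zip(periods, weights):
    eta_max = base_lr * weight  # each cycle restarts from the weighted maximum
    for step in range(period):
        expected.append(
            eta_min + 0.5 * (eta_max - eta_min) *
            (1 + math.cos(math.pi * step / period)))
expected.append(expected[-1])  # step 9 lies past all periods, so it is frozen
# expected ~= [0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.0226127,
#              0.0163627, 0.0086372, 0.0023872, 0.0023872]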
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
scheduler = construct()
for _ in range(epochs):
@ -387,6 +415,20 @@ class TestLRScheduler(TestCase):
lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002),
epochs=10)
def test_cosine_restart_scheduler_state_dict(self):
self._check_scheduler_state_dict(
lambda: CosineRestartLR(
self.optimizer,
periods=[4, 5],
restart_weights=[1, 0.5],
eta_min=0),
lambda: CosineRestartLR(
self.optimizer,
periods=[4, 6],
restart_weights=[1, 0.5],
eta_min=0),
epochs=10)
def test_step_scheduler_convert_iterbased(self):
# invalid epoch_length
with self.assertRaises(AssertionError):


@ -8,6 +8,7 @@ import torch.optim as optim
from mmengine.optim.scheduler import (ConstantMomentum,
CosineAnnealingMomentum,
CosineRestartMomentum,
ExponentialMomentum, LinearMomentum,
MultiStepMomentum, PolyMomentum,
StepMomentum, _ParamScheduler)
@ -399,6 +400,43 @@ class TestMomentumScheduler(TestCase):
self._test_scheduler_value(
self.optimizer_with_betas, scheduler, targets, epochs=10)
def test_cosine_restart_scheduler(self):
with self.assertRaises(AssertionError):
CosineRestartMomentum(
self.optimizer,
periods=[4, 5],
restart_weights=[1, 0.5],
eta_min=0,
eta_min_ratio=0.1)
with self.assertRaises(AssertionError):
CosineRestartMomentum(
self.optimizer,
periods=[4, 5],
restart_weights=[1, 0.5, 0.0],
eta_min=0)
single_targets = [
0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271,
0.0086372, 0.0023872, 0.0023872
]
targets = [
single_targets, [t * self.layer2_mult for t in single_targets]
]
scheduler = CosineRestartMomentum(
self.optimizer,
periods=[4, 5],
restart_weights=[1, 0.5],
eta_min=0)
self._test_scheduler_value(
self.optimizer, scheduler, targets, epochs=10)
scheduler = CosineRestartMomentum(
self.optimizer_with_betas,
periods=[4, 5],
restart_weights=[1, 0.5],
eta_min=0)
self._test_scheduler_value(
self.optimizer_with_betas, scheduler, targets, epochs=10)
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
scheduler = construct()
for _ in range(epochs):
@ -454,6 +492,20 @@ class TestMomentumScheduler(TestCase):
lambda: PolyMomentum(self.optimizer, power=0.8, eta_min=0.002),
epochs=10)
def test_cosine_restart_scheduler_state_dict(self):
self._check_scheduler_state_dict(
lambda: CosineRestartMomentum(
self.optimizer,
periods=[4, 5],
restart_weights=[1, 0.5],
eta_min=0),
lambda: CosineRestartMomentum(
self.optimizer,
periods=[4, 6],
restart_weights=[1, 0.5],
eta_min=0),
epochs=10)
def test_multi_scheduler_without_overlap_linear_multi_step(self):
# use Linear in the first 5 epochs and then use MultiStep
epochs = 12


@ -12,12 +12,13 @@ from mmengine.optim import OptimWrapper
# yapf: disable
from mmengine.optim.scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
CosineRestartParamScheduler,
ExponentialParamScheduler,
LinearParamScheduler,
MultiStepParamScheduler,
OneCycleParamScheduler,
PolyParamScheduler, StepParamScheduler,
_ParamScheduler)
from mmengine.optim.scheduler.param_scheduler import OneCycleParamScheduler
# yapf: enable
from mmengine.testing import assert_allclose
@ -406,6 +407,37 @@ class TestParameterScheduler(TestCase):
end=iters + 1)
self._test_scheduler_value(scheduler, targets, epochs=10)
def test_cosine_restart_scheduler(self):
with self.assertRaises(AssertionError):
CosineRestartParamScheduler(
self.optimizer,
param_name='lr',
periods=[4, 5],
restart_weights=[1, 0.5],
eta_min=0,
eta_min_ratio=0.1)
with self.assertRaises(AssertionError):
CosineRestartParamScheduler(
self.optimizer,
param_name='lr',
periods=[4, 5],
restart_weights=[1, 0.5, 0.0],
eta_min=0)
single_targets = [
0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271,
0.0086372, 0.0023872, 0.0023872
]
targets = [
single_targets, [t * self.layer2_mult for t in single_targets]
]
scheduler = CosineRestartParamScheduler(
self.optimizer,
param_name='lr',
periods=[4, 5],
restart_weights=[1, 0.5],
eta_min=0)
self._test_scheduler_value(scheduler, targets, epochs=10)
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
scheduler = construct()
for _ in range(epochs):
@ -483,6 +515,22 @@ class TestParameterScheduler(TestCase):
self.optimizer, param_name='lr', power=0.8, eta_min=0.002),
epochs=10)
def test_cosine_restart_scheduler_state_dict(self):
self._check_scheduler_state_dict(
lambda: CosineRestartParamScheduler(
self.optimizer,
param_name='lr',
periods=[4, 5],
restart_weights=[1, 0.5],
eta_min=0),
lambda: CosineRestartParamScheduler(
self.optimizer,
param_name='lr',
periods=[4, 6],
restart_weights=[1, 0.5],
eta_min=0),
epochs=10)
def test_step_scheduler_convert_iterbased(self):
# invalid epoch_length
with self.assertRaises(AssertionError):