[Features] Add OneCycleLR (#296)

* [Features] Add OneCycleLR

* [Features] Add OneCycleLR

* yapf disable

* build_iter_from_epoch

* add epoch

* fix args

* fix according to comments;

* lr-param

* fix according to comments

* defaults -> default to

* remove epochs and steps_per_epoch

* variable names
Miao Zheng 2022-06-13 21:23:59 +08:00 committed by GitHub
parent 8b0c9c5f6f
commit fd295741ca
6 changed files with 387 additions and 10 deletions

@@ -2,15 +2,18 @@
from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
AmpOptimWrapper, DefaultOptimWrapperConstructor,
OptimWrapper, OptimWrapperDict, build_optim_wrapper)
# yapf: disable
from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
CosineAnnealingLR, CosineAnnealingMomentum,
CosineAnnealingParamScheduler, ExponentialLR,
ExponentialMomentum, ExponentialParamScheduler,
LinearLR, LinearMomentum, LinearParamScheduler,
MultiStepLR, MultiStepMomentum,
MultiStepParamScheduler, StepLR, StepMomentum,
MultiStepParamScheduler, OneCycleLR,
OneCycleParamScheduler, StepLR, StepMomentum,
StepParamScheduler, _ParamScheduler)
# yapf: enable
__all__ = [
'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optim_wrapper',
'DefaultOptimWrapperConstructor', 'ConstantLR', 'CosineAnnealingLR',
@@ -19,5 +22,6 @@ __all__ = [
'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler',
'CosineAnnealingParamScheduler', 'ExponentialParamScheduler',
'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler',
'_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'OptimWrapperDict'
'_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'OptimWrapperDict',
'OneCycleParamScheduler', 'OneCycleLR'
]

@@ -1,14 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR,
LinearLR, MultiStepLR, PolyLR, StepLR)
LinearLR, MultiStepLR, OneCycleLR, PolyLR, StepLR)
from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
ExponentialMomentum, LinearMomentum,
MultiStepMomentum, PolyMomentum, StepMomentum)
from .param_scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
ExponentialParamScheduler, LinearParamScheduler,
MultiStepParamScheduler, PolyParamScheduler,
StepParamScheduler, _ParamScheduler)
MultiStepParamScheduler, OneCycleParamScheduler,
PolyParamScheduler, StepParamScheduler,
_ParamScheduler)
__all__ = [
'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
@@ -17,5 +18,6 @@ __all__ = [
'StepMomentum', 'ConstantParamScheduler', 'CosineAnnealingParamScheduler',
'ExponentialParamScheduler', 'LinearParamScheduler',
'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler',
'PolyParamScheduler', 'PolyLR', 'PolyMomentum'
'PolyParamScheduler', 'PolyLR', 'PolyMomentum', 'OneCycleParamScheduler',
'OneCycleLR'
]

@@ -3,8 +3,8 @@ from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
ExponentialParamScheduler, LinearParamScheduler,
MultiStepParamScheduler, PolyParamScheduler,
StepParamScheduler)
MultiStepParamScheduler, OneCycleParamScheduler,
PolyParamScheduler, StepParamScheduler)
class LRSchedulerMixin:
@@ -208,3 +208,72 @@ class PolyLR(LRSchedulerMixin, PolyParamScheduler):
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""
@PARAM_SCHEDULERS.register_module()
class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler):
r"""Sets the learning rate of each parameter group according to the
1cycle learning rate policy. The 1cycle policy anneals the learning
rate from an initial learning rate to some maximum learning rate and then
from that maximum learning rate to some minimum learning rate much lower
than the initial learning rate.
This policy was initially described in the paper `Super-Convergence:
Very Fast Training of Neural Networks Using Large Learning Rates`_.
The 1cycle learning rate policy changes the learning rate after every
batch. `step` should be called after a batch has been used for training.
This scheduler is not chainable.
Note also that the total number of steps in the cycle can be determined in
one of two ways (listed in order of precedence):
#. A value for total_steps is explicitly provided.
#. If total_steps is not provided, the ``begin`` and ``end`` of this
scheduler are used instead. In this case, the number of total steps is
inferred by total_steps = end - begin.
The default behaviour of this scheduler follows the fastai implementation
of 1cycle, which claims that "unpublished work has shown even better
results by using only two phases". To mimic the behaviour of the original
paper instead, set ``three_phase=True``.
Args:
optimizer (Optimizer): Wrapped optimizer.
eta_max (float or list): Upper parameter value boundaries in the cycle
for each parameter group.
total_steps (int): The total number of steps in the cycle. Note that
if a value is not provided here, it will be inferred as
``end - begin``. Defaults to None.
pct_start (float): The percentage of the cycle (in number of steps)
spent increasing the learning rate. Defaults to 0.3.
anneal_strategy (str): {'cos', 'linear'}
Specifies the annealing strategy: "cos" for cosine annealing,
"linear" for linear annealing. Defaults to 'cos'.
div_factor (float): Determines the initial learning rate via
initial_param = eta_max / div_factor. Defaults to 25.
final_div_factor (float): Determines the minimum learning rate via
eta_min = initial_param / final_div_factor. Defaults to 1e4.
three_phase (bool): If ``True``, use a third phase of the schedule to
annihilate the learning rate according to ``final_div_factor``
instead of modifying the second phase (the first two phases will be
symmetrical about the step indicated by ``pct_start``).
Defaults to False.
last_step (int): The index of last step. Used for resume without
state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
.. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
https://arxiv.org/abs/1708.07120
"""# noqa E501

@@ -4,7 +4,7 @@ import warnings
import weakref
from collections import Counter
from functools import wraps
from typing import Callable, List, Union
from typing import Callable, List, Optional, Union
from torch.optim import Optimizer
@@ -844,3 +844,245 @@ class PolyParamScheduler(_ParamScheduler):
return [(group[self.param_name] - self.eta_min) *
(1 - 1 / (self.total_iters - self.last_step + 1))**self.power +
self.eta_min for group in self.optimizer.param_groups]
@PARAM_SCHEDULERS.register_module()
class OneCycleParamScheduler(_ParamScheduler):
r"""Sets the parameters of each parameter group according to the
1cycle learning rate policy. The 1cycle policy anneals the learning
rate from an initial learning rate to some maximum learning rate and then
from that maximum learning rate to some minimum learning rate much lower
than the initial learning rate.
This policy was initially described in the paper `Super-Convergence:
Very Fast Training of Neural Networks Using Large Learning Rates`_.
The 1cycle learning rate policy changes the learning rate after every
batch. `step` should be called after a batch has been used for training.
This scheduler is not chainable.
Note also that the total number of steps in the cycle can be determined in
one of two ways (listed in order of precedence):
#. A value for total_steps is explicitly provided.
#. If total_steps is not defined, the ``begin`` and ``end`` of this
scheduler are used instead. In this case, the number of total steps is
inferred by total_steps = end - begin.
The default behaviour of this scheduler follows the fastai implementation
of 1cycle, which claims that "unpublished work has shown even better
results by using only two phases". To mimic the behaviour of the original
paper instead, set ``three_phase=True``.
Args:
optimizer (Optimizer): Wrapped optimizer.
eta_max (float or list): Upper parameter value boundaries in the cycle
for each parameter group.
total_steps (int): The total number of steps in the cycle. Note that
if a value is not provided here, then it will be equal to
``end - begin``. Defaults to None.
pct_start (float): The percentage of the cycle (in number of steps)
spent increasing the learning rate. Defaults to 0.3.
anneal_strategy (str): {'cos', 'linear'}
Specifies the annealing strategy: "cos" for cosine annealing,
"linear" for linear annealing. Defaults to 'cos'.
div_factor (float): Determines the initial learning rate via
initial_param = eta_max / div_factor. Defaults to 25.
final_div_factor (float): Determines the minimum learning rate via
eta_min = initial_param / final_div_factor. Defaults to 1e4.
three_phase (bool): If ``True``, use a third phase of the schedule to
annihilate the learning rate according to ``final_div_factor``
instead of modifying the second phase (the first two phases will be
symmetrical about the step indicated by ``pct_start``).
Defaults to False.
last_step (int): The index of last step. Used for resume without
state dict. Defaults to -1.
by_epoch (bool): Whether the scheduled parameters are updated by
epochs. Defaults to True.
verbose (bool): Whether to print the value for each update.
Defaults to False.
.. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
https://arxiv.org/abs/1708.07120
"""# noqa E501
def __init__(self,
optimizer: Union[Optimizer, OptimWrapper],
param_name: str,
eta_max: float = 0,
total_steps: Optional[int] = None,
pct_start: float = 0.3,
anneal_strategy: str = 'cos',
div_factor: float = 25.,
final_div_factor: float = 1e4,
three_phase: bool = False,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
assert param_name == 'lr', ('OneCycle only works for learning rate '
'updating, but got param_name as '
f'{param_name}')
self.eta_max = eta_max
self.div_factor = div_factor
self.final_div_factor = final_div_factor
# Validate total_steps
if total_steps is not None:
if total_steps <= 0 or not isinstance(total_steps, int):
raise ValueError('Expected positive integer total_steps, '
f'but got {total_steps}')
self.total_steps = total_steps
else:
# `self.begin` and `self.end` are only set by the parent constructor,
# which has not run yet, so use the arguments directly
self.total_steps = end - begin
# Validate pct_start
if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
raise ValueError('Expected float between 0 and 1 for pct_start, '
f'but got {pct_start}')
# Validate anneal_strategy
if anneal_strategy not in ['cos', 'linear']:
raise ValueError(
'anneal_strategy must be one of "cos" or "linear", '
f'instead got {anneal_strategy}')
elif anneal_strategy == 'cos':
self.anneal_func = self._annealing_cos
elif anneal_strategy == 'linear':
self.anneal_func = self._annealing_linear
if three_phase:
self._schedule_phases = [
{
'end_step': float(pct_start * self.total_steps) - 1,
f'start_{param_name}': f'initial_{param_name}',
f'end_{param_name}': f'max_{param_name}'
},
{
'end_step': float(2 * pct_start * self.total_steps) - 2,
f'start_{param_name}': f'max_{param_name}',
f'end_{param_name}': f'initial_{param_name}'
},
{
'end_step': self.total_steps - 1,
f'start_{param_name}': f'initial_{param_name}',
f'end_{param_name}': f'min_{param_name}'
},
]
else:
self._schedule_phases = [
{
'end_step': float(pct_start * self.total_steps) - 1,
f'start_{param_name}': f'initial_{param_name}',
f'end_{param_name}': f'max_{param_name}'
},
{
'end_step': self.total_steps - 1,
f'start_{param_name}': f'max_{param_name}',
f'end_{param_name}': f'min_{param_name}'
},
]
# Initialize parameters
max_values = self._format_param(f'max_{param_name}', optimizer,
eta_max)
if last_step == -1:
for idx, group in enumerate(optimizer.param_groups):
group[f'initial_{param_name}'] = max_values[idx] / div_factor
group[f'max_{param_name}'] = max_values[idx]
group[f'min_{param_name}'] = \
group[f'initial_{param_name}'] / final_div_factor
super().__init__(
optimizer=optimizer,
param_name=param_name,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
def _format_param(self, name, optimizer, param):
"""Return correctly formatted lr/momentum for each param group."""
if isinstance(param, (list, tuple)):
if len(param) != len(optimizer.param_groups):
raise ValueError(
f'expected {len(optimizer.param_groups)} values '
f'for {name}, got {len(param)}')
return param
else:
return [param] * len(optimizer.param_groups)
def _annealing_cos(self, start, end, pct):
"""Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
cos_out = math.cos(math.pi * pct) + 1
return end + (start - end) / 2.0 * cos_out
def _annealing_linear(self, start, end, pct):
"""Linearly anneal from `start` to `end` as pct goes from 0.0 to
1.0."""
return (end - start) * pct + start
@classmethod
def build_iter_from_epoch(cls,
*args,
begin=0,
end=INF,
total_steps=None,
by_epoch=True,
epoch_length=None,
**kwargs):
"""Build an iter-based instance of this scheduler from an epoch-based
config."""
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
'be converted to iter-based.'
assert epoch_length is not None and epoch_length > 0, \
f'`epoch_length` must be a positive integer, ' \
f'but got {epoch_length}.'
by_epoch = False
begin = begin * epoch_length
if end != INF:
end = end * epoch_length
if total_steps is not None:
total_steps = total_steps * epoch_length
return cls(
*args,
begin=begin,
end=end,
total_steps=total_steps,
by_epoch=by_epoch,
**kwargs)
def _get_value(self):
"""Compute value using chainable form of the scheduler."""
params = []
step_num = self.last_step
if step_num > self.total_steps:
raise ValueError(
f'Tried to step {step_num + 1} times. '
f'The specified number of total steps is {self.total_steps}')
for group in self.optimizer.param_groups:
start_step = 0
for i, phase in enumerate(self._schedule_phases):
end_step = phase['end_step']
if step_num <= end_step or i == len(self._schedule_phases) - 1:
pct = (step_num - start_step) / (end_step - start_step)
computed_param = self.anneal_func(
group[phase['start_' + self.param_name]],
group[phase['end_' + self.param_name]], pct)
break
start_step = phase['end_step']
params.append(computed_param)
return params
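
A rough sketch of ``build_iter_from_epoch``, which converts an epoch-based specification of this scheduler into an iter-based one (the 50-iteration ``epoch_length`` and the other values below are assumed for illustration):

import torch
from torch.optim import SGD

from mmengine.optim.scheduler.param_scheduler import OneCycleParamScheduler

optimizer = SGD(torch.nn.Linear(4, 2).parameters(), lr=0.01)

# A 10-epoch cycle with 50 iterations per epoch becomes a 500-step cycle:
# begin, end and total_steps are multiplied by epoch_length, and by_epoch is
# flipped to False.
scheduler = OneCycleParamScheduler.build_iter_from_epoch(
    optimizer,
    param_name='lr',
    eta_max=0.1,
    total_steps=10,   # interpreted as epochs before the conversion
    epoch_length=50)  # assumed number of iterations per epoch

assert scheduler.total_steps == 500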

@@ -8,7 +8,8 @@ import torch.optim as optim
from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
ExponentialLR, LinearLR, MultiStepLR,
PolyLR, StepLR, _ParamScheduler)
OneCycleLR, PolyLR, StepLR,
_ParamScheduler)
from mmengine.testing import assert_allclose
@@ -551,3 +552,44 @@ class TestLRScheduler(TestCase):
self.optimizer, T_max=5, eta_min=eta_min, begin=10, end=15)
self._test_scheduler_value([scheduler1, scheduler2], targets, epochs)
def test_onecycle_lr(self):
# test linear annealing
target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
scheduler = OneCycleLR(
self.optimizer,
eta_max=25,
final_div_factor=2,
total_steps=10,
anneal_strategy='linear')
self._test_scheduler_value(scheduler, [target], 10)
# test linear annealing three phase
target = [1, 9, 17, 25, 17, 9, 1, 0.75, 0.5, 0.25]
scheduler = OneCycleLR(
self.optimizer,
eta_max=25,
div_factor=25,
total_steps=10,
anneal_strategy='linear',
pct_start=0.4,
final_div_factor=4,
three_phase=True)
self._test_scheduler_value(scheduler, [target], 10)
# test cosine annealing
def annealing_cos(start, end, pct):
cos_out = math.cos(math.pi * pct) + 1
return end + (start - end) / 2.0 * cos_out
target = [
1, 13, 25,
annealing_cos(25, 0.5, 1 / 7.0),
annealing_cos(25, 0.5, 2 / 7.0),
annealing_cos(25, 0.5, 3 / 7.0),
annealing_cos(25, 0.5, 4 / 7.0),
annealing_cos(25, 0.5, 5 / 7.0),
annealing_cos(25, 0.5, 6 / 7.0), 0.5
]
scheduler = OneCycleLR(
self.optimizer, eta_max=25, final_div_factor=2, total_steps=10)
self._test_scheduler_value(scheduler, [target], 10)
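
The targets in ``test_onecycle_lr`` follow directly from the docstring formulas; the short sanity check below (assuming the defaults ``div_factor=25`` and ``pct_start=0.3`` used by the first case) reproduces the linear-annealing target list by hand:

# Reproduce target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5] by hand.
eta_max, div_factor, final_div_factor = 25, 25, 2
total_steps, pct_start = 10, 0.3

initial_lr = eta_max / div_factor        # 1.0
min_lr = initial_lr / final_div_factor   # 0.5
up_end = pct_start * total_steps - 1     # warm-up phase ends at step 2

def linear(start, end, pct):
    """Linearly anneal from `start` to `end` as pct goes from 0.0 to 1.0."""
    return (end - start) * pct + start

lrs = []
for step in range(total_steps):
    if step <= up_end:
        lrs.append(linear(initial_lr, eta_max, step / up_end))
    else:
        lrs.append(
            linear(eta_max, min_lr, (step - up_end) / (total_steps - 1 - up_end)))

print(lrs)  # [1.0, 13.0, 25.0, 21.5, 18.0, 14.5, 11.0, 7.5, 4.0, 0.5]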

@@ -14,6 +14,7 @@ from mmengine.optim.scheduler import (ConstantParamScheduler,
MultiStepParamScheduler,
PolyParamScheduler, StepParamScheduler,
_ParamScheduler)
from mmengine.optim.scheduler.param_scheduler import OneCycleParamScheduler
# yapf: enable
from mmengine.testing import assert_allclose
@@ -677,3 +678,20 @@ class TestParameterScheduler(TestCase):
end=15)
self._test_scheduler_value([scheduler1, scheduler2], targets, epochs)
def test_onecycle_scheduler(self):
# test invalid total steps
with self.assertRaises(ValueError):
OneCycleParamScheduler(
self.optimizer, param_name='lr', total_steps=-1)
# test invalid pct_start
with self.assertRaises(ValueError):
OneCycleParamScheduler(
self.optimizer, param_name='lr', total_steps=10, pct_start=-1)
# test invalid anneal_strategy
with self.assertRaises(ValueError):
OneCycleParamScheduler(
self.optimizer,
param_name='lr',
total_steps=10,
anneal_strategy='a')