[Features] Add OneCycleLR (#296)

* [Features] Add OneCycleLR
* [Features] Add OneCycleLR
* yapf disable
* build_iter_from_epoch
* add epoch
* fix args
* fix according to comments
* lr-param
* fix according to comments
* defaults -> default to
* remove epoch and steps per step
* variable names
parent 8b0c9c5f6f
commit fd295741ca
@@ -2,15 +2,18 @@
from .optimizer import (OPTIM_WRAPPER_CONSTRUCTORS, OPTIMIZERS,
                        AmpOptimWrapper, DefaultOptimWrapperConstructor,
                        OptimWrapper, OptimWrapperDict, build_optim_wrapper)
# yapf: disable
from .scheduler import (ConstantLR, ConstantMomentum, ConstantParamScheduler,
                        CosineAnnealingLR, CosineAnnealingMomentum,
                        CosineAnnealingParamScheduler, ExponentialLR,
                        ExponentialMomentum, ExponentialParamScheduler,
                        LinearLR, LinearMomentum, LinearParamScheduler,
                        MultiStepLR, MultiStepMomentum,
                        MultiStepParamScheduler, StepLR, StepMomentum,
                        MultiStepParamScheduler, OneCycleLR,
                        OneCycleParamScheduler, StepLR, StepMomentum,
                        StepParamScheduler, _ParamScheduler)

# yapf: enable
__all__ = [
    'OPTIM_WRAPPER_CONSTRUCTORS', 'OPTIMIZERS', 'build_optim_wrapper',
    'DefaultOptimWrapperConstructor', 'ConstantLR', 'CosineAnnealingLR',

@@ -19,5 +22,6 @@ __all__ = [
    'MultiStepMomentum', 'StepMomentum', 'ConstantParamScheduler',
    'CosineAnnealingParamScheduler', 'ExponentialParamScheduler',
    'LinearParamScheduler', 'MultiStepParamScheduler', 'StepParamScheduler',
    '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'OptimWrapperDict'
    '_ParamScheduler', 'OptimWrapper', 'AmpOptimWrapper', 'OptimWrapperDict',
    'OneCycleParamScheduler', 'OneCycleLR'
]
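With these exports in place the new schedulers become part of the public API; a minimal import sketch (assuming this hunk belongs to the mmengine.optim package __init__, which the from .optimizer / from .scheduler lines suggest):

# Sketch only; not part of this diff.
from mmengine.optim import OneCycleLR, OneCycleParamScheduler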
@@ -1,14 +1,15 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR,
                           LinearLR, MultiStepLR, PolyLR, StepLR)
                           LinearLR, MultiStepLR, OneCycleLR, PolyLR, StepLR)
from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
                                 ExponentialMomentum, LinearMomentum,
                                 MultiStepMomentum, PolyMomentum, StepMomentum)
from .param_scheduler import (ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
                              MultiStepParamScheduler, PolyParamScheduler,
                              StepParamScheduler, _ParamScheduler)
                              MultiStepParamScheduler, OneCycleParamScheduler,
                              PolyParamScheduler, StepParamScheduler,
                              _ParamScheduler)

__all__ = [
    'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',

@@ -17,5 +18,6 @@ __all__ = [
    'StepMomentum', 'ConstantParamScheduler', 'CosineAnnealingParamScheduler',
    'ExponentialParamScheduler', 'LinearParamScheduler',
    'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler',
    'PolyParamScheduler', 'PolyLR', 'PolyMomentum'
    'PolyParamScheduler', 'PolyLR', 'PolyMomentum', 'OneCycleParamScheduler',
    'OneCycleLR'
]
@@ -3,8 +3,8 @@ from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
                              MultiStepParamScheduler, PolyParamScheduler,
                              StepParamScheduler)
                              MultiStepParamScheduler, OneCycleParamScheduler,
                              PolyParamScheduler, StepParamScheduler)


class LRSchedulerMixin:

@@ -208,3 +208,72 @@ class PolyLR(LRSchedulerMixin, PolyParamScheduler):
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """


@PARAM_SCHEDULERS.register_module()
class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler):
    r"""Sets the learning rate of each parameter group according to the
    1cycle learning rate policy. The 1cycle policy anneals the learning
    rate from an initial learning rate to some maximum learning rate and then
    from that maximum learning rate to some minimum learning rate much lower
    than the initial learning rate.
    This policy was initially described in the paper `Super-Convergence:
    Very Fast Training of Neural Networks Using Large Learning Rates`_.

    The 1cycle learning rate policy changes the learning rate after every
    batch. `step` should be called after a batch has been used for training.

    This scheduler is not chainable.

    Note also that the total number of steps in the cycle can be determined in
    one of two ways (listed in order of precedence):

    #. A value for total_steps is explicitly provided.
    #. A number of epochs (epochs) and a number of steps per epoch
       (steps_per_epoch) are provided.
       In this case, the number of total steps is inferred by
       total_steps = epochs * steps_per_epoch

    You must either provide a value for total_steps or provide a value for
    both epochs and steps_per_epoch.

    The default behaviour of this scheduler follows the fastai implementation
    of 1cycle, which claims that "unpublished work has shown even better
    results by using only two phases". To mimic the behaviour of the original
    paper instead, set ``three_phase=True``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        eta_max (float or list): Upper parameter value boundaries in the cycle
            for each parameter group.
        total_steps (int): The total number of steps in the cycle. Note that
            if a value is not provided here, then it must be inferred by
            providing a value for epochs and steps_per_epoch.
            Defaults to None.
        pct_start (float): The percentage of the cycle (in number of steps)
            spent increasing the learning rate.
            Defaults to 0.3.
        anneal_strategy (str): {'cos', 'linear'}
            Specifies the annealing strategy: "cos" for cosine annealing,
            "linear" for linear annealing.
            Defaults to 'cos'.
        div_factor (float): Determines the initial learning rate via
            initial_param = eta_max / div_factor.
            Defaults to 25.
        final_div_factor (float): Determines the minimum learning rate via
            eta_min = initial_param / final_div_factor.
            Defaults to 1e4.
        three_phase (bool): If ``True``, use a third phase of the schedule to
            annihilate the learning rate according to 'final_div_factor'
            instead of modifying the second phase (the first two phases will
            be symmetrical about the step indicated by 'pct_start').
        last_step (int): The index of the last step. Used for resuming without
            a state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.

    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
        https://arxiv.org/abs/1708.07120
    """  # noqa: E501
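For orientation, a minimal usage sketch of the new OneCycleLR (not part of this diff; the toy model, optimizer and numbers are illustrative), based on the constructor arguments documented above and the tests further below:

import torch
from mmengine.optim.scheduler import OneCycleLR

model = torch.nn.Linear(2, 2)  # toy model, for illustration only
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# Warm up to eta_max over the first 30% of 100 iterations, then anneal down.
scheduler = OneCycleLR(
    optimizer,
    eta_max=1.0,
    total_steps=100,
    pct_start=0.3,
    anneal_strategy='cos',
    by_epoch=False)  # the 1cycle policy is stepped per batch, not per epoch

for _ in range(100):
    optimizer.step()   # the actual training step is omitted here
    scheduler.step()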
@@ -4,7 +4,7 @@ import warnings
import weakref
from collections import Counter
from functools import wraps
from typing import Callable, List, Union
from typing import Callable, List, Optional, Union

from torch.optim import Optimizer

@@ -844,3 +844,245 @@ class PolyParamScheduler(_ParamScheduler):
        return [(group[self.param_name] - self.eta_min) *
                (1 - 1 / (self.total_iters - self.last_step + 1))**self.power +
                self.eta_min for group in self.optimizer.param_groups]


@PARAM_SCHEDULERS.register_module()
class OneCycleParamScheduler(_ParamScheduler):
    r"""Sets the parameters of each parameter group according to the
    1cycle learning rate policy. The 1cycle policy anneals the learning
    rate from an initial learning rate to some maximum learning rate and then
    from that maximum learning rate to some minimum learning rate much lower
    than the initial learning rate.
    This policy was initially described in the paper `Super-Convergence:
    Very Fast Training of Neural Networks Using Large Learning Rates`_.

    The 1cycle learning rate policy changes the learning rate after every
    batch. `step` should be called after a batch has been used for training.

    This scheduler is not chainable.

    Note also that the total number of steps in the cycle can be determined in
    one of two ways (listed in order of precedence):

    #. A value for total_steps is explicitly provided.
    #. If total_steps is not defined, the ``begin`` and ``end`` of the
       scheduler are used instead. In this case, the number of total steps
       is inferred by total_steps = end - begin.

    The default behaviour of this scheduler follows the fastai implementation
    of 1cycle, which claims that "unpublished work has shown even better
    results by using only two phases". To mimic the behaviour of the original
    paper instead, set ``three_phase=True``.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        eta_max (float or list): Upper parameter value boundaries in the cycle
            for each parameter group.
        total_steps (int): The total number of steps in the cycle. Note that
            if a value is not provided here, then it will be equal to
            ``end - begin``. Defaults to None.
        pct_start (float): The percentage of the cycle (in number of steps)
            spent increasing the learning rate.
            Defaults to 0.3.
        anneal_strategy (str): {'cos', 'linear'}
            Specifies the annealing strategy: "cos" for cosine annealing,
            "linear" for linear annealing.
            Defaults to 'cos'.
        div_factor (float): Determines the initial learning rate via
            initial_param = eta_max / div_factor.
            Defaults to 25.
        final_div_factor (float): Determines the minimum learning rate via
            eta_min = initial_param / final_div_factor.
            Defaults to 1e4.
        three_phase (bool): If ``True``, use a third phase of the schedule to
            annihilate the learning rate according to 'final_div_factor'
            instead of modifying the second phase (the first two phases will
            be symmetrical about the step indicated by 'pct_start').
        last_step (int): The index of the last step. Used for resuming without
            a state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.

    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
        https://arxiv.org/abs/1708.07120
    """  # noqa: E501

    def __init__(self,
                 optimizer: Union[Optimizer, OptimWrapper],
                 param_name: str,
                 eta_max: float = 0,
                 total_steps: Optional[int] = None,
                 pct_start: float = 0.3,
                 anneal_strategy: str = 'cos',
                 div_factor: float = 25.,
                 final_div_factor: float = 1e4,
                 three_phase: bool = False,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):

        assert param_name == 'lr', ('OneCycle only works for learning rate '
                                    'updating, but got param_name as '
                                    f'{param_name}')

        self.eta_max = eta_max
        self.div_factor = div_factor
        self.final_div_factor = final_div_factor

        # Validate total_steps
        if total_steps is not None:
            if total_steps <= 0 or not isinstance(total_steps, int):
                raise ValueError('Expected positive integer total_steps, '
                                 f'but got {total_steps}')
            self.total_steps = total_steps
        else:
            self.total_steps = end - begin

        # Validate pct_start
        if pct_start < 0 or pct_start > 1 or not isinstance(pct_start, float):
            raise ValueError('Expected float between 0 and 1 pct_start, '
                             f'but got {pct_start}')

        # Validate anneal_strategy
        if anneal_strategy not in ['cos', 'linear']:
            raise ValueError(
                'anneal_strategy must be one of "cos" or "linear", '
                f'instead got {anneal_strategy}')
        elif anneal_strategy == 'cos':
            self.anneal_func = self._annealing_cos
        elif anneal_strategy == 'linear':
            self.anneal_func = self._annealing_linear

        if three_phase:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    f'start_{param_name}': f'initial_{param_name}',
                    f'end_{param_name}': f'max_{param_name}'
                },
                {
                    'end_step': float(2 * pct_start * self.total_steps) - 2,
                    f'start_{param_name}': f'max_{param_name}',
                    f'end_{param_name}': f'initial_{param_name}'
                },
                {
                    'end_step': self.total_steps - 1,
                    f'start_{param_name}': f'initial_{param_name}',
                    f'end_{param_name}': f'min_{param_name}'
                },
            ]
        else:
            self._schedule_phases = [
                {
                    'end_step': float(pct_start * self.total_steps) - 1,
                    f'start_{param_name}': f'initial_{param_name}',
                    f'end_{param_name}': f'max_{param_name}'
                },
                {
                    'end_step': self.total_steps - 1,
                    f'start_{param_name}': f'max_{param_name}',
                    f'end_{param_name}': f'min_{param_name}'
                },
            ]

        # Initialize parameters
        max_values = self._format_param(f'max_{param_name}', optimizer,
                                        eta_max)
        if last_step == -1:
            for idx, group in enumerate(optimizer.param_groups):
                group[f'initial_{param_name}'] = max_values[idx] / div_factor
                group[f'max_{param_name}'] = max_values[idx]
                group[f'min_{param_name}'] = \
                    group[f'initial_{param_name}'] / final_div_factor

        super().__init__(
            optimizer=optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    def _format_param(self, name, optimizer, param):
        """Return correctly formatted lr/momentum for each param group."""
        if isinstance(param, (list, tuple)):
            if len(param) != len(optimizer.param_groups):
                raise ValueError(
                    f'expected {len(optimizer.param_groups)} values '
                    f'for {name}, got {len(param)}')
            return param
        else:
            return [param] * len(optimizer.param_groups)

    def _annealing_cos(self, start, end, pct):
        """Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0."""

        cos_out = math.cos(math.pi * pct) + 1
        return end + (start - end) / 2.0 * cos_out

    def _annealing_linear(self, start, end, pct):
        """Linearly anneal from `start` to `end` as pct goes from 0.0 to
        1.0."""
        return (end - start) * pct + start

    @classmethod
    def build_iter_from_epoch(cls,
                              *args,
                              begin=0,
                              end=INF,
                              total_steps=None,
                              by_epoch=True,
                              epoch_length=None,
                              **kwargs):
        """Build an iter-based instance of this scheduler from an epoch-based
        config."""
        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
                         'be converted to iter-based.'
        assert epoch_length is not None and epoch_length > 0, \
            f'`epoch_length` must be a positive integer, ' \
            f'but got {epoch_length}.'
        by_epoch = False
        begin = begin * epoch_length
        if end != INF:
            end = end * epoch_length
        if total_steps is not None:
            total_steps = total_steps * epoch_length
        return cls(
            *args,
            begin=begin,
            end=end,
            total_steps=total_steps,
            by_epoch=by_epoch,
            **kwargs)

    def _get_value(self):
        """Compute the value of each parameter group for the current step."""

        params = []
        step_num = self.last_step

        if step_num > self.total_steps:
            raise ValueError(
                f'Tried to step {step_num + 1} times. '
                f'The specified number of total steps is {self.total_steps}')

        for group in self.optimizer.param_groups:
            start_step = 0
            for i, phase in enumerate(self._schedule_phases):
                end_step = phase['end_step']
                if step_num <= end_step or i == len(self._schedule_phases) - 1:
                    pct = (step_num - start_step) / (end_step - start_step)
                    computed_param = self.anneal_func(
                        group[phase['start_' + self.param_name]],
                        group[phase['end_' + self.param_name]], pct)
                    break
                start_step = phase['end_step']

            params.append(computed_param)

        return params
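As a sanity check on the phase arithmetic above, here is a small standalone sketch (illustrative only, independent of the module) that re-derives the two-phase linear schedule exercised by the test below, i.e. eta_max=25, final_div_factor=2, total_steps=10, with the default div_factor=25 and pct_start=0.3:

# Standalone re-derivation of the two-phase linear schedule; not committed code.
eta_max, div_factor, final_div_factor = 25, 25, 2
total_steps, pct_start = 10, 0.3

initial = eta_max / div_factor        # 1.0
minimum = initial / final_div_factor  # 0.5
phases = [
    # (end_step, start_value, end_value)
    (float(pct_start * total_steps) - 1, initial, eta_max),  # steps 0-2: warm up
    (total_steps - 1, eta_max, minimum),                     # steps 3-9: anneal
]

values = []
for step in range(total_steps):
    start_step = 0
    for i, (end_step, start_value, end_value) in enumerate(phases):
        if step <= end_step or i == len(phases) - 1:
            pct = (step - start_step) / (end_step - start_step)
            values.append((end_value - start_value) * pct + start_value)
            break
        start_step = end_step

# values matches the test's expected list (up to float tolerance):
# 1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5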
@@ -8,7 +8,8 @@ import torch.optim as optim

from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
                                      ExponentialLR, LinearLR, MultiStepLR,
                                      PolyLR, StepLR, _ParamScheduler)
                                      OneCycleLR, PolyLR, StepLR,
                                      _ParamScheduler)
from mmengine.testing import assert_allclose


@@ -551,3 +552,44 @@ class TestLRScheduler(TestCase):
            self.optimizer, T_max=5, eta_min=eta_min, begin=10, end=15)

        self._test_scheduler_value([scheduler1, scheduler2], targets, epochs)

    def test_onecycle_lr(self):
        # test linear annealing
        target = [1, 13, 25, 21.5, 18, 14.5, 11, 7.5, 4, 0.5]
        scheduler = OneCycleLR(
            self.optimizer,
            eta_max=25,
            final_div_factor=2,
            total_steps=10,
            anneal_strategy='linear')
        self._test_scheduler_value(scheduler, [target], 10)
        # test linear annealing three phase
        target = [1, 9, 17, 25, 17, 9, 1, 0.75, 0.5, 0.25]
        scheduler = OneCycleLR(
            self.optimizer,
            eta_max=25,
            div_factor=25,
            total_steps=10,
            anneal_strategy='linear',
            pct_start=0.4,
            final_div_factor=4,
            three_phase=True)
        self._test_scheduler_value(scheduler, [target], 10)

        # test cosine annealing
        def annealing_cos(start, end, pct):
            cos_out = math.cos(math.pi * pct) + 1
            return end + (start - end) / 2.0 * cos_out

        target = [
            1, 13, 25,
            annealing_cos(25, 0.5, 1 / 7.0),
            annealing_cos(25, 0.5, 2 / 7.0),
            annealing_cos(25, 0.5, 3 / 7.0),
            annealing_cos(25, 0.5, 4 / 7.0),
            annealing_cos(25, 0.5, 5 / 7.0),
            annealing_cos(25, 0.5, 6 / 7.0), 0.5
        ]
        scheduler = OneCycleLR(
            self.optimizer, eta_max=25, final_div_factor=2, total_steps=10)
        self._test_scheduler_value(scheduler, [target], 10)
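The three-phase expectations above can be checked by hand; a short worked sketch of the boundary arithmetic (values taken from the test parameters above, illustrative only):

# eta_max=25, div_factor=25, final_div_factor=4, total_steps=10, pct_start=0.4
initial = 25 / 25              # 1.0
minimum = initial / 4          # 0.25
phase1_end = 0.4 * 10 - 1      # 3.0 -> steps 0-3 rise  1 -> 25: 1, 9, 17, 25
phase2_end = 2 * 0.4 * 10 - 2  # 6.0 -> steps 4-6 fall 25 ->  1: 17, 9, 1
phase3_end = 10 - 1            # 9.0 -> steps 7-9 fall  1 -> 0.25: 0.75, 0.5, 0.25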
@@ -14,6 +14,7 @@ from mmengine.optim.scheduler import (ConstantParamScheduler,
                                      MultiStepParamScheduler,
                                      PolyParamScheduler, StepParamScheduler,
                                      _ParamScheduler)
from mmengine.optim.scheduler.param_scheduler import OneCycleParamScheduler
# yapf: enable
from mmengine.testing import assert_allclose


@@ -677,3 +678,20 @@ class TestParameterScheduler(TestCase):
            end=15)

        self._test_scheduler_value([scheduler1, scheduler2], targets, epochs)

    def test_onecycle_scheduler(self):
        # test invalid total steps
        with self.assertRaises(ValueError):
            OneCycleParamScheduler(
                self.optimizer, param_name='lr', total_steps=-1)
        # test invalid pct_start
        with self.assertRaises(ValueError):
            OneCycleParamScheduler(
                self.optimizer, param_name='lr', total_steps=10, pct_start=-1)
        # test invalid anneal_strategy
        with self.assertRaises(ValueError):
            OneCycleParamScheduler(
                self.optimizer,
                param_name='lr',
                total_steps=10,
                anneal_strategy='a')
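Beyond these validation paths, the build_iter_from_epoch classmethod added above converts an epoch-based specification into an iteration-based scheduler. A hedged usage sketch (the optimizer and the numbers are illustrative; in a full mmengine setup the runner normally performs this conversion from the config):

import torch
from mmengine.optim.scheduler import OneCycleLR

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.SGD(params, lr=0.1)

# 10 "epochs" of 100 iterations each become total_steps = 10 * 100 = 1000,
# and the returned scheduler is iteration-based (by_epoch=False).
scheduler = OneCycleLR.build_iter_from_epoch(
    optimizer,
    eta_max=0.1,
    total_steps=10,
    epoch_length=100)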