mirror of https://github.com/open-mmlab/mmengine.git
[Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR (#188)

* [Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR
* min_lr -> eta_min, refined docstr

commit c3aff4fc9a (parent e2a2b0438e)
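For orientation, here is a minimal usage sketch of the new PolyLR class added by this commit; the toy model, optimizer, and hyperparameter values are illustrative only, not part of the change:

import torch

from mmengine.optim.scheduler import PolyLR

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

# Decay the lr polynomially from 0.05 towards eta_min over the first 4 steps
# (begin=0, end=5), then stop updating it.
scheduler = PolyLR(optimizer, eta_min=0.001, power=0.9, end=5)

for epoch in range(10):
    # ... one training epoch ...
    scheduler.step()
    print(epoch, optimizer.param_groups[0]['lr'])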
mmengine/optim/scheduler/__init__.py

@@ -1,14 +1,14 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR,
-                           LinearLR, MultiStepLR, StepLR)
+                           LinearLR, MultiStepLR, PolyLR, StepLR)
 from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
                                  ExponentialMomentum, LinearMomentum,
-                                 MultiStepMomentum, StepMomentum)
+                                 MultiStepMomentum, PolyMomentum, StepMomentum)
 from .param_scheduler import (ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
-                              MultiStepParamScheduler, StepParamScheduler,
-                              _ParamScheduler)
+                              MultiStepParamScheduler, PolyParamScheduler,
+                              StepParamScheduler, _ParamScheduler)

 __all__ = [
     'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
@@ -16,5 +16,6 @@ __all__ = [
     'ExponentialMomentum', 'LinearMomentum', 'MultiStepMomentum',
     'StepMomentum', 'ConstantParamScheduler', 'CosineAnnealingParamScheduler',
     'ExponentialParamScheduler', 'LinearParamScheduler',
-    'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler'
+    'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler',
+    'PolyParamScheduler', 'PolyLR', 'PolyMomentum'
 ]
mmengine/optim/scheduler/lr_scheduler.py

@@ -7,7 +7,8 @@ from mmengine.registry import PARAM_SCHEDULERS
 from .param_scheduler import (INF, ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
-                              MultiStepParamScheduler, StepParamScheduler)
+                              MultiStepParamScheduler, PolyParamScheduler,
+                              StepParamScheduler)


 @PARAM_SCHEDULERS.register_module()
@@ -294,3 +295,49 @@ class StepLR(StepParamScheduler):
             last_step=last_step,
             by_epoch=by_epoch,
             verbose=verbose)
+
+
+@PARAM_SCHEDULERS.register_module()
+class PolyLR(PolyParamScheduler):
+    """Decays the learning rate of each parameter group in a polynomial decay
+    scheme.
+
+    Notice that such decay can happen simultaneously with other changes to the
+    parameter value from outside this scheduler.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        eta_min (float): Minimum learning rate at the end of scheduling.
+            Defaults to 0.
+        power (float): The power of the polynomial. Defaults to 1.0.
+        begin (int): Step at which to start updating the parameters.
+            Defaults to 0.
+        end (int): Step at which to stop updating the parameters.
+            Defaults to INF.
+        last_step (int): The index of last step. Used for resume without
+            state dict. Defaults to -1.
+        by_epoch (bool): Whether the scheduled parameters are updated by
+            epochs. Defaults to True.
+        verbose (bool): Whether to print the value for each update.
+            Defaults to False.
+    """
+
+    def __init__(self,
+                 optimizer: torch.optim.Optimizer,
+                 eta_min: float = 0,
+                 power: float = 1,
+                 begin: int = 0,
+                 end: int = INF,
+                 last_step: int = -1,
+                 by_epoch: bool = True,
+                 verbose: bool = False):
+        super().__init__(
+            optimizer,
+            param_name='lr',
+            eta_min=eta_min,
+            power=power,
+            begin=begin,
+            end=end,
+            last_step=last_step,
+            by_epoch=by_epoch,
+            verbose=verbose)
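Because PolyLR is registered with @PARAM_SCHEDULERS.register_module(), it can also be constructed from a config dict through the registry. A minimal sketch, assuming the usual mmengine Registry.build(cfg) behavior where 'type' selects the registered class and the remaining keys are forwarded to its constructor (the optimizer here is illustrative):

import torch

from mmengine.registry import PARAM_SCHEDULERS

optimizer = torch.optim.SGD(
    torch.nn.Linear(4, 2).parameters(), lr=0.05, momentum=0.9)

# 'type' picks PolyLR from the registry; the other keys become kwargs.
scheduler = PARAM_SCHEDULERS.build(
    dict(type='PolyLR', optimizer=optimizer, eta_min=0.001, power=0.9, end=5))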
mmengine/optim/scheduler/momentum_scheduler.py

@@ -7,7 +7,8 @@ from mmengine.registry import PARAM_SCHEDULERS
 from .param_scheduler import (INF, ConstantParamScheduler,
                               CosineAnnealingParamScheduler,
                               ExponentialParamScheduler, LinearParamScheduler,
-                              MultiStepParamScheduler, StepParamScheduler)
+                              MultiStepParamScheduler, PolyParamScheduler,
+                              StepParamScheduler)


 @PARAM_SCHEDULERS.register_module()
@@ -294,3 +295,49 @@ class StepMomentum(StepParamScheduler):
             last_step=last_step,
             by_epoch=by_epoch,
             verbose=verbose)
+
+
+@PARAM_SCHEDULERS.register_module()
+class PolyMomentum(PolyParamScheduler):
+    """Decays the momentum of each parameter group in a polynomial decay
+    scheme.
+
+    Notice that such decay can happen simultaneously with other changes to the
+    parameter value from outside this scheduler.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        eta_min (float): Minimum momentum at the end of scheduling.
+            Defaults to 0.
+        power (float): The power of the polynomial. Defaults to 1.0.
+        begin (int): Step at which to start updating the parameters.
+            Defaults to 0.
+        end (int): Step at which to stop updating the parameters.
+            Defaults to INF.
+        last_step (int): The index of last step. Used for resume without
+            state dict. Defaults to -1.
+        by_epoch (bool): Whether the scheduled parameters are updated by
+            epochs. Defaults to True.
+        verbose (bool): Whether to print the value for each update.
+            Defaults to False.
+    """
+
+    def __init__(self,
+                 optimizer: torch.optim.Optimizer,
+                 eta_min: float = 0,
+                 power: float = 1,
+                 begin: int = 0,
+                 end: int = INF,
+                 last_step: int = -1,
+                 by_epoch: bool = True,
+                 verbose: bool = False):
+        super().__init__(
+            optimizer,
+            param_name='momentum',
+            eta_min=eta_min,
+            power=power,
+            begin=begin,
+            end=end,
+            last_step=last_step,
+            by_epoch=by_epoch,
+            verbose=verbose)
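PolyMomentum is the analogous wrapper with param_name='momentum', so it drives the 'momentum' entry of each param group rather than the learning rate. A short sketch with illustrative values:

import torch

from mmengine.optim.scheduler import PolyMomentum

optimizer = torch.optim.SGD(
    torch.nn.Linear(4, 2).parameters(), lr=0.01, momentum=0.9)

# Decays optimizer.param_groups[0]['momentum'] from 0.9 towards eta_min.
scheduler = PolyMomentum(optimizer, eta_min=0.5, power=1.0, end=5)

for _ in range(4):
    scheduler.step()
    print(optimizer.param_groups[0]['momentum'])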
mmengine/optim/scheduler/param_scheduler.py

@@ -534,6 +534,7 @@ class LinearParamScheduler(_ParamScheduler):

     Notice that such decay can happen simultaneously with other changes to the
     parameter value from outside this scheduler.

     Args:
         optimizer (Optimizer): Wrapped optimizer.
         start_factor (float): The number we multiply parameter value in the
@@ -598,3 +599,64 @@ class LinearParamScheduler(_ParamScheduler):
                 (self.end_factor - self.start_factor)))
             for group in self.optimizer.param_groups
         ]
+
+
+@PARAM_SCHEDULERS.register_module()
+class PolyParamScheduler(_ParamScheduler):
+    """Decays the parameter value of each parameter group in a polynomial decay
+    scheme.
+
+    Notice that such decay can happen simultaneously with other changes to the
+    parameter value from outside this scheduler.
+
+    Args:
+        optimizer (Optimizer): Wrapped optimizer.
+        eta_min (float): Minimum parameter value at the end of scheduling.
+            Defaults to 0.
+        power (float): The power of the polynomial. Defaults to 1.0.
+        begin (int): Step at which to start updating the parameters.
+            Defaults to 0.
+        end (int): Step at which to stop updating the parameters.
+            Defaults to INF.
+        last_step (int): The index of last step. Used for resume without
+            state dict. Defaults to -1.
+        by_epoch (bool): Whether the scheduled parameters are updated by
+            epochs. Defaults to True.
+        verbose (bool): Whether to print the value for each update.
+            Defaults to False.
+    """
+
+    def __init__(self,
+                 optimizer: Optimizer,
+                 param_name: str,
+                 eta_min: float = 0,
+                 power: float = 1.0,
+                 begin: int = 0,
+                 end: int = INF,
+                 last_step: int = -1,
+                 by_epoch: bool = True,
+                 verbose: bool = False):
+
+        self.eta_min = eta_min
+        self.power = power
+        self.total_iters = end - begin - 1
+
+        super().__init__(
+            optimizer,
+            param_name=param_name,
+            begin=begin,
+            end=end,
+            last_step=last_step,
+            by_epoch=by_epoch,
+            verbose=verbose)
+
+    def _get_value(self):
+
+        if self.last_step == 0:
+            return [
+                group[self.param_name] for group in self.optimizer.param_groups
+            ]
+
+        return [(group[self.param_name] - self.eta_min) *
+                (1 - 1 / (self.total_iters - self.last_step + 1))**self.power +
+                self.eta_min for group in self.optimizer.param_groups]
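The recursive update in PolyParamScheduler._get_value is equivalent to the closed form the new tests check against, value_t = eta_min + (base - eta_min) * (1 - t / total_iters) ** power. A pure-Python sanity check of that equivalence, using the constants from the tests (no torch needed):

base, eta_min, power, total_iters = 0.05, 0.001, 0.9, 4

# Recursive form, mirroring _get_value for a single param group.
recursive = [base]
for last_step in range(1, total_iters + 1):
    prev = recursive[-1]
    recursive.append((prev - eta_min) *
                     (1 - 1 / (total_iters - last_step + 1))**power + eta_min)

# Closed form, as used to build single_targets in the tests.
closed = [
    eta_min + (base - eta_min) * (1 - t / total_iters)**power
    for t in range(total_iters + 1)
]

assert all(abs(a - b) < 1e-12 for a, b in zip(recursive, closed))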
test_lr_scheduler.py

@@ -8,7 +8,7 @@ import torch.optim as optim

 from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
                                       ExponentialLR, LinearLR, MultiStepLR,
-                                      StepLR, _ParamScheduler)
+                                      PolyLR, StepLR, _ParamScheduler)
 from mmengine.testing import assert_allclose


@@ -283,6 +283,21 @@ class TestLRScheduler(TestCase):
         scheduler = CosineAnnealingLR(self.optimizer, T_max=t, eta_min=eta_min)
         self._test_scheduler_value(scheduler, targets, epochs)

+    def test_poly_scheduler(self):
+        epochs = 10
+        power = 0.9
+        min_lr = 0.001
+        iters = 4
+        single_targets = [
+            min_lr + (0.05 - min_lr) * (1 - i / iters)**power
+            for i in range(iters)
+        ] + [min_lr] * (
+            epochs - iters)
+        targets = [single_targets, [x * epochs for x in single_targets]]
+        scheduler = PolyLR(
+            self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
+        self._test_scheduler_value(scheduler, targets, epochs=10)
+
     def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
         scheduler = construct()
         for _ in range(epochs):
@@ -331,6 +346,12 @@ class TestLRScheduler(TestCase):
             lambda: LinearLR(self.optimizer, start_factor=0, end_factor=0.3),
             epochs=epochs)

+    def test_poly_scheduler_state_dict(self):
+        self._check_scheduler_state_dict(
+            lambda: PolyLR(self.optimizer, power=0.5, eta_min=0.001),
+            lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002),
+            epochs=10)
+
     def test_multi_scheduler_without_overlap_linear_multi_step(self):
         # use Linear in the first 5 epochs and then use MultiStep
         epochs = 12
test_momentum_scheduler.py

@@ -9,8 +9,8 @@ import torch.optim as optim
 from mmengine.optim.scheduler import (ConstantMomentum,
                                       CosineAnnealingMomentum,
                                       ExponentialMomentum, LinearMomentum,
-                                      MultiStepMomentum, StepMomentum,
-                                      _ParamScheduler)
+                                      MultiStepMomentum, PolyMomentum,
+                                      StepMomentum, _ParamScheduler)
 from mmengine.testing import assert_allclose


@@ -284,6 +284,21 @@ class TestMomentumScheduler(TestCase):
             self.optimizer, T_max=t, eta_min=eta_min)
         self._test_scheduler_value(scheduler, targets, epochs)

+    def test_poly_scheduler(self):
+        epochs = 10
+        power = 0.9
+        min_lr = 0.001
+        iters = 4
+        single_targets = [
+            min_lr + (0.05 - min_lr) * (1 - i / iters)**power
+            for i in range(iters)
+        ] + [min_lr] * (
+            epochs - iters)
+        targets = [single_targets, [x * epochs for x in single_targets]]
+        scheduler = PolyMomentum(
+            self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
+        self._test_scheduler_value(scheduler, targets, epochs=10)
+
     def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
         scheduler = construct()
         for _ in range(epochs):
@@ -333,6 +348,12 @@ class TestMomentumScheduler(TestCase):
                 self.optimizer, start_factor=0, end_factor=0.3),
             epochs=epochs)

+    def test_poly_scheduler_state_dict(self):
+        self._check_scheduler_state_dict(
+            lambda: PolyMomentum(self.optimizer, power=0.5, eta_min=0.001),
+            lambda: PolyMomentum(self.optimizer, power=0.8, eta_min=0.002),
+            epochs=10)
+
     def test_multi_scheduler_without_overlap_linear_multi_step(self):
         # use Linear in the first 5 epochs and then use MultiStep
         epochs = 12
test_param_scheduler.py

@@ -6,12 +6,15 @@ import torch
 import torch.nn.functional as F
 import torch.optim as optim

+# yapf: disable
 from mmengine.optim.scheduler import (ConstantParamScheduler,
                                       CosineAnnealingParamScheduler,
                                       ExponentialParamScheduler,
                                       LinearParamScheduler,
                                       MultiStepParamScheduler,
-                                      StepParamScheduler, _ParamScheduler)
+                                      PolyParamScheduler, StepParamScheduler,
+                                      _ParamScheduler)
+# yapf: enable
 from mmengine.testing import assert_allclose


@@ -336,6 +339,25 @@ class TestParameterScheduler(TestCase):
             self.optimizer, param_name='lr', T_max=t, eta_min=eta_min)
         self._test_scheduler_value(scheduler, targets, epochs)

+    def test_poly_scheduler(self):
+        epochs = 10
+        power = 0.9
+        min_lr = 0.001
+        iters = 4
+        single_targets = [
+            min_lr + (0.05 - min_lr) * (1 - i / iters)**power
+            for i in range(iters)
+        ] + [min_lr] * (
+            epochs - iters)
+        targets = [single_targets, [x * epochs for x in single_targets]]
+        scheduler = PolyParamScheduler(
+            self.optimizer,
+            param_name='lr',
+            power=power,
+            eta_min=min_lr,
+            end=iters + 1)
+        self._test_scheduler_value(scheduler, targets, epochs=10)
+
     def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
         scheduler = construct()
         for _ in range(epochs):
@@ -402,6 +424,14 @@ class TestParameterScheduler(TestCase):
                 end_factor=0.3),
             epochs=epochs)

+    def test_poly_scheduler_state_dict(self):
+        self._check_scheduler_state_dict(
+            lambda: PolyParamScheduler(
+                self.optimizer, param_name='lr', power=0.5, eta_min=0.001),
+            lambda: PolyParamScheduler(
+                self.optimizer, param_name='lr', power=0.8, eta_min=0.002),
+            epochs=10)
+
     def test_multi_scheduler_without_overlap_linear_multi_step(self):
         # use Linear in the first 5 epochs and then use MultiStep
         epochs = 12