[Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR (#188)

* [Enhancement] Add PolyParamScheduler, PolyMomentum and PolyLR

* min_lr -> eta_min, refined docstr
Tong Gao 2022-04-25 13:44:15 +08:00 committed by GitHub
parent e2a2b0438e
commit c3aff4fc9a
7 changed files with 240 additions and 11 deletions

View File

@@ -1,14 +1,14 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR,
-                           LinearLR, MultiStepLR, StepLR)
+                           LinearLR, MultiStepLR, PolyLR, StepLR)
from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
                                 ExponentialMomentum, LinearMomentum,
-                                 MultiStepMomentum, StepMomentum)
+                                 MultiStepMomentum, PolyMomentum, StepMomentum)
from .param_scheduler import (ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
-                              MultiStepParamScheduler, StepParamScheduler,
-                              _ParamScheduler)
+                              MultiStepParamScheduler, PolyParamScheduler,
+                              StepParamScheduler, _ParamScheduler)

__all__ = [
    'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
@@ -16,5 +16,6 @@ __all__ = [
    'ExponentialMomentum', 'LinearMomentum', 'MultiStepMomentum',
    'StepMomentum', 'ConstantParamScheduler', 'CosineAnnealingParamScheduler',
    'ExponentialParamScheduler', 'LinearParamScheduler',
-    'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler'
+    'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler',
+    'PolyParamScheduler', 'PolyLR', 'PolyMomentum'
]

View File

@@ -7,7 +7,8 @@ from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (INF, ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
-                              MultiStepParamScheduler, StepParamScheduler)
+                              MultiStepParamScheduler, PolyParamScheduler,
+                              StepParamScheduler)


@PARAM_SCHEDULERS.register_module()
@@ -294,3 +295,49 @@ class StepLR(StepParamScheduler):
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

@PARAM_SCHEDULERS.register_module()
class PolyLR(PolyParamScheduler):
    """Decays the learning rate of each parameter group in a polynomial decay
    scheme.

    Notice that such decay can happen simultaneously with other changes to the
    parameter value from outside this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        eta_min (float): Minimum learning rate at the end of scheduling.
            Defaults to 0.
        power (float): The power of the polynomial. Defaults to 1.0.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 eta_min: float = 0,
                 power: float = 1,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='lr',
            eta_min=eta_min,
            power=power,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)
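
For reference, a minimal usage sketch of the new PolyLR on a plain SGD optimizer. The toy model, optimizer settings, and training loop below are illustrative assumptions, not code from this commit:

import torch
import torch.nn as nn

from mmengine.optim.scheduler import PolyLR

# Toy setup: a linear layer and an SGD optimizer with base lr 0.05.
model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

# Decay lr polynomially from 0.05 towards eta_min; with begin=0 and end=11,
# total_iters is 10, so the lr reaches eta_min after the 10th scheduler step.
scheduler = PolyLR(optimizer, eta_min=0.001, power=0.9, end=11)

for epoch in range(10):
    # ... run one training epoch, calling optimizer.step() per batch ...
    optimizer.step()
    scheduler.step()
    print(epoch, optimizer.param_groups[0]['lr'])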

View File

@@ -7,7 +7,8 @@ from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (INF, ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
-                              MultiStepParamScheduler, StepParamScheduler)
+                              MultiStepParamScheduler, PolyParamScheduler,
+                              StepParamScheduler)


@PARAM_SCHEDULERS.register_module()
@@ -294,3 +295,49 @@ class StepMomentum(StepParamScheduler):
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

@PARAM_SCHEDULERS.register_module()
class PolyMomentum(PolyParamScheduler):
    """Decays the momentum of each parameter group in a polynomial decay
    scheme.

    Notice that such decay can happen simultaneously with other changes to the
    parameter value from outside this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        eta_min (float): Minimum momentum at the end of scheduling.
            Defaults to 0.
        power (float): The power of the polynomial. Defaults to 1.0.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: torch.optim.Optimizer,
                 eta_min: float = 0,
                 power: float = 1,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        super().__init__(
            optimizer,
            param_name='momentum',
            eta_min=eta_min,
            power=power,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)
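
PolyMomentum exposes the same interface with param_name fixed to 'momentum'. A hedged sketch of scheduling learning rate and momentum together on one optimizer (the model, SGD settings, and loop are again assumptions for illustration):

import torch
import torch.nn as nn

from mmengine.optim.scheduler import PolyLR, PolyMomentum

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

# Both schedulers share the optimizer; each one only touches its own key
# ('lr' or 'momentum') in optimizer.param_groups.
lr_sched = PolyLR(optimizer, eta_min=0.001, power=0.9, end=11)
mom_sched = PolyMomentum(optimizer, eta_min=0.5, power=0.9, end=11)

for epoch in range(10):
    # ... one training epoch ...
    optimizer.step()
    lr_sched.step()
    mom_sched.step()
    group = optimizer.param_groups[0]
    print(epoch, group['lr'], group['momentum'])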

View File

@@ -534,6 +534,7 @@ class LinearParamScheduler(_ParamScheduler):
    Notice that such decay can happen simultaneously with other changes to the
    parameter value from outside this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        start_factor (float): The number we multiply parameter value in the
@@ -598,3 +599,64 @@ class LinearParamScheduler(_ParamScheduler):
              (self.end_factor - self.start_factor)))
            for group in self.optimizer.param_groups
        ]

@PARAM_SCHEDULERS.register_module()
class PolyParamScheduler(_ParamScheduler):
    """Decays the parameter value of each parameter group in a polynomial decay
    scheme.

    Notice that such decay can happen simultaneously with other changes to the
    parameter value from outside this scheduler.

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        eta_min (float): Minimum parameter value at the end of scheduling.
            Defaults to 0.
        power (float): The power of the polynomial. Defaults to 1.0.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of last step. Used for resume without
            state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: Optimizer,
                 param_name: str,
                 eta_min: float = 0,
                 power: float = 1.0,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        self.eta_min = eta_min
        self.power = power
        self.total_iters = end - begin - 1
        super().__init__(
            optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    def _get_value(self):
        if self.last_step == 0:
            return [
                group[self.param_name] for group in self.optimizer.param_groups
            ]
        return [(group[self.param_name] - self.eta_min) *
                (1 - 1 / (self.total_iters - self.last_step + 1))**self.power +
                self.eta_min for group in self.optimizer.param_groups]
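
_get_value updates the value multiplicatively from the previous step rather than recomputing from the base value. A small hedged check, using made-up numbers, that this recursion reproduces the closed-form polynomial decay eta_min + (base - eta_min) * (1 - t / total_iters) ** power:

# Hedged numeric check: applying the per-step factor
# (1 - 1 / (T - t + 1)) ** power to (value - eta_min) at every step t
# yields the same sequence as the closed-form decay.
eta_min, base, power, T = 0.001, 0.05, 0.9, 4  # T plays the role of total_iters

value, recursive = base, [base]
for t in range(1, T + 1):
    value = (value - eta_min) * (1 - 1 / (T - t + 1))**power + eta_min
    recursive.append(value)

closed_form = [
    eta_min + (base - eta_min) * (1 - t / T)**power for t in range(T + 1)
]
assert all(abs(a - b) < 1e-12 for a, b in zip(recursive, closed_form))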

View File

@@ -8,7 +8,7 @@ import torch.optim as optim
from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
                                      ExponentialLR, LinearLR, MultiStepLR,
-                                      StepLR, _ParamScheduler)
+                                      PolyLR, StepLR, _ParamScheduler)
from mmengine.testing import assert_allclose

@@ -283,6 +283,21 @@ class TestLRScheduler(TestCase):
        scheduler = CosineAnnealingLR(self.optimizer, T_max=t, eta_min=eta_min)
        self._test_scheduler_value(scheduler, targets, epochs)
    def test_poly_scheduler(self):
        epochs = 10
        power = 0.9
        min_lr = 0.001
        iters = 4
        single_targets = [
            min_lr + (0.05 - min_lr) * (1 - i / iters)**power
            for i in range(iters)
        ] + [min_lr] * (
            epochs - iters)
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = PolyLR(
            self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
        self._test_scheduler_value(scheduler, targets, epochs=10)
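
For orientation, a hedged back-of-the-envelope of the expected values in single_targets above; the base learning rate 0.05 is taken from the constant in the test, and the rounding is only illustrative:

power, min_lr, iters = 0.9, 0.001, 4
vals = [min_lr + (0.05 - min_lr) * (1 - i / iters)**power for i in range(iters)]
print([round(v, 4) for v in vals])
# -> approximately [0.05, 0.0388, 0.0273, 0.0151]; every later epoch stays at min_lr.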
    def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
        scheduler = construct()
        for _ in range(epochs):
@@ -331,6 +346,12 @@ class TestLRScheduler(TestCase):
            lambda: LinearLR(self.optimizer, start_factor=0, end_factor=0.3),
            epochs=epochs)
    def test_poly_scheduler_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: PolyLR(self.optimizer, power=0.5, eta_min=0.001),
            lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002),
            epochs=10)
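
A hedged sketch of the save/resume pattern this test exercises, with an assumed SGD setup rather than the real test fixture:

import torch
import torch.nn as nn

from mmengine.optim.scheduler import PolyLR

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

scheduler = PolyLR(optimizer, power=0.5, eta_min=0.001)
for _ in range(3):
    optimizer.step()
    scheduler.step()

state = scheduler.state_dict()  # e.g. stored inside a checkpoint

# A freshly built scheduler picks up the schedule from the saved last_step.
resumed = PolyLR(optimizer, power=0.5, eta_min=0.001)
resumed.load_state_dict(state)
assert resumed.last_step == scheduler.last_step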
    def test_multi_scheduler_without_overlap_linear_multi_step(self):
        # use Linear in the first 5 epochs and then use MultiStep
        epochs = 12

View File

@@ -9,8 +9,8 @@ import torch.optim as optim
from mmengine.optim.scheduler import (ConstantMomentum,
                                      CosineAnnealingMomentum,
                                      ExponentialMomentum, LinearMomentum,
-                                      MultiStepMomentum, StepMomentum,
-                                      _ParamScheduler)
+                                      MultiStepMomentum, PolyMomentum,
+                                      StepMomentum, _ParamScheduler)
from mmengine.testing import assert_allclose

@@ -284,6 +284,21 @@ class TestMomentumScheduler(TestCase):
            self.optimizer, T_max=t, eta_min=eta_min)
        self._test_scheduler_value(scheduler, targets, epochs)
    def test_poly_scheduler(self):
        epochs = 10
        power = 0.9
        min_lr = 0.001
        iters = 4
        single_targets = [
            min_lr + (0.05 - min_lr) * (1 - i / iters)**power
            for i in range(iters)
        ] + [min_lr] * (
            epochs - iters)
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = PolyMomentum(
            self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
        self._test_scheduler_value(scheduler, targets, epochs=10)
    def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
        scheduler = construct()
        for _ in range(epochs):
@@ -333,6 +348,12 @@ class TestMomentumScheduler(TestCase):
                self.optimizer, start_factor=0, end_factor=0.3),
            epochs=epochs)
    def test_poly_scheduler_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: PolyMomentum(self.optimizer, power=0.5, eta_min=0.001),
            lambda: PolyMomentum(self.optimizer, power=0.8, eta_min=0.002),
            epochs=10)
    def test_multi_scheduler_without_overlap_linear_multi_step(self):
        # use Linear in the first 5 epochs and then use MultiStep
        epochs = 12

View File

@@ -6,12 +6,15 @@ import torch
import torch.nn.functional as F
import torch.optim as optim

+# yapf: disable
from mmengine.optim.scheduler import (ConstantParamScheduler,
                                      CosineAnnealingParamScheduler,
                                      ExponentialParamScheduler,
                                      LinearParamScheduler,
                                      MultiStepParamScheduler,
-                                      StepParamScheduler, _ParamScheduler)
+                                      PolyParamScheduler, StepParamScheduler,
+                                      _ParamScheduler)
+# yapf: enable
from mmengine.testing import assert_allclose

@@ -336,6 +339,25 @@ class TestParameterScheduler(TestCase):
            self.optimizer, param_name='lr', T_max=t, eta_min=eta_min)
        self._test_scheduler_value(scheduler, targets, epochs)
    def test_poly_scheduler(self):
        epochs = 10
        power = 0.9
        min_lr = 0.001
        iters = 4
        single_targets = [
            min_lr + (0.05 - min_lr) * (1 - i / iters)**power
            for i in range(iters)
        ] + [min_lr] * (
            epochs - iters)
        targets = [single_targets, [x * epochs for x in single_targets]]
        scheduler = PolyParamScheduler(
            self.optimizer,
            param_name='lr',
            power=power,
            eta_min=min_lr,
            end=iters + 1)
        self._test_scheduler_value(scheduler, targets, epochs=10)
    def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
        scheduler = construct()
        for _ in range(epochs):
@@ -402,6 +424,14 @@ class TestParameterScheduler(TestCase):
                end_factor=0.3),
            epochs=epochs)
    def test_poly_scheduler_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: PolyParamScheduler(
                self.optimizer, param_name='lr', power=0.5, eta_min=0.001),
            lambda: PolyParamScheduler(
                self.optimizer, param_name='lr', power=0.8, eta_min=0.002),
            epochs=10)
    def test_multi_scheduler_without_overlap_linear_multi_step(self):
        # use Linear in the first 5 epochs and then use MultiStep
        epochs = 12