[Feature] Support converting epoch-based schedulers to iter-based. (#221)

* [Feature] Support converting epoch-based schedulers to iter-based.

* Support conversion and refactor LR and Momentum schedulers to use mixins.

* Add unit tests

* Fix args and add runner unit tests

* resolve comments
RangiLyu 2022-05-10 15:17:51 +08:00 committed by GitHub
parent 92b94e8e60
commit 1912660db9
7 changed files with 546 additions and 302 deletions

View File

@ -1,18 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (INF, ConstantParamScheduler,
from .param_scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
ExponentialParamScheduler, LinearParamScheduler,
MultiStepParamScheduler, PolyParamScheduler,
StepParamScheduler)
class LRSchedulerMixin:
"""A mixin class for learning rate schedulers."""
def __init__(self, optimizer, *args, **kwargs):
super().__init__(optimizer, 'lr', *args, **kwargs)
@PARAM_SCHEDULERS.register_module()
class ConstantLR(ConstantParamScheduler):
class ConstantLR(LRSchedulerMixin, ConstantParamScheduler):
"""Decays the learning rate value of each parameter group by a small
constant factor until the number of epochs reaches a pre-defined milestone:
``end``. Notice that such decay can happen simultaneously with other
@ -34,27 +37,9 @@ class ConstantLR(ConstantParamScheduler):
Defaults to False.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
factor: float = 1.0 / 3,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='lr',
factor=factor,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
@PARAM_SCHEDULERS.register_module()
class CosineAnnealingLR(CosineAnnealingParamScheduler):
class CosineAnnealingLR(LRSchedulerMixin, CosineAnnealingParamScheduler):
r"""Set the learning rate of each parameter group using a cosine annealing
schedule, where :math:`\eta_{max}` is set to the initial value and
:math:`T_{cur}` is the number of epochs since the last restart in SGDR:
@ -101,29 +86,9 @@ class CosineAnnealingLR(CosineAnnealingParamScheduler):
https://arxiv.org/abs/1608.03983
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
T_max: int,
eta_min: int = 0,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='lr',
T_max=T_max,
eta_min=eta_min,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
@PARAM_SCHEDULERS.register_module()
class ExponentialLR(ExponentialParamScheduler):
class ExponentialLR(LRSchedulerMixin, ExponentialParamScheduler):
"""Decays the learning rate of each parameter group by gamma every epoch.
Args:
@ -141,27 +106,9 @@ class ExponentialLR(ExponentialParamScheduler):
Defaults to False.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
gamma: float,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='lr',
gamma=gamma,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
@PARAM_SCHEDULERS.register_module()
class LinearLR(LinearParamScheduler):
class LinearLR(LRSchedulerMixin, LinearParamScheduler):
"""Decays the learning rate of each parameter group by linearly changing
small multiplicative factor until the number of epoch reaches a pre-defined
milestone: ``end``.
@ -187,29 +134,9 @@ class LinearLR(LinearParamScheduler):
Defaults to False.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
start_factor: float = 1.0 / 3,
end_factor: float = 1.0,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='lr',
start_factor=start_factor,
end_factor=end_factor,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
@PARAM_SCHEDULERS.register_module()
class MultiStepLR(MultiStepParamScheduler):
class MultiStepLR(LRSchedulerMixin, MultiStepParamScheduler):
"""Decays the specified learning rate in each parameter group by gamma once
the number of epochs reaches one of the milestones. Notice that such decay
can happen simultaneously with other changes to the learning rate from
@ -232,29 +159,9 @@ class MultiStepLR(MultiStepParamScheduler):
Defaults to False.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
milestones: List[int],
gamma: float = 0.1,
last_step: int = -1,
begin: int = 0,
end: int = INF,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='lr',
milestones=milestones,
gamma=gamma,
last_step=last_step,
begin=begin,
end=end,
by_epoch=by_epoch,
verbose=verbose)
@PARAM_SCHEDULERS.register_module()
class StepLR(StepParamScheduler):
class StepLR(LRSchedulerMixin, StepParamScheduler):
"""Decays the learning rate of each parameter group by gamma every
step_size epochs. Notice that such decay can happen simultaneously with
other changes to the learning rate from outside this scheduler.
@ -276,29 +183,9 @@ class StepLR(StepParamScheduler):
Defaults to False.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
step_size: int,
gamma: float = 0.1,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='lr',
step_size=step_size,
gamma=gamma,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
@PARAM_SCHEDULERS.register_module()
class PolyLR(PolyParamScheduler):
class PolyLR(LRSchedulerMixin, PolyParamScheduler):
"""Decays the learning rate of each parameter group in a polynomial decay
scheme.
@ -321,23 +208,3 @@ class PolyLR(PolyParamScheduler):
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
eta_min: float = 0,
power: float = 1,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='lr',
eta_min=eta_min,
power=power,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
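
To make the refactor concrete, here is a minimal, self-contained sketch of the cooperative-mixin pattern used above (the `_Demo*` names are placeholders for illustration, not mmengine classes): the mixin fixes `param_name='lr'` and forwards everything else through `super().__init__`, so each concrete LR scheduler keeps the plain `(optimizer, ...)` signature without repeating the boilerplate deleted in this file.

class _DemoParamScheduler:
    """Stand-in for _ParamScheduler: records which hyper-parameter it drives."""

    def __init__(self, optimizer, param_name, factor=1.0 / 3):
        self.optimizer = optimizer
        self.param_name = param_name
        self.factor = factor


class _DemoLRMixin:
    """Stand-in for LRSchedulerMixin: injects param_name='lr' via the MRO."""

    def __init__(self, optimizer, *args, **kwargs):
        super().__init__(optimizer, 'lr', *args, **kwargs)


class _DemoConstantLR(_DemoLRMixin, _DemoParamScheduler):
    pass


scheduler = _DemoConstantLR(optimizer='fake_optimizer', factor=0.5)
assert scheduler.param_name == 'lr' and scheduler.factor == 0.5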

View File

@ -1,18 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List
import torch
from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (INF, ConstantParamScheduler,
from .param_scheduler import (ConstantParamScheduler,
CosineAnnealingParamScheduler,
ExponentialParamScheduler, LinearParamScheduler,
MultiStepParamScheduler, PolyParamScheduler,
StepParamScheduler)
class MomentumSchedulerMixin:
"""A mixin class for momentum schedulers."""
def __init__(self, optimizer, *args, **kwargs):
super().__init__(optimizer, 'momentum', *args, **kwargs)
@PARAM_SCHEDULERS.register_module()
class ConstantMomentum(ConstantParamScheduler):
class ConstantMomentum(MomentumSchedulerMixin, ConstantParamScheduler):
"""Decays the momentum value of each parameter group by a small constant
factor until the number of epochs reaches a pre-defined milestone: ``end``.
Notice that such decay can happen simultaneously with other changes to the
@ -34,27 +37,10 @@ class ConstantMomentum(ConstantParamScheduler):
Defaults to False.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
factor: float = 1.0 / 3,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='momentum',
factor=factor,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
@PARAM_SCHEDULERS.register_module()
class CosineAnnealingMomentum(CosineAnnealingParamScheduler):
class CosineAnnealingMomentum(MomentumSchedulerMixin,
CosineAnnealingParamScheduler):
r"""Set the momentum of each parameter group using a cosine annealing
schedule, where :math:`\eta_{max}` is set to the initial value and
:math:`T_{cur}` is the number of epochs since the last restart in SGDR:
@ -101,29 +87,9 @@ class CosineAnnealingMomentum(CosineAnnealingParamScheduler):
https://arxiv.org/abs/1608.03983
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
T_max: int,
eta_min: int = 0,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='momentum',
T_max=T_max,
eta_min=eta_min,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
@PARAM_SCHEDULERS.register_module()
class ExponentialMomentum(ExponentialParamScheduler):
class ExponentialMomentum(MomentumSchedulerMixin, ExponentialParamScheduler):
"""Decays the momentum of each parameter group by gamma every epoch.
Args:
@ -141,27 +107,9 @@ class ExponentialMomentum(ExponentialParamScheduler):
Defaults to False.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
gamma: float,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='momentum',
gamma=gamma,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
@PARAM_SCHEDULERS.register_module()
class LinearMomentum(LinearParamScheduler):
class LinearMomentum(MomentumSchedulerMixin, LinearParamScheduler):
"""Decays the momentum of each parameter group by linearly changing
small multiplicative factor until the number of epoch reaches a pre-defined
milestone: ``end``.
@ -187,29 +135,9 @@ class LinearMomentum(LinearParamScheduler):
Defaults to False.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
start_factor: float = 1.0 / 3,
end_factor: float = 1.0,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='momentum',
start_factor=start_factor,
end_factor=end_factor,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
@PARAM_SCHEDULERS.register_module()
class MultiStepMomentum(MultiStepParamScheduler):
class MultiStepMomentum(MomentumSchedulerMixin, MultiStepParamScheduler):
"""Decays the specified momentum in each parameter group by gamma once the
number of epochs reaches one of the milestones. Notice that such decay can
happen simultaneously with other changes to the momentum from outside this
@ -232,29 +160,9 @@ class MultiStepMomentum(MultiStepParamScheduler):
Defaults to False.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
milestones: List[int],
gamma: float = 0.1,
last_step: int = -1,
begin: int = 0,
end: int = INF,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='momentum',
milestones=milestones,
gamma=gamma,
last_step=last_step,
begin=begin,
end=end,
by_epoch=by_epoch,
verbose=verbose)
@PARAM_SCHEDULERS.register_module()
class StepMomentum(StepParamScheduler):
class StepMomentum(MomentumSchedulerMixin, StepParamScheduler):
"""Decays the momentum of each parameter group by gamma every step_size
epochs. Notice that such decay can happen simultaneously with other changes
to the momentum from outside this scheduler.
@ -276,29 +184,9 @@ class StepMomentum(StepParamScheduler):
Defaults to False.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
step_size: int,
gamma: float = 0.1,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='momentum',
step_size=step_size,
gamma=gamma,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
@PARAM_SCHEDULERS.register_module()
class PolyMomentum(PolyParamScheduler):
class PolyMomentum(MomentumSchedulerMixin, PolyParamScheduler):
"""Decays the momentum of each parameter group in a polynomial decay
scheme.
@ -321,23 +209,3 @@ class PolyMomentum(PolyParamScheduler):
verbose (bool): Whether to print the value for each update.
Defaults to False.
"""
def __init__(self,
optimizer: torch.optim.Optimizer,
eta_min: float = 0,
power: float = 1,
begin: int = 0,
end: int = INF,
last_step: int = -1,
by_epoch: bool = True,
verbose: bool = False):
super().__init__(
optimizer,
param_name='momentum',
eta_min=eta_min,
power=power,
begin=begin,
end=end,
last_step=last_step,
by_epoch=by_epoch,
verbose=verbose)
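
A hedged usage sketch of the parallel momentum refactor (the import path and optimizer setup are assumptions for illustration): with the mixin in place, the two constructions below are expected to be equivalent, since MomentumSchedulerMixin merely injects `param_name='momentum'`.

import torch
from mmengine.optim.scheduler import (ConstantMomentum,  # import path assumed
                                      ConstantParamScheduler)

model = torch.nn.Linear(2, 2)
optim = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)

# The mixin-based class hides param_name ...
scheduler_a = ConstantMomentum(optim, factor=0.5)
# ... and should behave like the generic scheduler spelled out in full.
scheduler_b = ConstantParamScheduler(optim, param_name='momentum', factor=0.5)
assert scheduler_a.param_name == scheduler_b.param_name == 'momentum'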

View File

@ -255,6 +255,35 @@ class StepParamScheduler(_ParamScheduler):
by_epoch=by_epoch,
verbose=verbose)
@classmethod
def build_iter_from_epoch(cls,
*args,
step_size,
begin=0,
end=INF,
by_epoch=True,
epoch_length=None,
**kwargs):
"""Build an iter-based instance of this scheduler from an epoch-based
config."""
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
'be converted to iter-based.'
assert epoch_length is not None and epoch_length > 0, \
f'`epoch_length` must be a positive integer, ' \
f'but got {epoch_length}.'
by_epoch = False
step_size = step_size * epoch_length
begin = begin * epoch_length
if end != INF:
end = end * epoch_length
return cls(
*args,
step_size=step_size,
begin=begin,
end=end,
by_epoch=by_epoch,
**kwargs)
def _get_value(self):
"""Compute value using chainable form of the scheduler."""
if (self.last_step == 0) or (self.last_step % self.step_size != 0):
@ -312,6 +341,35 @@ class MultiStepParamScheduler(_ParamScheduler):
by_epoch=by_epoch,
verbose=verbose)
@classmethod
def build_iter_from_epoch(cls,
*args,
milestones,
begin=0,
end=INF,
by_epoch=True,
epoch_length=None,
**kwargs):
"""Build an iter-based instance of this scheduler from an epoch-based
config."""
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
'be converted to iter-based.'
assert epoch_length is not None and epoch_length > 0, \
f'`epoch_length` must be a positive integer, ' \
f'but got {epoch_length}.'
by_epoch = False
milestones = [i * epoch_length for i in milestones]
begin = begin * epoch_length
if end != INF:
end = end * epoch_length
return cls(
*args,
milestones=milestones,
begin=begin,
end=end,
by_epoch=by_epoch,
**kwargs)
def _get_value(self):
"""Compute value using chainable form of the scheduler."""
if self.last_step not in self.milestones:
@ -372,6 +430,27 @@ class ConstantParamScheduler(_ParamScheduler):
by_epoch=by_epoch,
verbose=verbose)
@classmethod
def build_iter_from_epoch(cls,
*args,
begin=0,
end=INF,
by_epoch=True,
epoch_length=None,
**kwargs):
"""Build an iter-based instance of this scheduler from an epoch-based
config."""
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
'be converted to iter-based.'
assert epoch_length is not None and epoch_length > 0, \
f'`epoch_length` must be a positive integer, ' \
f'but got {epoch_length}.'
by_epoch = False
begin = begin * epoch_length
if end != INF:
end = end * epoch_length
return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
def _get_value(self):
"""Compute value using chainable form of the scheduler."""
if self.last_step == 0:
@ -431,6 +510,27 @@ class ExponentialParamScheduler(_ParamScheduler):
by_epoch=by_epoch,
verbose=verbose)
@classmethod
def build_iter_from_epoch(cls,
*args,
begin=0,
end=INF,
by_epoch=True,
epoch_length=None,
**kwargs):
"""Build an iter-based instance of this scheduler from an epoch-based
config."""
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
'be converted to iter-based.'
assert epoch_length is not None and epoch_length > 0, \
f'`epoch_length` must be a positive integer, ' \
f'but got {epoch_length}.'
by_epoch = False
begin = begin * epoch_length
if end != INF:
end = end * epoch_length
return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
def _get_value(self):
"""Compute value using chainable form of the scheduler."""
if self.last_step == 0:
@ -512,6 +612,35 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
by_epoch=by_epoch,
verbose=verbose)
@classmethod
def build_iter_from_epoch(cls,
*args,
T_max,
begin=0,
end=INF,
by_epoch=True,
epoch_length=None,
**kwargs):
"""Build an iter-based instance of this scheduler from an epoch-based
config."""
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
'be converted to iter-based.'
assert epoch_length is not None and epoch_length > 0, \
f'`epoch_length` must be a positive integer, ' \
f'but got {epoch_length}.'
by_epoch = False
T_max = T_max * epoch_length
begin = begin * epoch_length
if end != INF:
end = end * epoch_length
return cls(
*args,
T_max=T_max,
begin=begin,
end=end,
by_epoch=by_epoch,
**kwargs)
def _get_value(self):
"""Compute value using chainable form of the scheduler."""
if self.last_step == 0:
@ -589,6 +718,27 @@ class LinearParamScheduler(_ParamScheduler):
by_epoch=by_epoch,
verbose=verbose)
@classmethod
def build_iter_from_epoch(cls,
*args,
begin=0,
end=INF,
by_epoch=True,
epoch_length=None,
**kwargs):
"""Build an iter-based instance of this scheduler from an epoch-based
config."""
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
'be converted to iter-based.'
assert epoch_length is not None and epoch_length > 0, \
f'`epoch_length` must be a positive integer, ' \
f'but got {epoch_length}.'
by_epoch = False
begin = begin * epoch_length
if end != INF:
end = end * epoch_length
return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
def _get_value(self):
"""Compute value using chainable form of the scheduler."""
if self.last_step == 0:
@ -655,6 +805,27 @@ class PolyParamScheduler(_ParamScheduler):
by_epoch=by_epoch,
verbose=verbose)
@classmethod
def build_iter_from_epoch(cls,
*args,
begin=0,
end=INF,
by_epoch=True,
epoch_length=None,
**kwargs):
"""Build an iter-based instance of this scheduler from an epoch-based
config."""
assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
'be converted to iter-based.'
assert epoch_length is not None and epoch_length > 0, \
f'`epoch_length` must be a positive integer, ' \
f'but got {epoch_length}.'
by_epoch = False
begin = begin * epoch_length
if end != INF:
end = end * epoch_length
return cls(*args, begin=begin, end=end, by_epoch=by_epoch, **kwargs)
def _get_value(self):
"""Compute value using chainable form of the scheduler."""
if self.last_step == 0:
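
A hedged usage sketch of the new classmethod (the import path, the dataloader length of 100, and the milestone values are assumptions for illustration): `build_iter_from_epoch` rescales every epoch-denominated argument by `epoch_length` and forces `by_epoch=False`, so the resulting scheduler steps once per iteration.

import torch
from mmengine.optim.scheduler import MultiStepParamScheduler  # path assumed

optim = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.1)

# Epoch-based intent: decay at epochs 8 and 11, stop scheduling at epoch 12.
# With 100 iterations per epoch this is rescaled to milestones=[800, 1100],
# begin=0 and end=1200, and by_epoch is forced to False.
scheduler = MultiStepParamScheduler.build_iter_from_epoch(
    optim,
    param_name='lr',
    milestones=[8, 11],
    end=12,
    epoch_length=100)
assert scheduler.by_epoch is False and scheduler.end == 1200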

View File

@ -848,10 +848,29 @@ class Runner:
if isinstance(_scheduler, _ParamScheduler):
param_schedulers.append(_scheduler)
elif isinstance(_scheduler, dict):
param_schedulers.append(
PARAM_SCHEDULERS.build(
_scheduler,
default_args=dict(optimizer=self.optimizer)))
convert_to_iter = _scheduler.pop('convert_to_iter_based',
False)
if convert_to_iter:
assert _scheduler.get(
'by_epoch', True
), 'only epoch-based parameter scheduler can be ' \
'converted to iter-based'
assert isinstance(self.train_loop, BaseLoop), \
'Scheduler can only be converted to iter-based ' \
'when train loop is built.'
cls = PARAM_SCHEDULERS.get(_scheduler.pop('type'))
param_schedulers.append(
cls.build_iter_from_epoch( # type: ignore
optimizer=self.optimizer,
**_scheduler,
epoch_length=len(
self.train_loop.dataloader), # type: ignore
))
else:
param_schedulers.append(
PARAM_SCHEDULERS.build(
_scheduler,
default_args=dict(optimizer=self.optimizer)))
else:
raise TypeError(
'_scheduler should be a _ParamScheduler object or dict, '
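
For reference, a sketch of the user-facing config that the new branch consumes (field values are illustrative): when `convert_to_iter_based=True`, the runner pops the flag, resolves the scheduler class from the registry, and calls its `build_iter_from_epoch` with `epoch_length=len(train_dataloader)` instead of building the scheduler directly.

# Illustrative scheduler config handled by the branch above:
param_scheduler = dict(
    type='MultiStepLR',
    milestones=[8, 11],          # epoch-based values from the config ...
    by_epoch=True,
    convert_to_iter_based=True)  # ... rescaled to iterations at build time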

View File

@ -352,6 +352,144 @@ class TestLRScheduler(TestCase):
lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002),
epochs=10)
def test_step_scheduler_convert_iterbased(self):
# invalid epoch_length
with self.assertRaises(AssertionError):
scheduler = StepLR.build_iter_from_epoch(
self.optimizer, gamma=0.1, step_size=2, epoch_length=-1)
# lr = 0.05 if epoch < 2
# lr = 0.005 if 2 <= epoch < 4
epochs = 4
epoch_length = 7
single_targets = [0.05] * 2 * epoch_length + [0.005] * 2 * epoch_length
targets = [
single_targets,
[x * epochs * epoch_length for x in single_targets]
]
scheduler = StepLR.build_iter_from_epoch(
self.optimizer, gamma=0.1, step_size=2, epoch_length=epoch_length)
self._test_scheduler_value(
scheduler, targets, epochs * epoch_length, param_name='lr')
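
Spelling out the arithmetic behind the expected values above (a standalone sketch; the import path and the 0.05 base learning rate mirror the test fixture but are assumptions here):

import torch
from mmengine.optim.scheduler import StepLR  # import path assumed

optim = torch.optim.SGD(torch.nn.Linear(2, 2).parameters(), lr=0.05)
scheduler = StepLR.build_iter_from_epoch(
    optim, gamma=0.1, step_size=2, epoch_length=7)
# step_size is rescaled from 2 epochs to 2 * 7 = 14 iterations, so the lr is
# 0.05 for iterations 0-13 and 0.05 * 0.1 = 0.005 for iterations 14-27,
# matching the single_targets list constructed above.
assert scheduler.step_size == 14 and scheduler.by_epoch is False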
def test_multi_step_scheduler_convert_iterbased(self):
# lr = 0.05 if epoch < 2
# lr = 0.005 if 2 <= epoch < 5
# lr = 0.0005 if 5 <= epoch < 9
# lr = 0.00005 if epoch >= 9
epochs = 10
epoch_length = 7
single_targets = [0.05
] * 2 * epoch_length + [0.005] * 3 * epoch_length + [
0.0005
] * 4 * epoch_length + [0.00005] * 3 * epoch_length
targets = [
single_targets,
[x * epochs * epoch_length for x in single_targets]
]
scheduler = MultiStepLR.build_iter_from_epoch(
self.optimizer,
gamma=0.1,
milestones=[2, 5, 9],
epoch_length=epoch_length)
self._test_scheduler_value(scheduler, targets, epochs * epoch_length)
def test_constant_scheduler_convert_iterbased(self):
# lr = 0.025 if epoch < 5
# lr = 0.05 if 5 <= epoch
epochs = 10
epoch_length = 7
single_targets = [0.025] * (5 * epoch_length -
1) + [0.05] * (5 * epoch_length + 1)
targets = [
single_targets,
[x * epochs * epoch_length for x in single_targets]
]
scheduler = ConstantLR.build_iter_from_epoch(
self.optimizer, factor=1.0 / 2, end=5, epoch_length=epoch_length)
self._test_scheduler_value(scheduler, targets, epochs * epoch_length)
def test_linear_scheduler_convert_iterbased(self):
epochs = 10
start_factor = 1.0 / 2
end = 5
epoch_length = 11
iters = end * epoch_length - 1
interpolation = [
start_factor + i * (1 - start_factor) / iters for i in range(iters)
]
single_targets = [x * 0.05 for x in interpolation] + [0.05] * (
epochs * epoch_length - iters)
targets = [single_targets, [x * epochs for x in single_targets]]
scheduler = LinearLR.build_iter_from_epoch(
self.optimizer,
start_factor=start_factor,
end=end,
epoch_length=epoch_length)
self._test_scheduler_value(scheduler, targets, epochs)
def test_exp_scheduler_convert_iterbased(self):
epochs = 10
epoch_length = 7
single_targets = [
0.05 * (0.9**x) for x in range(epochs * epoch_length)
]
targets = [
single_targets,
[x * epochs * epoch_length for x in single_targets]
]
scheduler = ExponentialLR.build_iter_from_epoch(
self.optimizer, gamma=0.9, epoch_length=epoch_length)
self._test_scheduler_value(scheduler, targets, epochs * epoch_length)
def test_cos_anneal_scheduler_convert_iterbased(self):
epochs = 12
t = 10
eta_min = 1e-10
epoch_length = 11
single_targets = [
eta_min + (0.05 - eta_min) *
(1 + math.cos(math.pi * x / t / epoch_length)) / 2
for x in range(epochs * epoch_length)
]
targets = [
single_targets,
[x * epochs * epoch_length for x in single_targets]
]
scheduler = CosineAnnealingLR.build_iter_from_epoch(
self.optimizer,
T_max=t,
eta_min=eta_min,
epoch_length=epoch_length)
self._test_scheduler_value(scheduler, targets, epochs)
def test_poly_scheduler_convert_iterbased(self):
epochs = 10
power = 0.9
min_lr = 0.001
end = 5
epoch_length = 11
iters = end * epoch_length - 1
single_targets = [
min_lr + (0.05 - min_lr) * (1 - i / iters)**power
for i in range(iters)
] + [min_lr] * (
epochs - iters)
targets = [
single_targets,
[x * epochs * epoch_length for x in single_targets]
]
scheduler = PolyLR.build_iter_from_epoch(
self.optimizer,
power=power,
eta_min=min_lr,
end=end,
epoch_length=epoch_length)
self._test_scheduler_value(scheduler, targets, epochs=10)
def test_multi_scheduler_without_overlap_linear_multi_step(self):
# use Linear in the first 5 epochs and then use MultiStep
epochs = 12

View File

@ -47,7 +47,8 @@ class TestParameterScheduler(TestCase):
def test_invalid_optimizer(self):
with self.assertRaisesRegex(TypeError, 'should be an Optimizer'):
StepParamScheduler('invalid_optimizer', 'lr', step_size=1)
StepParamScheduler(
'invalid_optimizer', step_size=1, param_name='lr')
def test_overwrite_optimzer_step(self):
# raise warning if the counter in optimizer.step() is overwritten
@ -140,7 +141,8 @@ class TestParameterScheduler(TestCase):
def test_get_last_value(self):
epochs = 10
targets = [[0.05] * 3 + [0.005] * 3 + [0.0005] * 3 + [0.00005]]
scheduler = StepParamScheduler(self.optimizer, 'lr', 3, gamma=0.1)
scheduler = StepParamScheduler(
self.optimizer, param_name='lr', step_size=3, gamma=0.1)
for epoch in range(epochs):
result = scheduler.get_last_value()
self.optimizer.step()
@ -432,6 +434,163 @@ class TestParameterScheduler(TestCase):
self.optimizer, param_name='lr', power=0.8, eta_min=0.002),
epochs=10)
def test_step_scheduler_convert_iterbased(self):
# invalid epoch_length
with self.assertRaises(AssertionError):
scheduler = StepParamScheduler.build_iter_from_epoch(
self.optimizer,
param_name='momentum',
gamma=0.1,
step_size=2,
epoch_length=-1)
# momentum = 0.01 if epoch < 2
# momentum = 0.001 if 2 <= epoch < 4
epochs = 4
epoch_length = 7
single_targets = [0.01] * 2 * epoch_length + [0.001] * 2 * epoch_length
targets = [
single_targets,
[x * epochs * epoch_length for x in single_targets]
]
scheduler = StepParamScheduler.build_iter_from_epoch(
self.optimizer,
param_name='momentum',
gamma=0.1,
step_size=2,
epoch_length=epoch_length)
self._test_scheduler_value(
scheduler, targets, epochs * epoch_length, param_name='momentum')
def test_multi_step_scheduler_convert_iterbased(self):
# lr = 0.05 if epoch < 2
# lr = 0.005 if 2 <= epoch < 5
# lr = 0.0005 if 5 <= epoch < 9
# lr = 0.00005 if epoch >= 9
epochs = 10
epoch_length = 7
single_targets = [0.05
] * 2 * epoch_length + [0.005] * 3 * epoch_length + [
0.0005
] * 4 * epoch_length + [0.00005] * 3 * epoch_length
targets = [
single_targets,
[x * epochs * epoch_length for x in single_targets]
]
scheduler = MultiStepParamScheduler.build_iter_from_epoch(
self.optimizer,
param_name='lr',
gamma=0.1,
milestones=[2, 5, 9],
epoch_length=epoch_length)
self._test_scheduler_value(scheduler, targets, epochs * epoch_length)
def test_constant_scheduler_convert_iterbased(self):
# lr = 0.025 if epoch < 5
# lr = 0.05 if 5 <= epoch
epochs = 10
epoch_length = 7
single_targets = [0.025] * (5 * epoch_length -
1) + [0.05] * (5 * epoch_length + 1)
targets = [
single_targets,
[x * epochs * epoch_length for x in single_targets]
]
scheduler = ConstantParamScheduler.build_iter_from_epoch(
self.optimizer,
param_name='lr',
factor=1.0 / 2,
end=5,
epoch_length=epoch_length)
self._test_scheduler_value(scheduler, targets, epochs * epoch_length)
def test_linear_scheduler_convert_iterbased(self):
epochs = 10
start_factor = 1.0 / 2
end = 5
epoch_length = 11
iters = end * epoch_length - 1
interpolation = [
start_factor + i * (1 - start_factor) / iters for i in range(iters)
]
single_targets = [x * 0.05 for x in interpolation] + [0.05] * (
epochs * epoch_length - iters)
targets = [single_targets, [x * epochs for x in single_targets]]
scheduler = LinearParamScheduler.build_iter_from_epoch(
self.optimizer,
param_name='lr',
start_factor=start_factor,
end=end,
epoch_length=epoch_length)
self._test_scheduler_value(scheduler, targets, epochs)
def test_exp_scheduler_convert_iterbased(self):
epochs = 10
epoch_length = 7
single_targets = [
0.05 * (0.9**x) for x in range(epochs * epoch_length)
]
targets = [
single_targets,
[x * epochs * epoch_length for x in single_targets]
]
scheduler = ExponentialParamScheduler.build_iter_from_epoch(
self.optimizer,
param_name='lr',
gamma=0.9,
epoch_length=epoch_length)
self._test_scheduler_value(scheduler, targets, epochs * epoch_length)
def test_cos_anneal_scheduler_convert_iterbased(self):
epochs = 12
t = 10
eta_min = 1e-10
epoch_length = 11
single_targets = [
eta_min + (0.05 - eta_min) *
(1 + math.cos(math.pi * x / t / epoch_length)) / 2
for x in range(epochs * epoch_length)
]
targets = [
single_targets,
[x * epochs * epoch_length for x in single_targets]
]
scheduler = CosineAnnealingParamScheduler.build_iter_from_epoch(
self.optimizer,
param_name='lr',
T_max=t,
eta_min=eta_min,
epoch_length=epoch_length)
self._test_scheduler_value(scheduler, targets, epochs)
def test_poly_scheduler_convert_iterbased(self):
epochs = 10
power = 0.9
min_lr = 0.001
end = 5
epoch_length = 11
iters = end * epoch_length - 1
single_targets = [
min_lr + (0.05 - min_lr) * (1 - i / iters)**power
for i in range(iters)
] + [min_lr] * (
epochs - iters)
targets = [
single_targets,
[x * epochs * epoch_length for x in single_targets]
]
scheduler = PolyParamScheduler.build_iter_from_epoch(
self.optimizer,
param_name='lr',
power=power,
eta_min=min_lr,
end=end,
epoch_length=epoch_length)
self._test_scheduler_value(scheduler, targets, epochs=10)
def test_multi_scheduler_without_overlap_linear_multi_step(self):
# use Linear in the first 5 epochs and then use MultiStep
epochs = 12

View File

@ -603,6 +603,28 @@ class TestRunner(TestCase):
self.assertIsInstance(param_schedulers[0], MultiStepLR)
self.assertIsInstance(param_schedulers[1], StepLR)
# train loop should be built before converting the scheduler
cfg = dict(
type='MultiStepLR', milestones=[1, 2], convert_to_iter_based=True)
with self.assertRaisesRegex(
AssertionError,
'Scheduler can only be converted to iter-based when '
'train loop is built.'):
param_schedulers = runner.build_param_scheduler(cfg)
# convert epoch-based to iter-based scheduler
cfg = dict(
type='MultiStepLR',
milestones=[1, 2],
begin=1,
end=7,
convert_to_iter_based=True)
runner.train_loop = runner.build_train_loop(runner.train_loop)
param_schedulers = runner.build_param_scheduler(cfg)
self.assertFalse(param_schedulers[0].by_epoch)
self.assertEqual(param_schedulers[0].begin, 4)
self.assertEqual(param_schedulers[0].end, 28)
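
The two assertions follow directly from the conversion rule in `build_iter_from_epoch`; judging from the expected values, the test dataloader yields 4 iterations per epoch (an inference, not stated explicitly above):

epoch_length = 4                 # inferred: len(runner.train_loop.dataloader)
assert 1 * epoch_length == 4     # begin: epoch 1 -> iteration 4
assert 7 * epoch_length == 28    # end:   epoch 7 -> iteration 28
# milestones [1, 2] are rescaled the same way, to [4, 8].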
def test_build_evaluator(self):
cfg = copy.deepcopy(self.epoch_based_cfg)
cfg.experiment_name = 'test_build_evaluator'