[Feature] Support CosineRestartParamScheduler. (#397)
* [Feature] Support CosineRestartParamScheduler.
* add ut and docstring
* add docstring
parent b14cbc2576
commit 813f49bf23
@@ -1,16 +1,22 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .lr_scheduler import (ConstantLR, CosineAnnealingLR, ExponentialLR,
                           LinearLR, MultiStepLR, OneCycleLR, PolyLR, StepLR)
# yapf: disable
from .lr_scheduler import (ConstantLR, CosineAnnealingLR, CosineRestartLR,
                           ExponentialLR, LinearLR, MultiStepLR, OneCycleLR,
                           PolyLR, StepLR)
from .momentum_scheduler import (ConstantMomentum, CosineAnnealingMomentum,
                                 ExponentialMomentum, LinearMomentum,
                                 MultiStepMomentum, PolyMomentum, StepMomentum)
                                 CosineRestartMomentum, ExponentialMomentum,
                                 LinearMomentum, MultiStepMomentum,
                                 PolyMomentum, StepMomentum)
from .param_scheduler import (ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              CosineRestartParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
                              MultiStepParamScheduler, OneCycleParamScheduler,
                              PolyParamScheduler, StepParamScheduler,
                              _ParamScheduler)

# yapf: enable

__all__ = [
    'ConstantLR', 'CosineAnnealingLR', 'ExponentialLR', 'LinearLR',
    'MultiStepLR', 'StepLR', 'ConstantMomentum', 'CosineAnnealingMomentum',

@@ -19,5 +25,6 @@ __all__ = [
    'ExponentialParamScheduler', 'LinearParamScheduler',
    'MultiStepParamScheduler', 'StepParamScheduler', '_ParamScheduler',
    'PolyParamScheduler', 'PolyLR', 'PolyMomentum', 'OneCycleParamScheduler',
    'OneCycleLR'
    'OneCycleLR', 'CosineRestartParamScheduler', 'CosineRestartLR',
    'CosineRestartMomentum'
]
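With these exports in place the new schedulers can be used by importing them directly or, because each class below registers itself in PARAM_SCHEDULERS, by type name in a config. A minimal config-style sketch (my illustration of the usual mmengine pattern, not part of this commit):

param_scheduler = dict(
    type='CosineRestartLR',    # resolved through the PARAM_SCHEDULERS registry
    periods=[4, 5],            # two cosine cycles: 4 epochs, then 5 epochs
    restart_weights=[1, 0.5],  # the second cycle restarts at half the base lr
    eta_min=0,
    by_epoch=True)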
@@ -2,6 +2,7 @@
from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              CosineRestartParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
                              MultiStepParamScheduler, OneCycleParamScheduler,
                              PolyParamScheduler, StepParamScheduler)

@@ -277,3 +278,35 @@ class OneCycleLR(LRSchedulerMixin, OneCycleParamScheduler):
    .. _Super-Convergence\: Very Fast Training of Neural Networks Using Large Learning Rates:
        https://arxiv.org/abs/1708.07120
    """# noqa E501


@PARAM_SCHEDULERS.register_module()
class CosineRestartLR(LRSchedulerMixin, CosineRestartParamScheduler):
    """Sets the learning rate of each parameter group according to the cosine
    annealing with restarts scheme. The cosine restart policy anneals the
    learning rate from the initial value to `eta_min` with a cosine annealing
    schedule and then restarts another period from the maximum value
    multiplied by `restart_weight`.

    Args:
        optimizer (Optimizer or OptimWrapper): Optimizer or wrapped
            optimizer.
        periods (list[int]): Periods for each cosine annealing cycle.
        restart_weights (list[float]): Restart weights at each
            restart iteration. Defaults to [1].
        eta_min (float): Minimum parameter value at the end of scheduling.
            Defaults to None.
        eta_min_ratio (float, optional): The ratio of the minimum parameter
            value to the base parameter value. Either `eta_min` or
            `eta_min_ratio` should be specified. Defaults to None.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of the last step. Used for resuming
            without state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """
@@ -2,6 +2,7 @@
from mmengine.registry import PARAM_SCHEDULERS
from .param_scheduler import (ConstantParamScheduler,
                              CosineAnnealingParamScheduler,
                              CosineRestartParamScheduler,
                              ExponentialParamScheduler, LinearParamScheduler,
                              MultiStepParamScheduler, PolyParamScheduler,
                              StepParamScheduler)

@@ -243,3 +244,36 @@ class PolyMomentum(MomentumSchedulerMixin, PolyParamScheduler):
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """


@PARAM_SCHEDULERS.register_module()
class CosineRestartMomentum(MomentumSchedulerMixin,
                            CosineRestartParamScheduler):
    """Sets the momentum of each parameter group according to the cosine
    annealing with restarts scheme. The cosine restart policy anneals the
    momentum from the initial value to `eta_min` with a cosine annealing
    schedule and then restarts another period from the maximum value
    multiplied by `restart_weight`.

    Args:
        optimizer (Optimizer or OptimWrapper): Optimizer or wrapped
            optimizer.
        periods (list[int]): Periods for each cosine annealing cycle.
        restart_weights (list[float]): Restart weights at each
            restart iteration. Defaults to [1].
        eta_min (float): Minimum parameter value at the end of scheduling.
            Defaults to None.
        eta_min_ratio (float, optional): The ratio of the minimum parameter
            value to the base parameter value. Either `eta_min` or
            `eta_min_ratio` should be specified. Defaults to None.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of the last step. Used for resuming
            without state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """
@@ -9,7 +9,7 @@ import warnings
import weakref
from collections import Counter
from functools import wraps
from typing import Callable, List, Optional, Union
from typing import Callable, List, Optional, Sequence, Union

from torch.optim import Optimizer
@@ -227,6 +227,8 @@ class StepParamScheduler(_ParamScheduler):

    Args:
        optimizer (OptimWrapper or Optimizer): Wrapped optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        step_size (int): Period of parameter value decay.
        gamma (float): Multiplicative factor of parameter value decay.
            Defaults to 0.1.

@@ -313,6 +315,8 @@ class MultiStepParamScheduler(_ParamScheduler):

    Args:
        optimizer (OptimWrapper or Optimizer): Wrapped optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        milestones (list): List of epoch indices. Must be increasing.
        gamma (float): Multiplicative factor of parameter value decay.
            Defaults to 0.1.

@@ -401,6 +405,8 @@ class ConstantParamScheduler(_ParamScheduler):
    Args:
        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
            optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        factor (float): The number we multiply parameter value until the
            milestone. Defaults to 1./3.
        begin (int): Step at which to start updating the parameters.

@@ -488,6 +494,8 @@ class ExponentialParamScheduler(_ParamScheduler):
    Args:
        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
            optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        gamma (float): Multiplicative factor of parameter value decay.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.

@@ -585,6 +593,8 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
    Args:
        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
            optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        T_max (int, optional): Maximum number of iterations. If not specified,
            use ``end - begin``. Defaults to None.
        eta_min (float): Minimum parameter value. Defaults to 0.

@@ -684,6 +694,8 @@ class LinearParamScheduler(_ParamScheduler):
    Args:
        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
            optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        start_factor (float): The number we multiply parameter value in the
            first epoch. The multiplication factor changes towards end_factor
            in the following epochs. Defaults to 1./3.

@@ -780,6 +792,8 @@ class PolyParamScheduler(_ParamScheduler):
    Args:
        optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
            optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        eta_min (float): Minimum parameter value at the end of scheduling.
            Defaults to 0.
        power (float): The power of the polynomial. Defaults to 1.0.

@@ -882,6 +896,8 @@ class OneCycleParamScheduler(_ParamScheduler):

    Args:
        optimizer (Optimizer): Wrapped optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        eta_max (float or list): Upper parameter value boundaries in the cycle
            for each parameter group.
        total_steps (int): The total number of steps in the cycle. Note that
@@ -1094,3 +1110,159 @@ class OneCycleParamScheduler(_ParamScheduler):
            params.append(computed_param)

        return params


@PARAM_SCHEDULERS.register_module()
class CosineRestartParamScheduler(_ParamScheduler):
    """Sets the parameters of each parameter group according to the cosine
    annealing with restarts scheme. The cosine restart policy anneals the
    parameter from the initial value to `eta_min` with a cosine annealing
    schedule and then restarts another period from the maximum value
    multiplied by `restart_weight`.

    Args:
        optimizer (Optimizer or OptimWrapper): Optimizer or wrapped
            optimizer.
        param_name (str): Name of the parameter to be adjusted, such as
            ``lr``, ``momentum``.
        periods (list[int]): Periods for each cosine annealing cycle.
        restart_weights (list[float]): Restart weights at each
            restart iteration. Defaults to [1].
        eta_min (float): Minimum parameter value at the end of scheduling.
            Defaults to None.
        eta_min_ratio (float, optional): The ratio of the minimum parameter
            value to the base parameter value. Either `eta_min` or
            `eta_min_ratio` should be specified. Defaults to None.
        begin (int): Step at which to start updating the parameters.
            Defaults to 0.
        end (int): Step at which to stop updating the parameters.
            Defaults to INF.
        last_step (int): The index of the last step. Used for resuming
            without state dict. Defaults to -1.
        by_epoch (bool): Whether the scheduled parameters are updated by
            epochs. Defaults to True.
        verbose (bool): Whether to print the value for each update.
            Defaults to False.
    """

    def __init__(self,
                 optimizer: Union[Optimizer, OptimWrapper],
                 param_name: str,
                 periods: List[int],
                 restart_weights: Sequence[float] = (1, ),
                 eta_min: Optional[float] = None,
                 eta_min_ratio: Optional[float] = None,
                 begin: int = 0,
                 end: int = INF,
                 last_step: int = -1,
                 by_epoch: bool = True,
                 verbose: bool = False):
        assert (eta_min is None) ^ (eta_min_ratio is None)
        self.periods = periods
        self.eta_min = eta_min
        self.eta_min_ratio = eta_min_ratio
        self.restart_weights = restart_weights
        assert (len(self.periods) == len(self.restart_weights)
                ), 'periods and restart_weights should have the same length.'
        self.cumulative_periods = [
            sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))
        ]

        super().__init__(
            optimizer,
            param_name=param_name,
            begin=begin,
            end=end,
            last_step=last_step,
            by_epoch=by_epoch,
            verbose=verbose)

    @classmethod
    def build_iter_from_epoch(cls,
                              *args,
                              periods,
                              begin=0,
                              end=INF,
                              by_epoch=True,
                              epoch_length=None,
                              **kwargs):
        """Build an iter-based instance of this scheduler from an epoch-based
        config."""
        assert by_epoch, 'Only epoch-based kwargs whose `by_epoch=True` can ' \
                         'be converted to iter-based.'
        assert epoch_length is not None and epoch_length > 0, \
            f'`epoch_length` must be a positive integer, ' \
            f'but got {epoch_length}.'
        periods = [p * epoch_length for p in periods]
        by_epoch = False
        begin = int(begin * epoch_length)
        if end != INF:
            end = int(end * epoch_length)
        return cls(
            *args,
            periods=periods,
            begin=begin,
            end=end,
            by_epoch=by_epoch,
            **kwargs)

    def _get_value(self):
        """Compute value using chainable form of the scheduler."""
        idx = self.get_position_from_periods(self.last_step,
                                             self.cumulative_periods)
        # if the current step is not in the periods, return the original
        # parameters
        if idx is None:
            return [
                group[self.param_name] for group in self.optimizer.param_groups
            ]
        current_weight = self.restart_weights[idx]
        nearest_restart = 0 if idx == 0 else self.cumulative_periods[idx - 1]
        current_periods = self.periods[idx]
        step = self.last_step - nearest_restart
        values = []
        for base_value, group in zip(self.base_values,
                                     self.optimizer.param_groups):
            eta_max = base_value * current_weight
            if self.eta_min_ratio is None:
                eta_min = self.eta_min * (1 - current_weight)
            else:
                eta_min = base_value * self.eta_min_ratio * (1 -
                                                             current_weight)
            if step == 0:
                values.append(eta_max)

            elif (step - 1 - current_periods) % (2 * current_periods) == 0:
                values.append(group[self.param_name] + (eta_max - eta_min) *
                              (1 - math.cos(math.pi / current_periods)) / 2)
            else:
                values.append(
                    (1 + math.cos(math.pi * step / current_periods)) /
                    (1 + math.cos(math.pi * (step - 1) / current_periods)) *
                    (group[self.param_name] - eta_min) + eta_min)

        return values

    @staticmethod
    def get_position_from_periods(
            iteration: int, cumulative_periods: List[int]) -> Optional[int]:
        """Get the position from a period list.

        It will return the index of the right-closest number in the period
        list.
        For example, the cumulative_periods = [100, 200, 300, 400],
        if iteration == 50, return 0;
        if iteration == 210, return 2;
        if iteration == 300, return 3.

        Args:
            iteration (int): Current iteration.
            cumulative_periods (list[int]): Cumulative period list.

        Returns:
            Optional[int]: The position of the right-closest number in the
                period list. If not in the period, return None.
        """
        for i, period in enumerate(cumulative_periods):
            if iteration < period:
                return i
        return None
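For reference, a small sketch of using the generic scheduler directly (mine, not part of the diff): any optimizer hyper-parameter can be driven by passing its name as `param_name`, and `build_iter_from_epoch` converts an epoch-based setup into an iteration-based one (here assuming 100 iterations per epoch):

import torch
from torch.optim import SGD

from mmengine.optim.scheduler import CosineRestartParamScheduler

model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.05, momentum=0.9)
scheduler = CosineRestartParamScheduler(
    optimizer,
    param_name='lr',  # 'momentum' would schedule the momentum instead
    periods=[4, 5],
    restart_weights=[1, 0.5],
    eta_min=0)

# The same schedule expressed per iteration rather than per epoch:
iter_scheduler = CosineRestartParamScheduler.build_iter_from_epoch(
    optimizer,
    param_name='lr',
    periods=[4, 5],
    restart_weights=[1, 0.5],
    eta_min=0,
    epoch_length=100)  # periods become [400, 500] iterations

# get_position_from_periods maps a step to its cycle index: with cumulative
# periods [4, 9], step 2 -> 0, step 6 -> 1 and step 9 -> None (past the end).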
@@ -7,8 +7,8 @@ import torch.nn.functional as F
import torch.optim as optim

from mmengine.optim.scheduler import (ConstantLR, CosineAnnealingLR,
                                      ExponentialLR, LinearLR, MultiStepLR,
                                      OneCycleLR, PolyLR, StepLR,
                                      CosineRestartLR, ExponentialLR, LinearLR,
                                      MultiStepLR, OneCycleLR, PolyLR, StepLR,
                                      _ParamScheduler)
from mmengine.testing import assert_allclose

@@ -333,6 +333,34 @@ class TestLRScheduler(TestCase):
            self.optimizer, power=power, eta_min=min_lr, end=iters + 1)
        self._test_scheduler_value(scheduler, targets, epochs=10)

    def test_cosine_restart_scheduler(self):
        with self.assertRaises(AssertionError):
            CosineRestartLR(
                self.optimizer,
                periods=[4, 5],
                restart_weights=[1, 0.5],
                eta_min=0,
                eta_min_ratio=0.1)
        with self.assertRaises(AssertionError):
            CosineRestartLR(
                self.optimizer,
                periods=[4, 5],
                restart_weights=[1, 0.5, 0.0],
                eta_min=0)
        single_targets = [
            0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271,
            0.0086372, 0.0023872, 0.0023872
        ]
        targets = [
            single_targets, [t * self.layer2_mult for t in single_targets]
        ]
        scheduler = CosineRestartLR(
            self.optimizer,
            periods=[4, 5],
            restart_weights=[1, 0.5],
            eta_min=0)
        self._test_scheduler_value(scheduler, targets, epochs=10)
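The expected values above can be reproduced by hand: with `eta_min=0` each cycle follows `w * base_lr * (1 + cos(pi * t / T)) / 2`, where `w` is the cycle's restart weight, `T` its period and `t` the number of steps since the last restart, while steps beyond the last period keep the final value. A quick sanity-check sketch (mine, not part of the test suite):

import math

base_lr, periods, weights = 0.05, [4, 5], [1, 0.5]
values = []
for w, T in zip(weights, periods):
    values += [w * base_lr * (1 + math.cos(math.pi * t / T)) / 2
               for t in range(T)]
values.append(values[-1])  # step 9 lies outside both periods
# `values` agrees with `single_targets` above to within rounding.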

    def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
        scheduler = construct()
        for _ in range(epochs):

@@ -387,6 +415,20 @@ class TestLRScheduler(TestCase):
            lambda: PolyLR(self.optimizer, power=0.8, eta_min=0.002),
            epochs=10)

    def test_cosine_restart_scheduler_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: CosineRestartLR(
                self.optimizer,
                periods=[4, 5],
                restart_weights=[1, 0.5],
                eta_min=0),
            lambda: CosineRestartLR(
                self.optimizer,
                periods=[4, 6],
                restart_weights=[1, 0.5],
                eta_min=0),
            epochs=10)

    def test_step_scheduler_convert_iterbased(self):
        # invalid epoch_length
        with self.assertRaises(AssertionError):
@@ -8,6 +8,7 @@ import torch.optim as optim

from mmengine.optim.scheduler import (ConstantMomentum,
                                      CosineAnnealingMomentum,
                                      CosineRestartMomentum,
                                      ExponentialMomentum, LinearMomentum,
                                      MultiStepMomentum, PolyMomentum,
                                      StepMomentum, _ParamScheduler)

@@ -399,6 +400,43 @@ class TestMomentumScheduler(TestCase):
        self._test_scheduler_value(
            self.optimizer_with_betas, scheduler, targets, epochs=10)

    def test_cosine_restart_scheduler(self):
        with self.assertRaises(AssertionError):
            CosineRestartMomentum(
                self.optimizer,
                periods=[4, 5],
                restart_weights=[1, 0.5],
                eta_min=0,
                eta_min_ratio=0.1)
        with self.assertRaises(AssertionError):
            CosineRestartMomentum(
                self.optimizer,
                periods=[4, 5],
                restart_weights=[1, 0.5, 0.0],
                eta_min=0)
        single_targets = [
            0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271,
            0.0086372, 0.0023872, 0.0023872
        ]
        targets = [
            single_targets, [t * self.layer2_mult for t in single_targets]
        ]
        scheduler = CosineRestartMomentum(
            self.optimizer,
            periods=[4, 5],
            restart_weights=[1, 0.5],
            eta_min=0)
        self._test_scheduler_value(
            self.optimizer, scheduler, targets, epochs=10)

        scheduler = CosineRestartMomentum(
            self.optimizer_with_betas,
            periods=[4, 5],
            restart_weights=[1, 0.5],
            eta_min=0)
        self._test_scheduler_value(
            self.optimizer_with_betas, scheduler, targets, epochs=10)

    def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
        scheduler = construct()
        for _ in range(epochs):

@@ -454,6 +492,20 @@ class TestMomentumScheduler(TestCase):
            lambda: PolyMomentum(self.optimizer, power=0.8, eta_min=0.002),
            epochs=10)

    def test_cosine_restart_scheduler_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: CosineRestartMomentum(
                self.optimizer,
                periods=[4, 5],
                restart_weights=[1, 0.5],
                eta_min=0),
            lambda: CosineRestartMomentum(
                self.optimizer,
                periods=[4, 6],
                restart_weights=[1, 0.5],
                eta_min=0),
            epochs=10)

    def test_multi_scheduler_without_overlap_linear_multi_step(self):
        # use Linear in the first 5 epochs and then use MultiStep
        epochs = 12
@@ -12,12 +12,13 @@ from mmengine.optim import OptimWrapper
# yapf: disable
from mmengine.optim.scheduler import (ConstantParamScheduler,
                                      CosineAnnealingParamScheduler,
                                      CosineRestartParamScheduler,
                                      ExponentialParamScheduler,
                                      LinearParamScheduler,
                                      MultiStepParamScheduler,
                                      OneCycleParamScheduler,
                                      PolyParamScheduler, StepParamScheduler,
                                      _ParamScheduler)
from mmengine.optim.scheduler.param_scheduler import OneCycleParamScheduler
# yapf: enable
from mmengine.testing import assert_allclose

@@ -406,6 +407,37 @@ class TestParameterScheduler(TestCase):
            end=iters + 1)
        self._test_scheduler_value(scheduler, targets, epochs=10)

    def test_cosine_restart_scheduler(self):
        with self.assertRaises(AssertionError):
            CosineRestartParamScheduler(
                self.optimizer,
                param_name='lr',
                periods=[4, 5],
                restart_weights=[1, 0.5],
                eta_min=0,
                eta_min_ratio=0.1)
        with self.assertRaises(AssertionError):
            CosineRestartParamScheduler(
                self.optimizer,
                param_name='lr',
                periods=[4, 5],
                restart_weights=[1, 0.5, 0.0],
                eta_min=0)
        single_targets = [
            0.05, 0.0426776, 0.025, 0.00732233, 0.025, 0.022612712, 0.01636271,
            0.0086372, 0.0023872, 0.0023872
        ]
        targets = [
            single_targets, [t * self.layer2_mult for t in single_targets]
        ]
        scheduler = CosineRestartParamScheduler(
            self.optimizer,
            param_name='lr',
            periods=[4, 5],
            restart_weights=[1, 0.5],
            eta_min=0)
        self._test_scheduler_value(scheduler, targets, epochs=10)

    def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
        scheduler = construct()
        for _ in range(epochs):

@@ -483,6 +515,22 @@ class TestParameterScheduler(TestCase):
            self.optimizer, param_name='lr', power=0.8, eta_min=0.002),
            epochs=10)

    def test_cosine_restart_scheduler_state_dict(self):
        self._check_scheduler_state_dict(
            lambda: CosineRestartParamScheduler(
                self.optimizer,
                param_name='lr',
                periods=[4, 5],
                restart_weights=[1, 0.5],
                eta_min=0),
            lambda: CosineRestartParamScheduler(
                self.optimizer,
                param_name='lr',
                periods=[4, 6],
                restart_weights=[1, 0.5],
                eta_min=0),
            epochs=10)

    def test_step_scheduler_convert_iterbased(self):
        # invalid epoch_length
        with self.assertRaises(AssertionError):