diff --git a/mmengine/optim/scheduler/lr_scheduler.py b/mmengine/optim/scheduler/lr_scheduler.py
index cbb82dfc..b1eeabec 100644
--- a/mmengine/optim/scheduler/lr_scheduler.py
+++ b/mmengine/optim/scheduler/lr_scheduler.py
@@ -71,7 +71,7 @@ class CosineAnnealingLR(LRSchedulerMixin, CosineAnnealingParamScheduler):
     Args:
         optimizer (Optimizer or OptimWrapper): Wrapped optimizer.
         T_max (int): Maximum number of iterations.
-        eta_min (float): Minimum learning rate. Defaults to 0.
+        eta_min (float, optional): Minimum learning rate. Defaults to None.
         begin (int): Step at which to start updating the learning rate.
             Defaults to 0.
         end (int): Step at which to stop updating the learning rate.
@@ -82,6 +82,10 @@ class CosineAnnealingLR(LRSchedulerMixin, CosineAnnealingParamScheduler):
             epochs. Defaults to True.
         verbose (bool): Whether to print the learning rate for each update.
             Defaults to False.
+        eta_min_ratio (float, optional): The ratio of the minimum parameter
+            value to the base parameter value. Either `eta_min` or
+            `eta_min_ratio` should be specified. Defaults to None.
+            New in version 0.3.2.
 
     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
         https://arxiv.org/abs/1608.03983
diff --git a/mmengine/optim/scheduler/momentum_scheduler.py b/mmengine/optim/scheduler/momentum_scheduler.py
index 9cb5f9e9..102b1731 100644
--- a/mmengine/optim/scheduler/momentum_scheduler.py
+++ b/mmengine/optim/scheduler/momentum_scheduler.py
@@ -101,7 +101,7 @@ class CosineAnnealingMomentum(MomentumSchedulerMixin,
         optimizer (Optimizer or OptimWrapper): optimizer or Wrapped
             optimizer.
         T_max (int): Maximum number of iterations.
-        eta_min (float): Minimum momentum value. Defaults to 0.
+        eta_min (float, optional): Minimum momentum value. Defaults to None.
         begin (int): Step at which to start updating the momentum.
             Defaults to 0.
         end (int): Step at which to stop updating the momentum.
@@ -112,6 +112,10 @@ class CosineAnnealingMomentum(MomentumSchedulerMixin,
             epochs. Defaults to True.
         verbose (bool): Whether to print the momentum for each update.
             Defaults to False.
+        eta_min_ratio (float, optional): The ratio of the minimum parameter
+            value to the base parameter value. Either `eta_min` or
+            `eta_min_ratio` should be specified. Defaults to None.
+            New in version 0.3.2.
 
     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
         https://arxiv.org/abs/1608.03983
diff --git a/mmengine/optim/scheduler/param_scheduler.py b/mmengine/optim/scheduler/param_scheduler.py
index 31101dce..7fdf5c3d 100644
--- a/mmengine/optim/scheduler/param_scheduler.py
+++ b/mmengine/optim/scheduler/param_scheduler.py
@@ -599,7 +599,7 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
             ``lr``, ``momentum``.
         T_max (int, optional): Maximum number of iterations. If not specified,
             use ``end - begin``. Defaults to None.
-        eta_min (float): Minimum parameter value. Defaults to 0.
+        eta_min (float, optional): Minimum parameter value. Defaults to None.
         begin (int): Step at which to start updating the parameters.
             Defaults to 0.
         end (int): Step at which to stop updating the parameters.
@@ -610,6 +610,10 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
             epochs. Defaults to True.
         verbose (bool): Whether to print the value for each update.
             Defaults to False.
+        eta_min_ratio (float, optional): The ratio of the minimum parameter
+            value to the base parameter value. Either `eta_min` or
+            `eta_min_ratio` should be specified. Defaults to None.
+            New in version 0.3.2.
 
     .. _SGDR\: Stochastic Gradient Descent with Warm Restarts:
         https://arxiv.org/abs/1608.03983
@@ -619,14 +623,21 @@
                  optimizer: Union[Optimizer, OptimWrapper],
                  param_name: str,
                  T_max: Optional[int] = None,
-                 eta_min: float = 0.,
+                 eta_min: Optional[float] = None,
                  begin: int = 0,
                  end: int = INF,
                  last_step: int = -1,
                  by_epoch: bool = True,
-                 verbose: bool = False):
+                 verbose: bool = False,
+                 eta_min_ratio: Optional[float] = None):
+        # To preserve backwards compatibility
+        if eta_min is None and eta_min_ratio is None:
+            eta_min = 0.
+        assert (eta_min is None) ^ (eta_min_ratio is None), \
+            'Either `eta_min` or `eta_min_ratio` should be specified'
         self.T_max = T_max or (end - begin)
         self.eta_min = eta_min
+        self.eta_min_ratio = eta_min_ratio
         super().__init__(
             optimizer,
             param_name=param_name,
@@ -666,23 +677,31 @@ class CosineAnnealingParamScheduler(_ParamScheduler):
             by_epoch=by_epoch,
             **kwargs)
 
-    def _get_value(self):
+    def _get_value(self) -> list:
         """Compute value using chainable form of the scheduler."""
+
+        def _get_eta_min(base_value):
+            if self.eta_min_ratio is None:
+                return self.eta_min
+            return base_value * self.eta_min_ratio
+
         if self.last_step == 0:
             return [
                 group[self.param_name] for group in self.optimizer.param_groups
             ]
         elif (self.last_step - 1 - self.T_max) % (2 * self.T_max) == 0:
             return [
-                group[self.param_name] + (base_value - self.eta_min) *
+                group[self.param_name] +
+                (base_value - _get_eta_min(base_value)) *
                 (1 - math.cos(math.pi / self.T_max)) / 2
                 for base_value, group in zip(self.base_values,
                                              self.optimizer.param_groups)
             ]
         return [(1 + math.cos(math.pi * self.last_step / self.T_max)) /
                 (1 + math.cos(math.pi * (self.last_step - 1) / self.T_max)) *
-                (group[self.param_name] - self.eta_min) + self.eta_min
-                for group in self.optimizer.param_groups]
+                (group[self.param_name] - _get_eta_min(base_value)) +
+                _get_eta_min(base_value) for base_value, group in zip(
+                    self.base_values, self.optimizer.param_groups)]
 
 
 @PARAM_SCHEDULERS.register_module()
@@ -1131,11 +1150,11 @@ class CosineRestartParamScheduler(_ParamScheduler):
         periods (list[int]): Periods for each cosine anneling cycle.
         restart_weights (list[float]): Restart weights at each restart
             iteration. Defaults to [1].
-        eta_min (float): Minimum parameter value at the end of scheduling.
-            Defaults to None.
+        eta_min (float, optional): Minimum parameter value at the end of
+            scheduling. Defaults to None.
         eta_min_ratio (float, optional): The ratio of minimum parameter value
-            to the base parameter value. Either `min_lr` or `min_lr_ratio`
-            should be specified. Default: None.
+            to the base parameter value. Either `eta_min` or `eta_min_ratio`
+            should be specified. Defaults to None.
         begin (int): Step at which to start updating the parameters.
             Defaults to 0.
         end (int): Step at which to stop updating the parameters.
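For context, here is a minimal usage sketch of the new `eta_min_ratio` option (not part of the patch; it assumes a mmengine build that includes this change, i.e. >= 0.3.2, and uses a throwaway model and optimizer purely for illustration):

# Sketch only: anneal `lr` from its base value down to base * eta_min_ratio.
import torch
from torch.optim import SGD

from mmengine.optim import CosineAnnealingParamScheduler

model = torch.nn.Linear(2, 2)
optimizer = SGD(model.parameters(), lr=0.1)

# `lr` decays from 0.1 to 0.1 * 0.01 = 0.001 over 100 steps; `by_epoch=False`
# makes the schedule advance once per `scheduler.step()` call.
scheduler = CosineAnnealingParamScheduler(
    optimizer, param_name='lr', T_max=100, eta_min_ratio=0.01, by_epoch=False)

for _ in range(100):
    optimizer.step()
    scheduler.step()

Because `_get_eta_min` is evaluated per parameter group against that group's base value, groups with different base learning rates keep their relative scale all the way to the floor, which a fixed `eta_min` cannot express.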
diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py
index 82db0a76..ce86195b 100644
--- a/tests/test_optim/test_scheduler/test_param_scheduler.py
+++ b/tests/test_optim/test_scheduler/test_param_scheduler.py
@@ -364,20 +364,48 @@ class TestParameterScheduler(TestCase):
         self._test_scheduler_value(scheduler, targets, epochs)
 
     def test_cos_anneal_scheduler(self):
+        with self.assertRaises(AssertionError):
+            CosineAnnealingParamScheduler(
+                self.optimizer,
+                param_name='lr',
+                T_max=10,
+                eta_min=0,
+                eta_min_ratio=0.1)
         epochs = 12
         t = 10
-        eta_min = 1e-10
-        single_targets = [
+        eta_min = 5e-3
+        targets1 = [
             eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2
             for x in range(epochs)
         ]
-        targets = [
-            single_targets, [x * self.layer2_mult for x in single_targets]
+        targets2 = [
+            eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2
+            for x in range(epochs)
         ]
+        targets = [targets1, targets2]
         scheduler = CosineAnnealingParamScheduler(
             self.optimizer, param_name='lr', T_max=t, eta_min=eta_min)
         self._test_scheduler_value(scheduler, targets, epochs)
 
+        # Test `eta_min_ratio`
+        self.setUp()
+        eta_min_ratio = 1e-3
+        targets1 = [
+            0.05 * eta_min_ratio + (0.05 - 0.05 * eta_min_ratio) *
+            (1 + math.cos(math.pi * x / t)) / 2 for x in range(epochs)
+        ]
+        targets2 = [
+            0.5 * eta_min_ratio + (0.5 - 0.5 * eta_min_ratio) *
+            (1 + math.cos(math.pi * x / t)) / 2 for x in range(epochs)
+        ]
+        targets = [targets1, targets2]
+        scheduler = CosineAnnealingParamScheduler(
+            self.optimizer,
+            param_name='lr',
+            T_max=t,
+            eta_min_ratio=eta_min_ratio)
+        self._test_scheduler_value(scheduler, targets, epochs)
+
+        # Test default `T_max`
         scheduler = CosineAnnealingParamScheduler(
             self.optimizer, param_name='lr', begin=5, end=100, eta_min=eta_min)
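As a standalone sanity check (again not part of the patch), the chainable recurrence in the rewritten `_get_value` can be replayed for a single parameter group and compared against the closed-form SGDR schedule the tests assert, eta_t = eta_min + (eta_base - eta_min) * (1 + cos(pi * t / T_max)) / 2, using the same constants as the `eta_min_ratio` test above:

# Replay the two branches of `_get_value` for one group and check that they
# reproduce the closed-form targets computed in test_cos_anneal_scheduler.
import math

eta_base, eta_min_ratio, T_max, epochs = 0.05, 1e-3, 10, 12
eta_min = eta_base * eta_min_ratio  # what `_get_eta_min` returns

eta = eta_base  # value at step 0
for t in range(1, epochs):
    if (t - 1 - T_max) % (2 * T_max) == 0:
        # Restart branch of `_get_value`.
        eta = eta + (eta_base - eta_min) * (1 - math.cos(math.pi / T_max)) / 2
    else:
        # Chainable branch of `_get_value`.
        eta = ((1 + math.cos(math.pi * t / T_max)) /
               (1 + math.cos(math.pi * (t - 1) / T_max)) *
               (eta - eta_min) + eta_min)
    closed_form = eta_min + (eta_base - eta_min) * (
        1 + math.cos(math.pi * t / T_max)) / 2
    assert math.isclose(eta, closed_form, rel_tol=1e-9), (t, eta, closed_form)

The restart branch also covers the step where 1 + cos(pi * (t - 1) / T_max) reaches zero (t = T_max + 1 here), which is exactly why `_get_value` cannot use the chainable ratio there.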