From 090104df21acd05a8aadae5a0d743a7da3314f6f Mon Sep 17 00:00:00 2001 From: Z-Fran <49083766+Z-Fran@users.noreply.github.com> Date: Tue, 1 Nov 2022 15:48:39 +0800 Subject: [PATCH] [Fix] Fix the calculation error of eta_min in CosineRestart (#639) * [Fix] fix CosineRestart eta_min * add ut case * Enhance unit test Enhance unit test * remove unused code Co-authored-by: HAOCHENYE <21724054@zju.edu.cn> --- mmengine/optim/scheduler/param_scheduler.py | 9 ++------ .../test_scheduler/test_param_scheduler.py | 22 +++++++++++++++++++ 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/mmengine/optim/scheduler/param_scheduler.py b/mmengine/optim/scheduler/param_scheduler.py index a2a8276c..df2e25da 100644 --- a/mmengine/optim/scheduler/param_scheduler.py +++ b/mmengine/optim/scheduler/param_scheduler.py @@ -1224,16 +1224,11 @@ class CosineRestartParamScheduler(_ParamScheduler): self.optimizer.param_groups): eta_max = base_value * current_weight if self.eta_min_ratio is None: - eta_min = self.eta_min * (1 - current_weight) + eta_min = self.eta_min else: - eta_min = base_value * self.eta_min_ratio * (1 - - current_weight) + eta_min = base_value * self.eta_min_ratio if step == 0: values.append(eta_max) - - elif (step - 1 - current_periods) % (2 * current_periods) == 0: - values.append(group[self.param_name] + (eta_max - eta_min) * - (1 - math.cos(math.pi / current_periods)) / 2) else: values.append( (1 + math.cos(math.pi * step / current_periods)) / diff --git a/tests/test_optim/test_scheduler/test_param_scheduler.py b/tests/test_optim/test_scheduler/test_param_scheduler.py index e27f0dd7..82db0a76 100644 --- a/tests/test_optim/test_scheduler/test_param_scheduler.py +++ b/tests/test_optim/test_scheduler/test_param_scheduler.py @@ -430,6 +430,8 @@ class TestParameterScheduler(TestCase): targets = [ single_targets, [t * self.layer2_mult for t in single_targets] ] + + # Test with non-zero eta-min. scheduler = CosineRestartParamScheduler( self.optimizer, param_name='lr', @@ -438,6 +440,26 @@ class TestParameterScheduler(TestCase): eta_min=0) self._test_scheduler_value(scheduler, targets, epochs=10) + epochs = 10 + t = 10 + eta_min = 5e-3 + targets1 = [ + eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2 + for x in range(epochs) + ] + targets2 = [ + eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2 + for x in range(epochs) + ] + targets = [targets1, targets2] + scheduler = CosineRestartParamScheduler( + self.optimizer, + param_name='lr', + periods=[t], + restart_weights=[1], + eta_min=eta_min) + self._test_scheduler_value(scheduler, targets, epochs=10) + def _check_scheduler_state_dict(self, construct, construct2, epochs=10): scheduler = construct() for _ in range(epochs):