mirror of
https://github.com/open-mmlab/mmengine.git
synced 2025-06-03 21:54:44 +08:00
[Fix] Fix the calculation error of eta_min in CosineRestart (#639)
* [Fix] fix CosineRestart eta_min * add ut case * Enhance unit test Enhance unit test * remove unused code Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>
This commit is contained in:
parent
64ac14303f
commit
090104df21
@ -1224,16 +1224,11 @@ class CosineRestartParamScheduler(_ParamScheduler):
|
|||||||
self.optimizer.param_groups):
|
self.optimizer.param_groups):
|
||||||
eta_max = base_value * current_weight
|
eta_max = base_value * current_weight
|
||||||
if self.eta_min_ratio is None:
|
if self.eta_min_ratio is None:
|
||||||
eta_min = self.eta_min * (1 - current_weight)
|
eta_min = self.eta_min
|
||||||
else:
|
else:
|
||||||
eta_min = base_value * self.eta_min_ratio * (1 -
|
eta_min = base_value * self.eta_min_ratio
|
||||||
current_weight)
|
|
||||||
if step == 0:
|
if step == 0:
|
||||||
values.append(eta_max)
|
values.append(eta_max)
|
||||||
|
|
||||||
elif (step - 1 - current_periods) % (2 * current_periods) == 0:
|
|
||||||
values.append(group[self.param_name] + (eta_max - eta_min) *
|
|
||||||
(1 - math.cos(math.pi / current_periods)) / 2)
|
|
||||||
else:
|
else:
|
||||||
values.append(
|
values.append(
|
||||||
(1 + math.cos(math.pi * step / current_periods)) /
|
(1 + math.cos(math.pi * step / current_periods)) /
|
||||||
|
@ -430,6 +430,8 @@ class TestParameterScheduler(TestCase):
|
|||||||
targets = [
|
targets = [
|
||||||
single_targets, [t * self.layer2_mult for t in single_targets]
|
single_targets, [t * self.layer2_mult for t in single_targets]
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# Test with non-zero eta-min.
|
||||||
scheduler = CosineRestartParamScheduler(
|
scheduler = CosineRestartParamScheduler(
|
||||||
self.optimizer,
|
self.optimizer,
|
||||||
param_name='lr',
|
param_name='lr',
|
||||||
@ -438,6 +440,26 @@ class TestParameterScheduler(TestCase):
|
|||||||
eta_min=0)
|
eta_min=0)
|
||||||
self._test_scheduler_value(scheduler, targets, epochs=10)
|
self._test_scheduler_value(scheduler, targets, epochs=10)
|
||||||
|
|
||||||
|
epochs = 10
|
||||||
|
t = 10
|
||||||
|
eta_min = 5e-3
|
||||||
|
targets1 = [
|
||||||
|
eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2
|
||||||
|
for x in range(epochs)
|
||||||
|
]
|
||||||
|
targets2 = [
|
||||||
|
eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2
|
||||||
|
for x in range(epochs)
|
||||||
|
]
|
||||||
|
targets = [targets1, targets2]
|
||||||
|
scheduler = CosineRestartParamScheduler(
|
||||||
|
self.optimizer,
|
||||||
|
param_name='lr',
|
||||||
|
periods=[t],
|
||||||
|
restart_weights=[1],
|
||||||
|
eta_min=eta_min)
|
||||||
|
self._test_scheduler_value(scheduler, targets, epochs=10)
|
||||||
|
|
||||||
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
|
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
|
||||||
scheduler = construct()
|
scheduler = construct()
|
||||||
for _ in range(epochs):
|
for _ in range(epochs):
|
||||||
|
Loading…
x
Reference in New Issue
Block a user