[Fix] Fix the calculation error of eta_min in CosineRestart (#639)
* [Fix] fix CosineRestart eta_min * add ut case * Enhance unit test Enhance unit test * remove unused code Co-authored-by: HAOCHENYE <21724054@zju.edu.cn>pull/659/head
parent
64ac14303f
commit
090104df21
|
@ -1224,16 +1224,11 @@ class CosineRestartParamScheduler(_ParamScheduler):
|
|||
self.optimizer.param_groups):
|
||||
eta_max = base_value * current_weight
|
||||
if self.eta_min_ratio is None:
|
||||
eta_min = self.eta_min * (1 - current_weight)
|
||||
eta_min = self.eta_min
|
||||
else:
|
||||
eta_min = base_value * self.eta_min_ratio * (1 -
|
||||
current_weight)
|
||||
eta_min = base_value * self.eta_min_ratio
|
||||
if step == 0:
|
||||
values.append(eta_max)
|
||||
|
||||
elif (step - 1 - current_periods) % (2 * current_periods) == 0:
|
||||
values.append(group[self.param_name] + (eta_max - eta_min) *
|
||||
(1 - math.cos(math.pi / current_periods)) / 2)
|
||||
else:
|
||||
values.append(
|
||||
(1 + math.cos(math.pi * step / current_periods)) /
|
||||
|
|
|
@ -430,6 +430,8 @@ class TestParameterScheduler(TestCase):
|
|||
targets = [
|
||||
single_targets, [t * self.layer2_mult for t in single_targets]
|
||||
]
|
||||
|
||||
# Test with non-zero eta-min.
|
||||
scheduler = CosineRestartParamScheduler(
|
||||
self.optimizer,
|
||||
param_name='lr',
|
||||
|
@ -438,6 +440,26 @@ class TestParameterScheduler(TestCase):
|
|||
eta_min=0)
|
||||
self._test_scheduler_value(scheduler, targets, epochs=10)
|
||||
|
||||
epochs = 10
|
||||
t = 10
|
||||
eta_min = 5e-3
|
||||
targets1 = [
|
||||
eta_min + (0.05 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2
|
||||
for x in range(epochs)
|
||||
]
|
||||
targets2 = [
|
||||
eta_min + (0.5 - eta_min) * (1 + math.cos(math.pi * x / t)) / 2
|
||||
for x in range(epochs)
|
||||
]
|
||||
targets = [targets1, targets2]
|
||||
scheduler = CosineRestartParamScheduler(
|
||||
self.optimizer,
|
||||
param_name='lr',
|
||||
periods=[t],
|
||||
restart_weights=[1],
|
||||
eta_min=eta_min)
|
||||
self._test_scheduler_value(scheduler, targets, epochs=10)
|
||||
|
||||
def _check_scheduler_state_dict(self, construct, construct2, epochs=10):
|
||||
scheduler = construct()
|
||||
for _ in range(epochs):
|
||||
|
|
Loading…
Reference in New Issue