support Piecewise.learning_rate (#2899)

pull/2902/head
Tingquan Gao 2023-08-04 12:16:01 +08:00 committed by GitHub
parent 657037c4e7
commit 4247fac82e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 13 additions and 0 deletions

View File

@ -403,10 +403,23 @@ class Piecewise(LRBase):
warmup_start_lr=0.0,
last_epoch=-1,
by_epoch=False,
learning_rate=None,
**kwargs):
if learning_rate:
decay_epochs = list(range(0, epochs, 30))[1:]
values = [
learning_rate * (0.1**i)
for i in range(len(decay_epochs) + 1)
]
logger.warning(
"When 'learning_rate' of Piecewise has beed set, "
"the learning rate scheduler would be set by the rule that lr decay 10 times every 30 epochs. "
f"So, the 'decay_epochs' and 'values' have been set to {decay_epochs} and {values} respectively."
)
super(Piecewise,
self).__init__(epochs, step_each_epoch, values[0], warmup_epoch,
warmup_start_lr, last_epoch, by_epoch)
self.values = values
self.boundaries_steps = [e * step_each_epoch for e in decay_epochs]
if self.by_epoch is True: