deep-person-reid/torchreid/optim/lr_scheduler.py

from __future__ import absolute_import
from __future__ import print_function
import torch

AVAI_SCH = ['single_step', 'multi_step', 'cosine']


def build_lr_scheduler(optimizer, lr_scheduler='single_step', stepsize=1, gamma=0.1, max_epoch=1):
    """A function wrapper for building a learning rate scheduler.

    Args:
        optimizer (Optimizer): an Optimizer.
        lr_scheduler (str, optional): learning rate scheduler method. Default is "single_step".
        stepsize (int or list, optional): step size to decay learning rate. When ``lr_scheduler``
            is "single_step", ``stepsize`` should be an integer. When ``lr_scheduler`` is
            "multi_step", ``stepsize`` is a list. Default is 1.
        gamma (float, optional): decay rate. Default is 0.1.
        max_epoch (int, optional): maximum epoch (for cosine annealing). Default is 1.

    Examples::
        >>> # Decay learning rate by a factor of gamma every 20 epochs.
        >>> scheduler = torchreid.optim.build_lr_scheduler(
        >>>     optimizer, lr_scheduler='single_step', stepsize=20
        >>> )
        >>> # Decay learning rate at epochs 30, 50 and 55.
        >>> scheduler = torchreid.optim.build_lr_scheduler(
        >>>     optimizer, lr_scheduler='multi_step', stepsize=[30, 50, 55]
        >>> )
        >>> # Anneal learning rate with a cosine schedule over 60 epochs.
        >>> scheduler = torchreid.optim.build_lr_scheduler(
        >>>     optimizer, lr_scheduler='cosine', max_epoch=60
        >>> )
    """
    if lr_scheduler not in AVAI_SCH:
        raise ValueError(
            'Unsupported scheduler: {}. Must be one of {}'.format(lr_scheduler, AVAI_SCH)
        )

    if lr_scheduler == 'single_step':
        # A list is tolerated here; its last element is used as the step size.
        if isinstance(stepsize, list):
            stepsize = stepsize[-1]

        if not isinstance(stepsize, int):
            raise TypeError(
                'For single_step lr_scheduler, stepsize must '
                'be an integer, but got {}'.format(type(stepsize))
            )

        scheduler = torch.optim.lr_scheduler.StepLR(
            optimizer, step_size=stepsize, gamma=gamma
        )
    elif lr_scheduler == 'multi_step':
        if not isinstance(stepsize, list):
            raise TypeError(
                'For multi_step lr_scheduler, stepsize must '
                'be a list, but got {}'.format(type(stepsize))
            )

        scheduler = torch.optim.lr_scheduler.MultiStepLR(
            optimizer, milestones=stepsize, gamma=gamma
        )
    elif lr_scheduler == 'cosine':
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=float(max_epoch)
        )

    return scheduler
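

# A minimal usage sketch for the cosine variant. The toy model, the SGD
# hyperparameters, and the 60-epoch budget are hypothetical stand-ins
# (any nn.Module and epoch count would do); this is illustrative, not a
# definitive training loop.
if __name__ == '__main__':
    import torch.nn as nn

    model = nn.Linear(10, 2)  # hypothetical toy model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scheduler = build_lr_scheduler(
        optimizer, lr_scheduler='cosine', max_epoch=60
    )

    for epoch in range(60):
        # Dummy forward/backward pass standing in for a real training epoch.
        loss = model(torch.randn(4, 10)).sum()
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # Anneal the learning rate once per epoch, after the optimizer step.
        scheduler.step()
        print('epoch {:02d}: lr={:.6f}'.format(
            epoch + 1, optimizer.param_groups[0]['lr']
        ))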