deep-person-reid/torchreid/optim/optimizer.py

138 lines
4.9 KiB
Python
Raw Normal View History

2018-07-04 17:32:43 +08:00
from __future__ import absolute_import
2019-02-03 22:04:10 +08:00
from __future__ import print_function
2018-07-04 17:32:43 +08:00
import warnings
2018-04-27 16:51:04 +08:00
import torch
2019-02-03 22:04:10 +08:00
import torch.nn as nn
2018-04-27 16:51:04 +08:00
2019-03-20 01:26:08 +08:00
AVAI_OPTIMS = ['adam', 'amsgrad', 'sgd', 'rmsprop']


def build_optimizer(
    model,
    optim='adam',
    lr=0.0003,
    weight_decay=5e-04,
    momentum=0.9,
    sgd_dampening=0,
    sgd_nesterov=False,
    rmsprop_alpha=0.99,
    adam_beta1=0.9,
    adam_beta2=0.99,
    staged_lr=False,
    new_layers='',
    base_lr_mult=0.1
):
    """A function wrapper for building an optimizer.

    Args:
        model (nn.Module): model whose parameters will be optimized. When
            ``staged_lr`` is enabled, an ``nn.DataParallel`` or
            ``DistributedDataParallel`` wrapper is unwrapped so that
            ``new_layers`` refers to the wrapped model's attribute names.
        optim (str, optional): optimizer, one of ``AVAI_OPTIMS``. Default is "adam".
        lr (float, optional): learning rate. Default is 0.0003.
        weight_decay (float, optional): weight decay (L2 penalty). Default is 5e-04.
        momentum (float, optional): momentum factor used by sgd and rmsprop. Default is 0.9.
        sgd_dampening (float, optional): dampening for momentum in sgd. Default is 0.
        sgd_nesterov (bool, optional): enables Nesterov momentum in sgd. Default is False.
        rmsprop_alpha (float, optional): smoothing constant for rmsprop. Default is 0.99.
        adam_beta1 (float, optional): beta-1 value in adam/amsgrad. Default is 0.9.
        adam_beta2 (float, optional): beta-2 value in adam/amsgrad. Default is 0.99.
        staged_lr (bool, optional): uses different learning rates for base and new layers. Base
            layers are pretrained layers while new layers are randomly initialized, e.g. the
            identity classification layer. Enabling ``staged_lr`` can allow the base layers to
            be trained with a smaller learning rate determined by ``base_lr_mult``, while the new
            layers will take the ``lr``. Parameters registered directly on ``model`` (not inside
            a child module) are treated as base parameters. Default is False.
        new_layers (str or list, optional): attribute name(s) of direct children of ``model``
            that are treated as new layers. ``None``, ``''`` and ``[]`` all mean "no new
            layers" (a warning is emitted in that case). Default is empty.
        base_lr_mult (float, optional): learning rate multiplier for base layers. Default is 0.1.

    Returns:
        torch.optim.Optimizer: the constructed optimizer. With ``staged_lr`` it has two
        parameter groups: index 0 holds the base parameters (lr ``lr * base_lr_mult``) and
        index 1 holds the new-layer parameters (lr ``lr``).

    Raises:
        ValueError: if ``optim`` is not one of ``AVAI_OPTIMS``.
        TypeError: if ``model`` is not an instance of ``nn.Module``.

    Examples::
        >>> # A normal optimizer can be built by
        >>> optimizer = torchreid.optim.build_optimizer(model, optim='sgd', lr=0.01)
        >>> # If you want to use a smaller learning rate for pretrained layers
        >>> # and the attribute name for the randomly initialized layer is 'classifier',
        >>> # you can do
        >>> optimizer = torchreid.optim.build_optimizer(
        >>>     model, optim='sgd', lr=0.01, staged_lr=True,
        >>>     new_layers='classifier', base_lr_mult=0.1
        >>> )
        >>> # Now the `classifier` has learning rate 0.01 but the base layers
        >>> # have learning rate 0.01 * 0.1.
        >>> # new_layers can also take multiple attribute names. Say the new layers
        >>> # are 'fc' and 'classifier', you can do
        >>> optimizer = torchreid.optim.build_optimizer(
        >>>     model, optim='sgd', lr=0.01, staged_lr=True,
        >>>     new_layers=['fc', 'classifier'], base_lr_mult=0.1
        >>> )
    """
    if optim not in AVAI_OPTIMS:
        raise ValueError('Unsupported optim: {}. Must be one of {}'.format(optim, AVAI_OPTIMS))
    if not isinstance(model, nn.Module):
        raise TypeError('model given to build_optimizer must be an instance of nn.Module')

    if staged_lr:
        # Normalize new_layers to a list. The emptiness check must come first:
        # previously `is None` was only tested after confirming a str, so the
        # warning could never fire and None crashed on `name in None`.
        if not new_layers:
            warnings.warn('new_layers is empty, therefore, staged_lr is useless')
            new_layers = []
        elif isinstance(new_layers, str):
            new_layers = [new_layers]
        else:
            new_layers = list(new_layers)

        # Unwrap parallel wrappers so named_children() yields the user's own
        # attribute names instead of the single 'module' child.
        if isinstance(model, (nn.DataParallel, nn.parallel.DistributedDataParallel)):
            model = model.module

        new_params = []
        new_param_ids = set()
        found_layers = set()
        for name, module in model.named_children():
            if name not in new_layers:
                continue
            found_layers.add(name)
            for p in module.parameters():
                # Dedupe by identity so a parameter shared by two new layers
                # is registered only once.
                if id(p) not in new_param_ids:
                    new_param_ids.add(id(p))
                    new_params.append(p)

        missing_layers = [name for name in new_layers if name not in found_layers]
        if missing_layers:
            warnings.warn(
                'new_layers {} are not direct children of the model and are '
                'ignored'.format(missing_layers)
            )

        # Everything that is not a new-layer parameter is a base parameter.
        # Iterating model.parameters() (rather than only child modules) also
        # picks up parameters registered directly on the model, which were
        # previously left out of the optimizer entirely.
        base_params = [p for p in model.parameters() if id(p) not in new_param_ids]

        param_groups = [
            {'params': base_params, 'lr': lr * base_lr_mult},
            {'params': new_params},  # inherits the default lr
        ]
    else:
        param_groups = model.parameters()

    if optim in ('adam', 'amsgrad'):
        optimizer = torch.optim.Adam(
            param_groups,
            lr=lr,
            weight_decay=weight_decay,
            betas=(adam_beta1, adam_beta2),
            amsgrad=(optim == 'amsgrad'),
        )
    elif optim == 'sgd':
        optimizer = torch.optim.SGD(
            param_groups,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            dampening=sgd_dampening,
            nesterov=sgd_nesterov,
        )
    else:  # 'rmsprop' -- the only remaining value allowed by AVAI_OPTIMS
        optimizer = torch.optim.RMSprop(
            param_groups,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            alpha=rmsprop_alpha,
        )
    return optimizer