from __future__ import absolute_import
from __future__ import print_function

import warnings

import torch
import torch.nn as nn

AVAI_OPTIMS = ['adam', 'amsgrad', 'sgd', 'rmsprop']


def build_optimizer(
    model,
    optim='adam',
    lr=0.0003,
    weight_decay=5e-04,
    momentum=0.9,
    sgd_dampening=0,
    sgd_nesterov=False,
    rmsprop_alpha=0.99,
    adam_beta1=0.9,
    adam_beta2=0.99,
    staged_lr=False,
    new_layers='',
    base_lr_mult=0.1
):
"""A function wrapper for building an optimizer.
|
|
|
|
Args:
|
|
model (nn.Module): model.
|
|
optim (str, optional): optimizer. Default is "adam".
|
|
lr (float, optional): learning rate. Default is 0.0003.
|
|
weight_decay (float, optional): weight decay (L2 penalty). Default is 5e-04.
|
|
momentum (float, optional): momentum factor in sgd. Default is 0.9,
|
|
sgd_dampening (float, optional): dampening for momentum. Default is 0.
|
|
sgd_nesterov (bool, optional): enables Nesterov momentum. Default is False.
|
|
rmsprop_alpha (float, optional): smoothing constant for rmsprop. Default is 0.99.
|
|
adam_beta1 (float, optional): beta-1 value in adam. Default is 0.9.
|
|
adam_beta2 (float, optional): beta-2 value in adam. Default is 0.99,
|
|
staged_lr (bool, optional): uses different learning rates for base and new layers. Base
|
|
layers are pretrained layers while new layers are randomly initialized, e.g. the
|
|
identity classification layer. Enabling ``staged_lr`` can allow the base layers to
|
|
be trained with a smaller learning rate determined by ``base_lr_mult``, while the new
|
|
layers will take the ``lr``. Default is False.
|
|
new_layers (str or list): attribute names in ``model``. Default is empty.
|
|
base_lr_mult (float, optional): learning rate multiplier for base layers. Default is 0.1.
|
|
|
|
Examples::
|
|
>>> # A normal optimizer can be built by
|
|
>>> optimizer = torchreid.optim.build_optimizer(model, optim='sgd', lr=0.01)
|
|
>>> # If you want to use a smaller learning rate for pretrained layers
|
|
>>> # and the attribute name for the randomly initialized layer is 'classifier',
|
|
>>> # you can do
|
|
>>> optimizer = torchreid.optim.build_optimizer(
|
|
>>> model, optim='sgd', lr=0.01, staged_lr=True,
|
|
>>> new_layers='classifier', base_lr_mult=0.1
|
|
>>> )
|
|
>>> # Now the `classifier` has learning rate 0.01 but the base layers
|
|
>>> # have learning rate 0.01 * 0.1.
|
|
>>> # new_layers can also take multiple attribute names. Say the new layers
|
|
>>> # are 'fc' and 'classifier', you can do
|
|
>>> optimizer = torchreid.optim.build_optimizer(
|
|
>>> model, optim='sgd', lr=0.01, staged_lr=True,
|
|
>>> new_layers=['fc', 'classifier'], base_lr_mult=0.1
|
|
>>> )
|
|
"""
|
|
    if optim not in AVAI_OPTIMS:
        raise ValueError('Unsupported optim: {}. Must be one of {}'.format(optim, AVAI_OPTIMS))

    if not isinstance(model, nn.Module):
        raise TypeError('model given to build_optimizer must be an instance of nn.Module')

    if staged_lr:
        if isinstance(new_layers, str):
            # An empty string means no new layers were named, so staged_lr has no effect.
            if not new_layers:
                warnings.warn('new_layers is empty, therefore, staged_lr is useless')
            new_layers = [new_layers]

        if isinstance(model, nn.DataParallel):
            model = model.module

        base_params = []
        base_layers = []
        new_params = []

        for name, module in model.named_children():
            if name in new_layers:
                new_params += [p for p in module.parameters()]
            else:
                base_params += [p for p in module.parameters()]
                base_layers.append(name)

        param_groups = [
            {'params': base_params, 'lr': lr * base_lr_mult},
            {'params': new_params},
        ]

        print('Use staged learning rate')
        print('Base layers (lr*{}): {}'.format(base_lr_mult, base_layers))
        print('New layers (lr): {}'.format(new_layers))

    else:
        param_groups = model.parameters()

    print('Initializing optimizer: {}'.format(optim))

    if optim == 'adam':
        optimizer = torch.optim.Adam(
            param_groups,
            lr=lr,
            weight_decay=weight_decay,
            betas=(adam_beta1, adam_beta2),
        )

    elif optim == 'amsgrad':
        optimizer = torch.optim.Adam(
            param_groups,
            lr=lr,
            weight_decay=weight_decay,
            betas=(adam_beta1, adam_beta2),
            amsgrad=True,
        )

    elif optim == 'sgd':
        optimizer = torch.optim.SGD(
            param_groups,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            dampening=sgd_dampening,
            nesterov=sgd_nesterov,
        )

    elif optim == 'rmsprop':
        optimizer = torch.optim.RMSprop(
            param_groups,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            alpha=rmsprop_alpha,
        )

    return optimizer
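

if __name__ == '__main__':
    # Editor's sketch, not part of the original module: a minimal illustration of the
    # staged_lr mechanics under the assumption of a small hypothetical model whose
    # pretrained-style body is named 'backbone' and whose randomly initialized head is
    # named 'classifier'. It shows that the base parameter group receives lr * base_lr_mult
    # while the new-layer group keeps the base lr.
    class _DemoNet(nn.Module):

        def __init__(self):
            super(_DemoNet, self).__init__()
            self.backbone = nn.Linear(16, 8)   # stands in for a pretrained body
            self.classifier = nn.Linear(8, 4)  # stands in for a randomly initialized head

        def forward(self, x):
            return self.classifier(self.backbone(x))

    demo_optimizer = build_optimizer(
        _DemoNet(),
        optim='sgd',
        lr=0.01,
        staged_lr=True,
        new_layers='classifier',
        base_lr_mult=0.1
    )
    # Expected: the 'backbone' group at lr=0.001 (0.01 * 0.1) and the 'classifier'
    # group at lr=0.01.
    for group in demo_optimizer.param_groups:
        print('num params: {}, lr: {}'.format(len(group['params']), group['lr']))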