deep-person-reid/torchreid/optimizers.py

34 lines
1.4 KiB
Python
Raw Normal View History

2018-07-04 17:32:43 +08:00
from __future__ import absolute_import
2018-04-27 16:51:04 +08:00
import torch
2018-11-08 01:09:23 +08:00
def init_optimizer(params,
                   optim='adam',
                   lr=0.003,
                   weight_decay=5e-4,
                   momentum=0.9,  # momentum factor for sgd and rmsprop
                   sgd_dampening=0,  # sgd's dampening for momentum
                   sgd_nesterov=False,  # whether to enable sgd's Nesterov momentum
                   rmsprop_alpha=0.99,  # rmsprop's smoothing constant
                   adam_beta1=0.9,  # exponential decay rate for adam's first moment
                   adam_beta2=0.999  # exponential decay rate for adam's second moment
                   ):
    """Construct a ``torch.optim`` optimizer chosen by name.

    Args:
        params: parameters (or parameter groups) forwarded unchanged to the
            optimizer constructor.
        optim: one of ``'adam'``, ``'amsgrad'``, ``'sgd'`` or ``'rmsprop'``.
            ``'amsgrad'`` builds ``torch.optim.Adam`` with ``amsgrad=True``.
        lr: learning rate, passed to every optimizer.
        weight_decay: weight decay, passed to every optimizer.
        momentum: momentum factor; only used by ``'sgd'`` and ``'rmsprop'``.
        sgd_dampening: ``dampening`` argument; only used by ``'sgd'``.
        sgd_nesterov: ``nesterov`` argument; only used by ``'sgd'``.
        rmsprop_alpha: ``alpha`` argument; only used by ``'rmsprop'``.
        adam_beta1: first element of ``betas``; only used by
            ``'adam'`` / ``'amsgrad'``.
        adam_beta2: second element of ``betas``; only used by
            ``'adam'`` / ``'amsgrad'``.

    Returns:
        The newly created optimizer instance.

    Raises:
        KeyError: if ``optim`` is not one of the supported names.
    """
    # Keyword arguments every supported optimizer receives.
    shared = dict(lr=lr, weight_decay=weight_decay)

    # 'adam' and 'amsgrad' are the same optimizer class; they differ only in
    # the value of the ``amsgrad`` flag.
    if optim in ('adam', 'amsgrad'):
        return torch.optim.Adam(
            params,
            betas=(adam_beta1, adam_beta2),
            amsgrad=(optim == 'amsgrad'),
            **shared
        )

    if optim == 'sgd':
        return torch.optim.SGD(
            params,
            momentum=momentum,
            dampening=sgd_dampening,
            nesterov=sgd_nesterov,
            **shared
        )

    if optim == 'rmsprop':
        return torch.optim.RMSprop(
            params,
            momentum=momentum,
            alpha=rmsprop_alpha,
            **shared
        )

    raise KeyError("Unsupported optimizer: {}".format(optim))