deep-person-reid/optimizers.py

16 lines
626 B
Python
Raw Normal View History

2018-07-04 10:32:43 +01:00
from __future__ import absolute_import
2018-04-27 09:51:04 +01:00
import torch
def init_optim(optim, params, lr, weight_decay, momentum=0.9,
               sgd_nesterov=False, adam_beta1=0.9, adam_beta2=0.999):
    """Build a ``torch.optim`` optimizer selected by name.

    Args:
        optim (str): one of ``'adam'``, ``'amsgrad'``, ``'sgd'``, ``'rmsprop'``.
        params: parameters (or parameter-group dicts) to optimize; passed
            straight through to the optimizer constructor.
        lr (float): learning rate.
        weight_decay (float): weight decay coefficient.
        momentum (float): momentum factor used by ``'sgd'`` and ``'rmsprop'``.
            Defaults to 0.9, the value this function always used before.
        sgd_nesterov (bool): enable Nesterov momentum for ``'sgd'``.
        adam_beta1 (float): first beta coefficient for ``'adam'``/``'amsgrad'``.
        adam_beta2 (float): second beta coefficient for ``'adam'``/``'amsgrad'``.
            The beta defaults (0.9, 0.999) match ``torch.optim.Adam``'s own
            defaults, so omitting them preserves the previous behavior.

    Returns:
        torch.optim.Optimizer: the constructed optimizer.

    Raises:
        KeyError: if ``optim`` is not one of the supported names.
    """
    if optim in ('adam', 'amsgrad'):
        # 'amsgrad' is the same Adam constructor with the amsgrad flag set.
        return torch.optim.Adam(
            params,
            lr=lr,
            weight_decay=weight_decay,
            betas=(adam_beta1, adam_beta2),
            amsgrad=(optim == 'amsgrad'),
        )
    if optim == 'sgd':
        return torch.optim.SGD(
            params,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
            nesterov=sgd_nesterov,
        )
    if optim == 'rmsprop':
        return torch.optim.RMSprop(
            params,
            lr=lr,
            momentum=momentum,
            weight_decay=weight_decay,
        )
    # KeyError (not ValueError) is kept for backward compatibility with callers.
    raise KeyError("Unsupported optimizer: {}".format(optim))