from __future__ import absolute_import
import torch
def init_optim(optim, params, lr, weight_decay):
    """Build a torch optimizer for the given parameters.

    Args:
        optim (str): optimizer name; one of ``'adam'``, ``'amsgrad'``,
            ``'sgd'``, ``'rmsprop'``.
        params: iterable of parameters (or parameter groups) to optimize.
        lr (float): learning rate.
        weight_decay (float): L2 penalty coefficient.

    Returns:
        torch.optim.Optimizer: the configured optimizer instance.

    Raises:
        KeyError: if ``optim`` does not name a supported optimizer.
    """
    if optim == 'adam':
        return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    elif optim == 'amsgrad':
        # Adam with the AMSGrad convergence fix enabled.
        return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay,
                                amsgrad=True)
    elif optim == 'sgd':
        return torch.optim.SGD(params, lr=lr, momentum=0.9,
                               weight_decay=weight_decay)
    elif optim == 'rmsprop':
        return torch.optim.RMSprop(params, lr=lr, momentum=0.9,
                                   weight_decay=weight_decay)
    else:
        # KeyError (not ValueError) is kept for backward compatibility
        # with callers that catch it.
        raise KeyError("Unsupported optimizer: {}".format(optim))