add amsgrad

pull/62/head
KaiyangZhou 2018-07-02 13:56:46 +01:00
parent 462e251006
commit 4d4d70eb94
1 changed files with 2 additions and 0 deletions

View File

@ -4,6 +4,8 @@ import torch
def init_optim(optim, params, lr, weight_decay):
if optim == 'adam':
return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
elif optim == 'amsgrad':
return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay, amsgrad=True)
elif optim == 'sgd':
return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay)
elif optim == 'rmsprop':