add amsgrad
parent
462e251006
commit
4d4d70eb94
|
@ -4,6 +4,8 @@ import torch
|
|||
def init_optim(optim, params, lr, weight_decay):
|
||||
if optim == 'adam':
|
||||
return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
|
||||
elif optim == 'amsgrad':
|
||||
return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay, amsgrad=True)
|
||||
elif optim == 'sgd':
|
||||
return torch.optim.SGD(params, lr=lr, momentum=0.9, weight_decay=weight_decay)
|
||||
elif optim == 'rmsprop':
|
||||
|
|
Loading…
Reference in New Issue