11 lines
370 B
Python
11 lines
370 B
Python
|
import torch
|
||
|
|
||
|
__all__ = ['init_optim']


def init_optim(optim, params, lr, weight_decay, *, momentum=0.9):
    """Build a ``torch.optim`` optimizer selected by name.

    Args:
        optim: Optimizer name, either ``'adam'`` or ``'sgd'`` (compared
            case-sensitively).
        params: Parameters (or parameter-group dicts) to optimize; passed
            straight through to the optimizer constructor.
        lr: Learning rate forwarded to the optimizer.
        weight_decay: Weight-decay value forwarded to the optimizer.
        momentum: Momentum factor used only for SGD (ignored for Adam).
            Defaults to 0.9, the value this function previously hard-coded.

    Returns:
        A ``torch.optim.Adam`` or ``torch.optim.SGD`` instance.

    Raises:
        KeyError: If ``optim`` is not a supported name. KeyError (rather than
            ValueError) is kept so existing callers that catch it still work.
    """
    if optim == 'adam':
        return torch.optim.Adam(params, lr=lr, weight_decay=weight_decay)
    if optim == 'sgd':
        return torch.optim.SGD(
            params, lr=lr, momentum=momentum, weight_decay=weight_decay
        )
    raise KeyError("Unsupported optim: {}".format(optim))
|