pytorch-image-models/optim/optim_factory.py
Ross Wightman 9c3859fb9c Uniform pretrained model handling.
* All models have 'default_cfgs' dict
* load/resume/pretrained helpers factored out
* pretrained load operates on state_dict based on default_cfg
* test all models in validate
* schedule, optim factory factored out
* test time pool wrapper applied based on default_cfg
2019-04-11 21:32:16 -07:00

31 lines
1.2 KiB
Python

from torch import optim as optim
from optim import Nadam, AdaBound
def create_optimizer(args, parameters):
    """Build an optimizer for ``parameters`` from command-line style ``args``.

    Supported values of ``args.opt`` (matched case-insensitively) are
    'sgd', 'adam', 'nadam', 'adabound', 'adadelta' and 'rmsprop'.

    Args:
        args: Namespace-like object. Always reads ``opt``, ``lr`` and
            ``weight_decay``; depending on the chosen optimizer it also reads
            ``momentum`` (sgd, rmsprop) and ``opt_eps`` (all others).
        parameters: Parameters (or parameter groups) passed unchanged as the
            first argument of the optimizer constructor.

    Returns:
        The constructed optimizer instance.

    Raises:
        ValueError: If ``args.opt`` does not name a supported optimizer.
    """
    # Normalize once instead of re-lowering the name in every branch.
    opt_name = args.opt.lower()
    if opt_name == 'sgd':
        # Plain 'sgd' is always built with Nesterov momentum enabled.
        optimizer = optim.SGD(
            parameters, lr=args.lr,
            momentum=args.momentum, weight_decay=args.weight_decay, nesterov=True)
    elif opt_name == 'adam':
        optimizer = optim.Adam(
            parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif opt_name == 'nadam':
        optimizer = Nadam(
            parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif opt_name == 'adabound':
        # The starting lr is args.lr / 100, while args.lr itself is passed as
        # AdaBound's final_lr.
        optimizer = AdaBound(
            parameters, lr=args.lr / 100, weight_decay=args.weight_decay, eps=args.opt_eps,
            final_lr=args.lr)
    elif opt_name == 'adadelta':
        optimizer = optim.Adadelta(
            parameters, lr=args.lr, weight_decay=args.weight_decay, eps=args.opt_eps)
    elif opt_name == 'rmsprop':
        # alpha is fixed at 0.9 here rather than taken from args.
        optimizer = optim.RMSprop(
            parameters, lr=args.lr, alpha=0.9, eps=args.opt_eps,
            momentum=args.momentum, weight_decay=args.weight_decay)
    else:
        # The previous `assert False and "Invalid optimizer"` dropped its message
        # (`False and ...` is just False) and is stripped under `python -O`;
        # raise one consistent, descriptive error instead.
        raise ValueError('Invalid optimizer: {}'.format(args.opt))
    return optimizer