mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix optimizer factory regression for optimizers like sgd/momentum that don't have an eps arg
This commit is contained in:
parent
27a93e9de7
commit
477a78ed81
@ -61,8 +61,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
|
|||||||
opt_split = opt_lower.split('_')
|
opt_split = opt_lower.split('_')
|
||||||
opt_lower = opt_split[-1]
|
opt_lower = opt_split[-1]
|
||||||
if opt_lower == 'sgd' or opt_lower == 'nesterov':
|
if opt_lower == 'sgd' or opt_lower == 'nesterov':
|
||||||
|
del opt_args['eps']
|
||||||
optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
|
optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
|
||||||
elif opt_lower == 'momentum':
|
elif opt_lower == 'momentum':
|
||||||
|
del opt_args['eps']
|
||||||
optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
|
optimizer = optim.SGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
|
||||||
elif opt_lower == 'adam':
|
elif opt_lower == 'adam':
|
||||||
optimizer = optim.Adam(parameters, **opt_args)
|
optimizer = optim.Adam(parameters, **opt_args)
|
||||||
@ -93,8 +95,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
|
|||||||
elif opt_lower == 'nvnovograd':
|
elif opt_lower == 'nvnovograd':
|
||||||
optimizer = NvNovoGrad(parameters, **opt_args)
|
optimizer = NvNovoGrad(parameters, **opt_args)
|
||||||
elif opt_lower == 'fusedsgd':
|
elif opt_lower == 'fusedsgd':
|
||||||
|
del opt_args['eps']
|
||||||
optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
|
optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=True, **opt_args)
|
||||||
elif opt_lower == 'fusedmomentum':
|
elif opt_lower == 'fusedmomentum':
|
||||||
|
del opt_args['eps']
|
||||||
optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
|
optimizer = FusedSGD(parameters, momentum=args.momentum, nesterov=False, **opt_args)
|
||||||
elif opt_lower == 'fusedadam':
|
elif opt_lower == 'fusedadam':
|
||||||
optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
|
optimizer = FusedAdam(parameters, adam_w_mode=False, **opt_args)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user