mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add model based wd skip support. Improve cross version compat of optimizer factory. Fix #247
This commit is contained in:
parent
80078c47bb
commit
a4d8fea61e
@ -41,7 +41,10 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
|
||||
opt_lower = args.opt.lower()
|
||||
weight_decay = args.weight_decay
|
||||
if weight_decay and filter_bias_and_bn:
|
||||
parameters = add_weight_decay(model, weight_decay)
|
||||
skip = {}
|
||||
if hasattr(model, 'no_weight_decay'):
|
||||
skip = model.no_weight_decay
|
||||
parameters = add_weight_decay(model, weight_decay, skip)
|
||||
weight_decay = 0.
|
||||
else:
|
||||
parameters = model.parameters()
|
||||
@ -50,9 +53,9 @@ def create_optimizer(args, model, filter_bias_and_bn=True):
|
||||
assert has_apex and torch.cuda.is_available(), 'APEX and CUDA required for fused optimizers'
|
||||
|
||||
opt_args = dict(lr=args.lr, weight_decay=weight_decay)
|
||||
if args.opt_eps is not None:
|
||||
if hasattr(args, 'opt_eps') and args.opt_eps is not None:
|
||||
opt_args['eps'] = args.opt_eps
|
||||
if args.opt_betas is not None:
|
||||
if hasattr(args, 'opt_betas') and args.opt_betas is not None:
|
||||
opt_args['betas'] = args.opt_betas
|
||||
|
||||
opt_split = opt_lower.split('_')
|
||||
|
Loading…
x
Reference in New Issue
Block a user