mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix pretrained override logic for validate, checkpoint always trump pretrained flag during model create
This commit is contained in:
parent
0e1fd11ad8
commit
b9f8d40b10
@ -36,7 +36,7 @@ def create_model(
|
|||||||
else:
|
else:
|
||||||
raise RuntimeError('Unknown model (%s)' % model_name)
|
raise RuntimeError('Unknown model (%s)' % model_name)
|
||||||
|
|
||||||
if checkpoint_path and not pretrained:
|
if checkpoint_path:
|
||||||
load_checkpoint(model, checkpoint_path)
|
load_checkpoint(model, checkpoint_path)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
@ -54,6 +54,8 @@ parser.add_argument('--use-ema', dest='use_ema', action='store_true',
|
|||||||
|
|
||||||
|
|
||||||
def validate(args):
|
def validate(args):
|
||||||
|
# might as well try to validate something
|
||||||
|
args.pretrained = args.pretrained or not args.checkpoint
|
||||||
|
|
||||||
# create model
|
# create model
|
||||||
model = create_model(
|
model = create_model(
|
||||||
@ -62,10 +64,8 @@ def validate(args):
|
|||||||
in_chans=3,
|
in_chans=3,
|
||||||
pretrained=args.pretrained)
|
pretrained=args.pretrained)
|
||||||
|
|
||||||
if args.checkpoint and not args.pretrained:
|
if args.checkpoint:
|
||||||
load_checkpoint(model, args.checkpoint, args.use_ema)
|
load_checkpoint(model, args.checkpoint, args.use_ema)
|
||||||
else:
|
|
||||||
args.pretrained = True # might as well try to validate something...
|
|
||||||
|
|
||||||
param_count = sum([m.numel() for m in model.parameters()])
|
param_count = sum([m.numel() for m in model.parameters()])
|
||||||
print('Model %s created, param count: %d' % (args.model, param_count))
|
print('Model %s created, param count: %d' % (args.model, param_count))
|
||||||
|
Loading…
x
Reference in New Issue
Block a user