mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix potential issue with change to num_classes arg in train/validate.py defaulting to None (rely on model def / default_cfg)
This commit is contained in:
parent
587780e56b
commit
38d8f67570
@ -198,6 +198,7 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
|
||||
|
||||
classifier_name = cfg['classifier']
|
||||
if num_classes == 1000 and cfg['num_classes'] == 1001:
|
||||
# FIXME this special case is problematic as number of pretrained weight sources increases
|
||||
# special case for imagenet trained models with extra background class in pretrained weights
|
||||
classifier_weight = state_dict[classifier_name + '.weight']
|
||||
state_dict[classifier_name + '.weight'] = classifier_weight[1:]
|
||||
|
3
train.py
3
train.py
@ -337,6 +337,9 @@ def main():
|
||||
bn_eps=args.bn_eps,
|
||||
scriptable=args.torchscript,
|
||||
checkpoint_path=args.initial_checkpoint)
|
||||
if args.num_classes is None:
|
||||
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
|
||||
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
|
||||
|
||||
if args.local_rank == 0:
|
||||
_logger.info('Model %s created, param count: %d' %
|
||||
|
@ -137,6 +137,9 @@ def validate(args):
|
||||
in_chans=3,
|
||||
global_pool=args.gp,
|
||||
scriptable=args.torchscript)
|
||||
if args.num_classes is None:
|
||||
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
|
||||
args.num_classes = model.num_classes
|
||||
|
||||
if args.checkpoint:
|
||||
load_checkpoint(model, args.checkpoint, args.use_ema)
|
||||
|
Loading…
x
Reference in New Issue
Block a user