mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix potential issue with change to num_classes arg in train/validate.py defaulting to None (rely on model def / default_cfg)
This commit is contained in:
parent
587780e56b
commit
38d8f67570
@ -198,6 +198,7 @@ def load_pretrained(model, cfg=None, num_classes=1000, in_chans=3, filter_fn=Non
|
|||||||
|
|
||||||
classifier_name = cfg['classifier']
|
classifier_name = cfg['classifier']
|
||||||
if num_classes == 1000 and cfg['num_classes'] == 1001:
|
if num_classes == 1000 and cfg['num_classes'] == 1001:
|
||||||
|
# FIXME this special case is problematic as number of pretrained weight sources increases
|
||||||
# special case for imagenet trained models with extra background class in pretrained weights
|
# special case for imagenet trained models with extra background class in pretrained weights
|
||||||
classifier_weight = state_dict[classifier_name + '.weight']
|
classifier_weight = state_dict[classifier_name + '.weight']
|
||||||
state_dict[classifier_name + '.weight'] = classifier_weight[1:]
|
state_dict[classifier_name + '.weight'] = classifier_weight[1:]
|
||||||
|
3
train.py
3
train.py
@ -337,6 +337,9 @@ def main():
|
|||||||
bn_eps=args.bn_eps,
|
bn_eps=args.bn_eps,
|
||||||
scriptable=args.torchscript,
|
scriptable=args.torchscript,
|
||||||
checkpoint_path=args.initial_checkpoint)
|
checkpoint_path=args.initial_checkpoint)
|
||||||
|
if args.num_classes is None:
|
||||||
|
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
|
||||||
|
args.num_classes = model.num_classes # FIXME handle model default vs config num_classes more elegantly
|
||||||
|
|
||||||
if args.local_rank == 0:
|
if args.local_rank == 0:
|
||||||
_logger.info('Model %s created, param count: %d' %
|
_logger.info('Model %s created, param count: %d' %
|
||||||
|
@ -137,6 +137,9 @@ def validate(args):
|
|||||||
in_chans=3,
|
in_chans=3,
|
||||||
global_pool=args.gp,
|
global_pool=args.gp,
|
||||||
scriptable=args.torchscript)
|
scriptable=args.torchscript)
|
||||||
|
if args.num_classes is None:
|
||||||
|
assert hasattr(model, 'num_classes'), 'Model must have `num_classes` attr if not set on cmd line/config.'
|
||||||
|
args.num_classes = model.num_classes
|
||||||
|
|
||||||
if args.checkpoint:
|
if args.checkpoint:
|
||||||
load_checkpoint(model, args.checkpoint, args.use_ema)
|
load_checkpoint(model, args.checkpoint, args.use_ema)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user