mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add --pretrained-path arg to train script to allow passing local checkpoint as pretrained. Add missing/unexpected keys log.
This commit is contained in:
parent
17a47c0e35
commit
60b170b200
@ -234,7 +234,15 @@ def load_pretrained(
|
|||||||
classifier_bias = state_dict[classifier_name + '.bias']
|
classifier_bias = state_dict[classifier_name + '.bias']
|
||||||
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
|
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
|
||||||
|
|
||||||
model.load_state_dict(state_dict, strict=strict)
|
load_result = model.load_state_dict(state_dict, strict=strict)
|
||||||
|
if load_result.missing_keys:
|
||||||
|
_logger.info(
|
||||||
|
f'Missing keys ({", ".join(load_result.missing_keys)}) discovered while loading pretrained weights.'
|
||||||
|
f' This is expected if model is being adapted.')
|
||||||
|
if load_result.unexpected_keys:
|
||||||
|
_logger.warning(
|
||||||
|
f'Unexpected keys ({", ".join(load_result.unexpected_keys)}) found while loading pretrained weights.'
|
||||||
|
f' This may be expected if model is being adapted.')
|
||||||
|
|
||||||
|
|
||||||
def pretrained_cfg_for_features(pretrained_cfg):
|
def pretrained_cfg_for_features(pretrained_cfg):
|
||||||
|
10
train.py
10
train.py
@ -103,8 +103,10 @@ group.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
|
|||||||
help='Name of model to train (default: "resnet50")')
|
help='Name of model to train (default: "resnet50")')
|
||||||
group.add_argument('--pretrained', action='store_true', default=False,
|
group.add_argument('--pretrained', action='store_true', default=False,
|
||||||
help='Start with pretrained version of specified network (if avail)')
|
help='Start with pretrained version of specified network (if avail)')
|
||||||
|
group.add_argument('--pretrained-path', default=None, type=str,
|
||||||
|
help='Load this checkpoint as if they were the pretrained weights (with adaptation).')
|
||||||
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
|
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
|
||||||
help='Initialize model from this checkpoint (default: none)')
|
help='Load this checkpoint into model after initialization (default: none)')
|
||||||
group.add_argument('--resume', default='', type=str, metavar='PATH',
|
group.add_argument('--resume', default='', type=str, metavar='PATH',
|
||||||
help='Resume full model and optimizer state from checkpoint (default: none)')
|
help='Resume full model and optimizer state from checkpoint (default: none)')
|
||||||
group.add_argument('--no-resume-opt', action='store_true', default=False,
|
group.add_argument('--no-resume-opt', action='store_true', default=False,
|
||||||
@ -420,6 +422,11 @@ def main():
|
|||||||
elif args.input_size is not None:
|
elif args.input_size is not None:
|
||||||
in_chans = args.input_size[0]
|
in_chans = args.input_size[0]
|
||||||
|
|
||||||
|
factory_kwargs = {}
|
||||||
|
if args.pretrained_path:
|
||||||
|
# merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
|
||||||
|
factory_kwargs['pretrained_cfg_overlay'] = dict(file=args.pretrained_path)
|
||||||
|
|
||||||
model = create_model(
|
model = create_model(
|
||||||
args.model,
|
args.model,
|
||||||
pretrained=args.pretrained,
|
pretrained=args.pretrained,
|
||||||
@ -433,6 +440,7 @@ def main():
|
|||||||
bn_eps=args.bn_eps,
|
bn_eps=args.bn_eps,
|
||||||
scriptable=args.torchscript,
|
scriptable=args.torchscript,
|
||||||
checkpoint_path=args.initial_checkpoint,
|
checkpoint_path=args.initial_checkpoint,
|
||||||
|
**factory_kwargs,
|
||||||
**args.model_kwargs,
|
**args.model_kwargs,
|
||||||
)
|
)
|
||||||
if args.head_init_scale is not None:
|
if args.head_init_scale is not None:
|
||||||
|
Loading…
x
Reference in New Issue
Block a user