mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add --pretrained-path arg to train script to allow passing local checkpoint as pretrained. Add missing/unexpected keys log.
This commit is contained in:
parent
17a47c0e35
commit
60b170b200
@ -234,7 +234,15 @@ def load_pretrained(
|
||||
classifier_bias = state_dict[classifier_name + '.bias']
|
||||
state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
|
||||
|
||||
model.load_state_dict(state_dict, strict=strict)
|
||||
load_result = model.load_state_dict(state_dict, strict=strict)
|
||||
if load_result.missing_keys:
|
||||
_logger.info(
|
||||
f'Missing keys ({", ".join(load_result.missing_keys)}) discovered while loading pretrained weights.'
|
||||
f' This is expected if model is being adapted.')
|
||||
if load_result.unexpected_keys:
|
||||
_logger.warning(
|
||||
f'Unexpected keys ({", ".join(load_result.unexpected_keys)}) found while loading pretrained weights.'
|
||||
f' This may be expected if model is being adapted.')
|
||||
|
||||
|
||||
def pretrained_cfg_for_features(pretrained_cfg):
|
||||
|
10
train.py
10
train.py
@ -103,8 +103,10 @@ group.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
|
||||
help='Name of model to train (default: "resnet50")')
|
||||
group.add_argument('--pretrained', action='store_true', default=False,
|
||||
help='Start with pretrained version of specified network (if avail)')
|
||||
group.add_argument('--pretrained-path', default=None, type=str,
|
||||
help='Load this checkpoint as if they were the pretrained weights (with adaptation).')
|
||||
group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
|
||||
help='Initialize model from this checkpoint (default: none)')
|
||||
help='Load this checkpoint into model after initialization (default: none)')
|
||||
group.add_argument('--resume', default='', type=str, metavar='PATH',
|
||||
help='Resume full model and optimizer state from checkpoint (default: none)')
|
||||
group.add_argument('--no-resume-opt', action='store_true', default=False,
|
||||
@ -420,6 +422,11 @@ def main():
|
||||
elif args.input_size is not None:
|
||||
in_chans = args.input_size[0]
|
||||
|
||||
factory_kwargs = {}
|
||||
if args.pretrained_path:
|
||||
# merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
|
||||
factory_kwargs['pretrained_cfg_overlay'] = dict(file=args.pretrained_path)
|
||||
|
||||
model = create_model(
|
||||
args.model,
|
||||
pretrained=args.pretrained,
|
||||
@ -433,6 +440,7 @@ def main():
|
||||
bn_eps=args.bn_eps,
|
||||
scriptable=args.torchscript,
|
||||
checkpoint_path=args.initial_checkpoint,
|
||||
**factory_kwargs,
|
||||
**args.model_kwargs,
|
||||
)
|
||||
if args.head_init_scale is not None:
|
||||
|
Loading…
x
Reference in New Issue
Block a user