Add --pretrained-path arg to train script to allow passing a local checkpoint as pretrained weights. Add missing/unexpected keys logging.

Ross Wightman 2023-12-11 09:13:22 -08:00 committed by Ross Wightman
parent 17a47c0e35
commit 60b170b200
2 changed files with 18 additions and 2 deletions

timm/models/_builder.py

@@ -234,7 +234,15 @@ def load_pretrained(
                 classifier_bias = state_dict[classifier_name + '.bias']
                 state_dict[classifier_name + '.bias'] = classifier_bias[label_offset:]
 
-    model.load_state_dict(state_dict, strict=strict)
+    load_result = model.load_state_dict(state_dict, strict=strict)
+    if load_result.missing_keys:
+        _logger.info(
+            f'Missing keys ({", ".join(load_result.missing_keys)}) discovered while loading pretrained weights.'
+            f' This is expected if model is being adapted.')
+    if load_result.unexpected_keys:
+        _logger.warning(
+            f'Unexpected keys ({", ".join(load_result.unexpected_keys)}) found while loading pretrained weights.'
+            f' This may be expected if model is being adapted.')
 
 
 def pretrained_cfg_for_features(pretrained_cfg):
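
For reference, `torch.nn.Module.load_state_dict` returns a named tuple with `missing_keys` and `unexpected_keys` fields when it does not raise, which is what the new logging inspects. A minimal standalone sketch of that behavior (the toy model and the 'extra.scale' key are illustrative, not from the commit):

    import torch
    import torch.nn as nn

    # Toy two-layer model; the doctored state_dict below mimics an adapted
    # pretrained checkpoint: the head entries are absent and a stray key
    # is present.
    model = nn.Sequential(nn.Linear(8, 16), nn.Linear(16, 10))

    state_dict = model.state_dict()
    state_dict.pop('1.weight')                 # will be reported as missing
    state_dict.pop('1.bias')                   # will be reported as missing
    state_dict['extra.scale'] = torch.ones(1)  # will be reported as unexpected

    # With strict=False (as load_pretrained uses when the classifier is
    # being adapted), load_state_dict returns its result instead of raising.
    result = model.load_state_dict(state_dict, strict=False)
    print(result.missing_keys)     # ['1.weight', '1.bias']
    print(result.unexpected_keys)  # ['extra.scale']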

train.py

@@ -103,8 +103,10 @@ group.add_argument('--model', default='resnet50', type=str, metavar='MODEL',
                    help='Name of model to train (default: "resnet50")')
 group.add_argument('--pretrained', action='store_true', default=False,
                    help='Start with pretrained version of specified network (if avail)')
+group.add_argument('--pretrained-path', default=None, type=str,
+                   help='Load this checkpoint as if it were the pretrained weights (with adaptation).')
 group.add_argument('--initial-checkpoint', default='', type=str, metavar='PATH',
-                   help='Initialize model from this checkpoint (default: none)')
+                   help='Load this checkpoint into model after initialization (default: none)')
 group.add_argument('--resume', default='', type=str, metavar='PATH',
                    help='Resume full model and optimizer state from checkpoint (default: none)')
 group.add_argument('--no-resume-opt', action='store_true', default=False,
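
Usage sketch (model name and path are placeholders, not from the commit): a run like `python train.py --model resnet50 --pretrained --pretrained-path ./my_checkpoint.pth --num-classes 10` should treat the local file as the pretrained weights; `--pretrained` still gates whether pretrained weights are loaded at all, while `--pretrained-path` only redirects where they come from.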
@@ -420,6 +422,11 @@ def main():
     elif args.input_size is not None:
         in_chans = args.input_size[0]
 
+    factory_kwargs = {}
+    if args.pretrained_path:
+        # merge with pretrained_cfg of model, 'file' has priority over 'url' and 'hf_hub'.
+        factory_kwargs['pretrained_cfg_overlay'] = dict(file=args.pretrained_path)
+
     model = create_model(
         args.model,
         pretrained=args.pretrained,
@@ -433,6 +440,7 @@ def main():
         bn_eps=args.bn_eps,
         scriptable=args.torchscript,
         checkpoint_path=args.initial_checkpoint,
+        **factory_kwargs,
         **args.model_kwargs,
     )
     if args.head_init_scale is not None:
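
The overlay mechanism used above is plain `create_model` functionality, so the same effect is available outside the train script. A minimal sketch (model name, path, and num_classes are placeholders; it assumes a timm version with `pretrained_cfg_overlay` support, which this commit's code relies on):

    import timm

    # Mirrors `--pretrained --pretrained-path ./local_ckpt.pth`: the overlay's
    # 'file' entry takes priority over the 'url'/'hf_hub' sources in the
    # model's default pretrained config.
    model = timm.create_model(
        'resnet50',
        pretrained=True,
        pretrained_cfg_overlay=dict(file='./local_ckpt.pth'),
        num_classes=10,  # head differs from checkpoint -> missing keys logged at INFO
    )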