Merge pull request #2142 from huggingface/fix_hrnet_head

Fix #2139, disable strict weight loading when head changes from classification
pull/2162/head
Ross Wightman 2024-04-09 10:04:09 -07:00 committed by GitHub
commit 9531eb793c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
1 changed files with 6 additions and 1 deletions

View File

@ -862,12 +862,17 @@ def _create_hrnet(variant, pretrained=False, cfg_variant=None, **model_kwargs):
kwargs_filter = ('num_classes', 'global_pool')
features_only = True
cfg_variant = cfg_variant or variant
pretrained_strict = model_kwargs.pop(
'pretrained_strict',
not features_only and model_kwargs.get('head', 'classification') == 'classification'
)
model = build_model_with_cfg(
model_cls,
variant,
pretrained,
model_cfg=cfg_cls[cfg_variant],
pretrained_strict=not features_only,
pretrained_strict=pretrained_strict,
kwargs_filter=kwargs_filter,
**model_kwargs,
)