Fix #2139, disable strict weight loading when head changes from classification

This commit is contained in:
Ross Wightman 2024-04-09 08:40:52 -07:00
parent 59b3d86c1d
commit 17b892f703

View File

@ -862,12 +862,17 @@ def _create_hrnet(variant, pretrained=False, cfg_variant=None, **model_kwargs):
kwargs_filter = ('num_classes', 'global_pool')
features_only = True
cfg_variant = cfg_variant or variant
pretrained_strict = model_kwargs.pop(
'pretrained_strict',
not features_only and model_kwargs.get('head', 'classification') == 'classification'
)
model = build_model_with_cfg(
model_cls,
variant,
pretrained,
model_cfg=cfg_cls[cfg_variant],
pretrained_strict=not features_only,
pretrained_strict=pretrained_strict,
kwargs_filter=kwargs_filter,
**model_kwargs,
)