Fix #2139, disable strict weight loading when head changes from classification

This commit is contained in:
Ross Wightman 2024-04-09 08:40:52 -07:00
parent 59b3d86c1d
commit 17b892f703

View File

@ -862,12 +862,17 @@ def _create_hrnet(variant, pretrained=False, cfg_variant=None, **model_kwargs):
kwargs_filter = ('num_classes', 'global_pool') kwargs_filter = ('num_classes', 'global_pool')
features_only = True features_only = True
cfg_variant = cfg_variant or variant cfg_variant = cfg_variant or variant
pretrained_strict = model_kwargs.pop(
'pretrained_strict',
not features_only and model_kwargs.get('head', 'classification') == 'classification'
)
model = build_model_with_cfg( model = build_model_with_cfg(
model_cls, model_cls,
variant, variant,
pretrained, pretrained,
model_cfg=cfg_cls[cfg_variant], model_cfg=cfg_cls[cfg_variant],
pretrained_strict=not features_only, pretrained_strict=pretrained_strict,
kwargs_filter=kwargs_filter, kwargs_filter=kwargs_filter,
**model_kwargs, **model_kwargs,
) )