diff --git a/timm/models/hrnet.py b/timm/models/hrnet.py index 20ea7674..d00adfa1 100644 --- a/timm/models/hrnet.py +++ b/timm/models/hrnet.py @@ -862,12 +862,17 @@ def _create_hrnet(variant, pretrained=False, cfg_variant=None, **model_kwargs): kwargs_filter = ('num_classes', 'global_pool') features_only = True cfg_variant = cfg_variant or variant + + pretrained_strict = model_kwargs.pop( + 'pretrained_strict', + not features_only and model_kwargs.get('head', 'classification') == 'classification' + ) model = build_model_with_cfg( model_cls, variant, pretrained, model_cfg=cfg_cls[cfg_variant], - pretrained_strict=not features_only, + pretrained_strict=pretrained_strict, kwargs_filter=kwargs_filter, **model_kwargs, )