mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #2142 from huggingface/fix_hrnet_head
Fix #2139, disable strict weight loading when head changes from classification
This commit is contained in:
commit
9531eb793c
@ -862,12 +862,17 @@ def _create_hrnet(variant, pretrained=False, cfg_variant=None, **model_kwargs):
|
|||||||
kwargs_filter = ('num_classes', 'global_pool')
|
kwargs_filter = ('num_classes', 'global_pool')
|
||||||
features_only = True
|
features_only = True
|
||||||
cfg_variant = cfg_variant or variant
|
cfg_variant = cfg_variant or variant
|
||||||
|
|
||||||
|
pretrained_strict = model_kwargs.pop(
|
||||||
|
'pretrained_strict',
|
||||||
|
not features_only and model_kwargs.get('head', 'classification') == 'classification'
|
||||||
|
)
|
||||||
model = build_model_with_cfg(
|
model = build_model_with_cfg(
|
||||||
model_cls,
|
model_cls,
|
||||||
variant,
|
variant,
|
||||||
pretrained,
|
pretrained,
|
||||||
model_cfg=cfg_cls[cfg_variant],
|
model_cfg=cfg_cls[cfg_variant],
|
||||||
pretrained_strict=not features_only,
|
pretrained_strict=pretrained_strict,
|
||||||
kwargs_filter=kwargs_filter,
|
kwargs_filter=kwargs_filter,
|
||||||
**model_kwargs,
|
**model_kwargs,
|
||||||
)
|
)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user