diff --git a/timm/models/hiera.py b/timm/models/hiera.py index fb4746e9..f229daf4 100644 --- a/timm/models/hiera.py +++ b/timm/models/hiera.py @@ -803,62 +803,62 @@ def _cfg(url='', **kwargs): default_cfgs = generate_default_cfgs({ "hiera_tiny_224.mae_in1k_ft_in1k": _cfg( - url="https://dl.fbaipublicfiles.com/hiera/hiera_tiny_224.pth", - #hf_hb='timm/', + hf_hub_id='timm/', + license='cc-by-nc-4.0', ), "hiera_tiny_224.mae": _cfg( - url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_tiny_224.pth", - #hf_hb='timm/', + hf_hub_id='timm/', + license='cc-by-nc-4.0', num_classes=0, ), "hiera_small_224.mae_in1k_ft_in1k": _cfg( - url="https://dl.fbaipublicfiles.com/hiera/hiera_small_224.pth", - #hf_hb='timm/', + hf_hub_id='timm/', + license='cc-by-nc-4.0', ), "hiera_small_224.mae": _cfg( - url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_small_224.pth", - #hf_hb='timm/', + hf_hub_id='timm/', + license='cc-by-nc-4.0', num_classes=0, ), "hiera_base_224.mae_in1k_ft_in1k": _cfg( - url="https://dl.fbaipublicfiles.com/hiera/hiera_base_224.pth", - #hf_hb='timm/', + hf_hub_id='timm/', + license='cc-by-nc-4.0', ), "hiera_base_224.mae": _cfg( - url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth", - #hf_hb='timm/', + hf_hub_id='timm/', + license='cc-by-nc-4.0', num_classes=0, ), "hiera_base_plus_224.mae_in1k_ft_in1k": _cfg( - url="https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_224.pth", - #hf_hb='timm/', + hf_hub_id='timm/', + license='cc-by-nc-4.0', ), "hiera_base_plus_224.mae": _cfg( - url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth", - #hf_hb='timm/', + hf_hub_id='timm/', + license='cc-by-nc-4.0', num_classes=0, ), "hiera_large_224.mae_in1k_ft_in1k": _cfg( - url="https://dl.fbaipublicfiles.com/hiera/hiera_large_224.pth", - #hf_hb='timm/', + hf_hub_id='timm/', + license='cc-by-nc-4.0', ), "hiera_large_224.mae": _cfg( - url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth", - #hf_hb='timm/', + hf_hub_id='timm/', + license='cc-by-nc-4.0', num_classes=0, ), "hiera_huge_224.mae_in1k_ft_in1k": _cfg( - url="https://dl.fbaipublicfiles.com/hiera/hiera_huge_224.pth", - #hf_hb='timm/', + hf_hub_id='timm/', + license='cc-by-nc-4.0', ), "hiera_huge_224.mae": _cfg( - url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth", - #hf_hb='timm/', + hf_hub_id='timm/', + license='cc-by-nc-4.0', num_classes=0, ), }) @@ -880,7 +880,9 @@ def checkpoint_filter_fn(state_dict, model=None): pass if 'head.projection.' in k: k = k.replace('head.projection.', 'head.fc.') - if k.startswith('norm.'): + if k.startswith('encoder_norm.'): + k = k.replace('encoder_norm.', 'head.norm.') + elif k.startswith('norm.'): k = k.replace('norm.', 'head.norm.') output[k] = v return output @@ -893,7 +895,6 @@ def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera: Hiera, variant, pretrained, - #pretrained_strict=False, pretrained_filter_fn=checkpoint_filter_fn, feature_cfg=dict(out_indices=out_indices, feature_cls='getter'), **kwargs,