Hiera weights on hub

This commit is contained in:
Ross Wightman 2024-05-13 11:43:22 -07:00
parent c838c4233f
commit 7a4e987b9f

View File

@ -803,62 +803,62 @@ def _cfg(url='', **kwargs):
default_cfgs = generate_default_cfgs({
"hiera_tiny_224.mae_in1k_ft_in1k": _cfg(
url="https://dl.fbaipublicfiles.com/hiera/hiera_tiny_224.pth",
#hf_hb='timm/',
hf_hub_id='timm/',
license='cc-by-nc-4.0',
),
"hiera_tiny_224.mae": _cfg(
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_tiny_224.pth",
#hf_hb='timm/',
hf_hub_id='timm/',
license='cc-by-nc-4.0',
num_classes=0,
),
"hiera_small_224.mae_in1k_ft_in1k": _cfg(
url="https://dl.fbaipublicfiles.com/hiera/hiera_small_224.pth",
#hf_hb='timm/',
hf_hub_id='timm/',
license='cc-by-nc-4.0',
),
"hiera_small_224.mae": _cfg(
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_small_224.pth",
#hf_hb='timm/',
hf_hub_id='timm/',
license='cc-by-nc-4.0',
num_classes=0,
),
"hiera_base_224.mae_in1k_ft_in1k": _cfg(
url="https://dl.fbaipublicfiles.com/hiera/hiera_base_224.pth",
#hf_hb='timm/',
hf_hub_id='timm/',
license='cc-by-nc-4.0',
),
"hiera_base_224.mae": _cfg(
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth",
#hf_hb='timm/',
hf_hub_id='timm/',
license='cc-by-nc-4.0',
num_classes=0,
),
"hiera_base_plus_224.mae_in1k_ft_in1k": _cfg(
url="https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_224.pth",
#hf_hb='timm/',
hf_hub_id='timm/',
license='cc-by-nc-4.0',
),
"hiera_base_plus_224.mae": _cfg(
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth",
#hf_hb='timm/',
hf_hub_id='timm/',
license='cc-by-nc-4.0',
num_classes=0,
),
"hiera_large_224.mae_in1k_ft_in1k": _cfg(
url="https://dl.fbaipublicfiles.com/hiera/hiera_large_224.pth",
#hf_hb='timm/',
hf_hub_id='timm/',
license='cc-by-nc-4.0',
),
"hiera_large_224.mae": _cfg(
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth",
#hf_hb='timm/',
hf_hub_id='timm/',
license='cc-by-nc-4.0',
num_classes=0,
),
"hiera_huge_224.mae_in1k_ft_in1k": _cfg(
url="https://dl.fbaipublicfiles.com/hiera/hiera_huge_224.pth",
#hf_hb='timm/',
hf_hub_id='timm/',
license='cc-by-nc-4.0',
),
"hiera_huge_224.mae": _cfg(
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth",
#hf_hb='timm/',
hf_hub_id='timm/',
license='cc-by-nc-4.0',
num_classes=0,
),
})
@ -880,7 +880,9 @@ def checkpoint_filter_fn(state_dict, model=None):
pass
if 'head.projection.' in k:
k = k.replace('head.projection.', 'head.fc.')
if k.startswith('norm.'):
if k.startswith('encoder_norm.'):
k = k.replace('encoder_norm.', 'head.norm.')
elif k.startswith('norm.'):
k = k.replace('norm.', 'head.norm.')
output[k] = v
return output
@ -893,7 +895,6 @@ def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera:
Hiera,
variant,
pretrained,
#pretrained_strict=False,
pretrained_filter_fn=checkpoint_filter_fn,
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
**kwargs,