mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Hiera weights on hub
This commit is contained in:
parent
c838c4233f
commit
7a4e987b9f
@ -803,62 +803,62 @@ def _cfg(url='', **kwargs):
|
|||||||
|
|
||||||
default_cfgs = generate_default_cfgs({
|
default_cfgs = generate_default_cfgs({
|
||||||
"hiera_tiny_224.mae_in1k_ft_in1k": _cfg(
|
"hiera_tiny_224.mae_in1k_ft_in1k": _cfg(
|
||||||
url="https://dl.fbaipublicfiles.com/hiera/hiera_tiny_224.pth",
|
hf_hub_id='timm/',
|
||||||
#hf_hb='timm/',
|
license='cc-by-nc-4.0',
|
||||||
),
|
),
|
||||||
"hiera_tiny_224.mae": _cfg(
|
"hiera_tiny_224.mae": _cfg(
|
||||||
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_tiny_224.pth",
|
hf_hub_id='timm/',
|
||||||
#hf_hb='timm/',
|
license='cc-by-nc-4.0',
|
||||||
num_classes=0,
|
num_classes=0,
|
||||||
),
|
),
|
||||||
|
|
||||||
"hiera_small_224.mae_in1k_ft_in1k": _cfg(
|
"hiera_small_224.mae_in1k_ft_in1k": _cfg(
|
||||||
url="https://dl.fbaipublicfiles.com/hiera/hiera_small_224.pth",
|
hf_hub_id='timm/',
|
||||||
#hf_hb='timm/',
|
license='cc-by-nc-4.0',
|
||||||
),
|
),
|
||||||
"hiera_small_224.mae": _cfg(
|
"hiera_small_224.mae": _cfg(
|
||||||
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_small_224.pth",
|
hf_hub_id='timm/',
|
||||||
#hf_hb='timm/',
|
license='cc-by-nc-4.0',
|
||||||
num_classes=0,
|
num_classes=0,
|
||||||
),
|
),
|
||||||
|
|
||||||
"hiera_base_224.mae_in1k_ft_in1k": _cfg(
|
"hiera_base_224.mae_in1k_ft_in1k": _cfg(
|
||||||
url="https://dl.fbaipublicfiles.com/hiera/hiera_base_224.pth",
|
hf_hub_id='timm/',
|
||||||
#hf_hb='timm/',
|
license='cc-by-nc-4.0',
|
||||||
),
|
),
|
||||||
"hiera_base_224.mae": _cfg(
|
"hiera_base_224.mae": _cfg(
|
||||||
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_224.pth",
|
hf_hub_id='timm/',
|
||||||
#hf_hb='timm/',
|
license='cc-by-nc-4.0',
|
||||||
num_classes=0,
|
num_classes=0,
|
||||||
),
|
),
|
||||||
|
|
||||||
"hiera_base_plus_224.mae_in1k_ft_in1k": _cfg(
|
"hiera_base_plus_224.mae_in1k_ft_in1k": _cfg(
|
||||||
url="https://dl.fbaipublicfiles.com/hiera/hiera_base_plus_224.pth",
|
hf_hub_id='timm/',
|
||||||
#hf_hb='timm/',
|
license='cc-by-nc-4.0',
|
||||||
),
|
),
|
||||||
"hiera_base_plus_224.mae": _cfg(
|
"hiera_base_plus_224.mae": _cfg(
|
||||||
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_base_plus_224.pth",
|
hf_hub_id='timm/',
|
||||||
#hf_hb='timm/',
|
license='cc-by-nc-4.0',
|
||||||
num_classes=0,
|
num_classes=0,
|
||||||
),
|
),
|
||||||
|
|
||||||
"hiera_large_224.mae_in1k_ft_in1k": _cfg(
|
"hiera_large_224.mae_in1k_ft_in1k": _cfg(
|
||||||
url="https://dl.fbaipublicfiles.com/hiera/hiera_large_224.pth",
|
hf_hub_id='timm/',
|
||||||
#hf_hb='timm/',
|
license='cc-by-nc-4.0',
|
||||||
),
|
),
|
||||||
"hiera_large_224.mae": _cfg(
|
"hiera_large_224.mae": _cfg(
|
||||||
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_large_224.pth",
|
hf_hub_id='timm/',
|
||||||
#hf_hb='timm/',
|
license='cc-by-nc-4.0',
|
||||||
num_classes=0,
|
num_classes=0,
|
||||||
),
|
),
|
||||||
|
|
||||||
"hiera_huge_224.mae_in1k_ft_in1k": _cfg(
|
"hiera_huge_224.mae_in1k_ft_in1k": _cfg(
|
||||||
url="https://dl.fbaipublicfiles.com/hiera/hiera_huge_224.pth",
|
hf_hub_id='timm/',
|
||||||
#hf_hb='timm/',
|
license='cc-by-nc-4.0',
|
||||||
),
|
),
|
||||||
"hiera_huge_224.mae": _cfg(
|
"hiera_huge_224.mae": _cfg(
|
||||||
url="https://dl.fbaipublicfiles.com/hiera/mae_hiera_huge_224.pth",
|
hf_hub_id='timm/',
|
||||||
#hf_hb='timm/',
|
license='cc-by-nc-4.0',
|
||||||
num_classes=0,
|
num_classes=0,
|
||||||
),
|
),
|
||||||
})
|
})
|
||||||
@ -880,7 +880,9 @@ def checkpoint_filter_fn(state_dict, model=None):
|
|||||||
pass
|
pass
|
||||||
if 'head.projection.' in k:
|
if 'head.projection.' in k:
|
||||||
k = k.replace('head.projection.', 'head.fc.')
|
k = k.replace('head.projection.', 'head.fc.')
|
||||||
if k.startswith('norm.'):
|
if k.startswith('encoder_norm.'):
|
||||||
|
k = k.replace('encoder_norm.', 'head.norm.')
|
||||||
|
elif k.startswith('norm.'):
|
||||||
k = k.replace('norm.', 'head.norm.')
|
k = k.replace('norm.', 'head.norm.')
|
||||||
output[k] = v
|
output[k] = v
|
||||||
return output
|
return output
|
||||||
@ -893,7 +895,6 @@ def _create_hiera(variant: str, pretrained: bool = False, **kwargs) -> Hiera:
|
|||||||
Hiera,
|
Hiera,
|
||||||
variant,
|
variant,
|
||||||
pretrained,
|
pretrained,
|
||||||
#pretrained_strict=False,
|
|
||||||
pretrained_filter_fn=checkpoint_filter_fn,
|
pretrained_filter_fn=checkpoint_filter_fn,
|
||||||
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
|
feature_cfg=dict(out_indices=out_indices, feature_cls='getter'),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user