Mirror of https://github.com/huggingface/pytorch-image-models.git

Commit e728f3efdb (parent 49a459e8f1)

    Cleanup ijepa models, they're just gap (global-avg-pool) models w/o heads.
    fc-norm conversion was wrong, gigantic should have been giant.
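With the rename, the I-JEPA weights load through the plain GAP ViT entrypoints touched below. A minimal usage sketch, assuming the renamed entrypoints and pretrained tags introduced in this commit (the checkpoints are headless encoders, hence num_classes=0):

import timm
import torch

# I-JEPA weights are plain global-avg-pool encoders without a classifier head,
# so instantiate as a feature extractor (num_classes=0, matching the default cfgs in this diff).
model = timm.create_model('vit_huge_patch14_gap_224.in1k_ijepa', pretrained=True, num_classes=0)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))  # global-avg-pooled features, shape (1, 1280) for ViT-H/14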
@@ -950,17 +950,6 @@ def _convert_dinov2(state_dict, model):
     return out_dict
 
 
-def _convert_ijepa(state_dict, model):
-    out_dict = {}
-    for k, v in state_dict['encoder'].items():
-        if k.startswith('module.'):
-            k = k[7:]
-        if k.startswith('norm.'):
-            k = 'fc_norm.' + k[5:]
-        out_dict[k] = v
-    return out_dict
-
-
 def checkpoint_filter_fn(
         state_dict,
         model,
@@ -973,6 +962,7 @@ def checkpoint_filter_fn(
     out_dict = {}
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('state_dict', state_dict)
+    prefix = ''
 
     if 'visual.class_embedding' in state_dict:
         return _convert_openai_clip(state_dict, model)
@@ -981,13 +971,17 @@ def checkpoint_filter_fn(
         state_dict = _convert_dinov2(state_dict, model)
 
     if "encoder" in state_dict:
-        state_dict = _convert_ijepa(state_dict, model)
+        state_dict = state_dict['encoder']
+        prefix = 'module.'
 
     if 'visual.trunk.pos_embed' in state_dict:
         # convert an OpenCLIP model with timm vision encoder
-        prefix = 'visual.trunk.'
-        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
         # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
+        prefix = 'visual.trunk.'
+
+    if prefix:
+        # filter on & remove prefix string from keys
+        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
 
     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k:
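The removed _convert_ijepa helper is subsumed by the generic prefix filter above: the checkpoint's 'encoder' sub-dict is unwrapped, the DDP-style 'module.' prefix is stripped, and because the new GAP model defs set fc_norm=False the checkpoint's 'norm.' keys appear to map directly (the old remap of 'norm.' to 'fc_norm.' is the conversion the commit message calls wrong). A rough sketch of the same filtering on a made-up checkpoint, key names illustrative only:

# Illustrative I-JEPA-style checkpoint: ViT weights nested under 'encoder' with 'module.' prefixes.
checkpoint = {
    'encoder': {
        'module.patch_embed.proj.weight': '<tensor>',
        'module.blocks.0.attn.qkv.weight': '<tensor>',
        'module.norm.weight': '<tensor>',
    },
}

state_dict = checkpoint['encoder']
prefix = 'module.'
# filter on & remove prefix string from keys (same expression as in checkpoint_filter_fn)
state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

print(list(state_dict))
# ['patch_embed.proj.weight', 'blocks.0.attn.qkv.weight', 'norm.weight']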
@@ -1529,23 +1523,23 @@ default_cfgs = generate_default_cfgs({
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
 
-    'vit_huge_patch14_ijepa_224.in1k': _cfg(
+    'vit_huge_patch14_gap_224.in1k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_huge_patch14_ijepa_224.in22k': _cfg(
+    'vit_huge_patch14_gap_224.in22k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_huge_patch16_ijepa_448.in1k': _cfg(
+    'vit_huge_patch16_gap_448.in1k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         input_size=(3, 448, 448), crop_pct=1.0,
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_gigantic_patch16_ijepa_224.in22k': _cfg(
+    'vit_giant_patch16_gap_224.in22k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
@@ -1856,7 +1850,7 @@ def vit_medium_patch16_gap_384(pretrained=False, **kwargs) -> VisionTransformer:
 
 @register_model
 def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256
+    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224
     """
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
@@ -1865,6 +1859,40 @@ def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_huge_patch14_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) w/ no class token, avg pool
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_huge_patch14_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_huge_patch16_gap_448(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/16) w/ no class token, avg pool @ 448x448
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_huge_patch16_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_giant_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Giant (little-gg) model (ViT-g/16) w/ no class token, avg pool
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
+        class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_giant_patch16_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch32_clip_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-B/32 CLIP image tower @ 224x224
@@ -2190,37 +2218,6 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
-@register_model
-def vit_huge_patch14_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Huge model (ViT-H/14) from `I-JEPA` - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg')
-    model = _create_vision_transformer(
-        'vit_huge_patch14_ijepa_224', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
-@register_model
-def vit_huge_patch16_ijepa_448(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Huge model (ViT-H/16) from `I-JEPA` - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', img_size=448)
-    model = _create_vision_transformer(
-        'vit_huge_patch16_ijepa_448', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
-@register_model
-def vit_gigantic_patch16_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Gigantic (big-G) model (ViT-G/16) from `I-JEPA - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(patch_size=16, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
-    model = _create_vision_transformer(
-        'vit_gigantic_patch16_ijepa_224', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
 @register_model
 def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
     model_args = dict(