Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Cleanup ijepa models, they're just gap (global-avg-pool) models w/o heads. fc-norm conversion was wrong, gigantic should have been giant
This commit is contained in:
parent 49a459e8f1
commit e728f3efdb
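The practical upshot of the cleanup: the I-JEPA checkpoints load into ordinary global-average-pool ViTs with no classifier head. A minimal usage sketch (not part of this commit), assuming the renamed pretrained tag below resolves in your installed timm version and the weights download from the URL in its cfg:

import timm
import torch

# Headless GAP ViT-H/14; the I-JEPA weights ship without a classifier, hence num_classes=0.
model = timm.create_model('vit_huge_patch14_gap_224.in1k_ijepa', pretrained=True, num_classes=0)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))  # pooled features, expected shape (1, 1280)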
@@ -950,17 +950,6 @@ def _convert_dinov2(state_dict, model):
     return out_dict
 
 
-def _convert_ijepa(state_dict, model):
-    out_dict = {}
-    for k, v in state_dict['encoder'].items():
-        if k.startswith('module.'):
-            k = k[7:]
-        if k.startswith('norm.'):
-            k = 'fc_norm.' + k[5:]
-        out_dict[k] = v
-    return out_dict
-
-
 def checkpoint_filter_fn(
         state_dict,
         model,
@@ -973,6 +962,7 @@ def checkpoint_filter_fn(
     out_dict = {}
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('state_dict', state_dict)
+    prefix = ''
 
     if 'visual.class_embedding' in state_dict:
         return _convert_openai_clip(state_dict, model)
@@ -981,13 +971,17 @@ def checkpoint_filter_fn(
         state_dict = _convert_dinov2(state_dict, model)
 
     if "encoder" in state_dict:
-        state_dict = _convert_ijepa(state_dict, model)
+        state_dict = state_dict['encoder']
+        prefix = 'module.'
 
     if 'visual.trunk.pos_embed' in state_dict:
         # convert an OpenCLIP model with timm vision encoder
-        prefix = 'visual.trunk.'
-        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
         # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
+        prefix = 'visual.trunk.'
+
+    if prefix:
+        # filter on & remove prefix string from keys
+        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
 
     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k:
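For reference, a toy sketch (placeholder values, not timm code) of the key handling this hunk folds into checkpoint_filter_fn: I-JEPA checkpoints store the encoder weights under 'encoder' with a 'module.' prefix, which is now unwrapped and stripped by the shared prefix-filtering step instead of a dedicated _convert_ijepa:

# Minimal stand-in for an I-JEPA checkpoint structure.
checkpoint = {
    'encoder': {
        'module.patch_embed.proj.weight': '<tensor>',
        'module.norm.weight': '<tensor>',  # stays under 'norm.*'; the gap models use fc_norm=False
    },
}

state_dict = checkpoint['encoder']
prefix = 'module.'
# filter on & remove prefix string from keys (same expression as in the diff above)
state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
print(list(state_dict))  # ['patch_embed.proj.weight', 'norm.weight']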
@@ -1529,23 +1523,23 @@ default_cfgs = generate_default_cfgs({
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
 
-    'vit_huge_patch14_ijepa_224.in1k': _cfg(
+    'vit_huge_patch14_gap_224.in1k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_huge_patch14_ijepa_224.in22k': _cfg(
+    'vit_huge_patch14_gap_224.in22k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_huge_patch16_ijepa_448.in1k': _cfg(
+    'vit_huge_patch16_gap_448.in1k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         input_size=(3, 448, 448), crop_pct=1.0,
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_gigantic_patch16_ijepa_224.in22k': _cfg(
+    'vit_giant_patch16_gap_224.in22k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
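For scripts that pin the old pretrained-cfg names, the renames in this hunk amount to the following mapping (an illustrative dict, not something shipped by timm):

IJEPA_CFG_RENAMES = {
    'vit_huge_patch14_ijepa_224.in1k': 'vit_huge_patch14_gap_224.in1k_ijepa',
    'vit_huge_patch14_ijepa_224.in22k': 'vit_huge_patch14_gap_224.in22k_ijepa',
    'vit_huge_patch16_ijepa_448.in1k': 'vit_huge_patch16_gap_448.in1k_ijepa',
    'vit_gigantic_patch16_ijepa_224.in22k': 'vit_giant_patch16_gap_224.in22k_ijepa',
}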
@@ -1856,7 +1850,7 @@ def vit_medium_patch16_gap_384(pretrained=False, **kwargs) -> VisionTransformer:
 
 @register_model
 def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256
+    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224
     """
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
@@ -1865,6 +1859,40 @@ def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_huge_patch14_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) w/ no class token, avg pool
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_huge_patch14_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_huge_patch16_gap_448(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/16) w/ no class token, avg pool @ 448x448
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_huge_patch16_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_giant_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Giant (little-gg) model (ViT-g/16) w/ no class token, avg pool
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
+        class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_giant_patch16_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch32_clip_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-B/32 CLIP image tower @ 224x224
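A quick sanity sketch (assuming a timm build containing this commit) of the corrected "giant" registration: ViT-g/16 here is the 1408-dim "little gg" configuration, not the 1664-dim "gigantic" (big G) that the old entrypoint name claimed:

import timm

model = timm.create_model('vit_giant_patch16_gap_224', pretrained=False, num_classes=0)
print(model.embed_dim)  # 1408 per the registration above; 'gigantic' would be 1664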
@@ -2190,37 +2218,6 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
-@register_model
-def vit_huge_patch14_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Huge model (ViT-H/14) from `I-JEPA` - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg')
-    model = _create_vision_transformer(
-        'vit_huge_patch14_ijepa_224', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
-@register_model
-def vit_huge_patch16_ijepa_448(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Huge model (ViT-H/16) from `I-JEPA` - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', img_size=448)
-    model = _create_vision_transformer(
-        'vit_huge_patch16_ijepa_448', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
-@register_model
-def vit_gigantic_patch16_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Gigantic (big-G) model (ViT-G/16) from `I-JEPA - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(patch_size=16, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
-    model = _create_vision_transformer(
-        'vit_gigantic_patch16_ijepa_224', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
 @register_model
 def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
     model_args = dict(