Cleanup ijepa models, they're just gap (global-avg-pool) models w/o heads. fc-norm conversion was wrong, gigantic should have been giant

Ross Wightman 2023-10-17 15:44:46 -07:00
parent 49a459e8f1
commit e728f3efdb


@@ -950,17 +950,6 @@ def _convert_dinov2(state_dict, model):
     return out_dict


-def _convert_ijepa(state_dict, model):
-    out_dict = {}
-    for k, v in state_dict['encoder'].items():
-        if k.startswith('module.'):
-            k = k[7:]
-        if k.startswith('norm.'):
-            k = 'fc_norm.' + k[5:]
-        out_dict[k] = v
-    return out_dict
-
-
 def checkpoint_filter_fn(
         state_dict,
         model,
@@ -973,6 +962,7 @@ def checkpoint_filter_fn(
     out_dict = {}
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('state_dict', state_dict)
+    prefix = ''

     if 'visual.class_embedding' in state_dict:
         return _convert_openai_clip(state_dict, model)
@@ -981,13 +971,17 @@ def checkpoint_filter_fn(
         state_dict = _convert_dinov2(state_dict, model)

     if "encoder" in state_dict:
-        state_dict = _convert_ijepa(state_dict, model)
+        state_dict = state_dict['encoder']
+        prefix = 'module.'

     if 'visual.trunk.pos_embed' in state_dict:
         # convert an OpenCLIP model with timm vision encoder
-        prefix = 'visual.trunk.'
-        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
         # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
+        prefix = 'visual.trunk.'
+
+    if prefix:
+        # filter on & remove prefix string from keys
+        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}

     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k:
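For reference, the reworked unwrap-and-strip path can be exercised on its own. A minimal sketch of what the new logic does to an I-JEPA-style checkpoint (the toy keys and shapes below are illustrative, not taken from a real file):

import torch

# Toy I-JEPA-style checkpoint: the ViT weights sit under an 'encoder' key with a
# 'module.' prefix. After this change the final 'norm.*' keys stay as 'norm.*'
# (pre-pool norm) instead of being remapped to 'fc_norm.*'.
checkpoint = {'encoder': {
    'module.pos_embed': torch.zeros(1, 256, 1280),   # shapes illustrative only
    'module.norm.weight': torch.ones(1280),
}}

state_dict = checkpoint
prefix = ''
if 'encoder' in state_dict:
    state_dict = state_dict['encoder']
    prefix = 'module.'
if prefix:
    # filter on & remove prefix string from keys
    state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
print(list(state_dict))  # ['pos_embed', 'norm.weight']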
@@ -1529,23 +1523,23 @@ default_cfgs = generate_default_cfgs({
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_huge_patch14_ijepa_224.in1k': _cfg(
+    'vit_huge_patch14_gap_224.in1k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_huge_patch14_ijepa_224.in22k': _cfg(
+    'vit_huge_patch14_gap_224.in22k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_huge_patch16_ijepa_448.in1k': _cfg(
+    'vit_huge_patch16_gap_448.in1k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         input_size=(3, 448, 448), crop_pct=1.0,
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_gigantic_patch16_ijepa_224.in22k': _cfg(
+    'vit_giant_patch16_gap_224.in22k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
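With the renames above, the I-JEPA weights become pretrained tags on the plain gap models. A minimal usage sketch, assuming a timm build that contains this commit (the checkpoints ship without a classifier head, hence num_classes=0 in the cfgs):

import timm
import torch

# Headless ViT-H/14 gap model with the I-JEPA IN1K weights; the output is the
# global-avg-pooled feature vector, not class logits.
model = timm.create_model('vit_huge_patch14_gap_224.in1k_ijepa', pretrained=True, num_classes=0)
model.eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # torch.Size([1, 1280])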
@@ -1856,7 +1850,7 @@ def vit_medium_patch16_gap_384(pretrained=False, **kwargs) -> VisionTransformer:

 @register_model
 def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256
+    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224
     """
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
@@ -1865,6 +1859,40 @@ def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
     return model


+@register_model
+def vit_huge_patch14_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) w/ no class token, avg pool
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_huge_patch14_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_huge_patch16_gap_448(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/16) w/ no class token, avg pool @ 448x448
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_huge_patch16_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_giant_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Giant (little-g) model (ViT-g/16) w/ no class token, avg pool
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
+        class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_giant_patch16_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch32_clip_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-B/32 CLIP image tower @ 224x224
@@ -2190,37 +2218,6 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     return model


-@register_model
-def vit_huge_patch14_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Huge model (ViT-H/14) from `I-JEPA` - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg')
-    model = _create_vision_transformer(
-        'vit_huge_patch14_ijepa_224', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
-@register_model
-def vit_huge_patch16_ijepa_448(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Huge model (ViT-H/16) from `I-JEPA` - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', img_size=448)
-    model = _create_vision_transformer(
-        'vit_huge_patch16_ijepa_448', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
-@register_model
-def vit_gigantic_patch16_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Gigantic (big-G) model (ViT-G/16) from `I-JEPA - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(patch_size=16, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
-    model = _create_vision_transformer(
-        'vit_gigantic_patch16_ijepa_224', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
 @register_model
 def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
     model_args = dict(
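On the "gigantic should have been giant" point from the commit message: the removed registration above used ViT-G (gigantic) dimensions, while the released IN22K-vit.g.16 checkpoint is a ViT-g (giant). A quick arithmetic sketch comparing the two, with values taken from the hunks above:

# New vit_giant_patch16_gap_224 (ViT-g) vs removed vit_gigantic_patch16_ijepa_224 (ViT-G).
giant    = dict(embed_dim=1408, depth=40, mlp_ratio=48 / 11)
gigantic = dict(embed_dim=1664, depth=48, mlp_ratio=64 / 13)

# MLP hidden width implied by each config; the I-JEPA ViT-g/16 weights expect 6144, not 8192.
print(int(giant['embed_dim'] * giant['mlp_ratio']))        # 6144
print(int(gigantic['embed_dim'] * gigantic['mlp_ratio']))  # 8192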