From e728f3efdb1e5da816d3defa6dd5f60b2090a8ac Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Tue, 17 Oct 2023 15:44:46 -0700
Subject: [PATCH] Cleanup ijepa models, they're just gap (global-avg-pool)
 models w/o heads. fc-norm conversion was wrong, gigantic should have been
 giant

---
 timm/models/vision_transformer.py | 97 +++++++++++++++----------------
 1 file changed, 47 insertions(+), 50 deletions(-)

diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index 2d374085..d56037c6 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -950,17 +950,6 @@ def _convert_dinov2(state_dict, model):
     return out_dict
 
 
-def _convert_ijepa(state_dict, model):
-    out_dict = {}
-    for k, v in state_dict['encoder'].items():
-        if k.startswith('module.'):
-            k = k[7:]
-        if k.startswith('norm.'):
-            k = 'fc_norm.' + k[5:]
-        out_dict[k] = v
-    return out_dict
-
-
 def checkpoint_filter_fn(
         state_dict,
         model,
@@ -973,6 +962,7 @@ def checkpoint_filter_fn(
     out_dict = {}
     state_dict = state_dict.get('model', state_dict)
     state_dict = state_dict.get('state_dict', state_dict)
+    prefix = ''
 
     if 'visual.class_embedding' in state_dict:
         return _convert_openai_clip(state_dict, model)
@@ -981,13 +971,17 @@ def checkpoint_filter_fn(
         state_dict = _convert_dinov2(state_dict, model)
 
     if "encoder" in state_dict:
-        state_dict = _convert_ijepa(state_dict, model)
+        state_dict = state_dict['encoder']
+        prefix = 'module.'
 
     if 'visual.trunk.pos_embed' in state_dict:
         # convert an OpenCLIP model with timm vision encoder
-        prefix = 'visual.trunk.'
-        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
         # FIXME remap final nn.Linear if it exists outside of the timm .trunk (ie in visual.head.proj)
+        prefix = 'visual.trunk.'
+
+    if prefix:
+        # filter on & remove prefix string from keys
+        state_dict = {k[len(prefix):]: v for k, v in state_dict.items() if k.startswith(prefix)}
 
     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k:
@@ -1529,23 +1523,23 @@ default_cfgs = generate_default_cfgs({
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
 
-    'vit_huge_patch14_ijepa_224.in1k': _cfg(
+    'vit_huge_patch14_gap_224.in1k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_huge_patch14_ijepa_224.in22k': _cfg(
+    'vit_huge_patch14_gap_224.in22k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_huge_patch16_ijepa_448.in1k': _cfg(
+    'vit_huge_patch16_gap_448.in1k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         input_size=(3, 448, 448), crop_pct=1.0,
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
-    'vit_gigantic_patch16_ijepa_224.in22k': _cfg(
+    'vit_giant_patch16_gap_224.in22k_ijepa': _cfg(
         url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar',
         # hf_hub_id='timm/',
         license='cc-by-nc-4.0',
@@ -1856,7 +1850,7 @@ def vit_medium_patch16_gap_384(pretrained=False, **kwargs) -> VisionTransformer:
 
 @register_model
 def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 256x256
+    """ ViT-Base (ViT-B/16) w/o class token, w/ avg-pool @ 224x224
     """
     model_args = dict(
         patch_size=16, embed_dim=768, depth=12, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
@@ -1865,6 +1859,40 @@ def vit_base_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_huge_patch14_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) w/ no class token, avg pool
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_huge_patch14_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_huge_patch16_gap_448(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/16) w/ no class token, avg pool @ 448x448
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_huge_patch16_gap_448', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_giant_patch16_gap_224(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Giant (little-gg) model (ViT-g/16) w/ no class token, avg pool
+    """
+    model_args = dict(
+        patch_size=16, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=48/11,
+        class_token=False, global_pool='avg', fc_norm=False)
+    model = _create_vision_transformer(
+        'vit_giant_patch16_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch32_clip_224(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-B/32 CLIP image tower @ 224x224
@@ -2190,37 +2218,6 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
-@register_model
-def vit_huge_patch14_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Huge model (ViT-H/14) from `I-JEPA` - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg')
-    model = _create_vision_transformer(
-        'vit_huge_patch14_ijepa_224', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
-@register_model
-def vit_huge_patch16_ijepa_448(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Huge model (ViT-H/16) from `I-JEPA` - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(
-        patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', img_size=448)
-    model = _create_vision_transformer(
-        'vit_huge_patch16_ijepa_448', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
-@register_model
-def vit_gigantic_patch16_ijepa_224(pretrained=False, **kwargs) -> VisionTransformer:
-    """ ViT-Gigantic (big-G) model (ViT-G/16) from `I-JEPA - https://arxiv.org/abs/2301.08243
-    """
-    model_args = dict(patch_size=16, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
-    model = _create_vision_transformer(
-        'vit_gigantic_patch16_ijepa_224', pretrained=pretrained, **dict(model_args, **kwargs))
-    return model
-
-
 @register_model
 def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
     model_args = dict(
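
Quick sanity-check sketch (not part of the patch): with the change above applied, the renamed entry points should load the I-JEPA weights as plain headless global-avg-pool models through the standard timm factory. The model name and pretrained tag below come from the new default cfgs; the only assumption is that the dl.fbaipublicfiles.com URL in the cfg is reachable (there is no hf_hub_id yet).

    import timm
    import torch

    # I-JEPA ViT-H/14 weights under the renamed gap model; num_classes=0 in the cfg, so no head is created
    model = timm.create_model('vit_huge_patch14_gap_224.in1k_ijepa', pretrained=True)
    model = model.eval()

    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        feats = model(x)  # global-avg-pooled features, shape (1, 1280) for embed_dim=1280
    print(feats.shape)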