diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index b4f15cb8..c072f13a 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -879,6 +879,17 @@ def _convert_dinov2(state_dict, model):
     return out_dict
 
 
+def _convert_ijepa(state_dict, model):
+    out_dict = {}
+    for k, v in state_dict['encoder'].items():
+        if k.startswith('module.'):
+            k = k[7:]
+        if k.startswith('norm.'):
+            k = 'fc_norm.' + k[5:]
+        out_dict[k] = v
+    return out_dict
+
+
 def checkpoint_filter_fn(
         state_dict,
         model,
@@ -896,7 +907,10 @@ def checkpoint_filter_fn(
         return _convert_openai_clip(state_dict, model)
 
     if "mask_token" in state_dict:
-        return _convert_dinov2(state_dict, model)
+        state_dict = _convert_dinov2(state_dict, model)
+
+    if "encoder" in state_dict:
+        state_dict = _convert_ijepa(state_dict, model)
 
     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k:
@@ -1437,6 +1451,27 @@ default_cfgs = generate_default_cfgs({
         hf_hub_id='timm/',
         license='cc-by-nc-4.0',
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+
+    'vit_huge_patch14_224_ijepa.in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
+        # hf_hub_id='timm/',
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+    'vit_huge_patch14_224_ijepa.in22k': _cfg(
+        url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar',
+        # hf_hub_id='timm/',
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+    'vit_huge_patch16_448_ijepa.in1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar',
+        # hf_hub_id='timm/',
+        license='cc-by-nc-4.0', input_size=(3, 448, 448), crop_pct=1.0,
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
+    'vit_gigantic_patch16_224_ijepa.in22k': _cfg(
+        url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar',
+        # hf_hub_id='timm/',
+        license='cc-by-nc-4.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
 })
 
 
@@ -2031,6 +2066,30 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
         'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
 
+@register_model
+def vit_huge_patch14_224_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/14) from `I-JEPA` - https://arxiv.org/abs/2301.08243
+    """
+    model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg')
+    model = _create_vision_transformer('vit_huge_patch14_224_ijepa', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+@register_model
+def vit_huge_patch16_448_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Huge model (ViT-H/16) from `I-JEPA` - https://arxiv.org/abs/2301.08243
+    """
+    model_args = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', img_size=448)
+    model = _create_vision_transformer('vit_huge_patch16_448_ijepa', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+@register_model
+def vit_gigantic_patch16_224_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-Gigantic (big-G) model (ViT-G/16) from `I-JEPA` - https://arxiv.org/abs/2301.08243
+    """
+    model_args = dict(patch_size=16, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16, class_token=False, global_pool='avg')
+    model = _create_vision_transformer(
+        'vit_gigantic_patch16_224_ijepa', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
 
 register_model_deprecations(__name__, {
     'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',
diff --git a/timm/models/vision_transformer_sam.py b/timm/models/vision_transformer_sam.py
index c8a8c53f..c561ea1b 100644
--- a/timm/models/vision_transformer_sam.py
+++ b/timm/models/vision_transformer_sam.py
@@ -605,6 +605,3 @@ def samvit_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
     model = _create_vision_transformer(
         'samvit_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
-
-# TODO:
-# support any input size, now only 1024 x 1024 (pretrained)
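
A minimal usage sketch for reviewers, assuming the fbaipublicfiles checkpoint above is reachable and converts cleanly through the new _convert_ijepa path; it only uses the public timm/torch API, and the expected 1280-dim output follows from embed_dim=1280 in the ViT-H/14 registration in this patch.

    import torch
    import timm

    # Build one of the newly registered I-JEPA backbones. With pretrained=True the
    # checkpoint URL from the default cfg is loaded and routed through
    # checkpoint_filter_fn -> _convert_ijepa (the 'encoder' branch added above).
    model = timm.create_model('vit_huge_patch14_224_ijepa', pretrained=True, num_classes=0)
    model = model.eval()

    with torch.no_grad():
        x = torch.randn(1, 3, 224, 224)  # dummy image batch
        feats = model(x)                 # avg-pooled features (class_token=False, global_pool='avg')

    print(feats.shape)  # torch.Size([1, 1280]) for the ViT-H/14 backbone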