Merge pull request #1846 from seefun/master

add I-JEPA pretrained weight for ViT
pull/1830/head
Ross Wightman 2023-06-15 11:12:53 -07:00 committed by GitHub
commit f9a24fa19f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 60 additions and 4 deletions

View File

@ -879,6 +879,17 @@ def _convert_dinov2(state_dict, model):
return out_dict
def _convert_ijepa(state_dict, model):
out_dict = {}
for k, v in state_dict['encoder'].items():
if k.startswith('module.'):
k = k[7:]
if k.startswith('norm.'):
k = 'fc_norm.' + k[5:]
out_dict[k] = v
return out_dict
def checkpoint_filter_fn(
state_dict,
model,
@ -896,7 +907,10 @@ def checkpoint_filter_fn(
return _convert_openai_clip(state_dict, model)
if "mask_token" in state_dict:
return _convert_dinov2(state_dict, model)
state_dict = _convert_dinov2(state_dict, model)
if "encoder" in state_dict:
state_dict = _convert_ijepa(state_dict, model)
for k, v in state_dict.items():
if 'patch_embed.proj.weight' in k:
@ -1437,6 +1451,27 @@ default_cfgs = generate_default_cfgs({
hf_hub_id='timm/',
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_huge_patch14_224_ijepa.in1k': _cfg(
url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
# hf_hub_id='timm/',
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_huge_patch14_224_ijepa.in22k': _cfg(
url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.h.14-900e.pth.tar',
# hf_hub_id='timm/',
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_huge_patch16_448_ijepa.in1k': _cfg(
url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.16-448px-300e.pth.tar',
# hf_hub_id='timm/',
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_gigantic_patch16_224_ijepa.in22k': _cfg(
url='https://dl.fbaipublicfiles.com/ijepa/IN22K-vit.g.16-600e.pth.tar',
# hf_hub_id='timm/',
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
})
@ -2031,6 +2066,30 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
'vit_giant_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_huge_patch14_224_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-Huge model (ViT-H/14) from `I-JEPA` - https://arxiv.org/abs/2301.08243
"""
model_args = dict(patch_size=14, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg')
model = _create_vision_transformer('vit_huge_patch14_224_ijepa', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_huge_patch16_448_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-Huge model (ViT-H/16) from `I-JEPA` - https://arxiv.org/abs/2301.08243
"""
model_args = dict(patch_size=16, embed_dim=1280, depth=32, num_heads=16, class_token=False, global_pool='avg', img_size=448)
model = _create_vision_transformer('vit_huge_patch16_448_ijepa', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_gigantic_patch16_224_ijepa(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-Gigantic (big-G) model (ViT-G/16) from `I-JEPA - https://arxiv.org/abs/2301.08243
"""
model_args = dict(patch_size=16, embed_dim=1664, mlp_ratio=64/13, depth=48, num_heads=16)
model = _create_vision_transformer(
'vit_gigantic_patch16_224_ijepa', pretrained=pretrained, **dict(model_args, **kwargs))
return model
register_model_deprecations(__name__, {
'vit_tiny_patch16_224_in21k': 'vit_tiny_patch16_224.augreg_in21k',

View File

@ -605,6 +605,3 @@ def samvit_huge_patch16(pretrained=False, **kwargs) -> VisionTransformerSAM:
model = _create_vision_transformer(
'samvit_huge_patch16', pretrained=pretrained, **dict(model_args, **kwargs))
return model
# TODO:
# support any input size, now only 1024 x 1024 (pretrained)