diff --git a/timm/models/_hub.py b/timm/models/_hub.py
index 91077f2b..c24b01cd 100644
--- a/timm/models/_hub.py
+++ b/timm/models/_hub.py
@@ -50,6 +50,8 @@ __all__ = ['get_cache_dir', 'download_cached_file', 'has_hf_hub', 'hf_split', 'l
 # Default name for a weights file hosted on the Huggingface Hub.
 HF_WEIGHTS_NAME = "pytorch_model.bin"  # default pytorch pkl
 HF_SAFE_WEIGHTS_NAME = "model.safetensors"  # safetensors version
+HF_OPEN_CLIP_WEIGHTS_NAME = "open_clip_pytorch_model.bin"  # default pytorch pkl
+HF_OPEN_CLIP_SAFE_WEIGHTS_NAME = "open_clip_model.safetensors"  # safetensors version
 
 
 def get_cache_dir(child_dir=''):
@@ -374,5 +376,7 @@ def _get_safe_alternatives(filename: str) -> Iterable[str]:
     """
     if filename == HF_WEIGHTS_NAME:
         yield HF_SAFE_WEIGHTS_NAME
-    if filename != HF_WEIGHTS_NAME and filename.endswith(".bin"):
-        return filename[:-4] + ".safetensors"
+    # if filename == HF_OPEN_CLIP_WEIGHTS_NAME:  # FIXME tracking safetensors yet
+    #     yield HF_OPEN_CLIP_SAFE_WEIGHTS_NAME
+    if filename not in (HF_WEIGHTS_NAME, HF_OPEN_CLIP_WEIGHTS_NAME) and filename.endswith(".bin"):
+        yield filename[:-4] + ".safetensors"
diff --git a/timm/models/eva.py b/timm/models/eva.py
index be7f3996..6142b2de 100644
--- a/timm/models/eva.py
+++ b/timm/models/eva.py
@@ -773,29 +773,51 @@ default_cfgs = generate_default_cfgs({
     ),
 
     # EVA01 and EVA02 CLIP image towers
-    'eva_giant_patch14_224.clip': _cfg(
-        #hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
+    'eva_giant_patch14_clip_224.laion400m': _cfg(
+        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
+        hf_hub_id='timm/eva_giant_patch14_clip_224.laion400m_s11b_b41k',  # float16 weights
+        hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=1024,
     ),
-    'eva02_base_patch16_clip_224.clip': _cfg(
-        #hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
+    'eva_giant_patch14_clip_224.merged2b': _cfg(
+        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA01_CLIP_g_14_plus_psz14_s11B.pt',
+        hf_hub_id='timm/eva_giant_patch14_plus_clip_224.merged2b_s11b_b114k',  # float16 weights
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=1024,
+    ),
+    'eva02_base_patch16_clip_224.merged2b': _cfg(
+        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
+        hf_hub_id='timm/eva02_base_patch16_clip_224.merged2b_s8b_b131k',  # float16 weights
+        hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=512,
     ),
-    'eva02_large_patch14_clip_224.clip': _cfg(
-        #hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
+    'eva02_large_patch14_clip_224.merged2b': _cfg(
+        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
+        hf_hub_id='timm/eva02_large_patch14_clip_224.merged2b_s4b_b131k',  # float16 weights
+        hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=768,
     ),
-    'eva02_large_patch14_clip_336.clip': _cfg(
+    'eva02_large_patch14_clip_336.merged2b': _cfg(
         # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_L_psz14_s4B.pt',
+        hf_hub_id='timm/eva02_large_patch14_clip_336.merged2b_s6b_b61k',  # float16 weights
+        hf_hub_filename='open_clip_pytorch_model.bin',
         input_size=(3, 336, 336), crop_pct=1.0,
         num_classes=768,
     ),
-    'eva02_enormous_patch14_clip_224.clip': _cfg(
-        #hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
+    'eva02_enormous_patch14_clip_224.laion2b': _cfg(
+        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
+        hf_hub_id='timm/eva02_enormous_patch14_clip_224.laion2b_s4b_b115k',  # float16 weights
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=1024,
+    ),
+    'eva02_enormous_patch14_clip_224.laion2b_plus': _cfg(
+        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_CLIP_E_psz14_plus_s9B.pt',
+        hf_hub_id='timm/eva02_enormous_patch14_plus_clip_224.laion2b_s9b_b144k',  # bfloat16 weights
+        hf_hub_filename='open_clip_pytorch_model.bin',
         num_classes=1024,
     ),
     'eva02_enormous_patch14_clip_224.pretrain': _cfg(
-        #hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_E_psz14.pt',
+        # hf_hub_id='QuanSun/EVA-CLIP', hf_hub_filename='EVA02_E_psz14.pt',
         num_classes=0,
     ),
 
@@ -970,9 +992,19 @@ def eva02_large_patch14_448(pretrained=False, **kwargs):
     return model
 
 
+@register_model
+def eva_giant_patch14_clip_224(pretrained=False, **kwargs):
+    """ EVA-g CLIP model (only difference from non-CLIP is the pooling) """
+    model_args = dict(
+        patch_size=14, embed_dim=1408, depth=40, num_heads=16, mlp_ratio=6144 / 1408,
+        global_pool=kwargs.pop('global_pool', 'token'))
+    model = _create_eva('eva_giant_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def eva02_base_patch16_clip_224(pretrained=False, **kwargs):
-    # A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_base
+    """ A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_base """
     model_args = dict(
         img_size=224,
         patch_size=16,
@@ -986,7 +1018,7 @@ def eva02_base_patch16_clip_224(pretrained=False, **kwargs):
         scale_attn_inner=True,
         use_rot_pos_emb=True,
         ref_feat_shape=(16, 16),  # 224/14
-        global_pool='token',
+        global_pool=kwargs.pop('global_pool', 'token'),
     )
     model = _create_eva('eva02_base_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -994,7 +1026,7 @@
 
 @register_model
 def eva02_large_patch14_clip_224(pretrained=False, **kwargs):
-    # A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_large
+    """ A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_large """
     model_args = dict(
         img_size=224,
         patch_size=14,
@@ -1008,7 +1040,7 @@ def eva02_large_patch14_clip_224(pretrained=False, **kwargs):
         scale_attn_inner=True,
         use_rot_pos_emb=True,
         ref_feat_shape=(16, 16),  # 224/14
-        global_pool='token',
+        global_pool=kwargs.pop('global_pool', 'token'),
    )
     model = _create_eva('eva02_large_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -1016,7 +1048,7 @@
 
 @register_model
 def eva02_large_patch14_clip_336(pretrained=False, **kwargs):
-    # A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_large
+    """ A EVA-CLIP specific variant that adds additional attn scale layernorm to eva02_large """
     model_args = dict(
         img_size=336,
         patch_size=14,
@@ -1030,7 +1062,7 @@ def eva02_large_patch14_clip_336(pretrained=False, **kwargs):
         scale_attn_inner=True,
         use_rot_pos_emb=True,
         ref_feat_shape=(16, 16),  # 224/14
-        global_pool='token',
+        global_pool=kwargs.pop('global_pool', 'token'),
     )
     model = _create_eva('eva02_large_patch14_clip_336', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -1038,7 +1070,7 @@
 
 @register_model
 def eva02_enormous_patch14_clip_224(pretrained=False, **kwargs):
-    # A EVA-CLIP specific variant that uses residual post-norm in blocks
+    """ A EVA-CLIP specific variant that uses residual post-norm in blocks """
     model_args = dict(
         img_size=224,
         patch_size=14,
@@ -1047,7 +1079,7 @@ def eva02_enormous_patch14_clip_224(pretrained=False, **kwargs):
         num_heads=16,
         mlp_ratio=15360 / 1792,
         use_post_norm=True,
-        global_pool='token',
+        global_pool=kwargs.pop('global_pool', 'token'),
     )
     model = _create_eva('eva02_enormous_patch14_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model
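
Note on the _hub.py hunk: the old code used return inside a generator, so the generic ".bin" -> ".safetensors" fallback value was silently discarded; the change switches it to yield and excludes the new open_clip weight name from that fallback (its safetensors alternative remains commented out for now). A minimal sketch of what the updated generator yields, assuming _get_safe_alternatives is imported directly from the private module (illustration only, not part of the patch):

from timm.models._hub import _get_safe_alternatives

# Sketch only: expected yields after this change.
print(list(_get_safe_alternatives("pytorch_model.bin")))            # ['model.safetensors']
print(list(_get_safe_alternatives("open_clip_pytorch_model.bin")))  # [] - open_clip safetensors lookup still disabled
print(list(_get_safe_alternatives("custom_weights.bin")))           # ['custom_weights.safetensors']
print(list(_get_safe_alternatives("model.pth")))                    # [] - only .bin files get the generic fallback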
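
For the eva.py changes, the renamed and newly added pretrained tags resolve to timm-hosted open_clip checkpoint files, and the CLIP entrypoints now read global_pool from kwargs via kwargs.pop('global_pool', 'token') instead of hard-coding it. A hedged usage sketch, assuming a timm build that includes these registrations and Hub access; the tag and output size below simply mirror the cfgs in the diff:

import torch
import timm

# Image tower with pretrained EVA02-CLIP weights (open_clip_pytorch_model.bin from the timm/ Hub repo).
model = timm.create_model('eva02_large_patch14_clip_224.merged2b', pretrained=True).eval()

# global_pool is now overridable at creation time instead of being fixed to 'token'.
model_avg = timm.create_model('eva02_large_patch14_clip_224', pretrained=False, global_pool='avg')

with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 768]) - matches num_classes=768 in the .merged2b cfg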