From 83aee5c28c4ebde457d2a750770a358125c5ea58 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Wed, 15 May 2024 07:53:19 -0700
Subject: [PATCH] Add explicit GAP (avg pool) variants of other SigLIP models.

---
 timm/models/vision_transformer.py | 99 +++++++++++++++++++++++++++++++
 1 file changed, 99 insertions(+)

diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index ace8e532..c17be8ca 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -1778,6 +1778,35 @@ default_cfgs = {
         input_size=(3, 384, 384),
         num_classes=0),
 
+    'vit_base_patch16_siglip_gap_224.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        num_classes=0),
+    'vit_base_patch16_siglip_gap_256.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 256, 256),
+        num_classes=0),
+    'vit_base_patch16_siglip_gap_384.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 384, 384),
+        num_classes=0),
+    'vit_base_patch16_siglip_gap_512.webli': _cfg(
+        hf_hub_id='timm/ViT-B-16-SigLIP-512',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 512, 512),
+        num_classes=0),
+    'vit_large_patch16_siglip_gap_256.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-256',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 256, 256),
+        num_classes=0),
+    'vit_large_patch16_siglip_gap_384.webli': _cfg(
+        hf_hub_id='timm/ViT-L-16-SigLIP-384',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        input_size=(3, 384, 384),
+        num_classes=0),
     'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
         hf_hub_id='timm/ViT-SO400M-14-SigLIP',
         hf_hub_filename='open_clip_pytorch_model.bin',
@@ -2803,8 +2832,75 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
     return model
 
 
+@register_model
+def vit_base_patch16_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_gap_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
+    model_args = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='avg', fc_norm=False,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
     model_args = dict(
         patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False,
         global_pool='avg', fc_norm=False,
@@ -2816,6 +2912,7 @@ def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> Vis
 
 @register_model
 def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
     model_args = dict(
         patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False,
         global_pool='avg', fc_norm=False,
@@ -2827,6 +2924,7 @@ def vit_so400m_patch14_siglip_gap_384(pretrained: bool = False, **kwargs) -> Vis
 
 @register_model
 def vit_so400m_patch14_siglip_gap_448(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
    model_args = dict(
         patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False,
         global_pool='avg', fc_norm=False,
@@ -2838,6 +2936,7 @@ def vit_so400m_patch14_siglip_gap_448(pretrained: bool = False, **kwargs) -> Vis
 
 @register_model
 def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    """ A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
     model_args = dict(
         patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False,
         global_pool='avg', fc_norm=False,
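
Usage sketch (not part of the diff above): a minimal example of exercising one of the newly registered GAP variants through timm's standard factory, assuming a timm install with this patch applied. The shapes noted in the comments follow from the vit_base_patch16_siglip_gap_224 config added here (patch_size=16, embed_dim=768, class_token=False, global_pool='avg', num_classes=0 for the pretrained .webli tag).

    import torch
    import timm

    # Build one of the new GAP (global average pooling) SigLIP variants.
    # pretrained=True would pull the `.webli` tag's open_clip_pytorch_model.bin
    # weights from the HF Hub per the default_cfgs entries in this patch;
    # pretrained=False builds a randomly initialized model for a quick check.
    model = timm.create_model('vit_base_patch16_siglip_gap_224', pretrained=False, num_classes=0)
    model.eval()

    x = torch.randn(1, 3, 224, 224)
    with torch.no_grad():
        tokens = model.forward_features(x)  # (1, 196, 768): 14x14 patch tokens, no class token
        pooled = model(x)                   # (1, 768): avg-pooled features (num_classes=0 -> no classifier head)
    print(tokens.shape, pooled.shape)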