Prep for siglip2 release

siglip2
Ross Wightman 2025-02-20 12:07:49 -08:00
parent 105a667baa
commit dc5c7989ae
2 changed files with 287 additions and 20 deletions

View File

@ -29,6 +29,7 @@ class AttentionPoolLatent(nn.Module):
pos_embed: str = '',
pool_type: str = 'token',
norm_layer: Optional[nn.Module] = None,
act_layer: Optional[nn.Module] = nn.GELU,
drop: float = 0.0,
):
super().__init__()
@ -54,13 +55,18 @@ class AttentionPoolLatent(nn.Module):
self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
if qk_norm:
qk_norm_layer = norm_layer or nn.LayerNorm
self.q_norm = qk_norm_layer(self.head_dim)
self.k_norm = qk_norm_layer(self.head_dim)
else:
self.q_norm = nn.Identity()
self.k_norm = nn.Identity()
self.proj = nn.Linear(embed_dim, embed_dim)
self.proj_drop = nn.Dropout(drop)
self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio), act_layer=act_layer)
self.init_weights()

View File

@ -584,6 +584,7 @@ class VisionTransformer(nn.Module):
num_heads=num_heads,
mlp_ratio=mlp_ratio,
norm_layer=norm_layer,
act_layer=act_layer,
)
else:
self.attn_pool = None
@ -1887,9 +1888,20 @@ default_cfgs = {
license='cc-by-nc-4.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
'vit_base_patch32_siglip_256.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_224.v2_webli': _cfg(
# hf_hub_id='timm/',
num_classes=0),
'vit_base_patch16_siglip_224.webli': _cfg(
hf_hub_id='timm/',
num_classes=0),
'vit_base_patch16_siglip_256.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_256.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256),
@ -1898,28 +1910,51 @@ default_cfgs = {
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_384.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_base_patch16_siglip_384.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_base_patch16_siglip_512.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_base_patch16_siglip_512.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_large_patch16_siglip_256.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_large_patch16_siglip_256.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_large_patch16_siglip_384.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_large_patch16_siglip_384.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_large_patch16_siglip_512.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_so400m_patch14_siglip_224.v2_webli': _cfg(
# hf_hub_id='timm/',
num_classes=0),
'vit_so400m_patch14_siglip_224.webli': _cfg(
hf_hub_id='timm/',
num_classes=0),
'vit_so400m_patch16_siglip_256.webli_i18n': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256),
'vit_so400m_patch14_siglip_378.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 378, 378),
num_classes=0),
'vit_so400m_patch14_siglip_378.webli': _cfg(
hf_hub_id='timm/',
@ -1929,10 +1964,45 @@ default_cfgs = {
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_so400m_patch16_siglip_256.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_so400m_patch16_siglip_256.webli_i18n': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_so400m_patch16_siglip_384.v2_webli': _cfg(
#hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_so400m_patch16_siglip_512.v2_webli': _cfg(
#hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_giantopt_patch16_siglip_256.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_giantopt_patch16_siglip_384.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_base_patch32_siglip_gap_256.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_gap_224.v2_webli': _cfg(
# hf_hub_id='timm/',
num_classes=0),
'vit_base_patch16_siglip_gap_224.webli': _cfg(
hf_hub_id='timm/',
num_classes=0),
'vit_base_patch16_siglip_gap_256.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_gap_256.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256),
@ -1941,22 +2011,45 @@ default_cfgs = {
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_base_patch16_siglip_gap_384.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_base_patch16_siglip_gap_384.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_base_patch16_siglip_gap_512.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_base_patch16_siglip_gap_512.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_large_patch16_siglip_gap_256.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_large_patch16_siglip_gap_256.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_large_patch16_siglip_gap_384.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_large_patch16_siglip_gap_384.webli': _cfg(
hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_large_patch16_siglip_gap_512.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_so400m_patch14_siglip_gap_224.v2_webli': _cfg(
# hf_hub_id='timm/',
num_classes=0),
'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
hf_hub_id='timm/',
num_classes=0),
@ -1977,9 +2070,9 @@ default_cfgs = {
# hf_hub_filename='pt_27b_224.npz',
# custom_load='hf',
# num_classes=0),
'vit_so400m_patch16_siglip_gap_256.webli_i18n': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256),
'vit_so400m_patch14_siglip_gap_378.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 378, 378),
num_classes=0),
'vit_so400m_patch14_siglip_gap_378.webli': _cfg(
hf_hub_id='timm/',
@ -2053,6 +2146,30 @@ default_cfgs = {
# custom_load='hf',
# input_size=(3, 896, 896), crop_pct=1.0,
# num_classes=0),
'vit_so400m_patch16_siglip_gap_256.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_so400m_patch16_siglip_gap_256.webli_i18n': _cfg(
hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_so400m_patch16_siglip_gap_384.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_so400m_patch16_siglip_gap_512.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 512, 512),
num_classes=0),
'vit_giantopt_patch16_siglip_gap_256.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 256, 256),
num_classes=0),
'vit_giantopt_patch16_siglip_gap_384.v2_webli': _cfg(
# hf_hub_id='timm/',
input_size=(3, 384, 384),
num_classes=0),
'vit_so400m_patch14_siglip_378.webli_ft_in1k': _cfg(
hf_hub_id='timm/',
@ -3114,6 +3231,17 @@ def vit_giant_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionT
return model
@register_model
def vit_base_patch32_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
act_layer='gelu_tanh',
)
model = _create_vision_transformer(
'vit_base_patch32_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_base_patch16_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
@ -3174,6 +3302,17 @@ def vit_large_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTr
return model
@register_model
def vit_large_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
act_layer='gelu_tanh'
)
model = _create_vision_transformer(
'vit_large_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_so400m_patch14_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
@ -3184,17 +3323,6 @@ def vit_so400m_patch14_siglip_224(pretrained: bool = False, **kwargs) -> VisionT
return model
@register_model
def vit_so400m_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
# this is a corrected variant of the 384 with a res properly divisible by patch size (no padding/truncation)
model_args = dict(
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
)
model = _create_vision_transformer(
'vit_so400m_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_so400m_patch14_siglip_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
# this is a corrected variant of the 384 with a res properly divisible by patch size (no padding/truncation)
@ -3216,6 +3344,72 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
return model
@register_model
def vit_so400m_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
act_layer='gelu_tanh',
)
model = _create_vision_transformer(
'vit_so400m_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_so400m_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
act_layer='gelu_tanh',
)
model = _create_vision_transformer(
'vit_so400m_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_so400m_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
act_layer='gelu_tanh',
)
model = _create_vision_transformer(
'vit_so400m_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_giantopt_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False, global_pool='map',
act_layer='gelu_tanh',
)
model = _create_vision_transformer(
'vit_giantopt_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_giantopt_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False, global_pool='map',
act_layer='gelu_tanh',
)
model = _create_vision_transformer(
'vit_giantopt_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_base_patch32_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=32, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
act_layer='gelu_tanh',
)
model = _create_vision_transformer(
'vit_base_patch32_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_base_patch16_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
@ -3282,6 +3476,17 @@ def vit_large_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> Visi
return model
@register_model
def vit_large_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False,
global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
)
model = _create_vision_transformer(
'vit_large_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
@ -3354,6 +3559,62 @@ def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> Vis
return model
@register_model
def vit_so400m_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False,
global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
)
model = _create_vision_transformer(
'vit_so400m_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_so400m_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False,
global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
)
model = _create_vision_transformer(
'vit_so400m_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_so400m_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False,
global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
)
model = _create_vision_transformer(
'vit_so400m_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_giantopt_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False,
global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
)
model = _create_vision_transformer(
'vit_giantopt_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_giantopt_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False,
global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
)
model = _create_vision_transformer(
'vit_giantopt_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(