Prep for siglip2 release
parent
105a667baa
commit
dc5c7989ae
|
@ -29,6 +29,7 @@ class AttentionPoolLatent(nn.Module):
|
|||
pos_embed: str = '',
|
||||
pool_type: str = 'token',
|
||||
norm_layer: Optional[nn.Module] = None,
|
||||
act_layer: Optional[nn.Module] = nn.GELU,
|
||||
drop: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
|
@ -54,13 +55,18 @@ class AttentionPoolLatent(nn.Module):
|
|||
|
||||
self.q = nn.Linear(embed_dim, embed_dim, bias=qkv_bias)
|
||||
self.kv = nn.Linear(embed_dim, embed_dim * 2, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
if qk_norm:
|
||||
qk_norm_layer = norm_layer or nn.LayerNorm
|
||||
self.q_norm = qk_norm_layer(self.head_dim)
|
||||
self.k_norm = qk_norm_layer(self.head_dim)
|
||||
else:
|
||||
self.q_norm = nn.Identity()
|
||||
self.k_norm = nn.Identity()
|
||||
self.proj = nn.Linear(embed_dim, embed_dim)
|
||||
self.proj_drop = nn.Dropout(drop)
|
||||
|
||||
self.norm = norm_layer(out_features) if norm_layer is not None else nn.Identity()
|
||||
self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio))
|
||||
self.mlp = Mlp(embed_dim, int(embed_dim * mlp_ratio), act_layer=act_layer)
|
||||
|
||||
self.init_weights()
|
||||
|
||||
|
|
|
@ -584,6 +584,7 @@ class VisionTransformer(nn.Module):
|
|||
num_heads=num_heads,
|
||||
mlp_ratio=mlp_ratio,
|
||||
norm_layer=norm_layer,
|
||||
act_layer=act_layer,
|
||||
)
|
||||
else:
|
||||
self.attn_pool = None
|
||||
|
@ -1887,9 +1888,20 @@ default_cfgs = {
|
|||
license='cc-by-nc-4.0',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
|
||||
|
||||
'vit_base_patch32_siglip_256.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
num_classes=0),
|
||||
'vit_base_patch16_siglip_224.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
num_classes=0),
|
||||
'vit_base_patch16_siglip_224.webli': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
num_classes=0),
|
||||
'vit_base_patch16_siglip_256.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
num_classes=0),
|
||||
'vit_base_patch16_siglip_256.webli': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
|
@ -1898,28 +1910,51 @@ default_cfgs = {
|
|||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
num_classes=0),
|
||||
'vit_base_patch16_siglip_384.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384),
|
||||
num_classes=0),
|
||||
'vit_base_patch16_siglip_384.webli': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384),
|
||||
num_classes=0),
|
||||
'vit_base_patch16_siglip_512.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 512, 512),
|
||||
num_classes=0),
|
||||
'vit_base_patch16_siglip_512.webli': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 512, 512),
|
||||
num_classes=0),
|
||||
'vit_large_patch16_siglip_256.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
num_classes=0),
|
||||
'vit_large_patch16_siglip_256.webli': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
num_classes=0),
|
||||
'vit_large_patch16_siglip_384.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384),
|
||||
num_classes=0),
|
||||
'vit_large_patch16_siglip_384.webli': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384),
|
||||
num_classes=0),
|
||||
'vit_large_patch16_siglip_512.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 512, 512),
|
||||
num_classes=0),
|
||||
'vit_so400m_patch14_siglip_224.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
num_classes=0),
|
||||
'vit_so400m_patch14_siglip_224.webli': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
num_classes=0),
|
||||
'vit_so400m_patch16_siglip_256.webli_i18n': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
'vit_so400m_patch14_siglip_378.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 378, 378),
|
||||
num_classes=0),
|
||||
'vit_so400m_patch14_siglip_378.webli': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
|
@ -1929,10 +1964,45 @@ default_cfgs = {
|
|||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384),
|
||||
num_classes=0),
|
||||
'vit_so400m_patch16_siglip_256.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
num_classes=0),
|
||||
'vit_so400m_patch16_siglip_256.webli_i18n': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
num_classes=0),
|
||||
'vit_so400m_patch16_siglip_384.v2_webli': _cfg(
|
||||
#hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384),
|
||||
num_classes=0),
|
||||
'vit_so400m_patch16_siglip_512.v2_webli': _cfg(
|
||||
#hf_hub_id='timm/',
|
||||
input_size=(3, 512, 512),
|
||||
num_classes=0),
|
||||
'vit_giantopt_patch16_siglip_256.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
num_classes=0),
|
||||
'vit_giantopt_patch16_siglip_384.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384),
|
||||
num_classes=0),
|
||||
|
||||
'vit_base_patch32_siglip_gap_256.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
num_classes=0),
|
||||
'vit_base_patch16_siglip_gap_224.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
num_classes=0),
|
||||
'vit_base_patch16_siglip_gap_224.webli': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
num_classes=0),
|
||||
'vit_base_patch16_siglip_gap_256.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
num_classes=0),
|
||||
'vit_base_patch16_siglip_gap_256.webli': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
|
@ -1941,22 +2011,45 @@ default_cfgs = {
|
|||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
num_classes=0),
|
||||
'vit_base_patch16_siglip_gap_384.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384),
|
||||
num_classes=0),
|
||||
'vit_base_patch16_siglip_gap_384.webli': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384),
|
||||
num_classes=0),
|
||||
'vit_base_patch16_siglip_gap_512.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 512, 512),
|
||||
num_classes=0),
|
||||
'vit_base_patch16_siglip_gap_512.webli': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 512, 512),
|
||||
num_classes=0),
|
||||
'vit_large_patch16_siglip_gap_256.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
num_classes=0),
|
||||
'vit_large_patch16_siglip_gap_256.webli': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
num_classes=0),
|
||||
'vit_large_patch16_siglip_gap_384.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384),
|
||||
num_classes=0),
|
||||
'vit_large_patch16_siglip_gap_384.webli': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384),
|
||||
num_classes=0),
|
||||
'vit_large_patch16_siglip_gap_512.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 512, 512),
|
||||
num_classes=0),
|
||||
'vit_so400m_patch14_siglip_gap_224.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
num_classes=0),
|
||||
'vit_so400m_patch14_siglip_gap_224.webli': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
num_classes=0),
|
||||
|
@ -1977,9 +2070,9 @@ default_cfgs = {
|
|||
# hf_hub_filename='pt_27b_224.npz',
|
||||
# custom_load='hf',
|
||||
# num_classes=0),
|
||||
'vit_so400m_patch16_siglip_gap_256.webli_i18n': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
'vit_so400m_patch14_siglip_gap_378.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 378, 378),
|
||||
num_classes=0),
|
||||
'vit_so400m_patch14_siglip_gap_378.webli': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
|
@ -2053,6 +2146,30 @@ default_cfgs = {
|
|||
# custom_load='hf',
|
||||
# input_size=(3, 896, 896), crop_pct=1.0,
|
||||
# num_classes=0),
|
||||
'vit_so400m_patch16_siglip_gap_256.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
num_classes=0),
|
||||
'vit_so400m_patch16_siglip_gap_256.webli_i18n': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
num_classes=0),
|
||||
'vit_so400m_patch16_siglip_gap_384.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384),
|
||||
num_classes=0),
|
||||
'vit_so400m_patch16_siglip_gap_512.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 512, 512),
|
||||
num_classes=0),
|
||||
'vit_giantopt_patch16_siglip_gap_256.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 256, 256),
|
||||
num_classes=0),
|
||||
'vit_giantopt_patch16_siglip_gap_384.v2_webli': _cfg(
|
||||
# hf_hub_id='timm/',
|
||||
input_size=(3, 384, 384),
|
||||
num_classes=0),
|
||||
|
||||
'vit_so400m_patch14_siglip_378.webli_ft_in1k': _cfg(
|
||||
hf_hub_id='timm/',
|
||||
|
@ -3114,6 +3231,17 @@ def vit_giant_patch14_reg4_dinov2(pretrained: bool = False, **kwargs) -> VisionT
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch32_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
patch_size=32, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
|
||||
act_layer='gelu_tanh',
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_base_patch32_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch16_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
|
@ -3174,6 +3302,17 @@ def vit_large_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTr
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
|
||||
act_layer='gelu_tanh'
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_large_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_so400m_patch14_siglip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
|
@ -3184,17 +3323,6 @@ def vit_so400m_patch14_siglip_224(pretrained: bool = False, **kwargs) -> VisionT
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_so400m_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
# this is a corrected variant of the 384 with a res properly divisible by patch size (no padding/truncation)
|
||||
model_args = dict(
|
||||
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_so400m_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_so400m_patch14_siglip_378(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
# this is a corrected variant of the 384 with a res properly divisible by patch size (no padding/truncation)
|
||||
|
@ -3216,6 +3344,72 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_so400m_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
|
||||
act_layer='gelu_tanh',
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_so400m_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_so400m_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
|
||||
act_layer='gelu_tanh',
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_so400m_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_so400m_patch16_siglip_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
|
||||
act_layer='gelu_tanh',
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_so400m_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_giantopt_patch16_siglip_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False, global_pool='map',
|
||||
act_layer='gelu_tanh',
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_giantopt_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_giantopt_patch16_siglip_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False, global_pool='map',
|
||||
act_layer='gelu_tanh',
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_giantopt_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch32_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
patch_size=32, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='avg', fc_norm=False,
|
||||
act_layer='gelu_tanh',
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_base_patch32_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch16_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
""" A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
|
||||
|
@ -3282,6 +3476,17 @@ def vit_large_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> Visi
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False,
|
||||
global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_large_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_so400m_patch14_siglip_gap_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
""" A SigLIP variant of ViT with global average pooling (GAP) instead of attention pooling (MAP)."""
|
||||
|
@ -3354,6 +3559,62 @@ def vit_so400m_patch14_siglip_gap_896(pretrained: bool = False, **kwargs) -> Vis
|
|||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_so400m_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False,
|
||||
global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_so400m_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_so400m_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False,
|
||||
global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_so400m_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_so400m_patch16_siglip_gap_512(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
patch_size=16, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False,
|
||||
global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_so400m_patch16_siglip_gap_512', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_giantopt_patch16_siglip_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False,
|
||||
global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_giantopt_patch16_siglip_gap_256', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_giantopt_patch16_siglip_gap_384(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
patch_size=16, embed_dim=1536, depth=40, num_heads=16, class_token=False,
|
||||
global_pool='avg', fc_norm=False, act_layer='gelu_tanh'
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
'vit_giantopt_patch16_siglip_gap_384', pretrained=pretrained, **dict(model_args, **kwargs))
|
||||
return model
|
||||
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_wee_patch16_reg1_gap_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
|
||||
model_args = dict(
|
||||
|
|
Loading…
Reference in New Issue