Add full set of SigLIP models

Ross Wightman 2023-10-10 22:15:45 -07:00
parent b9dde58076
commit 42daa3b497


@@ -606,6 +606,7 @@ class VisionTransformer(nn.Module):
            self.attn_pool = AttentionPoolLatent(
                self.embed_dim,
                num_heads=num_heads,
                mlp_ratio=mlp_ratio,
                norm_layer=norm_layer,
            )
        else:
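
Side note on the forwarded mlp_ratio: the SO400M variants registered further down use a non-default ratio (3.7362), so without this argument the attention-pool head would silently fall back to the default of 4. A minimal sketch of the effect, assuming AttentionPoolLatent sizes its MLP hidden dim as int(embed_dim * mlp_ratio):

# Minimal sketch (assumption: AttentionPoolLatent uses hidden = int(embed_dim * mlp_ratio)).
embed_dim, mlp_ratio = 1152, 3.7362  # SO400M values from the configs below
print(int(embed_dim * mlp_ratio))    # 4304, vs. 4 * 1152 = 4608 with the default ratio
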
@@ -1644,6 +1645,39 @@ default_cfgs = generate_default_cfgs({
        input_size=(3, 256, 256),
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_base_patch16_siglip_384': _cfg(
        file='',
        custom_load=True,
        input_size=(3, 384, 384),
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_base_patch16_siglip_512': _cfg(
        file='',
        custom_load=True,
        input_size=(3, 512, 512),
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_large_patch16_siglip_256': _cfg(
        custom_load=True,
        input_size=(3, 256, 256),
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_large_patch16_siglip_384': _cfg(
        custom_load=True,
        input_size=(3, 384, 384),
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_so400m_patch14_siglip_224': _cfg(
        # file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
        custom_load=True,
        # hf_hub_id='timm/',
        num_classes=0),
    'vit_so400m_patch14_siglip_384': _cfg(
        # file='/data/n/temp/siglip/webli_en_b16_256_60500360.npz',
        custom_load=True,
        # hf_hub_id='timm/',
        input_size=(3, 384, 384),
        num_classes=0),
})
@@ -2290,6 +2324,65 @@ def vit_base_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer
    return model


@register_model
def vit_base_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
    )
    model = _create_vision_transformer(
        'vit_base_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_base_patch16_siglip_512(pretrained=False, **kwargs) -> VisionTransformer:
    model_args = dict(
        patch_size=16, embed_dim=768, depth=12, num_heads=12, class_token=False, global_pool='map',
    )
    model = _create_vision_transformer(
        'vit_base_patch16_siglip_512', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_large_patch16_siglip_256(pretrained=False, **kwargs) -> VisionTransformer:
    model_args = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
    )
    model = _create_vision_transformer(
        'vit_large_patch16_siglip_256', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_large_patch16_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
    model_args = dict(
        patch_size=16, embed_dim=1024, depth=24, num_heads=16, class_token=False, global_pool='map',
    )
    model = _create_vision_transformer(
        'vit_large_patch16_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_so400m_patch14_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
    model_args = dict(
        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
    )
    model = _create_vision_transformer(
        'vit_so400m_patch14_siglip_224', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_so400m_patch14_siglip_384(pretrained=False, **kwargs) -> VisionTransformer:
    model_args = dict(
        patch_size=14, embed_dim=1152, depth=27, num_heads=16, mlp_ratio=3.7362, class_token=False, global_pool='map',
    )
    model = _create_vision_transformer(
        'vit_so400m_patch14_siglip_384', pretrained=pretrained, **dict(model_args, **kwargs))
    return model


@register_model
def vit_medium_patch16_reg8_224(pretrained=False, **kwargs) -> VisionTransformer:
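
For reference, a quick usage sketch of one of the newly registered variants via timm.create_model. The hf_hub_id entries above are still commented out, so this assumes a randomly initialized model rather than pretrained SigLIP weights:

import timm
import torch

# Build one of the new SigLIP towers by its registered name (no pretrained weights assumed here).
model = timm.create_model('vit_so400m_patch14_siglip_384', pretrained=False, num_classes=0)
model.eval()

x = torch.randn(1, 3, 384, 384)  # matches input_size in the default cfg
with torch.no_grad():
    feats = model(x)
print(feats.shape)  # torch.Size([1, 1152]) -- pooled output of the 'map' (attention-pool) head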