Rename tinyclip models to fit existing 'clip' variants, use consistently mapped OpenCLIP compatible checkpoint on hf hub

This commit is contained in:
Ross Wightman 2024-03-20 15:21:46 -07:00
parent 1a1d07d479
commit 256cf19148

View File

@ -964,11 +964,13 @@ def _convert_openai_clip(
v = v.unsqueeze(0)
if v.shape[1] != model.pos_embed.shape[1]:
# To resize pos embedding when using model at different size from pretrained weights
v = resize_pos_embed(
num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) \
else getattr(model, 'num_prefix_tokens', 1)
v = resample_abs_pos_embed(
v,
model.pos_embed,
0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
model.patch_embed.grid_size
new_size=model.patch_embed.grid_size,
num_prefix_tokens=num_prefix_tokens,
verbose=True,
)
out_dict[k] = v
return out_dict
@ -1015,8 +1017,6 @@ def checkpoint_filter_fn(
return _convert_openai_clip(state_dict, model)
elif 'module.visual.class_embedding' in state_dict:
return _convert_openai_clip(state_dict, model, prefix='module.visual.')
elif '_image_encoder.module.visual.class_embedding' in state_dict:
return _convert_openai_clip(state_dict, model, prefix='_image_encoder.module.visual.')
if "mask_token" in state_dict:
state_dict = _convert_dinov2(state_dict, model)
@ -1737,20 +1737,24 @@ default_cfgs = {
input_size=(3, 384, 384),
num_classes=0),
'vit_8m_patch16_tinyclip_224.yfcc15m': _cfg(
url='https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M.pt',
'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
license='mit',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_39m_patch16_tinyclip_224.yfcc15m': _cfg(
url='https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-39M-16-Text-19M-YFCC15M.pt',
'vit_medium_patch32_clip_224.tinyclip_laion400m': _cfg(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
license='mit',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_40m_patch32_tinyclip_224.laion400m': _cfg(
url='https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-40M-32-Text-19M-LAION400M.pt',
'vit_medium_patch16_clip_224.tinyclip_yfcc15m': _cfg(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
license='mit',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
'vit_61m_patch32_tinyclip_224.laion400m': _cfg(
url='https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-61M-32-Text-29M-LAION400M.pt',
'vit_betwixt_patch32_clip_224.tinyclip_laion400m': _cfg(
hf_hub_id='timm/',
hf_hub_filename='open_clip_pytorch_model.bin',
license='mit',
mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
@ -2092,6 +2096,44 @@ def vit_giant_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTrans
return model
@register_model
def vit_xsmall_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
# TinyCLIP 8M
model_args = dict(embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_xsmall_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_medium_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
# TinyCLIP 40M
model_args = dict(
patch_size=32, embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_medium_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_medium_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
# TinyCLIP 39M
model_args = dict(embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_medium_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_betwixt_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
# TinyCLIP 61M
model_args = dict(
patch_size=32, embed_dim=640, depth=12, num_heads=10, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_betwixt_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_base_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
""" ViT-B/32 CLIP image tower @ 224x224
@ -2640,40 +2682,6 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
return model
@register_model
def vit_8m_patch16_tinyclip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_8m_patch16_tinyclip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_39m_patch16_tinyclip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_39m_patch16_tinyclip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_40m_patch32_tinyclip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=32, embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_40m_patch32_tinyclip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_61m_patch32_tinyclip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(
patch_size=32, embed_dim=640, depth=12, num_heads=10, pre_norm=True, norm_layer=nn.LayerNorm)
model = _create_vision_transformer(
'vit_61m_patch32_tinyclip_224', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
model_args = dict(