diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index e560ec9a..701fcb84 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -964,11 +964,13 @@ def _convert_openai_clip(
             v = v.unsqueeze(0)
             if v.shape[1] != model.pos_embed.shape[1]:
                 # To resize pos embedding when using model at different size from pretrained weights
-                v = resize_pos_embed(
+                num_prefix_tokens = 0 if getattr(model, 'no_embed_class', False) \
+                    else getattr(model, 'num_prefix_tokens', 1)
+                v = resample_abs_pos_embed(
                     v,
-                    model.pos_embed,
-                    0 if getattr(model, 'no_embed_class') else getattr(model, 'num_prefix_tokens', 1),
-                    model.patch_embed.grid_size
+                    new_size=model.patch_embed.grid_size,
+                    num_prefix_tokens=num_prefix_tokens,
+                    verbose=True,
                 )
         out_dict[k] = v
     return out_dict
@@ -1735,6 +1737,27 @@ default_cfgs = {
         input_size=(3, 384, 384), num_classes=0),
 
+    'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m': _cfg(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        license='mit',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+    'vit_medium_patch32_clip_224.tinyclip_laion400m': _cfg(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        license='mit',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+    'vit_medium_patch16_clip_224.tinyclip_yfcc15m': _cfg(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        license='mit',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+    'vit_betwixt_patch32_clip_224.tinyclip_laion400m': _cfg(
+        hf_hub_id='timm/',
+        hf_hub_filename='open_clip_pytorch_model.bin',
+        license='mit',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+
     'vit_medium_patch16_reg4_256': _cfg(
         input_size=(3, 256, 256)),
     'vit_medium_patch16_reg4_gap_256': _cfg(
@@ -2073,6 +2096,44 @@ def vit_giant_patch16_gap_224(pretrained: bool = False, **kwargs) -> VisionTrans
     return model
 
 
+@register_model
+def vit_xsmall_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # TinyCLIP 8M
+    model_args = dict(embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_xsmall_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_medium_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # TinyCLIP 40M
+    model_args = dict(
+        patch_size=32, embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_medium_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_medium_patch16_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # TinyCLIP 39M
+    model_args = dict(embed_dim=512, depth=12, num_heads=8, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_medium_patch16_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_betwixt_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    # TinyCLIP 61M
+    model_args = dict(
+        patch_size=32, embed_dim=640, depth=12, num_heads=10, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_betwixt_patch32_clip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch32_clip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
     """ ViT-B/32 CLIP image tower @ 224x224
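
Usage sketch (illustrative, not part of the patch): building one of the newly registered TinyCLIP towers through timm's normal factory, assuming the corresponding weights are published on the HF Hub under the `.tinyclip_*` pretrained tags defined in `default_cfgs` above.

import timm
import torch

# Build the 8M TinyCLIP image tower registered by this patch.
# With pretrained=True, timm fetches 'open_clip_pytorch_model.bin' from the Hub and
# remaps it via _convert_openai_clip(); if the patch grid differs from the checkpoint,
# the position embedding is now resampled with resample_abs_pos_embed().
model = timm.create_model(
    'vit_xsmall_patch16_clip_224.tinyclip_yfcc15m',
    pretrained=True,
    num_classes=0,  # drop the 512-dim CLIP projection head, return pooled features
)
model.eval()

with torch.no_grad():
    feats = model(torch.randn(1, 3, 224, 224))
print(feats.shape)  # expected torch.Size([1, 256]) for the embed_dim=256 xsmall tower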