diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py
index e560ec9a..091390c4 100644
--- a/timm/models/vision_transformer.py
+++ b/timm/models/vision_transformer.py
@@ -1015,6 +1015,8 @@ def checkpoint_filter_fn(
         return _convert_openai_clip(state_dict, model)
     elif 'module.visual.class_embedding' in state_dict:
         return _convert_openai_clip(state_dict, model, prefix='module.visual.')
+    elif '_image_encoder.module.visual.class_embedding' in state_dict:
+        return _convert_openai_clip(state_dict, model, prefix='_image_encoder.module.visual.')
 
     if "mask_token" in state_dict:
         state_dict = _convert_dinov2(state_dict, model)
@@ -1735,6 +1737,10 @@ default_cfgs = {
         input_size=(3, 384, 384),
         num_classes=0),
 
+    'vit_8m_patch16_tinyclip_224.yfcc15m': _cfg(
+        url='https://github.com/wkcn/TinyCLIP-model-zoo/releases/download/checkpoints/TinyCLIP-ViT-8M-16-Text-3M-YFCC15M.pt',
+        mean=OPENAI_CLIP_MEAN, std=OPENAI_CLIP_STD, num_classes=512),
+
     'vit_medium_patch16_reg4_256': _cfg(
         input_size=(3, 256, 256)),
     'vit_medium_patch16_reg4_gap_256': _cfg(
@@ -2621,6 +2627,14 @@ def vit_so400m_patch14_siglip_384(pretrained: bool = False, **kwargs) -> VisionT
     return model
 
 
+@register_model
+def vit_8m_patch16_tinyclip_224(pretrained: bool = False, **kwargs) -> VisionTransformer:
+    model_args = dict(embed_dim=256, depth=10, num_heads=4, pre_norm=True, norm_layer=nn.LayerNorm)
+    model = _create_vision_transformer(
+        'vit_8m_patch16_tinyclip_224', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_medium_patch16_reg4_256(pretrained: bool = False, **kwargs) -> VisionTransformer:
     model_args = dict(
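
A minimal usage sketch of the new registration, assuming this patch is applied to a local timm checkout and the TinyCLIP release asset is reachable. Passing `num_classes=512` (matching the pretrained cfg) is assumed here to keep the converted CLIP projection as the head; the shapes in the comments are expectations, not verified output.

```python
import torch
import timm

# Hypothetical smoke test: build the newly registered TinyCLIP ViT-8M/16 image tower.
# pretrained=True pulls the YFCC15M checkpoint declared in default_cfgs and routes it
# through the new '_image_encoder.module.visual.' branch of checkpoint_filter_fn.
model = timm.create_model(
    'vit_8m_patch16_tinyclip_224',
    pretrained=True,
    num_classes=512,  # match the pretrained cfg so the converted CLIP projection head is kept
)
model.eval()

x = torch.randn(1, 3, 224, 224)
with torch.no_grad():
    embedding = model(x)                # expected (1, 512): CLIP image embedding
    tokens = model.forward_features(x)  # expected (1, 197, 256): cls token + 14*14 patches

# For real images, the preprocessing (OPENAI_CLIP mean/std, 224x224) can be resolved
# from the pretrained cfg:
#   cfg = timm.data.resolve_data_config({}, model=model)
#   transform = timm.data.create_transform(**cfg)
```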