diff --git a/timm/models/vision_transformer_hybrid.py b/timm/models/vision_transformer_hybrid.py
index c16e7c78..0c690c35 100644
--- a/timm/models/vision_transformer_hybrid.py
+++ b/timm/models/vision_transformer_hybrid.py
@@ -389,12 +389,12 @@ default_cfgs = generate_default_cfgs({
     'vit_base_mci_224.apple_mclip': _cfg(
         url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_b.pt',
         num_classes=512,
-        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv.weight',
+        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv',
     ),
     'vit_base_mci_224.apple_mclip_lt': _cfg(
         url='https://docs-assets.developer.apple.com/ml-research/datasets/mobileclip/mobileclip_blt.pt',
         num_classes=512,
-        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv.weight',
+        mean=(0., 0., 0.), std=(1., 1., 1.), first_conv='patch_embed.backbone.0.conv',
     ),
 })
 
@@ -552,6 +552,7 @@ def vit_base_mci_224(pretrained=False, **kwargs) -> VisionTransformer:
         stride=(4, 2, 2),
         kernel_size=(4, 2, 2),
         padding=0,
+        in_chans=kwargs.get('in_chans', 3),
         act_layer=nn.GELU,
     )
     model_args = dict(embed_dim=768, depth=12, num_heads=12, no_embed_class=True)
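The `first_conv` entries now name the stem conv module rather than its `.weight` parameter, which is what timm's pretrained-weight loading expects when it adapts the first conv for non-RGB inputs, and the `in_chans` pass-through lets the ConvStem actually be built with the requested channel count. A minimal sketch of how the change is exercised, assuming a timm build that registers the `vit_base_mci_224` entrypoint; the grayscale example is illustrative, not part of this diff:

```python
# Sketch: with in_chans forwarded into the ConvStem, a non-RGB model builds cleanly.
# Assumes a timm version that includes the vit_base_mci_224 entrypoint.
import torch
import timm

# single-channel (grayscale) input; in_chans now reaches the conv stem
model = timm.create_model('vit_base_mci_224', pretrained=False, in_chans=1)
out = model(torch.randn(1, 1, 224, 224))
print(out.shape)  # output sized by the model's num_classes head
```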