diff --git a/timm/models/vision_transformer.py b/timm/models/vision_transformer.py index 1ccb96db..04669f54 100644 --- a/timm/models/vision_transformer.py +++ b/timm/models/vision_transformer.py @@ -567,7 +567,11 @@ class VisionTransformer(nn.Module): def reset_classifier(self, num_classes: int, global_pool=None): self.num_classes = num_classes if global_pool is not None: - assert global_pool in ('', 'avg', 'token') + assert global_pool in ('', 'avg', 'token', 'map') + if global_pool == 'map' and self.attn_pool is None: + assert False, "Cannot currently add attention pooling in reset_classifier()." + elif global_pool != 'map ' and self.attn_pool is not None: + self.attn_pool = None # remove attention pooling self.global_pool = global_pool self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity() @@ -937,10 +941,14 @@ def _convert_openai_clip(state_dict, model): def _convert_dinov2(state_dict, model): import re out_dict = {} + state_dict.pop("mask_token", None) + if 'register_tokens' in state_dict: + # convert dinov2 w/ registers to no_embed_class timm model (neither cls or reg tokens overlap pos embed) + out_dict['reg_token'] = state_dict.pop('register_tokens') + out_dict['cls_token'] = state_dict.pop('cls_token') + state_dict['pos_embed'][:, 0] + out_dict['pos_embed'] = state_dict.pop('pos_embed')[:, 1:] for k, v in state_dict.items(): - if k == "mask_token": - continue - elif re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k): + if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k): out_dict[k.replace("w12", "fc1")] = v continue elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k): @@ -1229,6 +1237,32 @@ default_cfgs = generate_default_cfgs({ mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, input_size=(3, 518, 518), crop_pct=1.0), + # DINOv2 pretrained w/ registers - https://arxiv.org/abs/2309.16588 (no classifier head, for fine-tune/features only) + 'vit_small_patch14_reg4_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth', + # hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + 'vit_base_patch14_reg4_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth', + # hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + 'vit_large_patch14_reg4_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth', + # hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + 'vit_giant_patch14_reg4_dinov2.lvd142m': _cfg( + url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth', + # hf_hub_id='timm/', + license='apache-2.0', + mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0, + input_size=(3, 518, 518), crop_pct=1.0), + # ViT ImageNet-21K-P pretraining by MILL 'vit_base_patch16_224_miil.in21k': _cfg( url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth', @@ -2173,9 +2207,7 @@ def vit_huge_patch14_xp_224(pretrained=False, **kwargs) -> VisionTransformer: def vit_small_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: """ ViT-S/14 for DINOv2 """ - model_args = dict( - patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518, - ) + model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518) model = _create_vision_transformer( 'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2185,9 +2217,7 @@ def vit_small_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: def vit_base_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: """ ViT-B/14 for DINOv2 """ - model_args = dict( - patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518, - ) + model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518) model = _create_vision_transformer( 'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2197,9 +2227,7 @@ def vit_base_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: def vit_large_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: """ ViT-L/14 for DINOv2 """ - model_args = dict( - patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518, - ) + model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518) model = _create_vision_transformer( 'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) return model @@ -2209,12 +2237,10 @@ def vit_large_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: """ ViT-G/14 for DINOv2 """ - # The hidden_features of SwiGLU is calculated by: # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 # When embed_dim=1536, hidden_features=4096 # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192 - model_args = dict( patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU @@ -2224,6 +2250,62 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer: return model +@register_model +def vit_small_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer: + """ ViT-S/14 for DINOv2 w/ 4 registers + """ + model_args = dict( + patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, + reg_tokens=4, no_embed_class=True, + ) + model = _create_vision_transformer( + 'vit_small_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_base_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer: + """ ViT-B/14 for DINOv2 w/ 4 registers + """ + model_args = dict( + patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, + reg_tokens=4, no_embed_class=True, + ) + model = _create_vision_transformer( + 'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_large_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer: + """ ViT-L/14 for DINOv2 w/ 4 registers + """ + model_args = dict( + patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, + reg_tokens=4, no_embed_class=True, + ) + model = _create_vision_transformer( + 'vit_large_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + +@register_model +def vit_giant_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer: + """ ViT-G/14 for DINOv2 + """ + # The hidden_features of SwiGLU is calculated by: + # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + # When embed_dim=1536, hidden_features=4096 + # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192 + model_args = dict( + patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, mlp_ratio=2.66667 * 2, + mlp_layer=SwiGLUPacked, act_layer=nn.SiLU, reg_tokens=4, no_embed_class=True, + ) + model = _create_vision_transformer( + 'vit_giant_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs)) + return model + + @register_model def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer: model_args = dict(