Add DINOv2 models with register tokens. Convert pos embed to non-overlapping for consistency.

pull/2015/head
Ross Wightman 2023-10-29 17:05:28 -07:00 committed by Ross Wightman
parent fe92fd93e5
commit 3f02392488
1 changed files with 97 additions and 15 deletions

View File

@ -567,7 +567,11 @@ class VisionTransformer(nn.Module):
def reset_classifier(self, num_classes: int, global_pool=None):
self.num_classes = num_classes
if global_pool is not None:
assert global_pool in ('', 'avg', 'token')
assert global_pool in ('', 'avg', 'token', 'map')
if global_pool == 'map' and self.attn_pool is None:
assert False, "Cannot currently add attention pooling in reset_classifier()."
elif global_pool != 'map ' and self.attn_pool is not None:
self.attn_pool = None # remove attention pooling
self.global_pool = global_pool
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
@ -937,10 +941,14 @@ def _convert_openai_clip(state_dict, model):
def _convert_dinov2(state_dict, model):
import re
out_dict = {}
state_dict.pop("mask_token", None)
if 'register_tokens' in state_dict:
# convert dinov2 w/ registers to no_embed_class timm model (neither cls or reg tokens overlap pos embed)
out_dict['reg_token'] = state_dict.pop('register_tokens')
out_dict['cls_token'] = state_dict.pop('cls_token') + state_dict['pos_embed'][:, 0]
out_dict['pos_embed'] = state_dict.pop('pos_embed')[:, 1:]
for k, v in state_dict.items():
if k == "mask_token":
continue
elif re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
out_dict[k.replace("w12", "fc1")] = v
continue
elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
@ -1229,6 +1237,32 @@ default_cfgs = generate_default_cfgs({
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 518, 518), crop_pct=1.0),
# DINOv2 pretrained w/ registers - https://arxiv.org/abs/2309.16588 (no classifier head, for fine-tune/features only)
'vit_small_patch14_reg4_dinov2.lvd142m': _cfg(
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth',
# hf_hub_id='timm/',
license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 518, 518), crop_pct=1.0),
'vit_base_patch14_reg4_dinov2.lvd142m': _cfg(
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth',
# hf_hub_id='timm/',
license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 518, 518), crop_pct=1.0),
'vit_large_patch14_reg4_dinov2.lvd142m': _cfg(
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth',
# hf_hub_id='timm/',
license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 518, 518), crop_pct=1.0),
'vit_giant_patch14_reg4_dinov2.lvd142m': _cfg(
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth',
# hf_hub_id='timm/',
license='apache-2.0',
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
input_size=(3, 518, 518), crop_pct=1.0),
# ViT ImageNet-21K-P pretraining by MILL
'vit_base_patch16_224_miil.in21k': _cfg(
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
@ -2173,9 +2207,7 @@ def vit_huge_patch14_xp_224(pretrained=False, **kwargs) -> VisionTransformer:
def vit_small_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-S/14 for DINOv2
"""
model_args = dict(
patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518,
)
model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518)
model = _create_vision_transformer(
'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -2185,9 +2217,7 @@ def vit_small_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
def vit_base_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-B/14 for DINOv2
"""
model_args = dict(
patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518,
)
model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518)
model = _create_vision_transformer(
'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -2197,9 +2227,7 @@ def vit_base_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
def vit_large_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-L/14 for DINOv2
"""
model_args = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518,
)
model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518)
model = _create_vision_transformer(
'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@ -2209,12 +2237,10 @@ def vit_large_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-G/14 for DINOv2
"""
# The hidden_features of SwiGLU is calculated by:
# hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
# When embed_dim=1536, hidden_features=4096
# With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
model_args = dict(
patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
@ -2224,6 +2250,62 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
return model
@register_model
def vit_small_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-S/14 for DINOv2 w/ 4 registers
"""
model_args = dict(
patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5,
reg_tokens=4, no_embed_class=True,
)
model = _create_vision_transformer(
'vit_small_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_base_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-B/14 for DINOv2 w/ 4 registers
"""
model_args = dict(
patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
reg_tokens=4, no_embed_class=True,
)
model = _create_vision_transformer(
'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_large_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-L/14 for DINOv2 w/ 4 registers
"""
model_args = dict(
patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5,
reg_tokens=4, no_embed_class=True,
)
model = _create_vision_transformer(
'vit_large_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_giant_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
""" ViT-G/14 for DINOv2
"""
# The hidden_features of SwiGLU is calculated by:
# hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
# When embed_dim=1536, hidden_features=4096
# With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
model_args = dict(
patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, mlp_ratio=2.66667 * 2,
mlp_layer=SwiGLUPacked, act_layer=nn.SiLU, reg_tokens=4, no_embed_class=True,
)
model = _create_vision_transformer(
'vit_giant_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
return model
@register_model
def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
model_args = dict(