Add DINOv2 models with register tokens. Convert pos embed to non-overlapping for consistency.
parent fe92fd93e5
commit 3f02392488
@@ -567,7 +567,11 @@ class VisionTransformer(nn.Module):
     def reset_classifier(self, num_classes: int, global_pool=None):
         self.num_classes = num_classes
         if global_pool is not None:
-            assert global_pool in ('', 'avg', 'token')
+            assert global_pool in ('', 'avg', 'token', 'map')
+            if global_pool == 'map' and self.attn_pool is None:
+                assert False, "Cannot currently add attention pooling in reset_classifier()."
+            elif global_pool != 'map' and self.attn_pool is not None:
+                self.attn_pool = None  # remove attention pooling
             self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
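For reviewers, a quick usage sketch of the reworked `reset_classifier()` (a hypothetical snippet, not part of the diff; assumes any `vit_*` model without an attention-pooled head):

```python
import timm

# swap in a fresh 10-class head; 'avg' pooling passes the new assert,
# and any existing attention pooling would be removed since pool != 'map'
model = timm.create_model('vit_small_patch14_dinov2', num_classes=0)
model.reset_classifier(10, global_pool='avg')
```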
@@ -937,10 +941,14 @@ def _convert_openai_clip(state_dict, model):
 def _convert_dinov2(state_dict, model):
     import re
     out_dict = {}
+    state_dict.pop("mask_token", None)
+    if 'register_tokens' in state_dict:
+        # convert dinov2 w/ registers to no_embed_class timm model (neither cls nor reg tokens overlap pos embed)
+        out_dict['reg_token'] = state_dict.pop('register_tokens')
+        out_dict['cls_token'] = state_dict.pop('cls_token') + state_dict['pos_embed'][:, 0]
+        out_dict['pos_embed'] = state_dict.pop('pos_embed')[:, 1:]
     for k, v in state_dict.items():
-        if k == "mask_token":
-            continue
-        elif re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
+        if re.match(r"blocks\.(\d+)\.mlp\.w12\.(?:weight|bias)", k):
             out_dict[k.replace("w12", "fc1")] = v
             continue
         elif re.match(r"blocks\.(\d+)\.mlp\.w3\.(?:weight|bias)", k):
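To make the register conversion above concrete, here is a minimal sketch with dummy tensors standing in for a DINOv2-with-registers checkpoint (not part of the diff; shapes assume ViT-S/14 at 518x518, i.e. 37x37 = 1369 patch positions plus one class position):

```python
import torch

# dummy stand-ins for the relevant checkpoint entries (ViT-S/14 shapes)
sd = {
    'cls_token': torch.zeros(1, 1, 384),
    'register_tokens': torch.zeros(1, 4, 384),
    'pos_embed': torch.randn(1, 1 + 1369, 384),  # class position + patch positions
}

# fold the class position of the pos embed into the cls token, then drop it,
# so neither the cls token nor the register tokens overlap the pos embed
cls_token = sd['cls_token'] + sd['pos_embed'][:, 0]
pos_embed = sd['pos_embed'][:, 1:]
reg_token = sd['register_tokens']
assert pos_embed.shape == (1, 1369, 384)  # patch positions only
```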
@@ -1229,6 +1237,32 @@ default_cfgs = generate_default_cfgs({
         mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
         input_size=(3, 518, 518), crop_pct=1.0),
 
+    # DINOv2 pretrained w/ registers - https://arxiv.org/abs/2309.16588 (no classifier head, for fine-tune/features only)
+    'vit_small_patch14_reg4_dinov2.lvd142m': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_reg4_pretrain.pth',
+        # hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_base_patch14_reg4_dinov2.lvd142m': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitb14/dinov2_vitb14_reg4_pretrain.pth',
+        # hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_large_patch14_reg4_dinov2.lvd142m': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_reg4_pretrain.pth',
+        # hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+    'vit_giant_patch14_reg4_dinov2.lvd142m': _cfg(
+        url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vitg14/dinov2_vitg14_reg4_pretrain.pth',
+        # hf_hub_id='timm/',
+        license='apache-2.0',
+        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0,
+        input_size=(3, 518, 518), crop_pct=1.0),
+
     # ViT ImageNet-21K-P pretraining by MILL
     'vit_base_patch16_224_miil.in21k': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-tresnet/vit_base_patch16_224_in21k_miil-887286df.pth',
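With these cfgs registered, the register variants should be usable as pure feature extractors in the usual timm way (a hedged usage sketch, not part of the diff; the published weights ship without a classifier head, hence `num_classes=0`):

```python
import timm
import torch

# create the small reg4 variant with pretrained LVD-142M weights
model = timm.create_model('vit_small_patch14_reg4_dinov2.lvd142m', pretrained=True).eval()
with torch.no_grad():
    feats = model(torch.randn(1, 3, 518, 518))  # pooled features; no logits with num_classes=0
```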
@@ -2173,9 +2207,7 @@ def vit_huge_patch14_xp_224(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_small_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-S/14 for DINOv2
     """
-    model_args = dict(
-        patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518,
-    )
+    model_args = dict(patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5, img_size=518)
     model = _create_vision_transformer(
         'vit_small_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -2185,9 +2217,7 @@ def vit_small_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_base_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-B/14 for DINOv2
     """
-    model_args = dict(
-        patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518,
-    )
+    model_args = dict(patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5, img_size=518)
     model = _create_vision_transformer(
         'vit_base_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -2197,9 +2227,7 @@ def vit_base_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_large_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-L/14 for DINOv2
     """
-    model_args = dict(
-        patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518,
-    )
+    model_args = dict(patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5, img_size=518)
     model = _create_vision_transformer(
         'vit_large_patch14_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
     return model
@@ -2209,12 +2237,10 @@ def vit_large_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
 def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     """ ViT-G/14 for DINOv2
     """
-
     # The hidden_features of SwiGLU is calculated by:
     # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
     # When embed_dim=1536, hidden_features=4096
     # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
-
     model_args = dict(
         patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
         mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
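Spelling out the arithmetic in the comment above (a sanity check, not part of the diff): DINOv2's SwiGLU shrinks the standard 4x MLP width by 2/3, rounds up to a multiple of 8, and `SwiGLUPacked` fuses the two input projections into one matrix, doubling the fc1 width:

```python
embed_dim = 1536
mlp_hidden = 4 * embed_dim                               # 6144, the usual 4x MLP width
swiglu_hidden = (int(mlp_hidden * 2 / 3) + 7) // 8 * 8   # 4096 after the 2/3 shrink, rounded to 8
packed_fc1 = 2 * swiglu_hidden                           # 8192: w1 and w2 stacked together
assert abs(packed_fc1 / embed_dim - 2.66667 * 2) < 1e-3  # matches mlp_ratio above
```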
@@ -2224,6 +2250,62 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
     return model
 
 
+@register_model
+def vit_small_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-S/14 for DINOv2 w/ 4 registers
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=384, depth=12, num_heads=6, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True,
+    )
+    model = _create_vision_transformer(
+        'vit_small_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_base_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-B/14 for DINOv2 w/ 4 registers
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=768, depth=12, num_heads=12, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True,
+    )
+    model = _create_vision_transformer(
+        'vit_base_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_large_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-L/14 for DINOv2 w/ 4 registers
+    """
+    model_args = dict(
+        patch_size=14, embed_dim=1024, depth=24, num_heads=16, init_values=1e-5,
+        reg_tokens=4, no_embed_class=True,
+    )
+    model = _create_vision_transformer(
+        'vit_large_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
+@register_model
+def vit_giant_patch14_reg4_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
+    """ ViT-G/14 for DINOv2 w/ 4 registers
+    """
+    # The hidden_features of SwiGLU is calculated by:
+    # hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+    # When embed_dim=1536, hidden_features=4096
+    # With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
+    model_args = dict(
+        patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5, mlp_ratio=2.66667 * 2,
+        mlp_layer=SwiGLUPacked, act_layer=nn.SiLU, reg_tokens=4, no_embed_class=True,
+    )
+    model = _create_vision_transformer(
+        'vit_giant_patch14_reg4_dinov2', pretrained=pretrained, **dict(model_args, **kwargs))
+    return model
+
+
 @register_model
 def vit_base_patch16_siglip_224(pretrained=False, **kwargs) -> VisionTransformer:
     model_args = dict(
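Finally, a hedged sketch of the token layout the reg4 entrypoints produce. With `no_embed_class=True` the prefix tokens (1 cls + 4 registers) sit ahead of the patch tokens and receive no positional embedding; `img_size=518` is passed explicitly here since the entrypoints above don't pin it (not part of the diff):

```python
import timm
import torch

model = timm.create_model('vit_small_patch14_reg4_dinov2', pretrained=False, img_size=518)
tokens = model.forward_features(torch.randn(1, 3, 518, 518))
# 1 cls token + 4 register tokens + 37*37 = 1369 patch tokens
assert tokens.shape == (1, 1 + 4 + 1369, 384)
patch_tokens = tokens[:, 5:]  # drop the prefix tokens for dense/feature use
```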