Support DeiT-3 (Revenge of the ViT) checkpoints. Add non-overlapping (w/ class token) pos-embed support to vit.
parent
d0c5bd5722
commit
7d4b3807d5
|
@ -1,7 +1,10 @@
|
|||
""" DeiT - Data-efficient Image Transformers
|
||||
|
||||
DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below
|
||||
paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
|
||||
|
||||
paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
|
||||
|
||||
paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118
|
||||
|
||||
Modifications copyright 2021, Ross Wightman
|
||||
"""
|
||||
|
@ -53,6 +56,46 @@ default_cfgs = {
|
|||
url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0,
|
||||
classifier=('head', 'head_dist')),
|
||||
|
||||
'deit3_small_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'),
|
||||
'deit3_small_patch16_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_base_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'),
|
||||
'deit3_base_patch16_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_large_patch16_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'),
|
||||
'deit3_large_patch16_384': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_huge_patch14_224': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'),
|
||||
|
||||
'deit3_small_patch16_224_in21ft1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth',
|
||||
crop_pct=1.0),
|
||||
'deit3_small_patch16_384_in21ft1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_base_patch16_224_in21ft1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth',
|
||||
crop_pct=1.0),
|
||||
'deit3_base_patch16_384_in21ft1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_large_patch16_224_in21ft1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth',
|
||||
crop_pct=1.0),
|
||||
'deit3_large_patch16_384_in21ft1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth',
|
||||
input_size=(3, 384, 384), crop_pct=1.0),
|
||||
'deit3_huge_patch14_224_in21ft1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth',
|
||||
crop_pct=1.0),
|
||||
}
|
||||
|
||||
|
||||
|
@ -68,9 +111,10 @@ class VisionTransformerDistilled(VisionTransformer):
|
|||
super().__init__(*args, **kwargs, weight_init='skip')
|
||||
assert self.global_pool in ('token',)
|
||||
|
||||
self.num_tokens = 2
|
||||
self.num_prefix_tokens = 2
|
||||
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim))
|
||||
self.pos_embed = nn.Parameter(
|
||||
torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim))
|
||||
self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
|
||||
self.distilled_training = False # must set this True to train w/ distillation token
|
||||
|
||||
|
@ -220,3 +264,157 @@ def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
|
|||
model = _create_deit(
|
||||
'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_small_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_small_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_small_patch16_384(pretrained=False, **kwargs):
|
||||
""" DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_small_patch16_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_base_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_base_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_base_patch16_384(pretrained=False, **kwargs):
|
||||
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_base_patch16_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_large_patch16_224(pretrained=False, **kwargs):
|
||||
""" DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_large_patch16_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_large_patch16_384(pretrained=False, **kwargs):
|
||||
""" DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_large_patch16_384', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_huge_patch14_224(pretrained=False, **kwargs):
|
||||
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-1k weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_small_patch16_224_in21ft1k(pretrained=False, **kwargs):
|
||||
""" DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_small_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_small_patch16_384_in21ft1k(pretrained=False, **kwargs):
|
||||
""" DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_small_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_base_patch16_224_in21ft1k(pretrained=False, **kwargs):
|
||||
""" DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_base_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_base_patch16_384_in21ft1k(pretrained=False, **kwargs):
|
||||
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_base_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_large_patch16_224_in21ft1k(pretrained=False, **kwargs):
|
||||
""" DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_large_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_large_patch16_384_in21ft1k(pretrained=False, **kwargs):
|
||||
""" DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_large_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def deit3_huge_patch14_224_in21ft1k(pretrained=False, **kwargs):
|
||||
""" DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
|
||||
ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
|
||||
model = _create_deit('deit3_huge_patch14_224_in21ft1k', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
|
|
|
@ -325,8 +325,8 @@ class VisionTransformer(nn.Module):
|
|||
def __init__(
|
||||
self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
|
||||
embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
|
||||
class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
|
||||
embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
|
||||
class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
|
||||
weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
|
||||
"""
|
||||
Args:
|
||||
img_size (int, tuple): input image size
|
||||
|
@ -360,15 +360,17 @@ class VisionTransformer(nn.Module):
|
|||
self.num_classes = num_classes
|
||||
self.global_pool = global_pool
|
||||
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
||||
self.num_tokens = 1 if class_token else 0
|
||||
self.num_prefix_tokens = 1 if class_token else 0
|
||||
self.no_embed_class = no_embed_class
|
||||
self.grad_checkpointing = False
|
||||
|
||||
self.patch_embed = embed_layer(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None
|
||||
self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02)
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
|
||||
embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
|
||||
self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
|
@ -428,11 +430,24 @@ class VisionTransformer(nn.Module):
|
|||
self.global_pool = global_pool
|
||||
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
|
||||
|
||||
def _pos_embed(self, x):
|
||||
if self.no_embed_class:
|
||||
# deit-3, updated JAX (big vision)
|
||||
# position embedding does not overlap with class token, add then concat
|
||||
x = x + self.pos_embed
|
||||
if self.cls_token is not None:
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
else:
|
||||
# original timm, JAX, and deit vit impl
|
||||
# pos_embed has entry for class token, concat then add
|
||||
if self.cls_token is not None:
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = x + self.pos_embed
|
||||
return self.pos_drop(x)
|
||||
|
||||
def forward_features(self, x):
|
||||
x = self.patch_embed(x)
|
||||
if self.cls_token is not None:
|
||||
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
||||
x = self.pos_drop(x + self.pos_embed)
|
||||
x = self._pos_embed(x)
|
||||
if self.grad_checkpointing and not torch.jit.is_scripting():
|
||||
x = checkpoint_seq(self.blocks, x)
|
||||
else:
|
||||
|
@ -442,7 +457,7 @@ class VisionTransformer(nn.Module):
|
|||
|
||||
def forward_head(self, x, pre_logits: bool = False):
|
||||
if self.global_pool:
|
||||
x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
|
||||
x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
|
||||
x = self.fc_norm(x)
|
||||
return x if pre_logits else self.head(x)
|
||||
|
||||
|
@ -556,7 +571,11 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
|
|||
pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
|
||||
if pos_embed_w.shape != model.pos_embed.shape:
|
||||
pos_embed_w = resize_pos_embed( # resize pos embedding when different size from pretrained weights
|
||||
pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
||||
pos_embed_w,
|
||||
model.pos_embed,
|
||||
getattr(model, 'num_prefix_tokens', 1),
|
||||
model.patch_embed.grid_size
|
||||
)
|
||||
model.pos_embed.copy_(pos_embed_w)
|
||||
model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
|
||||
model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
|
||||
|
@ -585,16 +604,16 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
|
|||
block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
|
||||
|
||||
|
||||
def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
|
||||
def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
|
||||
# Rescale the grid of position embeddings when loading from state_dict. Adapted from
|
||||
# https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
|
||||
_logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
|
||||
ntok_new = posemb_new.shape[1]
|
||||
if num_tokens:
|
||||
posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
|
||||
ntok_new -= num_tokens
|
||||
if num_prefix_tokens:
|
||||
posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
|
||||
ntok_new -= num_prefix_tokens
|
||||
else:
|
||||
posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
|
||||
posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
|
||||
gs_old = int(math.sqrt(len(posemb_grid)))
|
||||
if not len(gs_new): # backwards compatibility
|
||||
gs_new = [int(math.sqrt(ntok_new))] * 2
|
||||
|
@ -603,25 +622,34 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
|
|||
posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
|
||||
posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
|
||||
posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
|
||||
posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
|
||||
posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
|
||||
return posemb
|
||||
|
||||
|
||||
def checkpoint_filter_fn(state_dict, model):
|
||||
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
||||
import re
|
||||
out_dict = {}
|
||||
if 'model' in state_dict:
|
||||
# For deit models
|
||||
state_dict = state_dict['model']
|
||||
|
||||
for k, v in state_dict.items():
|
||||
if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
|
||||
# For old models that I trained prior to conv based patchification
|
||||
O, I, H, W = model.patch_embed.proj.weight.shape
|
||||
v = v.reshape(O, -1, H, W)
|
||||
elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
|
||||
elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
|
||||
# To resize pos embedding when using model at different size from pretrained weights
|
||||
v = resize_pos_embed(
|
||||
v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
|
||||
v,
|
||||
model.pos_embed,
|
||||
getattr(model, 'num_prefix_tokens', 1),
|
||||
model.patch_embed.grid_size
|
||||
)
|
||||
elif 'gamma_' in k:
|
||||
# remap layer-scale gamma into sub-module (deit3 models)
|
||||
k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
|
||||
elif 'pre_logits' in k:
|
||||
# NOTE representation layer removed as not used in latest 21k/1k pretrained weights
|
||||
continue
|
||||
|
|
Loading…
Reference in New Issue