mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Support DeiT-3 (Revenge of the ViT) checkpoints. Add non-overlapping (w/ class token) pos-embed support to vit.
parent d0c5bd5722
commit 7d4b3807d5
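
For context, a minimal usage sketch of the models this commit registers (assuming a timm build that includes it; timm.create_model is the standard entry point):

# Minimal sketch: instantiating one of the DeiT-3 models registered below
# (assumes a timm build containing this commit).
import timm
import torch

model = timm.create_model('deit3_small_patch16_224', pretrained=False)
model.eval()
with torch.no_grad():
    logits = model(torch.randn(1, 3, 224, 224))
print(logits.shape)  # torch.Size([1, 1000])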
timm/models/deit.py

@@ -1,7 +1,10 @@
 """ DeiT - Data-efficient Image Transformers
 
 DeiT model defs and weights from https://github.com/facebookresearch/deit, original copyright below
 
-paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
+paper: `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
+
+paper: `DeiT III: Revenge of the ViT` - https://arxiv.org/abs/2204.07118
 
 Modifications copyright 2021, Ross Wightman
 """
@@ -53,6 +56,46 @@ default_cfgs = {
         url='https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth',
         input_size=(3, 384, 384), crop_pct=1.0,
         classifier=('head', 'head_dist')),
+
+    'deit3_small_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_1k.pth'),
+    'deit3_small_patch16_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_1k.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'deit3_base_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_1k.pth'),
+    'deit3_base_patch16_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_1k.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'deit3_large_patch16_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_1k.pth'),
+    'deit3_large_patch16_384': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_1k.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'deit3_huge_patch14_224': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_1k.pth'),
+
+    'deit3_small_patch16_224_in21ft1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_224_21k.pth',
+        crop_pct=1.0),
+    'deit3_small_patch16_384_in21ft1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_small_384_21k.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'deit3_base_patch16_224_in21ft1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_224_21k.pth',
+        crop_pct=1.0),
+    'deit3_base_patch16_384_in21ft1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_base_384_21k.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'deit3_large_patch16_224_in21ft1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_224_21k.pth',
+        crop_pct=1.0),
+    'deit3_large_patch16_384_in21ft1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_large_384_21k.pth',
+        input_size=(3, 384, 384), crop_pct=1.0),
+    'deit3_huge_patch14_224_in21ft1k': _cfg(
+        url='https://dl.fbaipublicfiles.com/deit/deit_3_huge_224_21k_v1.pth',
+        crop_pct=1.0),
 }
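
For orientation, _cfg above is timm's small factory for per-model pretrained metadata; a rough sketch of the pattern follows (the exact defaults are assumed from neighboring timm models, not verified against this revision):

from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

def _cfg(url='', **kwargs):
    # Rough sketch of the deit.py config factory: a plain dict of
    # pretrained-weight metadata, with per-model overrides via kwargs.
    return dict(
        url=url, num_classes=1000, input_size=(3, 224, 224), pool_size=None,
        crop_pct=.9, interpolation='bicubic', fixed_input_size=True,
        mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD,
        first_conv='patch_embed.proj', classifier='head', **kwargs)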
@@ -68,9 +111,10 @@ class VisionTransformerDistilled(VisionTransformer):
         super().__init__(*args, **kwargs, weight_init='skip')
         assert self.global_pool in ('token',)
 
-        self.num_tokens = 2
+        self.num_prefix_tokens = 2
         self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
-        self.pos_embed = nn.Parameter(torch.zeros(1, self.patch_embed.num_patches + self.num_tokens, self.embed_dim))
+        self.pos_embed = nn.Parameter(
+            torch.zeros(1, self.patch_embed.num_patches + self.num_prefix_tokens, self.embed_dim))
         self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if self.num_classes > 0 else nn.Identity()
         self.distilled_training = False  # must set this True to train w/ distillation token
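
The rename to num_prefix_tokens reflects that the distilled model carries two prefix tokens (class + distillation), so its pos_embed spans two extra positions; a quick arithmetic sketch:

# Sketch: pos_embed length for a distilled DeiT at 224x224 / patch 16.
num_patches = (224 // 16) ** 2                # 196 patch tokens
num_prefix_tokens = 2                         # class token + distillation token
embed_len = num_patches + num_prefix_tokens   # 198 positions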
@@ -220,3 +264,157 @@ def deit_base_distilled_patch16_384(pretrained=False, **kwargs):
     model = _create_deit(
         'deit_base_distilled_patch16_384', pretrained=pretrained, distilled=True, **model_kwargs)
     return model
+
+
+@register_model
+def deit3_small_patch16_224(pretrained=False, **kwargs):
+    """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_small_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_small_patch16_384(pretrained=False, **kwargs):
+    """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_small_patch16_384', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_base_patch16_224(pretrained=False, **kwargs):
+    """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_base_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_base_patch16_384(pretrained=False, **kwargs):
+    """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_base_patch16_384', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_large_patch16_224(pretrained=False, **kwargs):
+    """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_large_patch16_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_large_patch16_384(pretrained=False, **kwargs):
+    """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_large_patch16_384', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_huge_patch14_224(pretrained=False, **kwargs):
+    """ DeiT-3 huge model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-1k weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_huge_patch14_224', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_small_patch16_224_in21ft1k(pretrained=False, **kwargs):
+    """ DeiT-3 small model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_small_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_small_patch16_384_in21ft1k(pretrained=False, **kwargs):
+    """ DeiT-3 small model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=384, depth=12, num_heads=6, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_small_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_base_patch16_224_in21ft1k(pretrained=False, **kwargs):
+    """ DeiT-3 base model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_base_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_base_patch16_384_in21ft1k(pretrained=False, **kwargs):
+    """ DeiT-3 base model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=768, depth=12, num_heads=12, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_base_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_large_patch16_224_in21ft1k(pretrained=False, **kwargs):
+    """ DeiT-3 large model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_large_patch16_224_in21ft1k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_large_patch16_384_in21ft1k(pretrained=False, **kwargs):
+    """ DeiT-3 large model @ 384x384 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=16, embed_dim=1024, depth=24, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_large_patch16_384_in21ft1k', pretrained=pretrained, **model_kwargs)
+    return model
+
+
+@register_model
+def deit3_huge_patch14_224_in21ft1k(pretrained=False, **kwargs):
+    """ DeiT-3 huge model @ 224x224 from paper (https://arxiv.org/abs/2204.07118).
+    ImageNet-21k pretrained weights from https://github.com/facebookresearch/deit.
+    """
+    model_kwargs = dict(
+        patch_size=14, embed_dim=1280, depth=32, num_heads=16, no_embed_class=True, init_values=1e-6, **kwargs)
+    model = _create_deit('deit3_huge_patch14_224_in21ft1k', pretrained=pretrained, **model_kwargs)
+    return model
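
All the new entry points pass init_values=1e-6, which enables layer-scale on the residual branches; a minimal sketch of the idea (not necessarily the exact timm module):

import torch
import torch.nn as nn

class LayerScale(nn.Module):
    # Minimal layer-scale sketch: a learnable per-channel gain, initialized
    # near zero so each residual branch starts close to identity.
    def __init__(self, dim, init_values=1e-6):
        super().__init__()
        self.gamma = nn.Parameter(init_values * torch.ones(dim))

    def forward(self, x):
        return x * self.gamma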
timm/models/vision_transformer.py

@@ -325,8 +325,8 @@ class VisionTransformer(nn.Module):
     def __init__(
             self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, global_pool='token',
             embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, init_values=None,
-            class_token=True, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0., weight_init='',
-            embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
+            class_token=True, no_embed_class=False, fc_norm=None, drop_rate=0., attn_drop_rate=0., drop_path_rate=0.,
+            weight_init='', embed_layer=PatchEmbed, norm_layer=None, act_layer=None, block_fn=Block):
         """
         Args:
             img_size (int, tuple): input image size
@@ -360,15 +360,17 @@ class VisionTransformer(nn.Module):
         self.num_classes = num_classes
         self.global_pool = global_pool
         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
-        self.num_tokens = 1 if class_token else 0
+        self.num_prefix_tokens = 1 if class_token else 0
+        self.no_embed_class = no_embed_class
         self.grad_checkpointing = False
 
         self.patch_embed = embed_layer(
             img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
         num_patches = self.patch_embed.num_patches
 
-        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if self.num_tokens > 0 else None
-        self.pos_embed = nn.Parameter(torch.randn(1, num_patches + self.num_tokens, embed_dim) * .02)
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if class_token else None
+        embed_len = num_patches if no_embed_class else num_patches + self.num_prefix_tokens
+        self.pos_embed = nn.Parameter(torch.randn(1, embed_len, embed_dim) * .02)
         self.pos_drop = nn.Dropout(p=drop_rate)
 
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
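
The embed_len line is the crux of the new no_embed_class option; worked out for a standard 224x224 / patch 16 model with a class token:

# Sketch: position embedding length under the two modes.
num_patches = (224 // 16) ** 2                        # 196
num_prefix_tokens = 1                                 # class token
embed_len_classic = num_patches + num_prefix_tokens   # 197: pos_embed covers cls too
embed_len_no_class = num_patches                      # 196: DeiT-3 / big vision style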
@@ -428,11 +430,24 @@ class VisionTransformer(nn.Module):
         self.global_pool = global_pool
         self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
 
+    def _pos_embed(self, x):
+        if self.no_embed_class:
+            # deit-3, updated JAX (big vision)
+            # position embedding does not overlap with class token, add then concat
+            x = x + self.pos_embed
+            if self.cls_token is not None:
+                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        else:
+            # original timm, JAX, and deit vit impl
+            # pos_embed has entry for class token, concat then add
+            if self.cls_token is not None:
+                x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+            x = x + self.pos_embed
+        return self.pos_drop(x)
+
     def forward_features(self, x):
         x = self.patch_embed(x)
-        if self.cls_token is not None:
-            x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
-        x = self.pos_drop(x + self.pos_embed)
+        x = self._pos_embed(x)
         if self.grad_checkpointing and not torch.jit.is_scripting():
             x = checkpoint_seq(self.blocks, x)
         else:
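
A standalone sketch of the two orderings _pos_embed implements, showing that both yield the same token count but differ in whether the class token receives a position embedding:

import torch

B, N, D = 2, 196, 384
patches = torch.randn(B, N, D)
cls_tok = torch.zeros(1, 1, D)

# no_embed_class=True: add pos_embed to patch tokens only, then prepend cls.
pos_no_class = torch.randn(1, N, D)
x1 = torch.cat((cls_tok.expand(B, -1, -1), patches + pos_no_class), dim=1)

# no_embed_class=False: prepend cls first, pos_embed covers cls as well.
pos_classic = torch.randn(1, N + 1, D)
x2 = torch.cat((cls_tok.expand(B, -1, -1), patches), dim=1) + pos_classic

assert x1.shape == x2.shape == (B, N + 1, D)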
@@ -442,7 +457,7 @@ class VisionTransformer(nn.Module):
 
     def forward_head(self, x, pre_logits: bool = False):
         if self.global_pool:
-            x = x[:, self.num_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
+            x = x[:, self.num_prefix_tokens:].mean(dim=1) if self.global_pool == 'avg' else x[:, 0]
             x = self.fc_norm(x)
         return x if pre_logits else self.head(x)
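
The forward_head change only renames the attribute; the pooling behavior it implements can be sketched as:

import torch

# Sketch: 'avg' pooling skips the prefix tokens; 'token' pooling takes index 0.
tokens = torch.randn(2, 198, 768)        # e.g. 2 prefix tokens + 196 patches
pooled_avg = tokens[:, 2:].mean(dim=1)   # mean over patch tokens only
pooled_cls = tokens[:, 0]                # class token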
@@ -556,7 +571,11 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
     if pos_embed_w.shape != model.pos_embed.shape:
         pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
-            pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
+            pos_embed_w,
+            model.pos_embed,
+            getattr(model, 'num_prefix_tokens', 1),
+            model.patch_embed.grid_size
+        )
     model.pos_embed.copy_(pos_embed_w)
     model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
     model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
@@ -585,16 +604,16 @@ def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str =
     block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
 
 
-def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
+def resize_pos_embed(posemb, posemb_new, num_prefix_tokens=1, gs_new=()):
     # Rescale the grid of position embeddings when loading from state_dict. Adapted from
     # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
     _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
     ntok_new = posemb_new.shape[1]
-    if num_tokens:
-        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
-        ntok_new -= num_tokens
+    if num_prefix_tokens:
+        posemb_prefix, posemb_grid = posemb[:, :num_prefix_tokens], posemb[0, num_prefix_tokens:]
+        ntok_new -= num_prefix_tokens
     else:
-        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
+        posemb_prefix, posemb_grid = posemb[:, :0], posemb[0]
     gs_old = int(math.sqrt(len(posemb_grid)))
     if not len(gs_new):  # backwards compatibility
         gs_new = [int(math.sqrt(ntok_new))] * 2
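
A usage sketch of resize_pos_embed as called when loading 224px weights into a 384px model (shapes assume patch 16 and one class token):

import torch
from timm.models.vision_transformer import resize_pos_embed

posemb_224 = torch.randn(1, 197, 384)    # 1 cls + 14x14 grid, pretrained
posemb_384 = torch.zeros(1, 577, 384)    # 1 cls + 24x24 grid, target model
resized = resize_pos_embed(posemb_224, posemb_384, num_prefix_tokens=1, gs_new=(24, 24))
assert resized.shape == posemb_384.shape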
@@ -603,25 +622,34 @@ def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
     posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
     posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bicubic', align_corners=False)
     posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
-    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
+    posemb = torch.cat([posemb_prefix, posemb_grid], dim=1)
     return posemb
 
 
 def checkpoint_filter_fn(state_dict, model):
     """ convert patch embedding weight from manual patchify + linear proj to conv"""
+    import re
     out_dict = {}
     if 'model' in state_dict:
         # For deit models
         state_dict = state_dict['model']
+
     for k, v in state_dict.items():
         if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
             # For old models that I trained prior to conv based patchification
             O, I, H, W = model.patch_embed.proj.weight.shape
             v = v.reshape(O, -1, H, W)
-        elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
+        elif k == 'pos_embed' and v.shape[1] != model.pos_embed.shape[1]:
             # To resize pos embedding when using model at different size from pretrained weights
             v = resize_pos_embed(
-                v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
+                v,
+                model.pos_embed,
+                getattr(model, 'num_prefix_tokens', 1),
+                model.patch_embed.grid_size
+            )
+        elif 'gamma_' in k:
+            # remap layer-scale gamma into sub-module (deit3 models)
+            k = re.sub(r'gamma_([0-9])', r'ls\1.gamma', k)
         elif 'pre_logits' in k:
             # NOTE representation layer removed as not used in latest 21k/1k pretrained weights
             continue