mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Update vision transformers to be compatible with official code. Port official ViT weights from jax impl.
This commit is contained in:
parent
7613094fb5
commit
736f209e7d
@ -2,6 +2,14 @@
|
||||
|
||||
## What's New
|
||||
|
||||
### Oct 26, 2020
|
||||
* Update Vision Transformer models to be compatible with official code release at https://github.com/google-research/vision_transformer
|
||||
* Add Vision Transformer weights (ImageNet-21k pretrain) for 384x384 base and large models converted from official jax impl
|
||||
* ViT-B/16 - 84.2
|
||||
* ViT-B/32 - 81.7
|
||||
* ViT-L/16 - 85.2
|
||||
* ViT-L/32 - 81.5
|
||||
|
||||
### Oct 21, 2020
|
||||
* Weights added for Vision Transformer (ViT) models. 77.86 top-1 for 'small' and 79.35 for 'base'. Thanks to [Christof](https://www.kaggle.com/christofhenkel) for training the base model w/ lots of GPUs.
|
||||
|
||||
|
@ -1,23 +1,18 @@
|
||||
""" Vision Transformer (ViT) in PyTorch
|
||||
|
||||
This is a WIP attempt to implement Vision Transformers as described in
|
||||
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' -
|
||||
https://openreview.net/pdf?id=YicbFdNTTy
|
||||
A PyTorch implement of Vision Transformers as described in
|
||||
'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale' - https://arxiv.org/abs/2010.11929
|
||||
|
||||
The paper is currently under review and there is no official reference impl. The
|
||||
code here is likely to change in the future and I will not make an effort to maintain
|
||||
backwards weight compatibility when it does.
|
||||
The official jax code is released and available at https://github.com/google-research/vision_transformer
|
||||
|
||||
Status/TODO:
|
||||
* Trained (supervised on ImageNet-1k) my custom 'small' patch model to ~75 top-1 after 4 days, 2x GPU,
|
||||
no dropout or stochastic depth active
|
||||
* Need more time for supervised training results with dropout and drop connect active, hparam tuning
|
||||
* Need more GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune
|
||||
* There are likely mistakes. If you notice any, I'd love to improve this. This is my first time
|
||||
fiddling with transformers/multi-head attn.
|
||||
* Hopefully end up with worthwhile pretrained model at some point...
|
||||
* Models updated to be compatible with official impl. Args added to support backward compat for old PyTorch weights.
|
||||
* Weights ported from official jax impl for 384x384 base and small models, 16x16 and 32x32 patches.
|
||||
* Trained (supervised on ImageNet-1k) my custom 'small' patch model to 77.9, 'base' to 79.4 top-1 with this code.
|
||||
* Hopefully find time and GPUs for SSL or unsupervised pretraining on OpenImages w/ ImageNet fine-tune in future.
|
||||
|
||||
Acknowledgments:
|
||||
* The paper authors for releasing code and weights, thanks!
|
||||
* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
|
||||
for some einops/einsum fun
|
||||
* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
|
||||
@ -27,6 +22,7 @@ Hacked together by / Copyright 2020 Ross Wightman
|
||||
"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from functools import partial
|
||||
|
||||
from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
|
||||
from .helpers import load_pretrained
|
||||
@ -52,13 +48,21 @@ default_cfgs = {
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_small_p16_224-15ec54c9.pth',
|
||||
),
|
||||
'vit_base_patch16_224': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_base_p16_224-4e355ebd.pth'
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/vit_base_p16_224-4e355ebd.pth',
|
||||
),
|
||||
'vit_base_patch16_384': _cfg(input_size=(3, 384, 384)),
|
||||
'vit_base_patch32_384': _cfg(input_size=(3, 384, 384)),
|
||||
'vit_base_patch16_384': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p16_384-83fb41ba.pth',
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
|
||||
'vit_base_patch32_384': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_base_p32_384-830016f5.pth',
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
|
||||
'vit_large_patch16_224': _cfg(),
|
||||
'vit_large_patch16_384': _cfg(input_size=(3, 384, 384)),
|
||||
'vit_large_patch32_384': _cfg(input_size=(3, 384, 384)),
|
||||
'vit_large_patch16_384': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p16_384-b3be5167.pth',
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
|
||||
'vit_large_patch32_384': _cfg(
|
||||
url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-vitjx/jx_vit_large_p32_384-9b920ba8.pth',
|
||||
input_size=(3, 384, 384), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5), crop_pct=1.0),
|
||||
'vit_huge_patch16_224': _cfg(),
|
||||
'vit_huge_patch32_384': _cfg(input_size=(3, 384, 384)),
|
||||
# hybrid models
|
||||
@ -77,38 +81,35 @@ class Mlp(nn.Module):
|
||||
self.fc1 = nn.Linear(in_features, hidden_features)
|
||||
self.act = act_layer()
|
||||
self.fc2 = nn.Linear(hidden_features, out_features)
|
||||
self.dropout = nn.Dropout(drop) # seems more common to have Transformer MLP drouput here?
|
||||
self.drop = nn.Dropout(drop)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop(x)
|
||||
x = self.fc2(x)
|
||||
x = self.dropout(x)
|
||||
x = self.drop(x)
|
||||
return x
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(self, dim, num_heads=8, attn_drop=0., proj_drop=0.):
|
||||
def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.):
|
||||
super().__init__()
|
||||
self.scale = 1. / dim ** 0.5
|
||||
self.num_heads = num_heads
|
||||
head_dim = dim // num_heads
|
||||
# NOTE scale factor was wrong in my original version, can set manually to be compat with prev weights
|
||||
self.scale = qk_scale or head_dim ** -0.5
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=False)
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x, attn_mask=None):
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
||||
q, k, v = qkv[:, :, 0].transpose(1, 2), qkv[:, :, 1].transpose(1, 2), qkv[:, :, 2].transpose(1, 2)
|
||||
|
||||
# TODO benchmark vs above
|
||||
#qkv = qkv.reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
#q, k, v = qkv
|
||||
q, k, v = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
# FIXME support masking
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
|
||||
@ -120,52 +121,44 @@ class Attention(nn.Module):
|
||||
|
||||
class Block(nn.Module):
|
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., act_layer=nn.GELU, drop=0., drop_path=0.):
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
|
||||
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(dim)
|
||||
self.attn = Attention(dim, num_heads=num_heads, attn_drop=drop, proj_drop=drop)
|
||||
self.norm1 = norm_layer(dim)
|
||||
self.attn = Attention(
|
||||
dim, num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)
|
||||
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
|
||||
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
||||
self.norm2 = nn.LayerNorm(dim)
|
||||
self.norm2 = norm_layer(dim)
|
||||
mlp_hidden_dim = int(dim * mlp_ratio)
|
||||
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
||||
|
||||
def forward(self, x, attn_mask=None):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x), attn_mask=attn_mask))
|
||||
def forward(self, x):
|
||||
x = x + self.drop_path(self.attn(self.norm1(x)))
|
||||
x = x + self.drop_path(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class PatchEmbed(nn.Module):
|
||||
""" Image to Patch Embedding
|
||||
Unfold image into fixed size patches, flatten into seq, project to embedding dim.
|
||||
"""
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, flatten_channels_last=False):
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
|
||||
super().__init__()
|
||||
img_size = to_2tuple(img_size)
|
||||
patch_size = to_2tuple(patch_size)
|
||||
assert img_size[0] % patch_size[0] == 0, 'image height must be divisible by the patch height'
|
||||
assert img_size[1] % patch_size[1] == 0, 'image width must be divisible by the patch width'
|
||||
num_patches = (img_size[1] // patch_size[1]) * (img_size[0] // patch_size[0])
|
||||
patch_dim = in_chans * patch_size[0] * patch_size[1]
|
||||
self.img_size = img_size
|
||||
self.patch_size = patch_size
|
||||
self.flatten_channels_last = flatten_channels_last
|
||||
self.num_patches = num_patches
|
||||
|
||||
self.proj = nn.Linear(patch_dim, embed_dim)
|
||||
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
|
||||
|
||||
def forward(self, x):
|
||||
B, C, H, W = x.shape
|
||||
Ph, Pw = self.patch_size
|
||||
# FIXME look at relaxing size constraints
|
||||
assert H == self.img_size[0] and W == self.img_size[1], \
|
||||
f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
|
||||
if self.flatten_channels_last:
|
||||
# flatten patches with channels last like the paper (likely using TF)
|
||||
x = x.unfold(2, Ph, Ph).unfold(3, Pw, Pw).permute(0, 2, 3, 4, 5, 1).reshape(B, -1, Ph * Pw * C)
|
||||
else:
|
||||
x = x.permute(0, 2, 3, 1).unfold(1, Ph, Ph).unfold(2, Pw, Pw).reshape(B, -1, C * Ph * Pw)
|
||||
x = self.proj(x)
|
||||
x = self.proj(x).flatten(2).transpose(1, 2)
|
||||
return x
|
||||
|
||||
|
||||
@ -208,37 +201,37 @@ class VisionTransformer(nn.Module):
|
||||
""" Vision Transformer with support for patch or hybrid CNN input stage
|
||||
"""
|
||||
def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
|
||||
num_heads=12, mlp_ratio=4., mlp_head=False, drop_rate=0., drop_path_rate=0.,
|
||||
flatten_channels_last=False, hybrid_backbone=None):
|
||||
num_heads=12, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop_rate=0., attn_drop_rate=0.,
|
||||
drop_path_rate=0., hybrid_backbone=None, norm_layer=nn.LayerNorm):
|
||||
super().__init__()
|
||||
if hybrid_backbone is not None:
|
||||
self.patch_embed = HybridEmbed(
|
||||
hybrid_backbone, img_size=img_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||
else:
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
|
||||
flatten_channels_last=flatten_channels_last)
|
||||
img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
||||
num_patches = self.patch_embed.num_patches
|
||||
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
||||
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, embed_dim))
|
||||
self.pos_drop = nn.Dropout(p=drop_rate)
|
||||
|
||||
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
||||
self.blocks = nn.ModuleList([
|
||||
Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, drop=drop_rate, drop_path=dpr[i])
|
||||
Block(
|
||||
dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,
|
||||
drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer)
|
||||
for i in range(depth)])
|
||||
self.norm = norm_layer(embed_dim)
|
||||
|
||||
self.norm = nn.LayerNorm(embed_dim)
|
||||
if mlp_head:
|
||||
# paper diagram suggests 'MLP head', but results in 4M extra parameters vs paper
|
||||
self.head = Mlp(embed_dim, int(embed_dim * mlp_ratio), num_classes)
|
||||
else:
|
||||
# with a single Linear layer as head, the param count within rounding of paper
|
||||
# NOTE as per official impl, we could have a pre-logits representation dense layer + tanh here
|
||||
#self.repr = nn.Linear(embed_dim, representation_size)
|
||||
#self.repr_act = nn.Tanh()
|
||||
|
||||
# Classifier head
|
||||
self.head = nn.Linear(embed_dim, num_classes)
|
||||
|
||||
# FIXME not quite sure what the proper weight init is supposed to be,
|
||||
# normal / trunc normal w/ std == .02 similar to other Bert like transformers
|
||||
trunc_normal_(self.pos_embed, std=.02) # embeddings same as weights?
|
||||
trunc_normal_(self.pos_embed, std=.02)
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
self.apply(self._init_weights)
|
||||
|
||||
@ -255,55 +248,80 @@ class VisionTransformer(nn.Module):
|
||||
def no_weight_decay(self):
|
||||
return {'pos_embed', 'cls_token'}
|
||||
|
||||
def forward(self, x, attn_mask=None):
|
||||
def forward(self, x):
|
||||
B = x.shape[0]
|
||||
x = self.patch_embed(x)
|
||||
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1) # stole cls_tokens impl from Phil Wang, thanks
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x += self.pos_embed
|
||||
x = x + self.pos_embed
|
||||
x = self.pos_drop(x)
|
||||
|
||||
for blk in self.blocks:
|
||||
x = blk(x, attn_mask=attn_mask)
|
||||
x = blk(x)
|
||||
|
||||
x = self.norm(x[:, 0])
|
||||
x = self.head(x)
|
||||
x = self.norm(x)
|
||||
x = self.head(x[:, 0])
|
||||
return x
|
||||
|
||||
|
||||
def _conv_filter(state_dict, patch_size=16):
|
||||
""" convert patch embedding weight from manual patchify + linear proj to conv"""
|
||||
out_dict = {}
|
||||
for k, v in state_dict.items():
|
||||
if 'patch_embed.proj.weight' in k:
|
||||
v = v.reshape((v.shape[0], 3, patch_size, patch_size))
|
||||
out_dict[k] = v
|
||||
return out_dict
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_small_patch16_224(pretrained=False, **kwargs):
|
||||
if pretrained:
|
||||
# NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
|
||||
kwargs.setdefault('qk_scale', 768 ** -0.5)
|
||||
model = VisionTransformer(patch_size=16, embed_dim=768, depth=8, num_heads=8, mlp_ratio=3., **kwargs)
|
||||
model.default_cfg = default_cfgs['vit_small_patch16_224']
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch16_224(pretrained=False, **kwargs):
|
||||
if pretrained:
|
||||
# NOTE my scale was wrong for original weights, leaving this here until I have better ones for this model
|
||||
kwargs.setdefault('qk_scale', 768 ** -0.5)
|
||||
model = VisionTransformer(patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
|
||||
model.default_cfg = default_cfgs['vit_base_patch16_224']
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3), filter_fn=_conv_filter)
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch16_384(pretrained=False, **kwargs):
|
||||
model = VisionTransformer(
|
||||
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
|
||||
img_size=384, patch_size=16, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
model.default_cfg = default_cfgs['vit_base_patch16_384']
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_base_patch32_384(pretrained=False, **kwargs):
|
||||
model = VisionTransformer(
|
||||
img_size=384, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, **kwargs)
|
||||
img_size=384, patch_size=32, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
model.default_cfg = default_cfgs['vit_base_patch32_384']
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@ -317,16 +335,24 @@ def vit_large_patch16_224(pretrained=False, **kwargs):
|
||||
@register_model
|
||||
def vit_large_patch16_384(pretrained=False, **kwargs):
|
||||
model = VisionTransformer(
|
||||
img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
|
||||
img_size=384, patch_size=16, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
model.default_cfg = default_cfgs['vit_large_patch16_384']
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@register_model
|
||||
def vit_large_patch32_384(pretrained=False, **kwargs):
|
||||
model = VisionTransformer(
|
||||
img_size=384, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, **kwargs)
|
||||
img_size=384, patch_size=32, embed_dim=1024, depth=24, num_heads=16, mlp_ratio=4, qkv_bias=True,
|
||||
norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
|
||||
model.default_cfg = default_cfgs['vit_large_patch32_384']
|
||||
if pretrained:
|
||||
load_pretrained(
|
||||
model, num_classes=kwargs.get('num_classes', 0), in_chans=kwargs.get('in_chans', 3))
|
||||
return model
|
||||
|
||||
|
||||
@ -383,5 +409,3 @@ def vit_base_resnet50d_224(pretrained=False, **kwargs):
|
||||
img_size=224, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, hybrid_backbone=backbone, **kwargs)
|
||||
model.default_cfg = default_cfgs['vit_base_resnet50d_224']
|
||||
return model
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user