[Feature] Add Tokens-to-Token ViT backbone and converted checkpoints. (#467)

* add t2t backbone

* register t2t_vit

* add t2t_vit config

* [Temp] Align posterize transform with timm.

* Fix lint

* Refactor t2t-vit

* Add config for t2t-vit

* Add metafile and README for t2t-vit

* Add unit tests

* configs

* Update metafile and README

* Improve docstring

* Fix batch size which should be 8x64 instead of 8x128

* Fix typo

* Update model zoo

* Update training augments config.

* Move some arguments of T2TModule to T2TViT

* Update docs.

* Update unit test

Co-authored-by: HIT-cwh <2892770585@qq.com>
pull/503/head
Ma Zerun 2021-10-29 10:37:16 +08:00 committed by GitHub
parent 2ce5825ef1
commit fffa30dd48
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 844 additions and 2 deletions

View File

@ -0,0 +1,71 @@
_base_ = ['./pipelines/rand_aug.py']
# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
size=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies={{_base_.rand_increasing_policies}},
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=img_norm_cfg['mean'][::-1],
fill_std=img_norm_cfg['std'][::-1]),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
size=(248, -1),
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_prefix='data/imagenet/train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
pipeline=test_pipeline),
test=dict(
# replace `data/val` with `data/test` for standard test
type=dataset_type,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
pipeline=test_pipeline))
evaluation = dict(interval=10, metric='accuracy')

View File

@ -0,0 +1,41 @@
# model settings
embed_dims = 384
num_classes = 1000
model = dict(
type='ImageClassifier',
backbone=dict(
type='T2T_ViT',
img_size=224,
in_channels=3,
embed_dims=embed_dims,
t2t_cfg=dict(
token_dims=64,
use_performer=False,
),
num_layers=14,
layer_cfgs=dict(
num_heads=6,
feedforward_channels=3 * embed_dims, # mlp_ratio = 3
),
drop_path_rate=0.1,
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
]),
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=num_classes,
in_channels=embed_dims,
loss=dict(
type='LabelSmoothLoss',
label_smooth_val=0.1,
mode='original',
),
topk=(1, 5),
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
]))

View File

@ -0,0 +1,41 @@
# model settings
embed_dims = 448
num_classes = 1000
model = dict(
type='ImageClassifier',
backbone=dict(
type='T2T_ViT',
img_size=224,
in_channels=3,
embed_dims=embed_dims,
t2t_cfg=dict(
token_dims=64,
use_performer=False,
),
num_layers=19,
layer_cfgs=dict(
num_heads=7,
feedforward_channels=3 * embed_dims, # mlp_ratio = 3
),
drop_path_rate=0.1,
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
]),
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=num_classes,
in_channels=embed_dims,
loss=dict(
type='LabelSmoothLoss',
label_smooth_val=0.1,
mode='original',
),
topk=(1, 5),
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
]))

View File

@ -0,0 +1,41 @@
# model settings
embed_dims = 512
num_classes = 1000
model = dict(
type='ImageClassifier',
backbone=dict(
type='T2T_ViT',
img_size=224,
in_channels=3,
embed_dims=embed_dims,
t2t_cfg=dict(
token_dims=64,
use_performer=False,
),
num_layers=24,
layer_cfgs=dict(
num_heads=8,
feedforward_channels=3 * embed_dims, # mlp_ratio = 3
),
drop_path_rate=0.1,
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
]),
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=num_classes,
in_channels=embed_dims,
loss=dict(
type='LabelSmoothLoss',
label_smooth_val=0.1,
mode='original',
),
topk=(1, 5),
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
]))

View File

@ -0,0 +1,33 @@
# Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet
<!-- {Tokens-to-Token ViT} -->
## Introduction
<!-- [ALGORITHM] -->
```latex
@article{yuan2021tokens,
title={Tokens-to-token vit: Training vision transformers from scratch on imagenet},
author={Yuan, Li and Chen, Yunpeng and Wang, Tao and Yu, Weihao and Shi, Yujun and Tay, Francis EH and Feng, Jiashi and Yan, Shuicheng},
journal={arXiv preprint arXiv:2101.11986},
year={2021}
}
```
## Pretrain model
The pre-trained modles are converted from [official repo](https://github.com/yitu-opensource/T2T-ViT/tree/main#2-t2t-vit-models).
### ImageNet-1k
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Download |
|:--------------:|:---------:|:--------:|:---------:|:---------:|:--------:|
| T2T-ViT_t-14\* | 21.47 | 4.34 | 81.69 | 95.85 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth) &#124; [log]()|
| T2T-ViT_t-19\* | 39.08 | 7.80 | 82.43 | 96.08 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth) &#124; [log]()|
| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth) &#124; [log]()|
*Models with \* are converted from other repos.*
## Results and models
Waiting for adding.

View File

@ -0,0 +1,64 @@
Collections:
- Name: Tokens-to-Token ViT
Metadata:
Training Data: ImageNet-1k
Architecture:
- Layer Normalization
- Scaled Dot-Product Attention
- Attention Dropout
- Dropout
- Tokens to Token
Paper:
URL: https://arxiv.org/abs/2101.11986
Title: "Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet"
README: configs/t2t_vit/README.md
Models:
- Name: t2t-vit-t-14_3rdparty_8xb64_in1k
Metadata:
FLOPs: 4340000000
Parameters: 21470000
In Collection: Tokens-to-Token ViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 81.69
Top 5 Accuracy: 95.85
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth
Converted From:
Weights: https://github.com/yitu-opensource/T2T-ViT/releases/download/main/81.7_T2T_ViTt_14.pth.tar
Code: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/t2t_vit.py#L243
Config: configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py
- Name: t2t-vit-t-19_3rdparty_8xb64_in1k
Metadata:
FLOPs: 7800000000
Parameters: 39080000
In Collection: Tokens-to-Token ViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.43
Top 5 Accuracy: 96.08
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth
Converted From:
Weights: https://github.com/yitu-opensource/T2T-ViT/releases/download/main/82.4_T2T_ViTt_19.pth.tar
Code: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/t2t_vit.py#L254
Config: configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py
- Name: t2t-vit-t-24_3rdparty_8xb64_in1k
Metadata:
FLOPs: 12690000000
Parameters: 64000000
In Collection: Tokens-to-Token ViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.55
Top 5 Accuracy: 96.06
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth
Converted From:
Weights: https://github.com/yitu-opensource/T2T-ViT/releases/download/main/82.6_T2T_ViTt_24.pth.tar
Code: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/t2t_vit.py#L265
Config: configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py

View File

@ -0,0 +1,31 @@
_base_ = [
'../_base_/models/t2t-vit-t-14.py',
'../_base_/datasets/imagenet_bs64_t2t_224.py',
'../_base_/default_runtime.py',
]
# optimizer
paramwise_cfg = dict(
bias_decay_mult=0.0,
custom_keys={'.backbone.cls_token': dict(decay_mult=0.0)},
)
optimizer = dict(
type='AdamW',
lr=5e-4,
weight_decay=0.05,
paramwise_cfg=paramwise_cfg,
)
optimizer_config = dict(grad_clip=None)
# learning policy
# FIXME: lr in the first 300 epochs conforms to the CosineAnnealing and
# the lr in the last 10 epoch equals to min_lr
lr_config = dict(
policy='CosineAnnealing',
min_lr=1e-5,
by_epoch=True,
warmup_by_epoch=True,
warmup='linear',
warmup_iters=10,
warmup_ratio=1e-6)
runner = dict(type='EpochBasedRunner', max_epochs=310)

View File

@ -0,0 +1,31 @@
_base_ = [
'../_base_/models/t2t-vit-t-19.py',
'../_base_/datasets/imagenet_bs64_t2t_224.py',
'../_base_/default_runtime.py',
]
# optimizer
paramwise_cfg = dict(
bias_decay_mult=0.0,
custom_keys={'.backbone.cls_token': dict(decay_mult=0.0)},
)
optimizer = dict(
type='AdamW',
lr=5e-4,
weight_decay=0.065,
paramwise_cfg=paramwise_cfg,
)
optimizer_config = dict(grad_clip=None)
# learning policy
# FIXME: lr in the first 300 epochs conforms to the CosineAnnealing and
# the lr in the last 10 epoch equals to min_lr
lr_config = dict(
policy='CosineAnnealing',
min_lr=1e-5,
by_epoch=True,
warmup_by_epoch=True,
warmup='linear',
warmup_iters=10,
warmup_ratio=1e-6)
runner = dict(type='EpochBasedRunner', max_epochs=310)

View File

@ -0,0 +1,31 @@
_base_ = [
'../_base_/models/t2t-vit-t-24.py',
'../_base_/datasets/imagenet_bs64_t2t_224.py',
'../_base_/default_runtime.py',
]
# optimizer
paramwise_cfg = dict(
bias_decay_mult=0.0,
custom_keys={'.backbone.cls_token': dict(decay_mult=0.0)},
)
optimizer = dict(
type='AdamW',
lr=5e-4,
weight_decay=0.065,
paramwise_cfg=paramwise_cfg,
)
optimizer_config = dict(grad_clip=None)
# learning policy
# FIXME: lr in the first 300 epochs conforms to the CosineAnnealing and
# the lr in the last 10 epoch equals to min_lr
lr_config = dict(
policy='CosineAnnealing',
min_lr=1e-5,
by_epoch=True,
warmup_by_epoch=True,
warmup='linear',
warmup_iters=10,
warmup_ratio=1e-6)
runner = dict(type='EpochBasedRunner', max_epochs=310)

View File

@ -58,6 +58,9 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| Swin-Transformer small| 49.61 | 8.52 | 83.02 | 96.29 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_small_224_b16x64_300e_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219.log.json)|
| Swin-Transformer base | 87.77 | 15.14 | 83.36 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_b16x64_300e_imagenet_20210616_190742-93230b0d.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_b16x64_300e_imagenet_20210616_190742.log.json)|
| Transformer in Transformer small\* | 23.76 | 3.36 | 81.52 | 95.73 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/tnt/tnt_s_patch16_224_evalonly_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth) &#124; [log]()|
| T2T-ViT_t-14\* | 21.47 | 4.34 | 81.69 | 95.85 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-420df0f6.pth) &#124; [log]()|
| T2T-ViT_t-19\* | 39.08 | 7.80 | 82.43 | 96.08 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-e479c2a6.pth) &#124; [log]()|
| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-b5bf2526.pth) &#124; [log]()|
Models with * are converted from other repos, others are trained by ourselves.

View File

@ -2,6 +2,7 @@
import copy
import inspect
import random
from math import ceil
from numbers import Number
from typing import Sequence
@ -668,7 +669,8 @@ class Posterize(object):
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
f'got {prob} instead.'
self.bits = int(bits)
# To align timm version, we need to round up to integer here.
self.bits = ceil(bits)
self.prob = prob
def __call__(self, results):

View File

@ -15,6 +15,7 @@ from .seresnext import SEResNeXt
from .shufflenet_v1 import ShuffleNetV1
from .shufflenet_v2 import ShuffleNetV2
from .swin_transformer import SwinTransformer
from .t2t_vit import T2T_ViT
from .timm_backbone import TIMMBackbone
from .tnt import TNT
from .vgg import VGG
@ -24,5 +25,5 @@ __all__ = [
'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
'SwinTransformer', 'TNT', 'TIMMBackbone', 'Res2Net', 'RepVGG'
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG'
]

View File

@ -0,0 +1,367 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from ..builder import BACKBONES
from ..utils import MultiheadAttention
from .base_backbone import BaseBackbone
class T2TTransformerLayer(BaseModule):
"""Transformer Layer for T2T_ViT.
Comparing with :obj:`TransformerEncoderLayer` in ViT, it supports
different ``input_dims`` and ``embed_dims``.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs
input_dims (int, optional): The input token dimension.
Defaults to None.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
attn_drop_rate (float): The drop out rate for attention output weights.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
num_fcs (int): The number of fully-connected layers for FFNs.
Defaults to 2.
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
qk_scale (float, optional): Override default qk scale of
``(input_dims // num_heads) ** -0.5`` if set. Defaults to None.
act_cfg (dict): The activation config for FFNs.
Defaluts to ``dict(type='GELU')``.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
Notes:
In general, ``qk_scale`` should be ``head_dims ** -0.5``, i.e.
``(embed_dims // num_heads) ** -0.5``. However, in the official
code, it uses ``(input_dims // num_heads) ** -0.5``, so here we
keep the same with the official implementation.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
input_dims=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=False,
qk_scale=None,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None):
super(T2TTransformerLayer, self).__init__(init_cfg=init_cfg)
self.v_shortcut = True if input_dims is not None else False
input_dims = input_dims or embed_dims
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, input_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.attn = MultiheadAttention(
input_dims=input_dims,
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
qkv_bias=qkv_bias,
qk_scale=qk_scale or (input_dims // num_heads)**-0.5,
v_shortcut=self.v_shortcut)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, embed_dims, postfix=2)
self.add_module(self.norm2_name, norm2)
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def forward(self, x):
if self.v_shortcut:
x = self.attn(self.norm1(x))
else:
x = x + self.attn(self.norm1(x))
x = self.ffn(self.norm2(x), identity=x)
return x
class T2TModule(BaseModule):
"""Tokens-to-Token module.
"Tokens-to-Token module" (T2T Module) can model the local structure
information of images and reduce the length of tokens progressively.
Args:
img_size (int): Input image size
in_channels (int): Number of input channels
embed_dims (int): Embedding dimension
token_dims (int): Tokens dimension in T2TModuleAttention.
use_performer (bool): If True, use Performer version self-attention to
adopt regular self-attention. Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Default: None.
Notes:
Usually, ``token_dim`` is set as a small value (32 or 64) to reduce
MACs
"""
def __init__(
self,
img_size=224,
in_channels=3,
embed_dims=384,
token_dims=64,
use_performer=False,
init_cfg=None,
):
super(T2TModule, self).__init__(init_cfg)
self.embed_dims = embed_dims
self.soft_split0 = nn.Unfold(
kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
self.soft_split1 = nn.Unfold(
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.soft_split2 = nn.Unfold(
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
if not use_performer:
self.attention1 = T2TTransformerLayer(
input_dims=in_channels * 7 * 7,
embed_dims=token_dims,
num_heads=1,
feedforward_channels=token_dims)
self.attention2 = T2TTransformerLayer(
input_dims=token_dims * 3 * 3,
embed_dims=token_dims,
num_heads=1,
feedforward_channels=token_dims)
self.project = nn.Linear(token_dims * 3 * 3, embed_dims)
else:
raise NotImplementedError("Performer hasn't been implemented.")
# there are 3 soft split, stride are 4,2,2 separately
self.num_patches = (img_size // (4 * 2 * 2))**2
def forward(self, x):
# step0: soft split
x = self.soft_split0(x).transpose(1, 2)
for step in [1, 2]:
# re-structurization/reconstruction
attn = getattr(self, f'attention{step}')
x = attn(x).transpose(1, 2)
B, C, new_HW = x.shape
x = x.reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
# soft split
soft_split = getattr(self, f'soft_split{step}')
x = soft_split(x).transpose(1, 2)
# final tokens
x = self.project(x)
return x
def get_sinusoid_encoding(n_position, embed_dims):
"""Generate sinusoid encoding table.
Sinusoid encoding is a kind of relative position encoding method came from
`Attention Is All You Need<https://arxiv.org/abs/1706.03762>`_.
Args:
n_position (int): The length of the input token.
embed_dims (int): The position embedding dimension.
Returns:
:obj:`torch.FloatTensor`: The sinusoid encoding table.
"""
def get_position_angle_vec(position):
return [
position / np.power(10000, 2 * (i // 2) / embed_dims)
for i in range(embed_dims)
]
sinusoid_table = np.array(
[get_position_angle_vec(pos) for pos in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
@BACKBONES.register_module()
class T2T_ViT(BaseBackbone):
"""Tokens-to-Token Vision Transformer (T2T-ViT)
A PyTorch implementation of `Tokens-to-Token ViT: Training Vision
Transformers from Scratch on ImageNet<https://arxiv.org/abs/2101.11986>`_
Args:
img_size (int): Input image size.
in_channels (int): Number of input channels.
embed_dims (int): Embedding dimension.
t2t_cfg (dict): Extra config of Tokens-to-Token module.
Defaults to an empty dict.
drop_rate (float): Dropout rate after position embedding.
Defaults to 0.
num_layers (int): Num of transformer layers in encoder.
Defaults to 14.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
norm_cfg (dict): Config dict for normalization layer. Defaults to
``dict(type='LN')``.
final_norm (bool): Whether to add a additional layer to normalize
final feature map. Defaults to True.
output_cls_token (bool): Whether output the cls_token.
Defaults to True.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
img_size=224,
in_channels=3,
embed_dims=384,
t2t_cfg=dict(),
drop_rate=0.,
num_layers=14,
out_indices=-1,
layer_cfgs=dict(),
drop_path_rate=0.,
norm_cfg=dict(type='LN'),
final_norm=True,
output_cls_token=True,
init_cfg=None):
super(T2T_ViT, self).__init__(init_cfg)
# Token-to-Token Module
self.tokens_to_token = T2TModule(
img_size=img_size,
in_channels=in_channels,
embed_dims=embed_dims,
**t2t_cfg)
num_patches = self.tokens_to_token.num_patches
# Class token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
# Position Embedding
sinusoid_table = get_sinusoid_encoding(num_patches + 1, embed_dims)
self.register_buffer('pos_embed', sinusoid_table)
self.drop_after_pos = nn.Dropout(p=drop_rate)
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = num_layers + index
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
self.out_indices = out_indices
dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)]
self.encoder = ModuleList()
for i in range(num_layers):
if isinstance(layer_cfgs, Sequence):
layer_cfg = layer_cfgs[i]
else:
layer_cfg = deepcopy(layer_cfgs)
layer_cfg = {
'embed_dims': embed_dims,
'num_heads': 6,
'feedforward_channels': 3 * embed_dims,
'drop_path_rate': dpr[i],
'qkv_bias': False,
'norm_cfg': norm_cfg,
**layer_cfg
}
layer = T2TTransformerLayer(**layer_cfg)
self.encoder.append(layer)
self.final_norm = final_norm
if final_norm:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = nn.Identity()
def init_weights(self):
super().init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress custom init if use pretrained model.
return
trunc_normal_(self.cls_token, std=.02)
def forward(self, x):
B = x.shape[0]
x = self.tokens_to_token(x)
num_patches = self.tokens_to_token.num_patches
patch_resolution = [int(np.sqrt(num_patches))] * 2
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
x = self.drop_after_pos(x)
outs = []
for i, layer in enumerate(self.encoder):
x = layer(x)
if i == len(self.encoder) - 1 and self.final_norm:
x = self.norm(x)
if i in self.out_indices:
B, _, C = x.shape
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
if self.output_cls_token:
out = [patch_token, cls_token]
else:
out = patch_token
outs.append(out)
return tuple(outs)

View File

@ -12,3 +12,4 @@ Import:
- configs/repvgg/metafile.yml
- configs/tnt/metafile.yml
- configs/vision_transformer/metafile.yml
- configs/t2t_vit/metafile.yml

View File

@ -0,0 +1,84 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.backbones import T2T_ViT
def is_norm(modules):
"""Check if is one of the norms."""
if isinstance(modules, (GroupNorm, _BatchNorm)):
return True
return False
def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True
def test_vit_backbone():
cfg_ori = dict(
img_size=224,
in_channels=3,
embed_dims=384,
t2t_cfg=dict(
token_dims=64,
use_performer=False,
),
num_layers=14,
layer_cfgs=dict(
num_heads=6,
feedforward_channels=3 * 384, # mlp_ratio = 3
),
drop_path_rate=0.1,
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
])
with pytest.raises(NotImplementedError):
# test if use performer
cfg = deepcopy(cfg_ori)
cfg['t2t_cfg']['use_performer'] = True
T2T_ViT(**cfg)
# Test T2T-ViT model with input size of 224
model = T2T_ViT(**cfg_ori)
model.init_weights()
model.train()
assert check_norm_state(model.modules(), True)
imgs = torch.randn(3, 3, 224, 224)
patch_token, cls_token = model(imgs)[-1]
assert cls_token.shape == (3, 384)
assert patch_token.shape == (3, 384, 14, 14)
# Test custom arch T2T-ViT without output cls token
cfg = deepcopy(cfg_ori)
cfg['embed_dims'] = 256
cfg['num_layers'] = 16
cfg['layer_cfgs'] = dict(num_heads=8, feedforward_channels=1024)
cfg['output_cls_token'] = False
model = T2T_ViT(**cfg)
patch_token = model(imgs)[-1]
assert patch_token.shape == (3, 256, 14, 14)
# Test T2T_ViT with multi out indices
cfg = deepcopy(cfg_ori)
cfg['out_indices'] = [-3, -2, -1]
model = T2T_ViT(**cfg)
for out in model(imgs):
assert out[0].shape == (3, 384, 14, 14)
assert out[1].shape == (3, 384)