[Feature] Add Tokens-to-Token ViT backbone and converted checkpoints. (#467)
* add t2t backbone * register t2t_vit * add t2t_vit config * [Temp] Align posterize transform with timm. * Fix lint * Refactor t2t-vit * Add config for t2t-vit * Add metafile and README for t2t-vit * Add unit tests * configs * Update metafile and README * Improve docstring * Fix batch size which should be 8x64 instead of 8x128 * Fix typo * Update model zoo * Update training augments config. * Move some arguments of T2TModule to T2TViT * Update docs. * Update unit test Co-authored-by: HIT-cwh <2892770585@qq.com>pull/503/head
parent
2ce5825ef1
commit
fffa30dd48
|
@ -0,0 +1,71 @@
|
|||
_base_ = ['./pipelines/rand_aug.py']
|
||||
|
||||
# dataset settings
|
||||
dataset_type = 'ImageNet'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='RandomResizedCrop',
|
||||
size=224,
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||
dict(
|
||||
type='RandAugment',
|
||||
policies={{_base_.rand_increasing_policies}},
|
||||
num_policies=2,
|
||||
total_level=10,
|
||||
magnitude_level=9,
|
||||
magnitude_std=0.5,
|
||||
hparams=dict(
|
||||
pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
|
||||
interpolation='bicubic')),
|
||||
dict(
|
||||
type='RandomErasing',
|
||||
erase_prob=0.25,
|
||||
mode='rand',
|
||||
min_area_ratio=0.02,
|
||||
max_area_ratio=1 / 3,
|
||||
fill_color=img_norm_cfg['mean'][::-1],
|
||||
fill_std=img_norm_cfg['std'][::-1]),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='ToTensor', keys=['gt_label']),
|
||||
dict(type='Collect', keys=['img', 'gt_label'])
|
||||
]
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='Resize',
|
||||
size=(248, -1),
|
||||
backend='pillow',
|
||||
interpolation='bicubic'),
|
||||
dict(type='CenterCrop', crop_size=224),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img'])
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=64,
|
||||
workers_per_gpu=4,
|
||||
train=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/train',
|
||||
pipeline=train_pipeline),
|
||||
val=dict(
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/val',
|
||||
ann_file='data/imagenet/meta/val.txt',
|
||||
pipeline=test_pipeline),
|
||||
test=dict(
|
||||
# replace `data/val` with `data/test` for standard test
|
||||
type=dataset_type,
|
||||
data_prefix='data/imagenet/val',
|
||||
ann_file='data/imagenet/meta/val.txt',
|
||||
pipeline=test_pipeline))
|
||||
|
||||
evaluation = dict(interval=10, metric='accuracy')
|
|
@ -0,0 +1,41 @@
|
|||
# model settings
|
||||
embed_dims = 384
|
||||
num_classes = 1000
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='T2T_ViT',
|
||||
img_size=224,
|
||||
in_channels=3,
|
||||
embed_dims=embed_dims,
|
||||
t2t_cfg=dict(
|
||||
token_dims=64,
|
||||
use_performer=False,
|
||||
),
|
||||
num_layers=14,
|
||||
layer_cfgs=dict(
|
||||
num_heads=6,
|
||||
feedforward_channels=3 * embed_dims, # mlp_ratio = 3
|
||||
),
|
||||
drop_path_rate=0.1,
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=.02),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
]),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='VisionTransformerClsHead',
|
||||
num_classes=num_classes,
|
||||
in_channels=embed_dims,
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss',
|
||||
label_smooth_val=0.1,
|
||||
mode='original',
|
||||
),
|
||||
topk=(1, 5),
|
||||
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
|
||||
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
|
||||
]))
|
|
@ -0,0 +1,41 @@
|
|||
# model settings
|
||||
embed_dims = 448
|
||||
num_classes = 1000
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='T2T_ViT',
|
||||
img_size=224,
|
||||
in_channels=3,
|
||||
embed_dims=embed_dims,
|
||||
t2t_cfg=dict(
|
||||
token_dims=64,
|
||||
use_performer=False,
|
||||
),
|
||||
num_layers=19,
|
||||
layer_cfgs=dict(
|
||||
num_heads=7,
|
||||
feedforward_channels=3 * embed_dims, # mlp_ratio = 3
|
||||
),
|
||||
drop_path_rate=0.1,
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=.02),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
]),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='VisionTransformerClsHead',
|
||||
num_classes=num_classes,
|
||||
in_channels=embed_dims,
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss',
|
||||
label_smooth_val=0.1,
|
||||
mode='original',
|
||||
),
|
||||
topk=(1, 5),
|
||||
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
|
||||
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
|
||||
]))
|
|
@ -0,0 +1,41 @@
|
|||
# model settings
|
||||
embed_dims = 512
|
||||
num_classes = 1000
|
||||
|
||||
model = dict(
|
||||
type='ImageClassifier',
|
||||
backbone=dict(
|
||||
type='T2T_ViT',
|
||||
img_size=224,
|
||||
in_channels=3,
|
||||
embed_dims=embed_dims,
|
||||
t2t_cfg=dict(
|
||||
token_dims=64,
|
||||
use_performer=False,
|
||||
),
|
||||
num_layers=24,
|
||||
layer_cfgs=dict(
|
||||
num_heads=8,
|
||||
feedforward_channels=3 * embed_dims, # mlp_ratio = 3
|
||||
),
|
||||
drop_path_rate=0.1,
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=.02),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
]),
|
||||
neck=None,
|
||||
head=dict(
|
||||
type='VisionTransformerClsHead',
|
||||
num_classes=num_classes,
|
||||
in_channels=embed_dims,
|
||||
loss=dict(
|
||||
type='LabelSmoothLoss',
|
||||
label_smooth_val=0.1,
|
||||
mode='original',
|
||||
),
|
||||
topk=(1, 5),
|
||||
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
|
||||
train_cfg=dict(augments=[
|
||||
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
|
||||
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
|
||||
]))
|
|
@ -0,0 +1,33 @@
|
|||
# Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet
|
||||
<!-- {Tokens-to-Token ViT} -->
|
||||
|
||||
## Introduction
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
```latex
|
||||
@article{yuan2021tokens,
|
||||
title={Tokens-to-token vit: Training vision transformers from scratch on imagenet},
|
||||
author={Yuan, Li and Chen, Yunpeng and Wang, Tao and Yu, Weihao and Shi, Yujun and Tay, Francis EH and Feng, Jiashi and Yan, Shuicheng},
|
||||
journal={arXiv preprint arXiv:2101.11986},
|
||||
year={2021}
|
||||
}
|
||||
```
|
||||
|
||||
## Pretrain model
|
||||
|
||||
The pre-trained modles are converted from [official repo](https://github.com/yitu-opensource/T2T-ViT/tree/main#2-t2t-vit-models).
|
||||
|
||||
### ImageNet-1k
|
||||
|
||||
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Download |
|
||||
|:--------------:|:---------:|:--------:|:---------:|:---------:|:--------:|
|
||||
| T2T-ViT_t-14\* | 21.47 | 4.34 | 81.69 | 95.85 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth) | [log]()|
|
||||
| T2T-ViT_t-19\* | 39.08 | 7.80 | 82.43 | 96.08 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth) | [log]()|
|
||||
| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth) | [log]()|
|
||||
|
||||
*Models with \* are converted from other repos.*
|
||||
|
||||
## Results and models
|
||||
|
||||
Waiting for adding.
|
|
@ -0,0 +1,64 @@
|
|||
Collections:
|
||||
- Name: Tokens-to-Token ViT
|
||||
Metadata:
|
||||
Training Data: ImageNet-1k
|
||||
Architecture:
|
||||
- Layer Normalization
|
||||
- Scaled Dot-Product Attention
|
||||
- Attention Dropout
|
||||
- Dropout
|
||||
- Tokens to Token
|
||||
Paper:
|
||||
URL: https://arxiv.org/abs/2101.11986
|
||||
Title: "Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet"
|
||||
README: configs/t2t_vit/README.md
|
||||
|
||||
Models:
|
||||
- Name: t2t-vit-t-14_3rdparty_8xb64_in1k
|
||||
Metadata:
|
||||
FLOPs: 4340000000
|
||||
Parameters: 21470000
|
||||
In Collection: Tokens-to-Token ViT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 81.69
|
||||
Top 5 Accuracy: 95.85
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth
|
||||
Converted From:
|
||||
Weights: https://github.com/yitu-opensource/T2T-ViT/releases/download/main/81.7_T2T_ViTt_14.pth.tar
|
||||
Code: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/t2t_vit.py#L243
|
||||
Config: configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py
|
||||
- Name: t2t-vit-t-19_3rdparty_8xb64_in1k
|
||||
Metadata:
|
||||
FLOPs: 7800000000
|
||||
Parameters: 39080000
|
||||
In Collection: Tokens-to-Token ViT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 82.43
|
||||
Top 5 Accuracy: 96.08
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth
|
||||
Converted From:
|
||||
Weights: https://github.com/yitu-opensource/T2T-ViT/releases/download/main/82.4_T2T_ViTt_19.pth.tar
|
||||
Code: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/t2t_vit.py#L254
|
||||
Config: configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py
|
||||
- Name: t2t-vit-t-24_3rdparty_8xb64_in1k
|
||||
Metadata:
|
||||
FLOPs: 12690000000
|
||||
Parameters: 64000000
|
||||
In Collection: Tokens-to-Token ViT
|
||||
Results:
|
||||
- Dataset: ImageNet-1k
|
||||
Metrics:
|
||||
Top 1 Accuracy: 82.55
|
||||
Top 5 Accuracy: 96.06
|
||||
Task: Image Classification
|
||||
Weights: https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth
|
||||
Converted From:
|
||||
Weights: https://github.com/yitu-opensource/T2T-ViT/releases/download/main/82.6_T2T_ViTt_24.pth.tar
|
||||
Code: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/t2t_vit.py#L265
|
||||
Config: configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py
|
|
@ -0,0 +1,31 @@
|
|||
_base_ = [
|
||||
'../_base_/models/t2t-vit-t-14.py',
|
||||
'../_base_/datasets/imagenet_bs64_t2t_224.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# optimizer
|
||||
paramwise_cfg = dict(
|
||||
bias_decay_mult=0.0,
|
||||
custom_keys={'.backbone.cls_token': dict(decay_mult=0.0)},
|
||||
)
|
||||
optimizer = dict(
|
||||
type='AdamW',
|
||||
lr=5e-4,
|
||||
weight_decay=0.05,
|
||||
paramwise_cfg=paramwise_cfg,
|
||||
)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
|
||||
# learning policy
|
||||
# FIXME: lr in the first 300 epochs conforms to the CosineAnnealing and
|
||||
# the lr in the last 10 epoch equals to min_lr
|
||||
lr_config = dict(
|
||||
policy='CosineAnnealing',
|
||||
min_lr=1e-5,
|
||||
by_epoch=True,
|
||||
warmup_by_epoch=True,
|
||||
warmup='linear',
|
||||
warmup_iters=10,
|
||||
warmup_ratio=1e-6)
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=310)
|
|
@ -0,0 +1,31 @@
|
|||
_base_ = [
|
||||
'../_base_/models/t2t-vit-t-19.py',
|
||||
'../_base_/datasets/imagenet_bs64_t2t_224.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# optimizer
|
||||
paramwise_cfg = dict(
|
||||
bias_decay_mult=0.0,
|
||||
custom_keys={'.backbone.cls_token': dict(decay_mult=0.0)},
|
||||
)
|
||||
optimizer = dict(
|
||||
type='AdamW',
|
||||
lr=5e-4,
|
||||
weight_decay=0.065,
|
||||
paramwise_cfg=paramwise_cfg,
|
||||
)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
|
||||
# learning policy
|
||||
# FIXME: lr in the first 300 epochs conforms to the CosineAnnealing and
|
||||
# the lr in the last 10 epoch equals to min_lr
|
||||
lr_config = dict(
|
||||
policy='CosineAnnealing',
|
||||
min_lr=1e-5,
|
||||
by_epoch=True,
|
||||
warmup_by_epoch=True,
|
||||
warmup='linear',
|
||||
warmup_iters=10,
|
||||
warmup_ratio=1e-6)
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=310)
|
|
@ -0,0 +1,31 @@
|
|||
_base_ = [
|
||||
'../_base_/models/t2t-vit-t-24.py',
|
||||
'../_base_/datasets/imagenet_bs64_t2t_224.py',
|
||||
'../_base_/default_runtime.py',
|
||||
]
|
||||
|
||||
# optimizer
|
||||
paramwise_cfg = dict(
|
||||
bias_decay_mult=0.0,
|
||||
custom_keys={'.backbone.cls_token': dict(decay_mult=0.0)},
|
||||
)
|
||||
optimizer = dict(
|
||||
type='AdamW',
|
||||
lr=5e-4,
|
||||
weight_decay=0.065,
|
||||
paramwise_cfg=paramwise_cfg,
|
||||
)
|
||||
optimizer_config = dict(grad_clip=None)
|
||||
|
||||
# learning policy
|
||||
# FIXME: lr in the first 300 epochs conforms to the CosineAnnealing and
|
||||
# the lr in the last 10 epoch equals to min_lr
|
||||
lr_config = dict(
|
||||
policy='CosineAnnealing',
|
||||
min_lr=1e-5,
|
||||
by_epoch=True,
|
||||
warmup_by_epoch=True,
|
||||
warmup='linear',
|
||||
warmup_iters=10,
|
||||
warmup_ratio=1e-6)
|
||||
runner = dict(type='EpochBasedRunner', max_epochs=310)
|
|
@ -58,6 +58,9 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
|
|||
| Swin-Transformer small| 49.61 | 8.52 | 83.02 | 96.29 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_small_224_b16x64_300e_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219.log.json)|
|
||||
| Swin-Transformer base | 87.77 | 15.14 | 83.36 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_b16x64_300e_imagenet_20210616_190742-93230b0d.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_b16x64_300e_imagenet_20210616_190742.log.json)|
|
||||
| Transformer in Transformer small\* | 23.76 | 3.36 | 81.52 | 95.73 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/tnt/tnt_s_patch16_224_evalonly_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth) | [log]()|
|
||||
| T2T-ViT_t-14\* | 21.47 | 4.34 | 81.69 | 95.85 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-420df0f6.pth) | [log]()|
|
||||
| T2T-ViT_t-19\* | 39.08 | 7.80 | 82.43 | 96.08 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-e479c2a6.pth) | [log]()|
|
||||
| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-b5bf2526.pth) | [log]()|
|
||||
|
||||
Models with * are converted from other repos, others are trained by ourselves.
|
||||
|
||||
|
|
|
@ -2,6 +2,7 @@
|
|||
import copy
|
||||
import inspect
|
||||
import random
|
||||
from math import ceil
|
||||
from numbers import Number
|
||||
from typing import Sequence
|
||||
|
||||
|
@ -668,7 +669,8 @@ class Posterize(object):
|
|||
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
|
||||
f'got {prob} instead.'
|
||||
|
||||
self.bits = int(bits)
|
||||
# To align timm version, we need to round up to integer here.
|
||||
self.bits = ceil(bits)
|
||||
self.prob = prob
|
||||
|
||||
def __call__(self, results):
|
||||
|
|
|
@ -15,6 +15,7 @@ from .seresnext import SEResNeXt
|
|||
from .shufflenet_v1 import ShuffleNetV1
|
||||
from .shufflenet_v2 import ShuffleNetV2
|
||||
from .swin_transformer import SwinTransformer
|
||||
from .t2t_vit import T2T_ViT
|
||||
from .timm_backbone import TIMMBackbone
|
||||
from .tnt import TNT
|
||||
from .vgg import VGG
|
||||
|
@ -24,5 +25,5 @@ __all__ = [
|
|||
'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
|
||||
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
|
||||
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
|
||||
'SwinTransformer', 'TNT', 'TIMMBackbone', 'Res2Net', 'RepVGG'
|
||||
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,367 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from copy import deepcopy
|
||||
from typing import Sequence
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_norm_layer
|
||||
from mmcv.cnn.bricks.transformer import FFN
|
||||
from mmcv.cnn.utils.weight_init import trunc_normal_
|
||||
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||
|
||||
from ..builder import BACKBONES
|
||||
from ..utils import MultiheadAttention
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
class T2TTransformerLayer(BaseModule):
|
||||
"""Transformer Layer for T2T_ViT.
|
||||
|
||||
Comparing with :obj:`TransformerEncoderLayer` in ViT, it supports
|
||||
different ``input_dims`` and ``embed_dims``.
|
||||
|
||||
Args:
|
||||
embed_dims (int): The feature dimension.
|
||||
num_heads (int): Parallel attention heads.
|
||||
feedforward_channels (int): The hidden dimension for FFNs
|
||||
input_dims (int, optional): The input token dimension.
|
||||
Defaults to None.
|
||||
drop_rate (float): Probability of an element to be zeroed
|
||||
after the feed forward layer. Defaults to 0.
|
||||
attn_drop_rate (float): The drop out rate for attention output weights.
|
||||
Defaults to 0.
|
||||
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
|
||||
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||
Defaults to 2.
|
||||
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
|
||||
qk_scale (float, optional): Override default qk scale of
|
||||
``(input_dims // num_heads) ** -0.5`` if set. Defaults to None.
|
||||
act_cfg (dict): The activation config for FFNs.
|
||||
Defaluts to ``dict(type='GELU')``.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Defaults to ``dict(type='LN')``.
|
||||
init_cfg (dict, optional): Initialization config dict.
|
||||
Defaults to None.
|
||||
|
||||
Notes:
|
||||
In general, ``qk_scale`` should be ``head_dims ** -0.5``, i.e.
|
||||
``(embed_dims // num_heads) ** -0.5``. However, in the official
|
||||
code, it uses ``(input_dims // num_heads) ** -0.5``, so here we
|
||||
keep the same with the official implementation.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
embed_dims,
|
||||
num_heads,
|
||||
feedforward_channels,
|
||||
input_dims=None,
|
||||
drop_rate=0.,
|
||||
attn_drop_rate=0.,
|
||||
drop_path_rate=0.,
|
||||
num_fcs=2,
|
||||
qkv_bias=False,
|
||||
qk_scale=None,
|
||||
act_cfg=dict(type='GELU'),
|
||||
norm_cfg=dict(type='LN'),
|
||||
init_cfg=None):
|
||||
super(T2TTransformerLayer, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
self.v_shortcut = True if input_dims is not None else False
|
||||
input_dims = input_dims or embed_dims
|
||||
|
||||
self.norm1_name, norm1 = build_norm_layer(
|
||||
norm_cfg, input_dims, postfix=1)
|
||||
self.add_module(self.norm1_name, norm1)
|
||||
|
||||
self.attn = MultiheadAttention(
|
||||
input_dims=input_dims,
|
||||
embed_dims=embed_dims,
|
||||
num_heads=num_heads,
|
||||
attn_drop=attn_drop_rate,
|
||||
proj_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
qkv_bias=qkv_bias,
|
||||
qk_scale=qk_scale or (input_dims // num_heads)**-0.5,
|
||||
v_shortcut=self.v_shortcut)
|
||||
|
||||
self.norm2_name, norm2 = build_norm_layer(
|
||||
norm_cfg, embed_dims, postfix=2)
|
||||
self.add_module(self.norm2_name, norm2)
|
||||
|
||||
self.ffn = FFN(
|
||||
embed_dims=embed_dims,
|
||||
feedforward_channels=feedforward_channels,
|
||||
num_fcs=num_fcs,
|
||||
ffn_drop=drop_rate,
|
||||
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||
act_cfg=act_cfg)
|
||||
|
||||
@property
|
||||
def norm1(self):
|
||||
return getattr(self, self.norm1_name)
|
||||
|
||||
@property
|
||||
def norm2(self):
|
||||
return getattr(self, self.norm2_name)
|
||||
|
||||
def forward(self, x):
|
||||
if self.v_shortcut:
|
||||
x = self.attn(self.norm1(x))
|
||||
else:
|
||||
x = x + self.attn(self.norm1(x))
|
||||
x = self.ffn(self.norm2(x), identity=x)
|
||||
return x
|
||||
|
||||
|
||||
class T2TModule(BaseModule):
|
||||
"""Tokens-to-Token module.
|
||||
|
||||
"Tokens-to-Token module" (T2T Module) can model the local structure
|
||||
information of images and reduce the length of tokens progressively.
|
||||
|
||||
Args:
|
||||
img_size (int): Input image size
|
||||
in_channels (int): Number of input channels
|
||||
embed_dims (int): Embedding dimension
|
||||
token_dims (int): Tokens dimension in T2TModuleAttention.
|
||||
use_performer (bool): If True, use Performer version self-attention to
|
||||
adopt regular self-attention. Defaults to False.
|
||||
init_cfg (dict, optional): The extra config for initialization.
|
||||
Default: None.
|
||||
|
||||
Notes:
|
||||
Usually, ``token_dim`` is set as a small value (32 or 64) to reduce
|
||||
MACs
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
img_size=224,
|
||||
in_channels=3,
|
||||
embed_dims=384,
|
||||
token_dims=64,
|
||||
use_performer=False,
|
||||
init_cfg=None,
|
||||
):
|
||||
super(T2TModule, self).__init__(init_cfg)
|
||||
|
||||
self.embed_dims = embed_dims
|
||||
|
||||
self.soft_split0 = nn.Unfold(
|
||||
kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
|
||||
self.soft_split1 = nn.Unfold(
|
||||
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
|
||||
self.soft_split2 = nn.Unfold(
|
||||
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
|
||||
|
||||
if not use_performer:
|
||||
self.attention1 = T2TTransformerLayer(
|
||||
input_dims=in_channels * 7 * 7,
|
||||
embed_dims=token_dims,
|
||||
num_heads=1,
|
||||
feedforward_channels=token_dims)
|
||||
|
||||
self.attention2 = T2TTransformerLayer(
|
||||
input_dims=token_dims * 3 * 3,
|
||||
embed_dims=token_dims,
|
||||
num_heads=1,
|
||||
feedforward_channels=token_dims)
|
||||
|
||||
self.project = nn.Linear(token_dims * 3 * 3, embed_dims)
|
||||
else:
|
||||
raise NotImplementedError("Performer hasn't been implemented.")
|
||||
|
||||
# there are 3 soft split, stride are 4,2,2 separately
|
||||
self.num_patches = (img_size // (4 * 2 * 2))**2
|
||||
|
||||
def forward(self, x):
|
||||
# step0: soft split
|
||||
x = self.soft_split0(x).transpose(1, 2)
|
||||
|
||||
for step in [1, 2]:
|
||||
# re-structurization/reconstruction
|
||||
attn = getattr(self, f'attention{step}')
|
||||
x = attn(x).transpose(1, 2)
|
||||
B, C, new_HW = x.shape
|
||||
x = x.reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
|
||||
|
||||
# soft split
|
||||
soft_split = getattr(self, f'soft_split{step}')
|
||||
x = soft_split(x).transpose(1, 2)
|
||||
|
||||
# final tokens
|
||||
x = self.project(x)
|
||||
return x
|
||||
|
||||
|
||||
def get_sinusoid_encoding(n_position, embed_dims):
|
||||
"""Generate sinusoid encoding table.
|
||||
|
||||
Sinusoid encoding is a kind of relative position encoding method came from
|
||||
`Attention Is All You Need<https://arxiv.org/abs/1706.03762>`_.
|
||||
|
||||
Args:
|
||||
n_position (int): The length of the input token.
|
||||
embed_dims (int): The position embedding dimension.
|
||||
|
||||
Returns:
|
||||
:obj:`torch.FloatTensor`: The sinusoid encoding table.
|
||||
"""
|
||||
|
||||
def get_position_angle_vec(position):
|
||||
return [
|
||||
position / np.power(10000, 2 * (i // 2) / embed_dims)
|
||||
for i in range(embed_dims)
|
||||
]
|
||||
|
||||
sinusoid_table = np.array(
|
||||
[get_position_angle_vec(pos) for pos in range(n_position)])
|
||||
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||
|
||||
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class T2T_ViT(BaseBackbone):
|
||||
"""Tokens-to-Token Vision Transformer (T2T-ViT)
|
||||
|
||||
A PyTorch implementation of `Tokens-to-Token ViT: Training Vision
|
||||
Transformers from Scratch on ImageNet<https://arxiv.org/abs/2101.11986>`_
|
||||
|
||||
Args:
|
||||
img_size (int): Input image size.
|
||||
in_channels (int): Number of input channels.
|
||||
embed_dims (int): Embedding dimension.
|
||||
t2t_cfg (dict): Extra config of Tokens-to-Token module.
|
||||
Defaults to an empty dict.
|
||||
drop_rate (float): Dropout rate after position embedding.
|
||||
Defaults to 0.
|
||||
num_layers (int): Num of transformer layers in encoder.
|
||||
Defaults to 14.
|
||||
out_indices (Sequence | int): Output from which stages.
|
||||
Defaults to -1, means the last stage.
|
||||
layer_cfgs (Sequence | dict): Configs of each transformer layer in
|
||||
encoder. Defaults to an empty dict.
|
||||
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||
norm_cfg (dict): Config dict for normalization layer. Defaults to
|
||||
``dict(type='LN')``.
|
||||
final_norm (bool): Whether to add a additional layer to normalize
|
||||
final feature map. Defaults to True.
|
||||
output_cls_token (bool): Whether output the cls_token.
|
||||
Defaults to True.
|
||||
init_cfg (dict, optional): The Config for initialization.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
img_size=224,
|
||||
in_channels=3,
|
||||
embed_dims=384,
|
||||
t2t_cfg=dict(),
|
||||
drop_rate=0.,
|
||||
num_layers=14,
|
||||
out_indices=-1,
|
||||
layer_cfgs=dict(),
|
||||
drop_path_rate=0.,
|
||||
norm_cfg=dict(type='LN'),
|
||||
final_norm=True,
|
||||
output_cls_token=True,
|
||||
init_cfg=None):
|
||||
super(T2T_ViT, self).__init__(init_cfg)
|
||||
|
||||
# Token-to-Token Module
|
||||
self.tokens_to_token = T2TModule(
|
||||
img_size=img_size,
|
||||
in_channels=in_channels,
|
||||
embed_dims=embed_dims,
|
||||
**t2t_cfg)
|
||||
num_patches = self.tokens_to_token.num_patches
|
||||
|
||||
# Class token
|
||||
self.output_cls_token = output_cls_token
|
||||
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||
|
||||
# Position Embedding
|
||||
sinusoid_table = get_sinusoid_encoding(num_patches + 1, embed_dims)
|
||||
self.register_buffer('pos_embed', sinusoid_table)
|
||||
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||
|
||||
if isinstance(out_indices, int):
|
||||
out_indices = [out_indices]
|
||||
assert isinstance(out_indices, Sequence), \
|
||||
f'"out_indices" must by a sequence or int, ' \
|
||||
f'get {type(out_indices)} instead.'
|
||||
for i, index in enumerate(out_indices):
|
||||
if index < 0:
|
||||
out_indices[i] = num_layers + index
|
||||
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
|
||||
self.out_indices = out_indices
|
||||
|
||||
dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)]
|
||||
self.encoder = ModuleList()
|
||||
for i in range(num_layers):
|
||||
if isinstance(layer_cfgs, Sequence):
|
||||
layer_cfg = layer_cfgs[i]
|
||||
else:
|
||||
layer_cfg = deepcopy(layer_cfgs)
|
||||
layer_cfg = {
|
||||
'embed_dims': embed_dims,
|
||||
'num_heads': 6,
|
||||
'feedforward_channels': 3 * embed_dims,
|
||||
'drop_path_rate': dpr[i],
|
||||
'qkv_bias': False,
|
||||
'norm_cfg': norm_cfg,
|
||||
**layer_cfg
|
||||
}
|
||||
|
||||
layer = T2TTransformerLayer(**layer_cfg)
|
||||
self.encoder.append(layer)
|
||||
|
||||
self.final_norm = final_norm
|
||||
if final_norm:
|
||||
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||
else:
|
||||
self.norm = nn.Identity()
|
||||
|
||||
def init_weights(self):
|
||||
super().init_weights()
|
||||
|
||||
if (isinstance(self.init_cfg, dict)
|
||||
and self.init_cfg['type'] == 'Pretrained'):
|
||||
# Suppress custom init if use pretrained model.
|
||||
return
|
||||
|
||||
trunc_normal_(self.cls_token, std=.02)
|
||||
|
||||
def forward(self, x):
|
||||
B = x.shape[0]
|
||||
x = self.tokens_to_token(x)
|
||||
num_patches = self.tokens_to_token.num_patches
|
||||
patch_resolution = [int(np.sqrt(num_patches))] * 2
|
||||
|
||||
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||
x = torch.cat((cls_tokens, x), dim=1)
|
||||
x = x + self.pos_embed
|
||||
x = self.drop_after_pos(x)
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.encoder):
|
||||
x = layer(x)
|
||||
|
||||
if i == len(self.encoder) - 1 and self.final_norm:
|
||||
x = self.norm(x)
|
||||
|
||||
if i in self.out_indices:
|
||||
B, _, C = x.shape
|
||||
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
|
||||
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||
cls_token = x[:, 0]
|
||||
if self.output_cls_token:
|
||||
out = [patch_token, cls_token]
|
||||
else:
|
||||
out = patch_token
|
||||
outs.append(out)
|
||||
|
||||
return tuple(outs)
|
|
@ -12,3 +12,4 @@ Import:
|
|||
- configs/repvgg/metafile.yml
|
||||
- configs/tnt/metafile.yml
|
||||
- configs/vision_transformer/metafile.yml
|
||||
- configs/t2t_vit/metafile.yml
|
||||
|
|
|
@ -0,0 +1,84 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from copy import deepcopy
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn.modules import GroupNorm
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import T2T_ViT
|
||||
|
||||
|
||||
def is_norm(modules):
|
||||
"""Check if is one of the norms."""
|
||||
if isinstance(modules, (GroupNorm, _BatchNorm)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
"""Check if norm layer is in correct train state."""
|
||||
for mod in modules:
|
||||
if isinstance(mod, _BatchNorm):
|
||||
if mod.training != train_state:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_vit_backbone():
|
||||
|
||||
cfg_ori = dict(
|
||||
img_size=224,
|
||||
in_channels=3,
|
||||
embed_dims=384,
|
||||
t2t_cfg=dict(
|
||||
token_dims=64,
|
||||
use_performer=False,
|
||||
),
|
||||
num_layers=14,
|
||||
layer_cfgs=dict(
|
||||
num_heads=6,
|
||||
feedforward_channels=3 * 384, # mlp_ratio = 3
|
||||
),
|
||||
drop_path_rate=0.1,
|
||||
init_cfg=[
|
||||
dict(type='TruncNormal', layer='Linear', std=.02),
|
||||
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||
])
|
||||
|
||||
with pytest.raises(NotImplementedError):
|
||||
# test if use performer
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['t2t_cfg']['use_performer'] = True
|
||||
T2T_ViT(**cfg)
|
||||
|
||||
# Test T2T-ViT model with input size of 224
|
||||
model = T2T_ViT(**cfg_ori)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
assert check_norm_state(model.modules(), True)
|
||||
|
||||
imgs = torch.randn(3, 3, 224, 224)
|
||||
patch_token, cls_token = model(imgs)[-1]
|
||||
assert cls_token.shape == (3, 384)
|
||||
assert patch_token.shape == (3, 384, 14, 14)
|
||||
|
||||
# Test custom arch T2T-ViT without output cls token
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['embed_dims'] = 256
|
||||
cfg['num_layers'] = 16
|
||||
cfg['layer_cfgs'] = dict(num_heads=8, feedforward_channels=1024)
|
||||
cfg['output_cls_token'] = False
|
||||
|
||||
model = T2T_ViT(**cfg)
|
||||
patch_token = model(imgs)[-1]
|
||||
assert patch_token.shape == (3, 256, 14, 14)
|
||||
|
||||
# Test T2T_ViT with multi out indices
|
||||
cfg = deepcopy(cfg_ori)
|
||||
cfg['out_indices'] = [-3, -2, -1]
|
||||
model = T2T_ViT(**cfg)
|
||||
for out in model(imgs):
|
||||
assert out[0].shape == (3, 384, 14, 14)
|
||||
assert out[1].shape == (3, 384)
|
Loading…
Reference in New Issue