[Feature] Add Tokens-to-Token ViT backbone and converted checkpoints. (#467)

* add t2t backbone

* register t2t_vit

* add t2t_vit config

* [Temp] Align posterize transform with timm.

* Fix lint

* Refactor t2t-vit

* Add config for t2t-vit

* Add metafile and README for t2t-vit

* Add unit tests

* configs

* Update metafile and README

* Improve docstring

* Fix batch size which should be 8x64 instead of 8x128

* Fix typo

* Update model zoo

* Update training augments config.

* Move some arguments of T2TModule to T2TViT

* Update docs.

* Update unit test

Co-authored-by: HIT-cwh <2892770585@qq.com>
Ma Zerun 2021-10-29 10:37:16 +08:00 committed by GitHub
parent 2ce5825ef1
commit fffa30dd48
15 changed files with 844 additions and 2 deletions
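As a quick sanity check of what this PR registers, here is a minimal sketch of building the new backbone through the registry. It assumes an environment with this commit installed and that `build_backbone` is exported from `mmcls.models` as it is for the existing backbones; the settings mirror the t2t-vit-t-14 config added below.

```python
# Minimal sketch; assumes this commit is installed and that `build_backbone`
# is exported from `mmcls.models` like the other backbones.
import torch
from mmcls.models import build_backbone

backbone = build_backbone(
    dict(
        type='T2T_ViT',
        img_size=224,
        embed_dims=384,
        t2t_cfg=dict(token_dims=64, use_performer=False),
        num_layers=14,
        layer_cfgs=dict(num_heads=6, feedforward_channels=3 * 384)))
backbone.init_weights()
backbone.eval()

with torch.no_grad():
    patch_token, cls_token = backbone(torch.randn(1, 3, 224, 224))[-1]

print(patch_token.shape)  # expected: torch.Size([1, 384, 14, 14])
print(cls_token.shape)    # expected: torch.Size([1, 384])
```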


@ -0,0 +1,71 @@
_base_ = ['./pipelines/rand_aug.py']
# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='RandomResizedCrop',
size=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(
type='RandAugment',
policies={{_base_.rand_increasing_policies}},
num_policies=2,
total_level=10,
magnitude_level=9,
magnitude_std=0.5,
hparams=dict(
pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
interpolation='bicubic')),
dict(
type='RandomErasing',
erase_prob=0.25,
mode='rand',
min_area_ratio=0.02,
max_area_ratio=1 / 3,
fill_color=img_norm_cfg['mean'][::-1],
fill_std=img_norm_cfg['std'][::-1]),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='ToTensor', keys=['gt_label']),
dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='Resize',
size=(248, -1),
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img'])
]
data = dict(
samples_per_gpu=64,
workers_per_gpu=4,
train=dict(
type=dataset_type,
data_prefix='data/imagenet/train',
pipeline=train_pipeline),
val=dict(
type=dataset_type,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
pipeline=test_pipeline),
test=dict(
# replace `data/val` with `data/test` for standard test
type=dataset_type,
data_prefix='data/imagenet/val',
ann_file='data/imagenet/meta/val.txt',
pipeline=test_pipeline))
evaluation = dict(interval=10, metric='accuracy')
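One detail worth noting in the pipeline above: `pad_val` and `fill_color` are the normalization mean reversed into BGR order, on the assumption that images are still in mmcv's default BGR layout until `Normalize` (with `to_rgb=True`) runs. The expressions evaluate as follows:

```python
# Evaluating the pad_val / fill_color expressions from the config above.
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)

# The [::-1] reversal converts the RGB mean to BGR order, since the pipeline
# operates on BGR images until `Normalize` (with to_rgb=True) runs.
pad_val = [round(x) for x in img_norm_cfg['mean'][::-1]]
fill_color = img_norm_cfg['mean'][::-1]

print(pad_val)     # [104, 116, 124]
print(fill_color)  # [103.53, 116.28, 123.675]
```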


@ -0,0 +1,41 @@
# model settings
embed_dims = 384
num_classes = 1000
model = dict(
type='ImageClassifier',
backbone=dict(
type='T2T_ViT',
img_size=224,
in_channels=3,
embed_dims=embed_dims,
t2t_cfg=dict(
token_dims=64,
use_performer=False,
),
num_layers=14,
layer_cfgs=dict(
num_heads=6,
feedforward_channels=3 * embed_dims, # mlp_ratio = 3
),
drop_path_rate=0.1,
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
]),
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=num_classes,
in_channels=embed_dims,
loss=dict(
type='LabelSmoothLoss',
label_smooth_val=0.1,
mode='original',
),
topk=(1, 5),
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
]))
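A few derived numbers implied by this config, as plain arithmetic for orientation: with `embed_dims=384` and `num_heads=6`, each attention head works on 64 channels; `mlp_ratio = 3` gives 1152 FFN channels; and the three soft splits reduce a 224x224 input to a 14x14 token grid plus one class token.

```python
# Plain arithmetic mirroring the model settings above; illustrative only.
embed_dims, num_heads, mlp_ratio, img_size = 384, 6, 3, 224

head_dims = embed_dims // num_heads              # 64 channels per head
feedforward_channels = mlp_ratio * embed_dims    # 1152, as written above

# Three soft splits with strides 4, 2, 2 shrink the spatial grid 16x overall.
num_patches = (img_size // (4 * 2 * 2)) ** 2     # 14 * 14 = 196 patch tokens
seq_len = num_patches + 1                        # + class token -> 197

print(head_dims, feedforward_channels, num_patches, seq_len)
```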


@ -0,0 +1,41 @@
# model settings
embed_dims = 448
num_classes = 1000
model = dict(
type='ImageClassifier',
backbone=dict(
type='T2T_ViT',
img_size=224,
in_channels=3,
embed_dims=embed_dims,
t2t_cfg=dict(
token_dims=64,
use_performer=False,
),
num_layers=19,
layer_cfgs=dict(
num_heads=7,
feedforward_channels=3 * embed_dims, # mlp_ratio = 3
),
drop_path_rate=0.1,
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
]),
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=num_classes,
in_channels=embed_dims,
loss=dict(
type='LabelSmoothLoss',
label_smooth_val=0.1,
mode='original',
),
topk=(1, 5),
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
]))


@ -0,0 +1,41 @@
# model settings
embed_dims = 512
num_classes = 1000
model = dict(
type='ImageClassifier',
backbone=dict(
type='T2T_ViT',
img_size=224,
in_channels=3,
embed_dims=embed_dims,
t2t_cfg=dict(
token_dims=64,
use_performer=False,
),
num_layers=24,
layer_cfgs=dict(
num_heads=8,
feedforward_channels=3 * embed_dims, # mlp_ratio = 3
),
drop_path_rate=0.1,
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
]),
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=num_classes,
in_channels=embed_dims,
loss=dict(
type='LabelSmoothLoss',
label_smooth_val=0.1,
mode='original',
),
topk=(1, 5),
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
train_cfg=dict(augments=[
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
]))

configs/t2t_vit/README.md

@ -0,0 +1,33 @@
# Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet
<!-- {Tokens-to-Token ViT} -->
## Introduction
<!-- [ALGORITHM] -->
```latex
@article{yuan2021tokens,
title={Tokens-to-token vit: Training vision transformers from scratch on imagenet},
author={Yuan, Li and Chen, Yunpeng and Wang, Tao and Yu, Weihao and Shi, Yujun and Tay, Francis EH and Feng, Jiashi and Yan, Shuicheng},
journal={arXiv preprint arXiv:2101.11986},
year={2021}
}
```
## Pretrained models
The pre-trained models are converted from the [official repo](https://github.com/yitu-opensource/T2T-ViT/tree/main#2-t2t-vit-models).
### ImageNet-1k
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:--------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
| T2T-ViT_t-14\* | 21.47 | 4.34 | 81.69 | 95.85 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth) &#124; [log]()|
| T2T-ViT_t-19\* | 39.08 | 7.80 | 82.43 | 96.08 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth) &#124; [log]()|
| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth) &#124; [log]()|
*Models with \* are converted from other repos.*
## Results and models
To be added.
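A possible way to try one of the converted checkpoints listed in the README above, assuming the standard `init_model`/`inference_model` helpers from `mmcls.apis` (which this PR does not touch) and the config/weight paths from the table:

```python
# Hedged usage sketch; `init_model` / `inference_model` come from mmcls.apis,
# and the paths below are taken from the table above.
from mmcls.apis import inference_model, init_model

config = 'configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py'
checkpoint = ('https://download.openmmlab.com/mmclassification/v0/t2t-vit/'
              't2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth')

model = init_model(config, checkpoint, device='cpu')
result = inference_model(model, 'demo/demo.JPEG')  # demo image shipped with mmcls
print(result['pred_class'], result['pred_score'])
```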


@ -0,0 +1,64 @@
Collections:
- Name: Tokens-to-Token ViT
Metadata:
Training Data: ImageNet-1k
Architecture:
- Layer Normalization
- Scaled Dot-Product Attention
- Attention Dropout
- Dropout
- Tokens to Token
Paper:
URL: https://arxiv.org/abs/2101.11986
Title: "Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet"
README: configs/t2t_vit/README.md
Models:
- Name: t2t-vit-t-14_3rdparty_8xb64_in1k
Metadata:
FLOPs: 4340000000
Parameters: 21470000
In Collection: Tokens-to-Token ViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 81.69
Top 5 Accuracy: 95.85
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth
Converted From:
Weights: https://github.com/yitu-opensource/T2T-ViT/releases/download/main/81.7_T2T_ViTt_14.pth.tar
Code: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/t2t_vit.py#L243
Config: configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py
- Name: t2t-vit-t-19_3rdparty_8xb64_in1k
Metadata:
FLOPs: 7800000000
Parameters: 39080000
In Collection: Tokens-to-Token ViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.43
Top 5 Accuracy: 96.08
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth
Converted From:
Weights: https://github.com/yitu-opensource/T2T-ViT/releases/download/main/82.4_T2T_ViTt_19.pth.tar
Code: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/t2t_vit.py#L254
Config: configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py
- Name: t2t-vit-t-24_3rdparty_8xb64_in1k
Metadata:
FLOPs: 12690000000
Parameters: 64000000
In Collection: Tokens-to-Token ViT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 82.55
Top 5 Accuracy: 96.06
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth
Converted From:
Weights: https://github.com/yitu-opensource/T2T-ViT/releases/download/main/82.6_T2T_ViTt_24.pth.tar
Code: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/t2t_vit.py#L265
Config: configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py
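The FLOPs and parameter counts recorded above are the kind of figures usually obtained with mmcls's complexity tooling; a hedged sketch of reproducing them, assuming `mmcv.cnn.utils.get_model_complexity_info` and the `extract_feat` convention used by mmcls's FLOPs script:

```python
# Rough sketch of how such numbers are usually produced; assumes mmcv's
# complexity helper and mmcls's config loader behave as in existing tools.
from mmcv import Config
from mmcv.cnn.utils import get_model_complexity_info
from mmcls.models import build_classifier

cfg = Config.fromfile('configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py')
model = build_classifier(cfg.model)
model.forward = model.extract_feat  # count backbone features only, head excluded

flops, params = get_model_complexity_info(model, (3, 224, 224))
print(flops, params)  # expected around 4.34 GFLOPs and 21.47 M parameters
```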


@ -0,0 +1,31 @@
_base_ = [
'../_base_/models/t2t-vit-t-14.py',
'../_base_/datasets/imagenet_bs64_t2t_224.py',
'../_base_/default_runtime.py',
]
# optimizer
paramwise_cfg = dict(
bias_decay_mult=0.0,
custom_keys={'.backbone.cls_token': dict(decay_mult=0.0)},
)
optimizer = dict(
type='AdamW',
lr=5e-4,
weight_decay=0.05,
paramwise_cfg=paramwise_cfg,
)
optimizer_config = dict(grad_clip=None)
# learning policy
# FIXME: the lr in the first 300 epochs follows the CosineAnnealing schedule,
# and the lr in the last 10 epochs equals min_lr.
lr_config = dict(
policy='CosineAnnealing',
min_lr=1e-5,
by_epoch=True,
warmup_by_epoch=True,
warmup='linear',
warmup_iters=10,
warmup_ratio=1e-6)
runner = dict(type='EpochBasedRunner', max_epochs=310)
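To make the FIXME above concrete, here is a simplified, epoch-level sketch of the schedule it describes: a linear warmup over 10 epochs, cosine annealing over the first 300 epochs, then `min_lr` for the last 10. This is an approximation for illustration, not MMCV's exact `LrUpdaterHook` logic, which applies the warmup per iteration:

```python
# Simplified epoch-level approximation of the schedule in the FIXME above;
# MMCV's hooks apply the warmup per iteration, so real values differ slightly.
import math

base_lr, min_lr = 5e-4, 1e-5
warmup_epochs, cosine_epochs, max_epochs = 10, 300, 310
warmup_ratio = 1e-6

def target_lr(epoch):
    if epoch < warmup_epochs:
        # linear warmup from warmup_ratio * base_lr up to base_lr
        return base_lr * (warmup_ratio + (1 - warmup_ratio) * epoch / warmup_epochs)
    if epoch >= cosine_epochs:
        return min_lr  # last 10 epochs stay at min_lr
    progress = (epoch - warmup_epochs) / (cosine_epochs - warmup_epochs)
    return min_lr + 0.5 * (base_lr - min_lr) * (1 + math.cos(math.pi * progress))

for e in (0, 9, 10, 150, 299, 305):
    print(e, f'{target_lr(e):.2e}')
```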


@ -0,0 +1,31 @@
_base_ = [
'../_base_/models/t2t-vit-t-19.py',
'../_base_/datasets/imagenet_bs64_t2t_224.py',
'../_base_/default_runtime.py',
]
# optimizer
paramwise_cfg = dict(
bias_decay_mult=0.0,
custom_keys={'.backbone.cls_token': dict(decay_mult=0.0)},
)
optimizer = dict(
type='AdamW',
lr=5e-4,
weight_decay=0.065,
paramwise_cfg=paramwise_cfg,
)
optimizer_config = dict(grad_clip=None)
# learning policy
# FIXME: the lr in the first 300 epochs follows the CosineAnnealing schedule,
# and the lr in the last 10 epochs equals min_lr.
lr_config = dict(
policy='CosineAnnealing',
min_lr=1e-5,
by_epoch=True,
warmup_by_epoch=True,
warmup='linear',
warmup_iters=10,
warmup_ratio=1e-6)
runner = dict(type='EpochBasedRunner', max_epochs=310)


@ -0,0 +1,31 @@
_base_ = [
'../_base_/models/t2t-vit-t-24.py',
'../_base_/datasets/imagenet_bs64_t2t_224.py',
'../_base_/default_runtime.py',
]
# optimizer
paramwise_cfg = dict(
bias_decay_mult=0.0,
custom_keys={'.backbone.cls_token': dict(decay_mult=0.0)},
)
optimizer = dict(
type='AdamW',
lr=5e-4,
weight_decay=0.065,
paramwise_cfg=paramwise_cfg,
)
optimizer_config = dict(grad_clip=None)
# learning policy
# FIXME: the lr in the first 300 epochs follows the CosineAnnealing schedule,
# and the lr in the last 10 epochs equals min_lr.
lr_config = dict(
policy='CosineAnnealing',
min_lr=1e-5,
by_epoch=True,
warmup_by_epoch=True,
warmup='linear',
warmup_iters=10,
warmup_ratio=1e-6)
runner = dict(type='EpochBasedRunner', max_epochs=310)


@ -58,6 +58,9 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| Swin-Transformer small| 49.61 | 8.52 | 83.02 | 96.29 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_small_224_b16x64_300e_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219.log.json)|
| Swin-Transformer base | 87.77 | 15.14 | 83.36 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_b16x64_300e_imagenet_20210616_190742-93230b0d.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_b16x64_300e_imagenet_20210616_190742.log.json)|
| Transformer in Transformer small\* | 23.76 | 3.36 | 81.52 | 95.73 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/tnt/tnt_s_patch16_224_evalonly_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth) &#124; [log]()|
| T2T-ViT_t-14\* | 21.47 | 4.34 | 81.69 | 95.85 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-420df0f6.pth) &#124; [log]()|
| T2T-ViT_t-19\* | 39.08 | 7.80 | 82.43 | 96.08 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-e479c2a6.pth) &#124; [log]()|
| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-b5bf2526.pth) &#124; [log]()|
Models with * are converted from other repos, others are trained by ourselves.


@ -2,6 +2,7 @@
import copy
import inspect
import random
from math import ceil
from numbers import Number
from typing import Sequence
@ -668,7 +669,8 @@ class Posterize(object):
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
f'got {prob} instead.'
- self.bits = int(bits)
+ # To align with the timm version, we need to round up to an integer here.
+ self.bits = ceil(bits)
self.prob = prob
def __call__(self, results):
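A small plain-Python illustration of why the rounding direction matters for `Posterize`: with a fractional magnitude, `int()` truncates to 0 bits (which would blank the image), while `ceil()` keeps at least one bit, which is the behaviour the comment above attributes to timm.

```python
# Plain-Python illustration of int() vs. ceil() on a fractional `bits` value.
from math import ceil

bits = 0.4           # e.g. a fractional value coming from magnitude sampling
print(int(bits))     # 0 -> posterize would keep zero bits per channel
print(ceil(bits))    # 1 -> at least one bit is kept, as the comment above intends
```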


@ -15,6 +15,7 @@ from .seresnext import SEResNeXt
from .shufflenet_v1 import ShuffleNetV1
from .shufflenet_v2 import ShuffleNetV2
from .swin_transformer import SwinTransformer
from .t2t_vit import T2T_ViT
from .timm_backbone import TIMMBackbone
from .tnt import TNT
from .vgg import VGG
@ -24,5 +25,5 @@ __all__ = [
'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
- 'SwinTransformer', 'TNT', 'TIMMBackbone', 'Res2Net', 'RepVGG'
+ 'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG'
]


@ -0,0 +1,367 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence
import numpy as np
import torch
import torch.nn as nn
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from ..builder import BACKBONES
from ..utils import MultiheadAttention
from .base_backbone import BaseBackbone
class T2TTransformerLayer(BaseModule):
"""Transformer Layer for T2T_ViT.
Compared with :obj:`TransformerEncoderLayer` in ViT, it supports
different ``input_dims`` and ``embed_dims``.
Args:
embed_dims (int): The feature dimension.
num_heads (int): Parallel attention heads.
feedforward_channels (int): The hidden dimension for FFNs.
input_dims (int, optional): The input token dimension.
Defaults to None.
drop_rate (float): Probability of an element to be zeroed
after the feed forward layer. Defaults to 0.
attn_drop_rate (float): The dropout rate for attention output weights.
Defaults to 0.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
num_fcs (int): The number of fully-connected layers for FFNs.
Defaults to 2.
qkv_bias (bool): Enable bias for qkv projections if True. Defaults to False.
qk_scale (float, optional): Override default qk scale of
``(input_dims // num_heads) ** -0.5`` if set. Defaults to None.
act_cfg (dict): The activation config for FFNs.
Defaults to ``dict(type='GELU')``.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
Notes:
In general, ``qk_scale`` should be ``head_dims ** -0.5``, i.e.
``(embed_dims // num_heads) ** -0.5``. However, in the official
code, it uses ``(input_dims // num_heads) ** -0.5``, so here we
keep the same with the official implementation.
"""
def __init__(self,
embed_dims,
num_heads,
feedforward_channels,
input_dims=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
num_fcs=2,
qkv_bias=False,
qk_scale=None,
act_cfg=dict(type='GELU'),
norm_cfg=dict(type='LN'),
init_cfg=None):
super(T2TTransformerLayer, self).__init__(init_cfg=init_cfg)
self.v_shortcut = True if input_dims is not None else False
input_dims = input_dims or embed_dims
self.norm1_name, norm1 = build_norm_layer(
norm_cfg, input_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self.attn = MultiheadAttention(
input_dims=input_dims,
embed_dims=embed_dims,
num_heads=num_heads,
attn_drop=attn_drop_rate,
proj_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
qkv_bias=qkv_bias,
qk_scale=qk_scale or (input_dims // num_heads)**-0.5,
v_shortcut=self.v_shortcut)
self.norm2_name, norm2 = build_norm_layer(
norm_cfg, embed_dims, postfix=2)
self.add_module(self.norm2_name, norm2)
self.ffn = FFN(
embed_dims=embed_dims,
feedforward_channels=feedforward_channels,
num_fcs=num_fcs,
ffn_drop=drop_rate,
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
act_cfg=act_cfg)
@property
def norm1(self):
return getattr(self, self.norm1_name)
@property
def norm2(self):
return getattr(self, self.norm2_name)
def forward(self, x):
if self.v_shortcut:
x = self.attn(self.norm1(x))
else:
x = x + self.attn(self.norm1(x))
x = self.ffn(self.norm2(x), identity=x)
return x
class T2TModule(BaseModule):
"""Tokens-to-Token module.
"Tokens-to-Token module" (T2T Module) can model the local structure
information of images and reduce the length of tokens progressively.
Args:
img_size (int): Input image size
in_channels (int): Number of input channels
embed_dims (int): Embedding dimension
token_dims (int): Tokens dimension in T2TModuleAttention.
use_performer (bool): If True, use Performer-version self-attention
instead of regular self-attention. Defaults to False.
init_cfg (dict, optional): The extra config for initialization.
Default: None.
Notes:
Usually, ``token_dims`` is set to a small value (32 or 64) to reduce
MACs.
"""
def __init__(
self,
img_size=224,
in_channels=3,
embed_dims=384,
token_dims=64,
use_performer=False,
init_cfg=None,
):
super(T2TModule, self).__init__(init_cfg)
self.embed_dims = embed_dims
self.soft_split0 = nn.Unfold(
kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
self.soft_split1 = nn.Unfold(
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
self.soft_split2 = nn.Unfold(
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
if not use_performer:
self.attention1 = T2TTransformerLayer(
input_dims=in_channels * 7 * 7,
embed_dims=token_dims,
num_heads=1,
feedforward_channels=token_dims)
self.attention2 = T2TTransformerLayer(
input_dims=token_dims * 3 * 3,
embed_dims=token_dims,
num_heads=1,
feedforward_channels=token_dims)
self.project = nn.Linear(token_dims * 3 * 3, embed_dims)
else:
raise NotImplementedError("Performer hasn't been implemented.")
# There are 3 soft splits, whose strides are 4, 2 and 2 respectively.
self.num_patches = (img_size // (4 * 2 * 2))**2
def forward(self, x):
# step0: soft split
x = self.soft_split0(x).transpose(1, 2)
for step in [1, 2]:
# re-structurization/reconstruction
attn = getattr(self, f'attention{step}')
x = attn(x).transpose(1, 2)
B, C, new_HW = x.shape
x = x.reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
# soft split
soft_split = getattr(self, f'soft_split{step}')
x = soft_split(x).transpose(1, 2)
# final tokens
x = self.project(x)
return x
def get_sinusoid_encoding(n_position, embed_dims):
"""Generate sinusoid encoding table.
Sinusoid encoding is a kind of position encoding method that comes from
`Attention Is All You Need <https://arxiv.org/abs/1706.03762>`_.
Args:
n_position (int): The length of the input token.
embed_dims (int): The position embedding dimension.
Returns:
:obj:`torch.FloatTensor`: The sinusoid encoding table.
"""
def get_position_angle_vec(position):
return [
position / np.power(10000, 2 * (i // 2) / embed_dims)
for i in range(embed_dims)
]
sinusoid_table = np.array(
[get_position_angle_vec(pos) for pos in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
@BACKBONES.register_module()
class T2T_ViT(BaseBackbone):
"""Tokens-to-Token Vision Transformer (T2T-ViT)
A PyTorch implementation of `Tokens-to-Token ViT: Training Vision
Transformers from Scratch on ImageNet <https://arxiv.org/abs/2101.11986>`_
Args:
img_size (int): Input image size.
in_channels (int): Number of input channels.
embed_dims (int): Embedding dimension.
t2t_cfg (dict): Extra config of Tokens-to-Token module.
Defaults to an empty dict.
drop_rate (float): Dropout rate after position embedding.
Defaults to 0.
num_layers (int): Num of transformer layers in encoder.
Defaults to 14.
out_indices (Sequence | int): Output from which stages.
Defaults to -1, meaning the last stage.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
norm_cfg (dict): Config dict for normalization layer. Defaults to
``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize the
final feature map. Defaults to True.
output_cls_token (bool): Whether output the cls_token.
Defaults to True.
init_cfg (dict, optional): The Config for initialization.
Defaults to None.
"""
def __init__(self,
img_size=224,
in_channels=3,
embed_dims=384,
t2t_cfg=dict(),
drop_rate=0.,
num_layers=14,
out_indices=-1,
layer_cfgs=dict(),
drop_path_rate=0.,
norm_cfg=dict(type='LN'),
final_norm=True,
output_cls_token=True,
init_cfg=None):
super(T2T_ViT, self).__init__(init_cfg)
# Token-to-Token Module
self.tokens_to_token = T2TModule(
img_size=img_size,
in_channels=in_channels,
embed_dims=embed_dims,
**t2t_cfg)
num_patches = self.tokens_to_token.num_patches
# Class token
self.output_cls_token = output_cls_token
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
# Position Embedding
sinusoid_table = get_sinusoid_encoding(num_patches + 1, embed_dims)
self.register_buffer('pos_embed', sinusoid_table)
self.drop_after_pos = nn.Dropout(p=drop_rate)
if isinstance(out_indices, int):
out_indices = [out_indices]
assert isinstance(out_indices, Sequence), \
f'"out_indices" must by a sequence or int, ' \
f'get {type(out_indices)} instead.'
for i, index in enumerate(out_indices):
if index < 0:
out_indices[i] = num_layers + index
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
self.out_indices = out_indices
dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)]
self.encoder = ModuleList()
for i in range(num_layers):
if isinstance(layer_cfgs, Sequence):
layer_cfg = layer_cfgs[i]
else:
layer_cfg = deepcopy(layer_cfgs)
layer_cfg = {
'embed_dims': embed_dims,
'num_heads': 6,
'feedforward_channels': 3 * embed_dims,
'drop_path_rate': dpr[i],
'qkv_bias': False,
'norm_cfg': norm_cfg,
**layer_cfg
}
layer = T2TTransformerLayer(**layer_cfg)
self.encoder.append(layer)
self.final_norm = final_norm
if final_norm:
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
else:
self.norm = nn.Identity()
def init_weights(self):
super().init_weights()
if (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
# Suppress custom init if use pretrained model.
return
trunc_normal_(self.cls_token, std=.02)
def forward(self, x):
B = x.shape[0]
x = self.tokens_to_token(x)
num_patches = self.tokens_to_token.num_patches
patch_resolution = [int(np.sqrt(num_patches))] * 2
cls_tokens = self.cls_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, x), dim=1)
x = x + self.pos_embed
x = self.drop_after_pos(x)
outs = []
for i, layer in enumerate(self.encoder):
x = layer(x)
if i == len(self.encoder) - 1 and self.final_norm:
x = self.norm(x)
if i in self.out_indices:
B, _, C = x.shape
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
if self.output_cls_token:
out = [patch_token, cls_token]
else:
out = patch_token
outs.append(out)
return tuple(outs)
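To make the token-length reduction performed by `T2TModule` concrete, here is a standalone shape check using the same unfold parameters as above; only the grid sizes matter here, and the 64-channel inputs to the later splits reflect `token_dims=64` after each attention step.

```python
# Standalone shape check of the three soft splits used by T2TModule above.
import torch
import torch.nn as nn

split0 = nn.Unfold(kernel_size=7, stride=4, padding=2)
split1 = nn.Unfold(kernel_size=3, stride=2, padding=1)
split2 = nn.Unfold(kernel_size=3, stride=2, padding=1)

x = torch.randn(1, 3, 224, 224)
print(split0(x).shape)  # (1, 147, 3136): 56x56 tokens of dim 3*7*7
# attention1 maps the 147-dim tokens to token_dims=64; only the grid matters below
print(split1(torch.randn(1, 64, 56, 56)).shape)  # (1, 576, 784): 28x28 tokens
print(split2(torch.randn(1, 64, 28, 28)).shape)  # (1, 576, 196): 14x14 tokens -> projected to embed_dims
```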


@ -12,3 +12,4 @@ Import:
- configs/repvgg/metafile.yml
- configs/tnt/metafile.yml
- configs/vision_transformer/metafile.yml
- configs/t2t_vit/metafile.yml


@ -0,0 +1,84 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.backbones import T2T_ViT
def is_norm(modules):
"""Check if is one of the norms."""
if isinstance(modules, (GroupNorm, _BatchNorm)):
return True
return False
def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True
def test_t2t_vit_backbone():
cfg_ori = dict(
img_size=224,
in_channels=3,
embed_dims=384,
t2t_cfg=dict(
token_dims=64,
use_performer=False,
),
num_layers=14,
layer_cfgs=dict(
num_heads=6,
feedforward_channels=3 * 384, # mlp_ratio = 3
),
drop_path_rate=0.1,
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
])
with pytest.raises(NotImplementedError):
# test if use performer
cfg = deepcopy(cfg_ori)
cfg['t2t_cfg']['use_performer'] = True
T2T_ViT(**cfg)
# Test T2T-ViT model with input size of 224
model = T2T_ViT(**cfg_ori)
model.init_weights()
model.train()
assert check_norm_state(model.modules(), True)
imgs = torch.randn(3, 3, 224, 224)
patch_token, cls_token = model(imgs)[-1]
assert cls_token.shape == (3, 384)
assert patch_token.shape == (3, 384, 14, 14)
# Test custom arch T2T-ViT without output cls token
cfg = deepcopy(cfg_ori)
cfg['embed_dims'] = 256
cfg['num_layers'] = 16
cfg['layer_cfgs'] = dict(num_heads=8, feedforward_channels=1024)
cfg['output_cls_token'] = False
model = T2T_ViT(**cfg)
patch_token = model(imgs)[-1]
assert patch_token.shape == (3, 256, 14, 14)
# Test T2T_ViT with multi out indices
cfg = deepcopy(cfg_ori)
cfg['out_indices'] = [-3, -2, -1]
model = T2T_ViT(**cfg)
for out in model(imgs):
assert out[0].shape == (3, 384, 14, 14)
assert out[1].shape == (3, 384)