mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
[Feature] Add Tokens-to-Token ViT backbone and converted checkpoints. (#467)
* add t2t backbone * register t2t_vit * add t2t_vit config * [Temp] Align posterize transform with timm. * Fix lint * Refactor t2t-vit * Add config for t2t-vit * Add metafile and README for t2t-vit * Add unit tests * configs * Update metafile and README * Improve docstring * Fix batch size which should be 8x64 instead of 8x128 * Fix typo * Update model zoo * Update training augments config. * Move some arguments of T2TModule to T2TViT * Update docs. * Update unit test Co-authored-by: HIT-cwh <2892770585@qq.com>
This commit is contained in:
parent
2ce5825ef1
commit
fffa30dd48
71
configs/_base_/datasets/imagenet_bs64_t2t_224.py
Normal file
71
configs/_base_/datasets/imagenet_bs64_t2t_224.py
Normal file
@ -0,0 +1,71 @@
|
|||||||
|
_base_ = ['./pipelines/rand_aug.py']
|
||||||
|
|
||||||
|
# dataset settings
|
||||||
|
dataset_type = 'ImageNet'
|
||||||
|
img_norm_cfg = dict(
|
||||||
|
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||||
|
|
||||||
|
train_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='RandomResizedCrop',
|
||||||
|
size=224,
|
||||||
|
backend='pillow',
|
||||||
|
interpolation='bicubic'),
|
||||||
|
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
|
||||||
|
dict(
|
||||||
|
type='RandAugment',
|
||||||
|
policies={{_base_.rand_increasing_policies}},
|
||||||
|
num_policies=2,
|
||||||
|
total_level=10,
|
||||||
|
magnitude_level=9,
|
||||||
|
magnitude_std=0.5,
|
||||||
|
hparams=dict(
|
||||||
|
pad_val=[round(x) for x in img_norm_cfg['mean'][::-1]],
|
||||||
|
interpolation='bicubic')),
|
||||||
|
dict(
|
||||||
|
type='RandomErasing',
|
||||||
|
erase_prob=0.25,
|
||||||
|
mode='rand',
|
||||||
|
min_area_ratio=0.02,
|
||||||
|
max_area_ratio=1 / 3,
|
||||||
|
fill_color=img_norm_cfg['mean'][::-1],
|
||||||
|
fill_std=img_norm_cfg['std'][::-1]),
|
||||||
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
dict(type='ToTensor', keys=['gt_label']),
|
||||||
|
dict(type='Collect', keys=['img', 'gt_label'])
|
||||||
|
]
|
||||||
|
|
||||||
|
test_pipeline = [
|
||||||
|
dict(type='LoadImageFromFile'),
|
||||||
|
dict(
|
||||||
|
type='Resize',
|
||||||
|
size=(248, -1),
|
||||||
|
backend='pillow',
|
||||||
|
interpolation='bicubic'),
|
||||||
|
dict(type='CenterCrop', crop_size=224),
|
||||||
|
dict(type='Normalize', **img_norm_cfg),
|
||||||
|
dict(type='ImageToTensor', keys=['img']),
|
||||||
|
dict(type='Collect', keys=['img'])
|
||||||
|
]
|
||||||
|
data = dict(
|
||||||
|
samples_per_gpu=64,
|
||||||
|
workers_per_gpu=4,
|
||||||
|
train=dict(
|
||||||
|
type=dataset_type,
|
||||||
|
data_prefix='data/imagenet/train',
|
||||||
|
pipeline=train_pipeline),
|
||||||
|
val=dict(
|
||||||
|
type=dataset_type,
|
||||||
|
data_prefix='data/imagenet/val',
|
||||||
|
ann_file='data/imagenet/meta/val.txt',
|
||||||
|
pipeline=test_pipeline),
|
||||||
|
test=dict(
|
||||||
|
# replace `data/val` with `data/test` for standard test
|
||||||
|
type=dataset_type,
|
||||||
|
data_prefix='data/imagenet/val',
|
||||||
|
ann_file='data/imagenet/meta/val.txt',
|
||||||
|
pipeline=test_pipeline))
|
||||||
|
|
||||||
|
evaluation = dict(interval=10, metric='accuracy')
|
41
configs/_base_/models/t2t-vit-t-14.py
Normal file
41
configs/_base_/models/t2t-vit-t-14.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# model settings
|
||||||
|
embed_dims = 384
|
||||||
|
num_classes = 1000
|
||||||
|
|
||||||
|
model = dict(
|
||||||
|
type='ImageClassifier',
|
||||||
|
backbone=dict(
|
||||||
|
type='T2T_ViT',
|
||||||
|
img_size=224,
|
||||||
|
in_channels=3,
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
t2t_cfg=dict(
|
||||||
|
token_dims=64,
|
||||||
|
use_performer=False,
|
||||||
|
),
|
||||||
|
num_layers=14,
|
||||||
|
layer_cfgs=dict(
|
||||||
|
num_heads=6,
|
||||||
|
feedforward_channels=3 * embed_dims, # mlp_ratio = 3
|
||||||
|
),
|
||||||
|
drop_path_rate=0.1,
|
||||||
|
init_cfg=[
|
||||||
|
dict(type='TruncNormal', layer='Linear', std=.02),
|
||||||
|
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||||
|
]),
|
||||||
|
neck=None,
|
||||||
|
head=dict(
|
||||||
|
type='VisionTransformerClsHead',
|
||||||
|
num_classes=num_classes,
|
||||||
|
in_channels=embed_dims,
|
||||||
|
loss=dict(
|
||||||
|
type='LabelSmoothLoss',
|
||||||
|
label_smooth_val=0.1,
|
||||||
|
mode='original',
|
||||||
|
),
|
||||||
|
topk=(1, 5),
|
||||||
|
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
|
||||||
|
train_cfg=dict(augments=[
|
||||||
|
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
|
||||||
|
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
|
||||||
|
]))
|
41
configs/_base_/models/t2t-vit-t-19.py
Normal file
41
configs/_base_/models/t2t-vit-t-19.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# model settings
|
||||||
|
embed_dims = 448
|
||||||
|
num_classes = 1000
|
||||||
|
|
||||||
|
model = dict(
|
||||||
|
type='ImageClassifier',
|
||||||
|
backbone=dict(
|
||||||
|
type='T2T_ViT',
|
||||||
|
img_size=224,
|
||||||
|
in_channels=3,
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
t2t_cfg=dict(
|
||||||
|
token_dims=64,
|
||||||
|
use_performer=False,
|
||||||
|
),
|
||||||
|
num_layers=19,
|
||||||
|
layer_cfgs=dict(
|
||||||
|
num_heads=7,
|
||||||
|
feedforward_channels=3 * embed_dims, # mlp_ratio = 3
|
||||||
|
),
|
||||||
|
drop_path_rate=0.1,
|
||||||
|
init_cfg=[
|
||||||
|
dict(type='TruncNormal', layer='Linear', std=.02),
|
||||||
|
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||||
|
]),
|
||||||
|
neck=None,
|
||||||
|
head=dict(
|
||||||
|
type='VisionTransformerClsHead',
|
||||||
|
num_classes=num_classes,
|
||||||
|
in_channels=embed_dims,
|
||||||
|
loss=dict(
|
||||||
|
type='LabelSmoothLoss',
|
||||||
|
label_smooth_val=0.1,
|
||||||
|
mode='original',
|
||||||
|
),
|
||||||
|
topk=(1, 5),
|
||||||
|
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
|
||||||
|
train_cfg=dict(augments=[
|
||||||
|
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
|
||||||
|
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
|
||||||
|
]))
|
41
configs/_base_/models/t2t-vit-t-24.py
Normal file
41
configs/_base_/models/t2t-vit-t-24.py
Normal file
@ -0,0 +1,41 @@
|
|||||||
|
# model settings
|
||||||
|
embed_dims = 512
|
||||||
|
num_classes = 1000
|
||||||
|
|
||||||
|
model = dict(
|
||||||
|
type='ImageClassifier',
|
||||||
|
backbone=dict(
|
||||||
|
type='T2T_ViT',
|
||||||
|
img_size=224,
|
||||||
|
in_channels=3,
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
t2t_cfg=dict(
|
||||||
|
token_dims=64,
|
||||||
|
use_performer=False,
|
||||||
|
),
|
||||||
|
num_layers=24,
|
||||||
|
layer_cfgs=dict(
|
||||||
|
num_heads=8,
|
||||||
|
feedforward_channels=3 * embed_dims, # mlp_ratio = 3
|
||||||
|
),
|
||||||
|
drop_path_rate=0.1,
|
||||||
|
init_cfg=[
|
||||||
|
dict(type='TruncNormal', layer='Linear', std=.02),
|
||||||
|
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||||
|
]),
|
||||||
|
neck=None,
|
||||||
|
head=dict(
|
||||||
|
type='VisionTransformerClsHead',
|
||||||
|
num_classes=num_classes,
|
||||||
|
in_channels=embed_dims,
|
||||||
|
loss=dict(
|
||||||
|
type='LabelSmoothLoss',
|
||||||
|
label_smooth_val=0.1,
|
||||||
|
mode='original',
|
||||||
|
),
|
||||||
|
topk=(1, 5),
|
||||||
|
init_cfg=dict(type='TruncNormal', layer='Linear', std=.02)),
|
||||||
|
train_cfg=dict(augments=[
|
||||||
|
dict(type='BatchMixup', alpha=0.8, prob=0.5, num_classes=num_classes),
|
||||||
|
dict(type='BatchCutMix', alpha=1.0, prob=0.5, num_classes=num_classes),
|
||||||
|
]))
|
33
configs/t2t_vit/README.md
Normal file
33
configs/t2t_vit/README.md
Normal file
@ -0,0 +1,33 @@
|
|||||||
|
# Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet
|
||||||
|
<!-- {Tokens-to-Token ViT} -->
|
||||||
|
|
||||||
|
## Introduction
|
||||||
|
|
||||||
|
<!-- [ALGORITHM] -->
|
||||||
|
|
||||||
|
```latex
|
||||||
|
@article{yuan2021tokens,
|
||||||
|
title={Tokens-to-token vit: Training vision transformers from scratch on imagenet},
|
||||||
|
author={Yuan, Li and Chen, Yunpeng and Wang, Tao and Yu, Weihao and Shi, Yujun and Tay, Francis EH and Feng, Jiashi and Yan, Shuicheng},
|
||||||
|
journal={arXiv preprint arXiv:2101.11986},
|
||||||
|
year={2021}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
## Pretrain model
|
||||||
|
|
||||||
|
The pre-trained modles are converted from [official repo](https://github.com/yitu-opensource/T2T-ViT/tree/main#2-t2t-vit-models).
|
||||||
|
|
||||||
|
### ImageNet-1k
|
||||||
|
|
||||||
|
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Download |
|
||||||
|
|:--------------:|:---------:|:--------:|:---------:|:---------:|:--------:|
|
||||||
|
| T2T-ViT_t-14\* | 21.47 | 4.34 | 81.69 | 95.85 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth) | [log]()|
|
||||||
|
| T2T-ViT_t-19\* | 39.08 | 7.80 | 82.43 | 96.08 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth) | [log]()|
|
||||||
|
| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth) | [log]()|
|
||||||
|
|
||||||
|
*Models with \* are converted from other repos.*
|
||||||
|
|
||||||
|
## Results and models
|
||||||
|
|
||||||
|
Waiting for adding.
|
64
configs/t2t_vit/metafile.yml
Normal file
64
configs/t2t_vit/metafile.yml
Normal file
@ -0,0 +1,64 @@
|
|||||||
|
Collections:
|
||||||
|
- Name: Tokens-to-Token ViT
|
||||||
|
Metadata:
|
||||||
|
Training Data: ImageNet-1k
|
||||||
|
Architecture:
|
||||||
|
- Layer Normalization
|
||||||
|
- Scaled Dot-Product Attention
|
||||||
|
- Attention Dropout
|
||||||
|
- Dropout
|
||||||
|
- Tokens to Token
|
||||||
|
Paper:
|
||||||
|
URL: https://arxiv.org/abs/2101.11986
|
||||||
|
Title: "Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet"
|
||||||
|
README: configs/t2t_vit/README.md
|
||||||
|
|
||||||
|
Models:
|
||||||
|
- Name: t2t-vit-t-14_3rdparty_8xb64_in1k
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 4340000000
|
||||||
|
Parameters: 21470000
|
||||||
|
In Collection: Tokens-to-Token ViT
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 81.69
|
||||||
|
Top 5 Accuracy: 95.85
|
||||||
|
Task: Image Classification
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-b7c09b62.pth
|
||||||
|
Converted From:
|
||||||
|
Weights: https://github.com/yitu-opensource/T2T-ViT/releases/download/main/81.7_T2T_ViTt_14.pth.tar
|
||||||
|
Code: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/t2t_vit.py#L243
|
||||||
|
Config: configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py
|
||||||
|
- Name: t2t-vit-t-19_3rdparty_8xb64_in1k
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 7800000000
|
||||||
|
Parameters: 39080000
|
||||||
|
In Collection: Tokens-to-Token ViT
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 82.43
|
||||||
|
Top 5 Accuracy: 96.08
|
||||||
|
Task: Image Classification
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-7f1478d5.pth
|
||||||
|
Converted From:
|
||||||
|
Weights: https://github.com/yitu-opensource/T2T-ViT/releases/download/main/82.4_T2T_ViTt_19.pth.tar
|
||||||
|
Code: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/t2t_vit.py#L254
|
||||||
|
Config: configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py
|
||||||
|
- Name: t2t-vit-t-24_3rdparty_8xb64_in1k
|
||||||
|
Metadata:
|
||||||
|
FLOPs: 12690000000
|
||||||
|
Parameters: 64000000
|
||||||
|
In Collection: Tokens-to-Token ViT
|
||||||
|
Results:
|
||||||
|
- Dataset: ImageNet-1k
|
||||||
|
Metrics:
|
||||||
|
Top 1 Accuracy: 82.55
|
||||||
|
Top 5 Accuracy: 96.06
|
||||||
|
Task: Image Classification
|
||||||
|
Weights: https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth
|
||||||
|
Converted From:
|
||||||
|
Weights: https://github.com/yitu-opensource/T2T-ViT/releases/download/main/82.6_T2T_ViTt_24.pth.tar
|
||||||
|
Code: https://github.com/yitu-opensource/T2T-ViT/blob/main/models/t2t_vit.py#L265
|
||||||
|
Config: configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py
|
31
configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py
Normal file
31
configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
_base_ = [
|
||||||
|
'../_base_/models/t2t-vit-t-14.py',
|
||||||
|
'../_base_/datasets/imagenet_bs64_t2t_224.py',
|
||||||
|
'../_base_/default_runtime.py',
|
||||||
|
]
|
||||||
|
|
||||||
|
# optimizer
|
||||||
|
paramwise_cfg = dict(
|
||||||
|
bias_decay_mult=0.0,
|
||||||
|
custom_keys={'.backbone.cls_token': dict(decay_mult=0.0)},
|
||||||
|
)
|
||||||
|
optimizer = dict(
|
||||||
|
type='AdamW',
|
||||||
|
lr=5e-4,
|
||||||
|
weight_decay=0.05,
|
||||||
|
paramwise_cfg=paramwise_cfg,
|
||||||
|
)
|
||||||
|
optimizer_config = dict(grad_clip=None)
|
||||||
|
|
||||||
|
# learning policy
|
||||||
|
# FIXME: lr in the first 300 epochs conforms to the CosineAnnealing and
|
||||||
|
# the lr in the last 10 epoch equals to min_lr
|
||||||
|
lr_config = dict(
|
||||||
|
policy='CosineAnnealing',
|
||||||
|
min_lr=1e-5,
|
||||||
|
by_epoch=True,
|
||||||
|
warmup_by_epoch=True,
|
||||||
|
warmup='linear',
|
||||||
|
warmup_iters=10,
|
||||||
|
warmup_ratio=1e-6)
|
||||||
|
runner = dict(type='EpochBasedRunner', max_epochs=310)
|
31
configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py
Normal file
31
configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
_base_ = [
|
||||||
|
'../_base_/models/t2t-vit-t-19.py',
|
||||||
|
'../_base_/datasets/imagenet_bs64_t2t_224.py',
|
||||||
|
'../_base_/default_runtime.py',
|
||||||
|
]
|
||||||
|
|
||||||
|
# optimizer
|
||||||
|
paramwise_cfg = dict(
|
||||||
|
bias_decay_mult=0.0,
|
||||||
|
custom_keys={'.backbone.cls_token': dict(decay_mult=0.0)},
|
||||||
|
)
|
||||||
|
optimizer = dict(
|
||||||
|
type='AdamW',
|
||||||
|
lr=5e-4,
|
||||||
|
weight_decay=0.065,
|
||||||
|
paramwise_cfg=paramwise_cfg,
|
||||||
|
)
|
||||||
|
optimizer_config = dict(grad_clip=None)
|
||||||
|
|
||||||
|
# learning policy
|
||||||
|
# FIXME: lr in the first 300 epochs conforms to the CosineAnnealing and
|
||||||
|
# the lr in the last 10 epoch equals to min_lr
|
||||||
|
lr_config = dict(
|
||||||
|
policy='CosineAnnealing',
|
||||||
|
min_lr=1e-5,
|
||||||
|
by_epoch=True,
|
||||||
|
warmup_by_epoch=True,
|
||||||
|
warmup='linear',
|
||||||
|
warmup_iters=10,
|
||||||
|
warmup_ratio=1e-6)
|
||||||
|
runner = dict(type='EpochBasedRunner', max_epochs=310)
|
31
configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py
Normal file
31
configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py
Normal file
@ -0,0 +1,31 @@
|
|||||||
|
_base_ = [
|
||||||
|
'../_base_/models/t2t-vit-t-24.py',
|
||||||
|
'../_base_/datasets/imagenet_bs64_t2t_224.py',
|
||||||
|
'../_base_/default_runtime.py',
|
||||||
|
]
|
||||||
|
|
||||||
|
# optimizer
|
||||||
|
paramwise_cfg = dict(
|
||||||
|
bias_decay_mult=0.0,
|
||||||
|
custom_keys={'.backbone.cls_token': dict(decay_mult=0.0)},
|
||||||
|
)
|
||||||
|
optimizer = dict(
|
||||||
|
type='AdamW',
|
||||||
|
lr=5e-4,
|
||||||
|
weight_decay=0.065,
|
||||||
|
paramwise_cfg=paramwise_cfg,
|
||||||
|
)
|
||||||
|
optimizer_config = dict(grad_clip=None)
|
||||||
|
|
||||||
|
# learning policy
|
||||||
|
# FIXME: lr in the first 300 epochs conforms to the CosineAnnealing and
|
||||||
|
# the lr in the last 10 epoch equals to min_lr
|
||||||
|
lr_config = dict(
|
||||||
|
policy='CosineAnnealing',
|
||||||
|
min_lr=1e-5,
|
||||||
|
by_epoch=True,
|
||||||
|
warmup_by_epoch=True,
|
||||||
|
warmup='linear',
|
||||||
|
warmup_iters=10,
|
||||||
|
warmup_ratio=1e-6)
|
||||||
|
runner = dict(type='EpochBasedRunner', max_epochs=310)
|
@ -58,6 +58,9 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
|
|||||||
| Swin-Transformer small| 49.61 | 8.52 | 83.02 | 96.29 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_small_224_b16x64_300e_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219.log.json)|
|
| Swin-Transformer small| 49.61 | 8.52 | 83.02 | 96.29 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_small_224_b16x64_300e_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219-7f9d988b.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_small_224_b16x64_300e_imagenet_20210615_110219.log.json)|
|
||||||
| Swin-Transformer base | 87.77 | 15.14 | 83.36 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_b16x64_300e_imagenet_20210616_190742-93230b0d.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_b16x64_300e_imagenet_20210616_190742.log.json)|
|
| Swin-Transformer base | 87.77 | 15.14 | 83.36 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/swin_transformer/swin_base_224_b16x64_300e_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_b16x64_300e_imagenet_20210616_190742-93230b0d.pth) | [log](https://download.openmmlab.com/mmclassification/v0/swin-transformer/swin_base_224_b16x64_300e_imagenet_20210616_190742.log.json)|
|
||||||
| Transformer in Transformer small\* | 23.76 | 3.36 | 81.52 | 95.73 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/tnt/tnt_s_patch16_224_evalonly_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth) | [log]()|
|
| Transformer in Transformer small\* | 23.76 | 3.36 | 81.52 | 95.73 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/tnt/tnt_s_patch16_224_evalonly_imagenet.py) | [model](https://download.openmmlab.com/mmclassification/v0/tnt/tnt-small-p16_3rdparty_in1k_20210903-c56ee7df.pth) | [log]()|
|
||||||
|
| T2T-ViT_t-14\* | 21.47 | 4.34 | 81.69 | 95.85 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-14_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-14_3rdparty_8xb64_in1k_20210928-420df0f6.pth) | [log]()|
|
||||||
|
| T2T-ViT_t-19\* | 39.08 | 7.80 | 82.43 | 96.08 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-19_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-19_3rdparty_8xb64_in1k_20210928-e479c2a6.pth) | [log]()|
|
||||||
|
| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-b5bf2526.pth) | [log]()|
|
||||||
|
|
||||||
Models with * are converted from other repos, others are trained by ourselves.
|
Models with * are converted from other repos, others are trained by ourselves.
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
import copy
|
import copy
|
||||||
import inspect
|
import inspect
|
||||||
import random
|
import random
|
||||||
|
from math import ceil
|
||||||
from numbers import Number
|
from numbers import Number
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
@ -668,7 +669,8 @@ class Posterize(object):
|
|||||||
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
|
assert 0 <= prob <= 1.0, 'The prob should be in range [0,1], ' \
|
||||||
f'got {prob} instead.'
|
f'got {prob} instead.'
|
||||||
|
|
||||||
self.bits = int(bits)
|
# To align timm version, we need to round up to integer here.
|
||||||
|
self.bits = ceil(bits)
|
||||||
self.prob = prob
|
self.prob = prob
|
||||||
|
|
||||||
def __call__(self, results):
|
def __call__(self, results):
|
||||||
|
@ -15,6 +15,7 @@ from .seresnext import SEResNeXt
|
|||||||
from .shufflenet_v1 import ShuffleNetV1
|
from .shufflenet_v1 import ShuffleNetV1
|
||||||
from .shufflenet_v2 import ShuffleNetV2
|
from .shufflenet_v2 import ShuffleNetV2
|
||||||
from .swin_transformer import SwinTransformer
|
from .swin_transformer import SwinTransformer
|
||||||
|
from .t2t_vit import T2T_ViT
|
||||||
from .timm_backbone import TIMMBackbone
|
from .timm_backbone import TIMMBackbone
|
||||||
from .tnt import TNT
|
from .tnt import TNT
|
||||||
from .vgg import VGG
|
from .vgg import VGG
|
||||||
@ -24,5 +25,5 @@ __all__ = [
|
|||||||
'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
|
'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
|
||||||
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
|
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
|
||||||
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
|
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
|
||||||
'SwinTransformer', 'TNT', 'TIMMBackbone', 'Res2Net', 'RepVGG'
|
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG'
|
||||||
]
|
]
|
||||||
|
367
mmcls/models/backbones/t2t_vit.py
Normal file
367
mmcls/models/backbones/t2t_vit.py
Normal file
@ -0,0 +1,367 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from copy import deepcopy
|
||||||
|
from typing import Sequence
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
import torch.nn as nn
|
||||||
|
from mmcv.cnn import build_norm_layer
|
||||||
|
from mmcv.cnn.bricks.transformer import FFN
|
||||||
|
from mmcv.cnn.utils.weight_init import trunc_normal_
|
||||||
|
from mmcv.runner.base_module import BaseModule, ModuleList
|
||||||
|
|
||||||
|
from ..builder import BACKBONES
|
||||||
|
from ..utils import MultiheadAttention
|
||||||
|
from .base_backbone import BaseBackbone
|
||||||
|
|
||||||
|
|
||||||
|
class T2TTransformerLayer(BaseModule):
|
||||||
|
"""Transformer Layer for T2T_ViT.
|
||||||
|
|
||||||
|
Comparing with :obj:`TransformerEncoderLayer` in ViT, it supports
|
||||||
|
different ``input_dims`` and ``embed_dims``.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
embed_dims (int): The feature dimension.
|
||||||
|
num_heads (int): Parallel attention heads.
|
||||||
|
feedforward_channels (int): The hidden dimension for FFNs
|
||||||
|
input_dims (int, optional): The input token dimension.
|
||||||
|
Defaults to None.
|
||||||
|
drop_rate (float): Probability of an element to be zeroed
|
||||||
|
after the feed forward layer. Defaults to 0.
|
||||||
|
attn_drop_rate (float): The drop out rate for attention output weights.
|
||||||
|
Defaults to 0.
|
||||||
|
drop_path_rate (float): Stochastic depth rate. Defaults to 0.
|
||||||
|
num_fcs (int): The number of fully-connected layers for FFNs.
|
||||||
|
Defaults to 2.
|
||||||
|
qkv_bias (bool): enable bias for qkv if True. Defaults to True.
|
||||||
|
qk_scale (float, optional): Override default qk scale of
|
||||||
|
``(input_dims // num_heads) ** -0.5`` if set. Defaults to None.
|
||||||
|
act_cfg (dict): The activation config for FFNs.
|
||||||
|
Defaluts to ``dict(type='GELU')``.
|
||||||
|
norm_cfg (dict): Config dict for normalization layer.
|
||||||
|
Defaults to ``dict(type='LN')``.
|
||||||
|
init_cfg (dict, optional): Initialization config dict.
|
||||||
|
Defaults to None.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
In general, ``qk_scale`` should be ``head_dims ** -0.5``, i.e.
|
||||||
|
``(embed_dims // num_heads) ** -0.5``. However, in the official
|
||||||
|
code, it uses ``(input_dims // num_heads) ** -0.5``, so here we
|
||||||
|
keep the same with the official implementation.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
embed_dims,
|
||||||
|
num_heads,
|
||||||
|
feedforward_channels,
|
||||||
|
input_dims=None,
|
||||||
|
drop_rate=0.,
|
||||||
|
attn_drop_rate=0.,
|
||||||
|
drop_path_rate=0.,
|
||||||
|
num_fcs=2,
|
||||||
|
qkv_bias=False,
|
||||||
|
qk_scale=None,
|
||||||
|
act_cfg=dict(type='GELU'),
|
||||||
|
norm_cfg=dict(type='LN'),
|
||||||
|
init_cfg=None):
|
||||||
|
super(T2TTransformerLayer, self).__init__(init_cfg=init_cfg)
|
||||||
|
|
||||||
|
self.v_shortcut = True if input_dims is not None else False
|
||||||
|
input_dims = input_dims or embed_dims
|
||||||
|
|
||||||
|
self.norm1_name, norm1 = build_norm_layer(
|
||||||
|
norm_cfg, input_dims, postfix=1)
|
||||||
|
self.add_module(self.norm1_name, norm1)
|
||||||
|
|
||||||
|
self.attn = MultiheadAttention(
|
||||||
|
input_dims=input_dims,
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
num_heads=num_heads,
|
||||||
|
attn_drop=attn_drop_rate,
|
||||||
|
proj_drop=drop_rate,
|
||||||
|
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||||
|
qkv_bias=qkv_bias,
|
||||||
|
qk_scale=qk_scale or (input_dims // num_heads)**-0.5,
|
||||||
|
v_shortcut=self.v_shortcut)
|
||||||
|
|
||||||
|
self.norm2_name, norm2 = build_norm_layer(
|
||||||
|
norm_cfg, embed_dims, postfix=2)
|
||||||
|
self.add_module(self.norm2_name, norm2)
|
||||||
|
|
||||||
|
self.ffn = FFN(
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
feedforward_channels=feedforward_channels,
|
||||||
|
num_fcs=num_fcs,
|
||||||
|
ffn_drop=drop_rate,
|
||||||
|
dropout_layer=dict(type='DropPath', drop_prob=drop_path_rate),
|
||||||
|
act_cfg=act_cfg)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def norm1(self):
|
||||||
|
return getattr(self, self.norm1_name)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def norm2(self):
|
||||||
|
return getattr(self, self.norm2_name)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
if self.v_shortcut:
|
||||||
|
x = self.attn(self.norm1(x))
|
||||||
|
else:
|
||||||
|
x = x + self.attn(self.norm1(x))
|
||||||
|
x = self.ffn(self.norm2(x), identity=x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
class T2TModule(BaseModule):
|
||||||
|
"""Tokens-to-Token module.
|
||||||
|
|
||||||
|
"Tokens-to-Token module" (T2T Module) can model the local structure
|
||||||
|
information of images and reduce the length of tokens progressively.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_size (int): Input image size
|
||||||
|
in_channels (int): Number of input channels
|
||||||
|
embed_dims (int): Embedding dimension
|
||||||
|
token_dims (int): Tokens dimension in T2TModuleAttention.
|
||||||
|
use_performer (bool): If True, use Performer version self-attention to
|
||||||
|
adopt regular self-attention. Defaults to False.
|
||||||
|
init_cfg (dict, optional): The extra config for initialization.
|
||||||
|
Default: None.
|
||||||
|
|
||||||
|
Notes:
|
||||||
|
Usually, ``token_dim`` is set as a small value (32 or 64) to reduce
|
||||||
|
MACs
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
img_size=224,
|
||||||
|
in_channels=3,
|
||||||
|
embed_dims=384,
|
||||||
|
token_dims=64,
|
||||||
|
use_performer=False,
|
||||||
|
init_cfg=None,
|
||||||
|
):
|
||||||
|
super(T2TModule, self).__init__(init_cfg)
|
||||||
|
|
||||||
|
self.embed_dims = embed_dims
|
||||||
|
|
||||||
|
self.soft_split0 = nn.Unfold(
|
||||||
|
kernel_size=(7, 7), stride=(4, 4), padding=(2, 2))
|
||||||
|
self.soft_split1 = nn.Unfold(
|
||||||
|
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
|
||||||
|
self.soft_split2 = nn.Unfold(
|
||||||
|
kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
|
||||||
|
|
||||||
|
if not use_performer:
|
||||||
|
self.attention1 = T2TTransformerLayer(
|
||||||
|
input_dims=in_channels * 7 * 7,
|
||||||
|
embed_dims=token_dims,
|
||||||
|
num_heads=1,
|
||||||
|
feedforward_channels=token_dims)
|
||||||
|
|
||||||
|
self.attention2 = T2TTransformerLayer(
|
||||||
|
input_dims=token_dims * 3 * 3,
|
||||||
|
embed_dims=token_dims,
|
||||||
|
num_heads=1,
|
||||||
|
feedforward_channels=token_dims)
|
||||||
|
|
||||||
|
self.project = nn.Linear(token_dims * 3 * 3, embed_dims)
|
||||||
|
else:
|
||||||
|
raise NotImplementedError("Performer hasn't been implemented.")
|
||||||
|
|
||||||
|
# there are 3 soft split, stride are 4,2,2 separately
|
||||||
|
self.num_patches = (img_size // (4 * 2 * 2))**2
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
# step0: soft split
|
||||||
|
x = self.soft_split0(x).transpose(1, 2)
|
||||||
|
|
||||||
|
for step in [1, 2]:
|
||||||
|
# re-structurization/reconstruction
|
||||||
|
attn = getattr(self, f'attention{step}')
|
||||||
|
x = attn(x).transpose(1, 2)
|
||||||
|
B, C, new_HW = x.shape
|
||||||
|
x = x.reshape(B, C, int(np.sqrt(new_HW)), int(np.sqrt(new_HW)))
|
||||||
|
|
||||||
|
# soft split
|
||||||
|
soft_split = getattr(self, f'soft_split{step}')
|
||||||
|
x = soft_split(x).transpose(1, 2)
|
||||||
|
|
||||||
|
# final tokens
|
||||||
|
x = self.project(x)
|
||||||
|
return x
|
||||||
|
|
||||||
|
|
||||||
|
def get_sinusoid_encoding(n_position, embed_dims):
|
||||||
|
"""Generate sinusoid encoding table.
|
||||||
|
|
||||||
|
Sinusoid encoding is a kind of relative position encoding method came from
|
||||||
|
`Attention Is All You Need<https://arxiv.org/abs/1706.03762>`_.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
n_position (int): The length of the input token.
|
||||||
|
embed_dims (int): The position embedding dimension.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
:obj:`torch.FloatTensor`: The sinusoid encoding table.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def get_position_angle_vec(position):
|
||||||
|
return [
|
||||||
|
position / np.power(10000, 2 * (i // 2) / embed_dims)
|
||||||
|
for i in range(embed_dims)
|
||||||
|
]
|
||||||
|
|
||||||
|
sinusoid_table = np.array(
|
||||||
|
[get_position_angle_vec(pos) for pos in range(n_position)])
|
||||||
|
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
|
||||||
|
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
|
||||||
|
|
||||||
|
return torch.FloatTensor(sinusoid_table).unsqueeze(0)
|
||||||
|
|
||||||
|
|
||||||
|
@BACKBONES.register_module()
|
||||||
|
class T2T_ViT(BaseBackbone):
|
||||||
|
"""Tokens-to-Token Vision Transformer (T2T-ViT)
|
||||||
|
|
||||||
|
A PyTorch implementation of `Tokens-to-Token ViT: Training Vision
|
||||||
|
Transformers from Scratch on ImageNet<https://arxiv.org/abs/2101.11986>`_
|
||||||
|
|
||||||
|
Args:
|
||||||
|
img_size (int): Input image size.
|
||||||
|
in_channels (int): Number of input channels.
|
||||||
|
embed_dims (int): Embedding dimension.
|
||||||
|
t2t_cfg (dict): Extra config of Tokens-to-Token module.
|
||||||
|
Defaults to an empty dict.
|
||||||
|
drop_rate (float): Dropout rate after position embedding.
|
||||||
|
Defaults to 0.
|
||||||
|
num_layers (int): Num of transformer layers in encoder.
|
||||||
|
Defaults to 14.
|
||||||
|
out_indices (Sequence | int): Output from which stages.
|
||||||
|
Defaults to -1, means the last stage.
|
||||||
|
layer_cfgs (Sequence | dict): Configs of each transformer layer in
|
||||||
|
encoder. Defaults to an empty dict.
|
||||||
|
drop_path_rate (float): stochastic depth rate. Defaults to 0.
|
||||||
|
norm_cfg (dict): Config dict for normalization layer. Defaults to
|
||||||
|
``dict(type='LN')``.
|
||||||
|
final_norm (bool): Whether to add a additional layer to normalize
|
||||||
|
final feature map. Defaults to True.
|
||||||
|
output_cls_token (bool): Whether output the cls_token.
|
||||||
|
Defaults to True.
|
||||||
|
init_cfg (dict, optional): The Config for initialization.
|
||||||
|
Defaults to None.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self,
|
||||||
|
img_size=224,
|
||||||
|
in_channels=3,
|
||||||
|
embed_dims=384,
|
||||||
|
t2t_cfg=dict(),
|
||||||
|
drop_rate=0.,
|
||||||
|
num_layers=14,
|
||||||
|
out_indices=-1,
|
||||||
|
layer_cfgs=dict(),
|
||||||
|
drop_path_rate=0.,
|
||||||
|
norm_cfg=dict(type='LN'),
|
||||||
|
final_norm=True,
|
||||||
|
output_cls_token=True,
|
||||||
|
init_cfg=None):
|
||||||
|
super(T2T_ViT, self).__init__(init_cfg)
|
||||||
|
|
||||||
|
# Token-to-Token Module
|
||||||
|
self.tokens_to_token = T2TModule(
|
||||||
|
img_size=img_size,
|
||||||
|
in_channels=in_channels,
|
||||||
|
embed_dims=embed_dims,
|
||||||
|
**t2t_cfg)
|
||||||
|
num_patches = self.tokens_to_token.num_patches
|
||||||
|
|
||||||
|
# Class token
|
||||||
|
self.output_cls_token = output_cls_token
|
||||||
|
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims))
|
||||||
|
|
||||||
|
# Position Embedding
|
||||||
|
sinusoid_table = get_sinusoid_encoding(num_patches + 1, embed_dims)
|
||||||
|
self.register_buffer('pos_embed', sinusoid_table)
|
||||||
|
self.drop_after_pos = nn.Dropout(p=drop_rate)
|
||||||
|
|
||||||
|
if isinstance(out_indices, int):
|
||||||
|
out_indices = [out_indices]
|
||||||
|
assert isinstance(out_indices, Sequence), \
|
||||||
|
f'"out_indices" must by a sequence or int, ' \
|
||||||
|
f'get {type(out_indices)} instead.'
|
||||||
|
for i, index in enumerate(out_indices):
|
||||||
|
if index < 0:
|
||||||
|
out_indices[i] = num_layers + index
|
||||||
|
assert out_indices[i] >= 0, f'Invalid out_indices {index}'
|
||||||
|
self.out_indices = out_indices
|
||||||
|
|
||||||
|
dpr = [x for x in np.linspace(0, drop_path_rate, num_layers)]
|
||||||
|
self.encoder = ModuleList()
|
||||||
|
for i in range(num_layers):
|
||||||
|
if isinstance(layer_cfgs, Sequence):
|
||||||
|
layer_cfg = layer_cfgs[i]
|
||||||
|
else:
|
||||||
|
layer_cfg = deepcopy(layer_cfgs)
|
||||||
|
layer_cfg = {
|
||||||
|
'embed_dims': embed_dims,
|
||||||
|
'num_heads': 6,
|
||||||
|
'feedforward_channels': 3 * embed_dims,
|
||||||
|
'drop_path_rate': dpr[i],
|
||||||
|
'qkv_bias': False,
|
||||||
|
'norm_cfg': norm_cfg,
|
||||||
|
**layer_cfg
|
||||||
|
}
|
||||||
|
|
||||||
|
layer = T2TTransformerLayer(**layer_cfg)
|
||||||
|
self.encoder.append(layer)
|
||||||
|
|
||||||
|
self.final_norm = final_norm
|
||||||
|
if final_norm:
|
||||||
|
self.norm = build_norm_layer(norm_cfg, embed_dims)[1]
|
||||||
|
else:
|
||||||
|
self.norm = nn.Identity()
|
||||||
|
|
||||||
|
def init_weights(self):
|
||||||
|
super().init_weights()
|
||||||
|
|
||||||
|
if (isinstance(self.init_cfg, dict)
|
||||||
|
and self.init_cfg['type'] == 'Pretrained'):
|
||||||
|
# Suppress custom init if use pretrained model.
|
||||||
|
return
|
||||||
|
|
||||||
|
trunc_normal_(self.cls_token, std=.02)
|
||||||
|
|
||||||
|
def forward(self, x):
|
||||||
|
B = x.shape[0]
|
||||||
|
x = self.tokens_to_token(x)
|
||||||
|
num_patches = self.tokens_to_token.num_patches
|
||||||
|
patch_resolution = [int(np.sqrt(num_patches))] * 2
|
||||||
|
|
||||||
|
cls_tokens = self.cls_token.expand(B, -1, -1)
|
||||||
|
x = torch.cat((cls_tokens, x), dim=1)
|
||||||
|
x = x + self.pos_embed
|
||||||
|
x = self.drop_after_pos(x)
|
||||||
|
|
||||||
|
outs = []
|
||||||
|
for i, layer in enumerate(self.encoder):
|
||||||
|
x = layer(x)
|
||||||
|
|
||||||
|
if i == len(self.encoder) - 1 and self.final_norm:
|
||||||
|
x = self.norm(x)
|
||||||
|
|
||||||
|
if i in self.out_indices:
|
||||||
|
B, _, C = x.shape
|
||||||
|
patch_token = x[:, 1:].reshape(B, *patch_resolution, C)
|
||||||
|
patch_token = patch_token.permute(0, 3, 1, 2)
|
||||||
|
cls_token = x[:, 0]
|
||||||
|
if self.output_cls_token:
|
||||||
|
out = [patch_token, cls_token]
|
||||||
|
else:
|
||||||
|
out = patch_token
|
||||||
|
outs.append(out)
|
||||||
|
|
||||||
|
return tuple(outs)
|
@ -12,3 +12,4 @@ Import:
|
|||||||
- configs/repvgg/metafile.yml
|
- configs/repvgg/metafile.yml
|
||||||
- configs/tnt/metafile.yml
|
- configs/tnt/metafile.yml
|
||||||
- configs/vision_transformer/metafile.yml
|
- configs/vision_transformer/metafile.yml
|
||||||
|
- configs/t2t_vit/metafile.yml
|
||||||
|
84
tests/test_models/test_backbones/test_t2t_vit.py
Normal file
84
tests/test_models/test_backbones/test_t2t_vit.py
Normal file
@ -0,0 +1,84 @@
|
|||||||
|
# Copyright (c) OpenMMLab. All rights reserved.
|
||||||
|
from copy import deepcopy
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
from torch.nn.modules import GroupNorm
|
||||||
|
from torch.nn.modules.batchnorm import _BatchNorm
|
||||||
|
|
||||||
|
from mmcls.models.backbones import T2T_ViT
|
||||||
|
|
||||||
|
|
||||||
|
def is_norm(modules):
|
||||||
|
"""Check if is one of the norms."""
|
||||||
|
if isinstance(modules, (GroupNorm, _BatchNorm)):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def check_norm_state(modules, train_state):
|
||||||
|
"""Check if norm layer is in correct train state."""
|
||||||
|
for mod in modules:
|
||||||
|
if isinstance(mod, _BatchNorm):
|
||||||
|
if mod.training != train_state:
|
||||||
|
return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def test_vit_backbone():
|
||||||
|
|
||||||
|
cfg_ori = dict(
|
||||||
|
img_size=224,
|
||||||
|
in_channels=3,
|
||||||
|
embed_dims=384,
|
||||||
|
t2t_cfg=dict(
|
||||||
|
token_dims=64,
|
||||||
|
use_performer=False,
|
||||||
|
),
|
||||||
|
num_layers=14,
|
||||||
|
layer_cfgs=dict(
|
||||||
|
num_heads=6,
|
||||||
|
feedforward_channels=3 * 384, # mlp_ratio = 3
|
||||||
|
),
|
||||||
|
drop_path_rate=0.1,
|
||||||
|
init_cfg=[
|
||||||
|
dict(type='TruncNormal', layer='Linear', std=.02),
|
||||||
|
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
|
||||||
|
])
|
||||||
|
|
||||||
|
with pytest.raises(NotImplementedError):
|
||||||
|
# test if use performer
|
||||||
|
cfg = deepcopy(cfg_ori)
|
||||||
|
cfg['t2t_cfg']['use_performer'] = True
|
||||||
|
T2T_ViT(**cfg)
|
||||||
|
|
||||||
|
# Test T2T-ViT model with input size of 224
|
||||||
|
model = T2T_ViT(**cfg_ori)
|
||||||
|
model.init_weights()
|
||||||
|
model.train()
|
||||||
|
|
||||||
|
assert check_norm_state(model.modules(), True)
|
||||||
|
|
||||||
|
imgs = torch.randn(3, 3, 224, 224)
|
||||||
|
patch_token, cls_token = model(imgs)[-1]
|
||||||
|
assert cls_token.shape == (3, 384)
|
||||||
|
assert patch_token.shape == (3, 384, 14, 14)
|
||||||
|
|
||||||
|
# Test custom arch T2T-ViT without output cls token
|
||||||
|
cfg = deepcopy(cfg_ori)
|
||||||
|
cfg['embed_dims'] = 256
|
||||||
|
cfg['num_layers'] = 16
|
||||||
|
cfg['layer_cfgs'] = dict(num_heads=8, feedforward_channels=1024)
|
||||||
|
cfg['output_cls_token'] = False
|
||||||
|
|
||||||
|
model = T2T_ViT(**cfg)
|
||||||
|
patch_token = model(imgs)[-1]
|
||||||
|
assert patch_token.shape == (3, 256, 14, 14)
|
||||||
|
|
||||||
|
# Test T2T_ViT with multi out indices
|
||||||
|
cfg = deepcopy(cfg_ori)
|
||||||
|
cfg['out_indices'] = [-3, -2, -1]
|
||||||
|
model = T2T_ViT(**cfg)
|
||||||
|
for out in model(imgs):
|
||||||
|
assert out[0].shape == (3, 384, 14, 14)
|
||||||
|
assert out[1].shape == (3, 384)
|
Loading…
x
Reference in New Issue
Block a user