[Feature] Add DeiT backbone and checkpoints. (#576)

* Support DeiT backbone.

* Use hook to automatically resize pos embed

* Update ViT training setting

* Add deit configs and update docs

* Fix vit arch assertion

* Remove useless init function

* Add unit tests.

* Fix resize_pos_embed for DeiT

* Improve according to comments.
Ma Zerun 2021-12-15 22:44:57 +08:00 committed by GitHub
parent 6f25bebe42
commit f9a2b04cee
21 changed files with 619 additions and 72 deletions
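The pos_embed resizing mentioned in the summary is implemented as a `load_state_dict` pre-hook on `VisionTransformer` (see the `vision_transformer.py` hunk below), so checkpoints trained at one resolution can be loaded into models built for another. A minimal sketch of the intended behavior, assuming mmcls with this patch installed; the shapes follow from 16x16 patches on 224px and 384px inputs:

```python
from mmcls.models.backbones import VisionTransformer

# A 224px DeiT-base style ViT: pos_embed covers 14*14 patch tokens + cls_token.
src = VisionTransformer(arch='deit-base', img_size=224, patch_size=16)
# The same arch built for 384px inputs: 24*24 patch tokens + cls_token.
dst = VisionTransformer(arch='deit-base', img_size=384, patch_size=16)

state = src.state_dict()
assert state['pos_embed'].shape == (1, 14 * 14 + 1, 768)

# The pre-hook registered in __init__ detects the shape mismatch and
# bicubically resizes pos_embed before it is copied into `dst`.
dst.load_state_dict(state)
assert dst.pos_embed.shape == (1, 24 * 24 + 1, 768)
```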


@@ -8,7 +8,11 @@ img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='RandomResizedCrop', size=224, backend='pillow'),
dict(
type='RandomResizedCrop',
size=224,
backend='pillow',
interpolation='bicubic'),
dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
dict(type='AutoAugment', policies={{_base_.policy_imagenet}}),
dict(type='Normalize', **img_norm_cfg),
@@ -18,7 +22,11 @@ train_pipeline = [
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='Resize', size=(256, -1), backend='pillow'),
dict(
type='Resize',
size=(256, -1),
backend='pillow',
interpolation='bicubic'),
dict(type='CenterCrop', crop_size=224),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),


@@ -1,18 +1,24 @@
# specific to vit pretrain
paramwise_cfg = dict(custom_keys={
'.cls_token': dict(decay_mult=0.0),
'.pos_embed': dict(decay_mult=0.0)
})
# optimizer
optimizer = dict(type='AdamW', lr=0.003, weight_decay=0.3)
optimizer = dict(
type='AdamW',
lr=0.003,
weight_decay=0.3,
paramwise_cfg=paramwise_cfg,
)
optimizer_config = dict(grad_clip=dict(max_norm=1.0))
# specific to vit pretrain
paramwise_cfg = dict(
custom_keys={
'.backbone.cls_token': dict(decay_mult=0.0),
'.backbone.pos_embed': dict(decay_mult=0.0)
})
# learning policy
lr_config = dict(
policy='CosineAnnealing',
min_lr=0,
warmup='linear',
warmup_iters=10000,
warmup_ratio=1e-4)
warmup_ratio=1e-4,
)
runner = dict(type='EpochBasedRunner', max_epochs=300)
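As a quick check of the new layout, where `paramwise_cfg` is nested inside the optimizer dict instead of standing alone, the merged schedule can be loaded through mmcv and inspected. The config path below is inferred from the DeiT configs further down, and running from the repository root is assumed:

```python
from mmcv import Config

cfg = Config.fromfile('configs/_base_/schedules/imagenet_bs4096_AdamW.py')
print(cfg.optimizer)
# Expected, per the hunk above:
# {'type': 'AdamW', 'lr': 0.003, 'weight_decay': 0.3,
#  'paramwise_cfg': {'custom_keys': {'.cls_token': {'decay_mult': 0.0},
#                                    '.pos_embed': {'decay_mult': 0.0}}}}
```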


@@ -0,0 +1,61 @@
# Training data-efficient image transformers & distillation through attention
<!-- {DeiT} -->
<!-- [ALGORITHM] -->
## Abstract
<!-- [ABSTRACT] -->
Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. However, these visual transformers are pre-trained with hundreds of millions of images using an expensive infrastructure, thereby limiting their adoption. In this work, we produce a competitive convolution-free transformer by training on Imagenet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop evaluation) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both Imagenet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models.
<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/26739999/143225703-c287c29e-82c9-4c85-a366-dfae30d198cd.png" width="40%"/>
</div>
## Citation
```{latex}
@InProceedings{pmlr-v139-touvron21a,
title = {Training data-efficient image transformers &amp; distillation through attention},
author = {Touvron, Hugo and Cord, Matthieu and Douze, Matthijs and Massa, Francisco and Sablayrolles, Alexandre and Jegou, Herve},
booktitle = {International Conference on Machine Learning},
pages = {10347--10357},
year = {2021},
volume = {139},
month = {July}
}
```
## Pretrained models
The pre-trained models are converted from the [official repo](https://github.com/facebookresearch/deit). The teacher used for the distilled DeiT models is RegNetY-16GF.
### ImageNet-1k
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) |
| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) |
| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) |
| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) |
| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) |
| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) |
*Models with \* are converted from other repos.*
## Fine-tuned models
The fine-tuned models are converted from the [official repo](https://github.com/facebookresearch/deit).
### ImageNet-1k
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
| DeiT-base 384px\* | 86.86 | 49.37 | 83.04 | 96.31 | [config](configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) |
| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth) |
*Models with \* are converted from other repos.*
```{warning}
MMClassification doesn't support training the distilled version of DeiT.
The distilled checkpoints are provided for inference only.
```
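For reference, a minimal inference sketch with one of the converted checkpoints; it assumes mmcls is installed, the command is run from the repository root, and the test image path is a placeholder:

```python
from mmcls.apis import inference_model, init_model

config = 'configs/deit/deit-tiny_pt-4xb256_in1k.py'
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth'

# Build the model and load the converted weights (CPU is enough for a demo).
model = init_model(config, checkpoint, device='cpu')
result = inference_model(model, 'demo/demo.JPEG')  # placeholder image path
print(result['pred_class'], result['pred_score'])
```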


@@ -0,0 +1,9 @@
_base_ = './deit-base_ft-16xb32_in1k-384px.py'
# model settings
model = dict(
backbone=dict(type='DistilledVisionTransformer'),
head=dict(type='DeiTClsHead'),
# Change to the path of the pretrained model
# init_cfg=dict(type='Pretrained', checkpoint=''),
)
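For fine-tuning, the commented `init_cfg` above is typically filled in with the 224px pre-trained weights. A hedged sketch of the override; the local checkpoint path is a placeholder, not a file shipped with this change:

```python
model = dict(
    init_cfg=dict(
        type='Pretrained',
        # Placeholder path; point this at the downloaded 224px checkpoint.
        checkpoint='checkpoints/deit-base-distilled_pt-16xb64_in1k.pth'),
)
```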


@@ -0,0 +1,10 @@
_base_ = './deit-small_pt-4xb256_in1k.py'
# model settings
model = dict(
backbone=dict(type='DistilledVisionTransformer', arch='deit-base'),
head=dict(type='DeiTClsHead', in_channels=768),
)
# data settings
data = dict(samples_per_gpu=64, workers_per_gpu=5)


@@ -0,0 +1,29 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_swin_384.py',
'../_base_/schedules/imagenet_bs4096_AdamW.py',
'../_base_/default_runtime.py'
]
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
arch='deit-base',
img_size=384,
patch_size=16,
),
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=1000,
in_channels=768,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
),
# Change to the path of the pretrained model
# init_cfg=dict(type='Pretrained', checkpoint=''),
)
# data settings
data = dict(samples_per_gpu=32, workers_per_gpu=5)


@@ -0,0 +1,10 @@
_base_ = './deit-small_pt-4xb256_in1k.py'
# model settings
model = dict(
backbone=dict(type='VisionTransformer', arch='deit-base'),
head=dict(type='VisionTransformerClsHead', in_channels=768),
)
# data settings
data = dict(samples_per_gpu=64, workers_per_gpu=5)


@@ -0,0 +1,7 @@
_base_ = './deit-small_pt-4xb256_in1k.py'
# model settings
model = dict(
backbone=dict(type='DistilledVisionTransformer', arch='deit-small'),
head=dict(type='DeiTClsHead', in_channels=384),
)


@@ -0,0 +1,29 @@
_base_ = [
'../_base_/datasets/imagenet_bs64_pil_resize_autoaug.py',
'../_base_/schedules/imagenet_bs4096_AdamW.py',
'../_base_/default_runtime.py'
]
# model settings
model = dict(
type='ImageClassifier',
backbone=dict(
type='VisionTransformer',
arch='deit-small',
img_size=224,
patch_size=16),
neck=None,
head=dict(
type='VisionTransformerClsHead',
num_classes=1000,
in_channels=384,
loss=dict(
type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
),
init_cfg=[
dict(type='TruncNormal', layer='Linear', std=.02),
dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
])
# data settings
data = dict(samples_per_gpu=256, workers_per_gpu=5)


@@ -0,0 +1,7 @@
_base_ = './deit-small_pt-4xb256_in1k.py'
# model settings
model = dict(
backbone=dict(type='DistilledVisionTransformer', arch='deit-tiny'),
head=dict(type='DeiTClsHead', in_channels=192),
)


@@ -0,0 +1,7 @@
_base_ = './deit-small_pt-4xb256_in1k.py'
# model settings
model = dict(
backbone=dict(type='VisionTransformer', arch='deit-tiny'),
head=dict(type='VisionTransformerClsHead', in_channels=192),
)


@@ -0,0 +1,143 @@
Collections:
- Name: DeiT
Metadata:
Training Data: ImageNet-1k
Architecture:
- Layer Normalization
- Scaled Dot-Product Attention
- Attention Dropout
- Multi-Head Attention
Paper:
URL: https://arxiv.org/abs/2012.12877
Title: "Training data-efficient image transformers & distillation through attention"
README: configs/deit/README.md
Models:
- Name: deit-tiny_3rdparty_pt-4xb256_in1k
Metadata:
FLOPs: 1080000000
Parameters: 5720000
In Collection: DeiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 72.13
Top 5 Accuracy: 91.13
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth
Converted From:
Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L63
Config: configs/deit/deit-tiny_pt-4xb256_in1k.py
- Name: deit-tiny-distilled_3rdparty_pt-4xb256_in1k
Metadata:
FLOPs: 1080000000
Parameters: 5720000
In Collection: DeiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 74.51
Top 5 Accuracy: 91.90
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth
Converted From:
Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L108
Config: configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py
- Name: deit-small_3rdparty_pt-4xb256_in1k
Metadata:
FLOPs: 4240000000
Parameters: 22050000
In Collection: DeiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 79.83
Top 5 Accuracy: 94.95
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth
Converted From:
Weights: https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L78
Config: configs/deit/deit-small_pt-4xb256_in1k.py
- Name: deit-small-distilled_3rdparty_pt-4xb256_in1k
Metadata:
FLOPs: 4240000000
Parameters: 22050000
In Collection: DeiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 81.17
Top 5 Accuracy: 95.40
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth
Converted From:
Weights: https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L123
Config: configs/deit/deit-small-distilled_pt-4xb256_in1k.py
- Name: deit-base_3rdparty_pt-16xb64_in1k
Metadata:
FLOPs: 16860000000
Parameters: 86570000
In Collection: DeiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 81.79
Top 5 Accuracy: 95.59
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth
Converted From:
Weights: https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L93
Config: configs/deit/deit-base_pt-16xb64_in1k.py
- Name: deit-base-distilled_3rdparty_pt-16xb64_in1k
Metadata:
FLOPs: 16860000000
Parameters: 86570000
In Collection: DeiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.33
Top 5 Accuracy: 96.49
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth
Converted From:
Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L138
Config: configs/deit/deit-base-distilled_pt-16xb64_in1k.py
- Name: deit-base_3rdparty_ft-16xb32_in1k-384px
Metadata:
FLOPs: 49370000000
Parameters: 86860000
In Collection: DeiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 83.04
Top 5 Accuracy: 96.31
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth
Converted From:
Weights: https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L153
Config: configs/deit/deit-base_ft-16xb32_in1k-384px.py
- Name: deit-base-distilled_3rdparty_ft-16xb32_in1k-384px
Metadata:
FLOPs: 49370000000
Parameters: 86860000
In Collection: DeiT
Results:
- Dataset: ImageNet-1k
Metrics:
Top 1 Accuracy: 85.55
Top 5 Accuracy: 97.35
Task: Image Classification
Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth
Converted From:
Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth
Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L168
Config: configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py


@@ -63,12 +63,17 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth) &#124; [log]()|
| Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) &#124; [log]()|
| Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) &#124; [log]()|
| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) &#124; [log]()|
| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) &#124; [log]()|
| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) &#124; [log]()|
| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) &#124; [log]()|
| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) &#124; [log]()|
| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) &#124; [log]()|
| Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) &#124; [log]()|
| Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) &#124; [log]()|
| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) &#124; [log]()|
| Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) &#124; [log]()|
Models with * are converted from other repos, others are trained by ourselves.
## CIFAR10


@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .alexnet import AlexNet
from .conformer import Conformer
from .deit import DistilledVisionTransformer
from .lenet import LeNet5
from .mlp_mixer import MlpMixer
from .mobilenet_v2 import MobileNetV2
@@ -28,5 +29,5 @@ __all__ = [
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
'Conformer', 'MlpMixer'
'Conformer', 'MlpMixer', 'DistilledVisionTransformer'
]


@@ -0,0 +1,85 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn.utils.weight_init import trunc_normal_
from ..builder import BACKBONES
from .vision_transformer import VisionTransformer
@BACKBONES.register_module()
class DistilledVisionTransformer(VisionTransformer):
"""Distilled Vision Transformer.
A PyTorch implementation of: `Training data-efficient image transformers &
distillation through attention <https://arxiv.org/abs/2012.12877>`_
Args:
arch (str | dict): Vision Transformer architecture
Default: 'b'
img_size (int | tuple): Input image size
patch_size (int | tuple): The patch size
out_indices (Sequence | int): Output from which stages.
Defaults to -1, means the last stage.
drop_rate (float): Probability of an element to be zeroed.
Defaults to 0.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='LN')``.
final_norm (bool): Whether to add an additional layer to normalize
final feature map. Defaults to True.
output_cls_token (bool): Whether output the cls_token. If set True,
`with_cls_token` must be True. Defaults to True.
interpolate_mode (str): Select the interpolate mode for position
embedding vector resize. Defaults to "bicubic".
patch_cfg (dict): Configs of patch embedding. Defaults to an empty dict.
layer_cfgs (Sequence | dict): Configs of each transformer layer in
encoder. Defaults to an empty dict.
init_cfg (dict, optional): Initialization config dict.
Defaults to None.
"""
num_extra_tokens = 2 # cls_token, dist_token
def __init__(self, *args, **kwargs):
super(DistilledVisionTransformer, self).__init__(*args, **kwargs)
self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))
def forward(self, x):
B = x.shape[0]
x = self.patch_embed(x)
patch_resolution = self.patch_embed.patches_resolution
# stole cls_tokens impl from Phil Wang, thanks
cls_tokens = self.cls_token.expand(B, -1, -1)
dist_token = self.dist_token.expand(B, -1, -1)
x = torch.cat((cls_tokens, dist_token, x), dim=1)
x = x + self.pos_embed
x = self.drop_after_pos(x)
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i == len(self.layers) - 1 and self.final_norm:
x = self.norm1(x)
if i in self.out_indices:
B, _, C = x.shape
patch_token = x[:, 2:].reshape(B, *patch_resolution, C)
patch_token = patch_token.permute(0, 3, 1, 2)
cls_token = x[:, 0]
dist_token = x[:, 1]
if self.output_cls_token:
out = [patch_token, cls_token, dist_token]
else:
out = patch_token
outs.append(out)
return tuple(outs)
def init_weights(self):
super(DistilledVisionTransformer, self).init_weights()
if not (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
trunc_normal_(self.dist_token, std=0.02)


@@ -1,5 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from copy import deepcopy
from typing import Sequence
import numpy as np
@@ -8,6 +7,7 @@ import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_norm_layer
from mmcv.cnn.bricks.transformer import FFN
from mmcv.cnn.utils.weight_init import trunc_normal_
from mmcv.runner.base_module import BaseModule, ModuleList
from mmcls.utils import get_root_logger
@@ -104,9 +104,8 @@ class TransformerEncoderLayer(BaseModule):
class VisionTransformer(BaseBackbone):
"""Vision Transformer.
A PyTorch implement of : `An Image is Worth 16x16 Words:
Transformers for Image Recognition at Scale
<https://arxiv.org/abs/2010.11929>`_
A PyTorch implement of : `An Image is Worth 16x16 Words: Transformers
for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_
Args:
arch (str | dict): Vision Transformer architecture
@@ -155,7 +154,30 @@ class VisionTransformer(BaseBackbone):
'num_heads': 16,
'feedforward_channels': 4096
}),
**dict.fromkeys(
['deit-t', 'deit-tiny'], {
'embed_dims': 192,
'num_layers': 12,
'num_heads': 3,
'feedforward_channels': 192 * 4
}),
**dict.fromkeys(
['deit-s', 'deit-small'], {
'embed_dims': 384,
'num_layers': 12,
'num_heads': 6,
'feedforward_channels': 384 * 4
}),
**dict.fromkeys(
['deit-b', 'deit-base'], {
'embed_dims': 768,
'num_layers': 12,
'num_heads': 12,
'feedforward_channels': 768 * 4
}),
}
# Some structures have multiple extra tokens, like DeiT.
num_extra_tokens = 1 # cls_token
def __init__(self,
arch='b',
@@ -182,7 +204,7 @@ class VisionTransformer(BaseBackbone):
essential_keys = {
'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
}
assert isinstance(arch, dict) and set(arch) == essential_keys, \
assert isinstance(arch, dict) and essential_keys <= set(arch), \
f'Custom arch needs a dict with keys {essential_keys}'
self.arch_settings = arch
@@ -208,7 +230,8 @@ class VisionTransformer(BaseBackbone):
# Set position embedding
self.interpolate_mode = interpolate_mode
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, self.embed_dims))
torch.zeros(1, num_patches + self.num_extra_tokens,
self.embed_dims))
self.drop_after_pos = nn.Dropout(p=drop_rate)
if isinstance(out_indices, int):
@@ -247,65 +270,49 @@ class VisionTransformer(BaseBackbone):
norm_cfg, self.embed_dims, postfix=1)
self.add_module(self.norm1_name, norm1)
self._register_load_state_dict_pre_hook(self._prepare_checkpoint_hook)
@property
def norm1(self):
return getattr(self, self.norm1_name)
def init_weights(self):
# Suppress default init if use pretrained model.
# And use custom load_checkpoint function to load checkpoint.
if (isinstance(self.init_cfg, dict)
super(VisionTransformer, self).init_weights()
if not (isinstance(self.init_cfg, dict)
and self.init_cfg['type'] == 'Pretrained'):
init_cfg = deepcopy(self.init_cfg)
init_cfg.pop('type')
self._load_checkpoint(**init_cfg)
else:
super(VisionTransformer, self).init_weights()
# Modified from ClassyVision
nn.init.normal_(self.pos_embed, std=0.02)
trunc_normal_(self.pos_embed, std=0.02)
def _load_checkpoint(self, checkpoint, prefix=None, map_location=None):
from mmcv.runner import (_load_checkpoint,
_load_checkpoint_with_prefix, load_state_dict)
from mmcv.utils import print_log
def _prepare_checkpoint_hook(self, state_dict, prefix, *args, **kwargs):
name = prefix + 'pos_embed'
if name not in state_dict.keys():
return
logger = get_root_logger()
if prefix is None:
print_log(f'load model from: {checkpoint}', logger=logger)
checkpoint = _load_checkpoint(checkpoint, map_location, logger)
# get state_dict from checkpoint
if 'state_dict' in checkpoint:
state_dict = checkpoint['state_dict']
else:
state_dict = checkpoint
else:
ckpt_pos_embed_shape = state_dict[name].shape
if self.pos_embed.shape != ckpt_pos_embed_shape:
from mmcv.utils import print_log
logger = get_root_logger()
print_log(
f'load {prefix} in model from: {checkpoint}', logger=logger)
state_dict = _load_checkpoint_with_prefix(prefix, checkpoint,
map_location)
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
f'to {self.pos_embed.shape}.',
logger=logger)
if 'pos_embed' in state_dict.keys():
ckpt_pos_embed_shape = state_dict['pos_embed'].shape
if self.pos_embed.shape != ckpt_pos_embed_shape:
print_log(
f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
f'to {self.pos_embed.shape}.',
logger=logger)
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
pos_embed_shape = self.patch_embed.patches_resolution
ckpt_pos_embed_shape = to_2tuple(
int(np.sqrt(ckpt_pos_embed_shape[1] - 1)))
pos_embed_shape = self.patch_embed.patches_resolution
state_dict['pos_embed'] = self.resize_pos_embed(
state_dict['pos_embed'], ckpt_pos_embed_shape,
pos_embed_shape, self.interpolate_mode)
# load state_dict
load_state_dict(self, state_dict, strict=False, logger=logger)
state_dict[name] = self.resize_pos_embed(state_dict[name],
ckpt_pos_embed_shape,
pos_embed_shape,
self.interpolate_mode,
self.num_extra_tokens)
@staticmethod
def resize_pos_embed(pos_embed, src_shape, dst_shape, mode='bicubic'):
def resize_pos_embed(pos_embed,
src_shape,
dst_shape,
mode='bicubic',
num_extra_tokens=1):
"""Resize pos_embed weights.
Args:
@@ -324,17 +331,17 @@ class VisionTransformer(BaseBackbone):
assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
_, L, C = pos_embed.shape
src_h, src_w = src_shape
assert L == src_h * src_w + 1
cls_token = pos_embed[:, :1]
assert L == src_h * src_w + num_extra_tokens
extra_tokens = pos_embed[:, :num_extra_tokens]
src_weight = pos_embed[:, 1:]
src_weight = pos_embed[:, num_extra_tokens:]
src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
dst_weight = F.interpolate(
src_weight, size=dst_shape, align_corners=False, mode=mode)
dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
return torch.cat((cls_token, dst_weight), dim=1)
return torch.cat((extra_tokens, dst_weight), dim=1)
def forward(self, x):
B = x.shape[0]

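A short sketch of the extended `resize_pos_embed` signature in the DeiT case, where the two extra tokens (cls_token and dist_token) are carried through unchanged; the shapes are illustrative:

```python
import torch

from mmcls.models.backbones import VisionTransformer

# DeiT-tiny style pos_embed: 14*14 patch tokens plus cls_token and dist_token.
pos_embed = torch.randn(1, 14 * 14 + 2, 192)
resized = VisionTransformer.resize_pos_embed(
    pos_embed, src_shape=(14, 14), dst_shape=(24, 24),
    mode='bicubic', num_extra_tokens=2)
assert resized.shape == (1, 24 * 24 + 2, 192)
```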

@@ -1,6 +1,7 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cls_head import ClsHead
from .conformer_head import ConformerHead
from .deit_head import DeiTClsHead
from .linear_head import LinearClsHead
from .multi_label_head import MultiLabelClsHead
from .multi_label_linear_head import MultiLabelLinearClsHead
@@ -9,5 +10,6 @@ from .vision_transformer_head import VisionTransformerClsHead
__all__ = [
'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'ConformerHead'
'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead',
'ConformerHead'
]


@@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from mmcls.utils import get_root_logger
from ..builder import HEADS
from .vision_transformer_head import VisionTransformerClsHead
@HEADS.register_module()
class DeiTClsHead(VisionTransformerClsHead):
def __init__(self, *args, **kwargs):
super(DeiTClsHead, self).__init__(*args, **kwargs)
self.head_dist = nn.Linear(self.in_channels, self.num_classes)
def simple_test(self, x):
"""Test without augmentation."""
x = x[-1]
assert isinstance(x, list) and len(x) == 3
_, cls_token, dist_token = x
cls_score = (self.layers(cls_token) + self.head_dist(dist_token)) / 2
pred = F.softmax(cls_score, dim=1) if cls_score is not None else None
return self.post_process(pred)
def forward_train(self, x, gt_label):
logger = get_root_logger()
logger.warning("MMClassification doesn't support to train the "
'distilled version DeiT.')
x = x[-1]
assert isinstance(x, list) and len(x) == 3
_, cls_token, dist_token = x
cls_score = (self.layers(cls_token) + self.head_dist(dist_token)) / 2
losses = self.loss(cls_score, gt_label)
return losses


@@ -14,3 +14,4 @@ Import:
- configs/t2t_vit/metafile.yml
- configs/mlp_mixer/metafile.yml
- configs/conformer/metafile.yml
- configs/deit/metafile.yml


@@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.backbones import DistilledVisionTransformer
def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True
def test_deit_backbone():
cfg_ori = dict(arch='deit-b', img_size=224, patch_size=16)
# Test structure
model = DistilledVisionTransformer(**cfg_ori)
model.init_weights()
model.train()
assert check_norm_state(model.modules(), True)
assert model.dist_token.shape == (1, 1, 768)
assert model.pos_embed.shape == (1, model.patch_embed.num_patches + 2, 768)
# Test forward
imgs = torch.rand(1, 3, 224, 224)
outs = model(imgs)
patch_token, cls_token, dist_token = outs[0]
assert patch_token.shape == (1, 768, 14, 14)
assert cls_token.shape == (1, 768)
assert dist_token.shape == (1, 768)
# Test multiple out_indices
model = DistilledVisionTransformer(
**cfg_ori, out_indices=(0, 1, 2, 3), output_cls_token=False)
outs = model(imgs)
for out in outs:
assert out.shape == (1, 768, 14, 14)


@@ -4,9 +4,9 @@ from unittest.mock import patch
import pytest
import torch
from mmcls.models.heads import (ClsHead, LinearClsHead, MultiLabelClsHead,
MultiLabelLinearClsHead, StackedLinearClsHead,
VisionTransformerClsHead)
from mmcls.models.heads import (ClsHead, DeiTClsHead, LinearClsHead,
MultiLabelClsHead, MultiLabelLinearClsHead,
StackedLinearClsHead, VisionTransformerClsHead)
@pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
@@ -157,3 +157,44 @@ def test_vit_head():
# test assertion
with pytest.raises(ValueError):
VisionTransformerClsHead(-1, 100)
def test_deit_head():
fake_features = ([
torch.rand(4, 7, 7, 16),
torch.rand(4, 100),
torch.rand(4, 100)
], )
fake_gt_label = torch.randint(0, 10, (4, ))
# test deit head forward
head = DeiTClsHead(num_classes=10, in_channels=100)
losses = head.forward_train(fake_features, fake_gt_label)
assert not hasattr(head.layers, 'pre_logits')
assert not hasattr(head.layers, 'act')
assert losses['loss'].item() > 0
# test deit head forward with hidden layer
head = DeiTClsHead(num_classes=10, in_channels=100, hidden_dim=20)
losses = head.forward_train(fake_features, fake_gt_label)
assert hasattr(head.layers, 'pre_logits') and hasattr(head.layers, 'act')
assert losses['loss'].item() > 0
# test deit head init_weights
head = DeiTClsHead(10, 100, hidden_dim=20)
head.init_weights()
assert abs(head.layers.pre_logits.weight).sum() > 0
# test simple_test
head = DeiTClsHead(10, 100, hidden_dim=20)
pred = head.simple_test(fake_features)
assert isinstance(pred, list) and len(pred) == 4
with patch('torch.onnx.is_in_onnx_export', return_value=True):
head = DeiTClsHead(10, 100, hidden_dim=20)
pred = head.simple_test(fake_features)
assert pred.shape == (4, 10)
# test assertion
with pytest.raises(ValueError):
DeiTClsHead(-1, 100)