[Feature] Add DeiT backbone and checkpoints. (#576)

* Support DeiT backbone.
* Use hook to automatically resize pos embed
* Update ViT training setting
* Add deit configs and update docs
* Fix vit arch assertion
* Remove useless init function
* Add unit tests.
* Fix resize_pos_embed for DeiT
* Improve according to comments.

parent 6f25bebe42
commit f9a2b04cee
configs/_base_/datasets/imagenet_bs64_pil_resize_autoaug.py
@@ -8,7 +8,11 @@ img_norm_cfg = dict(
     mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
 train_pipeline = [
     dict(type='LoadImageFromFile'),
-    dict(type='RandomResizedCrop', size=224, backend='pillow'),
+    dict(
+        type='RandomResizedCrop',
+        size=224,
+        backend='pillow',
+        interpolation='bicubic'),
     dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
     dict(type='AutoAugment', policies={{_base_.policy_imagenet}}),
     dict(type='Normalize', **img_norm_cfg),
@@ -18,7 +22,11 @@ train_pipeline = [
 ]
 test_pipeline = [
     dict(type='LoadImageFromFile'),
-    dict(type='Resize', size=(256, -1), backend='pillow'),
+    dict(
+        type='Resize',
+        size=(256, -1),
+        backend='pillow',
+        interpolation='bicubic'),
     dict(type='CenterCrop', crop_size=224),
     dict(type='Normalize', **img_norm_cfg),
     dict(type='ImageToTensor', keys=['img']),
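For context, the only change here is `interpolation='bicubic'`: the DeiT recipe resizes with Pillow's bicubic filter instead of the default bilinear. A minimal sketch of the resulting test-time transform, assuming mmcls 0.x is installed and a local `demo.jpg` exists (both the file name and the snippet are illustrative, not part of the commit):

```python
from mmcls.datasets.pipelines import Compose

img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
pipeline = Compose([
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1), backend='pillow',
         interpolation='bicubic'),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
])
# LoadImageFromFile reads from img_prefix + img_info['filename'].
result = pipeline(dict(img_info=dict(filename='demo.jpg'), img_prefix=None))
print(result['img'].shape)  # torch.Size([3, 224, 224])
```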
configs/_base_/schedules/imagenet_bs4096_AdamW.py
@@ -1,18 +1,24 @@
 # specific to vit pretrain
-paramwise_cfg = dict(custom_keys={
-    '.cls_token': dict(decay_mult=0.0),
-    '.pos_embed': dict(decay_mult=0.0)
-})
+paramwise_cfg = dict(
+    custom_keys={
+        '.backbone.cls_token': dict(decay_mult=0.0),
+        '.backbone.pos_embed': dict(decay_mult=0.0)
+    })
 
 # optimizer
-optimizer = dict(type='AdamW', lr=0.003, weight_decay=0.3)
+optimizer = dict(
+    type='AdamW',
+    lr=0.003,
+    weight_decay=0.3,
+    paramwise_cfg=paramwise_cfg,
+)
 optimizer_config = dict(grad_clip=dict(max_norm=1.0))
 
 # learning policy
 lr_config = dict(
     policy='CosineAnnealing',
     min_lr=0,
     warmup='linear',
     warmup_iters=10000,
-    warmup_ratio=1e-4)
+    warmup_ratio=1e-4,
+)
 runner = dict(type='EpochBasedRunner', max_epochs=300)
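To see what `custom_keys` buys: mmcv's default optimizer constructor multiplies the base `weight_decay` by `decay_mult` for every parameter whose qualified name matches a key, which is the standard way to exempt `cls_token` and `pos_embed` from weight decay. A toy sketch under that assumption (the `Toy` module is hypothetical; the repo's configs use the fuller `.backbone.cls_token` key because the optimizer sees the whole classifier's parameter names):

```python
import torch
import torch.nn as nn
from mmcv.runner import build_optimizer


class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.backbone = nn.Module()  # child module -> 'backbone.' name prefix
        self.backbone.cls_token = nn.Parameter(torch.zeros(1, 1, 8))
        self.backbone.fc = nn.Linear(8, 8)


optimizer = build_optimizer(
    Toy(),
    dict(
        type='AdamW', lr=0.003, weight_decay=0.3,
        paramwise_cfg=dict(custom_keys={'.cls_token': dict(decay_mult=0.0)})))
for group in optimizer.param_groups:
    # the matched cls_token gets weight_decay 0.0; the rest keep the base 0.3
    print(group['params'][0].shape, group.get('weight_decay', 'default (0.3)'))
```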
configs/deit/README.md
@@ -0,0 +1,61 @@
# Training data-efficient image transformers & distillation through attention
<!-- {DeiT} -->
<!-- [ALGORITHM] -->

## Abstract

<!-- [ABSTRACT] -->
Recently, neural networks purely based on attention were shown to address image understanding tasks such as image classification. However, these visual transformers are pre-trained with hundreds of millions of images using an expensive infrastructure, thereby limiting their adoption. In this work, we produce a competitive convolution-free transformer by training on ImageNet only. We train them on a single computer in less than 3 days. Our reference vision transformer (86M parameters) achieves top-1 accuracy of 83.1% (single-crop evaluation) on ImageNet with no external data. More importantly, we introduce a teacher-student strategy specific to transformers. It relies on a distillation token ensuring that the student learns from the teacher through attention. We show the interest of this token-based distillation, especially when using a convnet as a teacher. This leads us to report results competitive with convnets for both ImageNet (where we obtain up to 85.2% accuracy) and when transferring to other tasks. We share our code and models.

<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/26739999/143225703-c287c29e-82c9-4c85-a366-dfae30d198cd.png" width="40%"/>
</div>

## Citation

```{latex}
@InProceedings{pmlr-v139-touvron21a,
  title     = {Training data-efficient image transformers & distillation through attention},
  author    = {Touvron, Hugo and Cord, Matthieu and Douze, Matthijs and Massa, Francisco and Sablayrolles, Alexandre and Jegou, Herve},
  booktitle = {International Conference on Machine Learning},
  pages     = {10347--10357},
  year      = {2021},
  volume    = {139},
  month     = {July}
}
```

## Pretrained models

The pre-trained models are converted from the [official repo](https://github.com/facebookresearch/deit), and the teacher used by the distilled version of DeiT is RegNetY-16GF.

### ImageNet-1k

| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) |
| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) |
| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) |
| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) |
| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) |
| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) |

*Models with \* are converted from other repos.*

## Fine-tuned models

The fine-tuned models are converted from the [official repo](https://github.com/facebookresearch/deit).

### ImageNet-1k

| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
| DeiT-base 384px\* | 86.86 | 49.37 | 83.04 | 96.31 | [config](configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) |
| DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth) |

*Models with \* are converted from other repos.*

```{warning}
MMClassification doesn't support training the distilled version of DeiT, so the distilled checkpoints are provided for inference only.
```
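A minimal inference sketch with one of the converted checkpoints, assuming mmcls 0.x and that the checkpoint from the table above has been downloaded locally (the local paths are illustrative):

```python
from mmcls.apis import inference_model, init_model

config = 'configs/deit/deit-base_pt-16xb64_in1k.py'
# downloaded from the "Download" column above; local filename is hypothetical
checkpoint = 'deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth'

model = init_model(config, checkpoint, device='cpu')
result = inference_model(model, 'demo/demo.JPEG')
print(result['pred_class'], result['pred_score'])
```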
configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py
@@ -0,0 +1,9 @@
_base_ = './deit-base_ft-16xb32_in1k-384px.py'

# model settings
model = dict(
    backbone=dict(type='DistilledVisionTransformer'),
    head=dict(type='DeiTClsHead'),
    # Change to the path of the pretrained model
    # init_cfg=dict(type='Pretrained', checkpoint=''),
)
configs/deit/deit-base-distilled_pt-16xb64_in1k.py
@@ -0,0 +1,10 @@
_base_ = './deit-small_pt-4xb256_in1k.py'

# model settings
model = dict(
    backbone=dict(type='DistilledVisionTransformer', arch='deit-base'),
    head=dict(type='DeiTClsHead', in_channels=768),
)

# data settings
data = dict(samples_per_gpu=64, workers_per_gpu=5)
configs/deit/deit-base_ft-16xb32_in1k-384px.py
@@ -0,0 +1,29 @@
_base_ = [
    '../_base_/datasets/imagenet_bs64_swin_384.py',
    '../_base_/schedules/imagenet_bs4096_AdamW.py',
    '../_base_/default_runtime.py'
]

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VisionTransformer',
        arch='deit-base',
        img_size=384,
        patch_size=16,
    ),
    neck=None,
    head=dict(
        type='VisionTransformerClsHead',
        num_classes=1000,
        in_channels=768,
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
    ),
    # Change to the path of the pretrained model
    # init_cfg=dict(type='Pretrained', checkpoint=''),
)

# data settings
data = dict(samples_per_gpu=32, workers_per_gpu=5)
configs/deit/deit-base_pt-16xb64_in1k.py
@@ -0,0 +1,10 @@
_base_ = './deit-small_pt-4xb256_in1k.py'

# model settings
model = dict(
    backbone=dict(type='VisionTransformer', arch='deit-base'),
    head=dict(type='VisionTransformerClsHead', in_channels=768),
)

# data settings
data = dict(samples_per_gpu=64, workers_per_gpu=5)
configs/deit/deit-small-distilled_pt-4xb256_in1k.py
@@ -0,0 +1,7 @@
_base_ = './deit-small_pt-4xb256_in1k.py'

# model settings
model = dict(
    backbone=dict(type='DistilledVisionTransformer', arch='deit-small'),
    head=dict(type='DeiTClsHead', in_channels=384),
)
configs/deit/deit-small_pt-4xb256_in1k.py
@@ -0,0 +1,29 @@
_base_ = [
    '../_base_/datasets/imagenet_bs64_pil_resize_autoaug.py',
    '../_base_/schedules/imagenet_bs4096_AdamW.py',
    '../_base_/default_runtime.py'
]

# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='VisionTransformer',
        arch='deit-small',
        img_size=224,
        patch_size=16),
    neck=None,
    head=dict(
        type='VisionTransformerClsHead',
        num_classes=1000,
        in_channels=384,
        loss=dict(
            type='LabelSmoothLoss', label_smooth_val=0.1, mode='original'),
    ),
    init_cfg=[
        dict(type='TruncNormal', layer='Linear', std=.02),
        dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
    ])

# data settings
data = dict(samples_per_gpu=256, workers_per_gpu=5)
configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py
@@ -0,0 +1,7 @@
_base_ = './deit-small_pt-4xb256_in1k.py'

# model settings
model = dict(
    backbone=dict(type='DistilledVisionTransformer', arch='deit-tiny'),
    head=dict(type='DeiTClsHead', in_channels=192),
)
configs/deit/deit-tiny_pt-4xb256_in1k.py
@@ -0,0 +1,7 @@
_base_ = './deit-small_pt-4xb256_in1k.py'

# model settings
model = dict(
    backbone=dict(type='VisionTransformer', arch='deit-tiny'),
    head=dict(type='VisionTransformerClsHead', in_channels=192),
)
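The short configs above rely on mmcv's `_base_` inheritance: the base file is loaded first and the local dict overrides matching keys. A quick sketch, assuming a repository checkout so the relative `_base_` paths resolve:

```python
from mmcv import Config

cfg = Config.fromfile('configs/deit/deit-base_pt-16xb64_in1k.py')
print(cfg.model.backbone.arch)   # 'deit-base' (overrides the deit-small base)
print(cfg.data.samples_per_gpu)  # 64 (overrides 256 from the base config)
```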
configs/deit/metafile.yml
@@ -0,0 +1,143 @@
Collections:
  - Name: DeiT
    Metadata:
      Training Data: ImageNet-1k
      Architecture:
        - Layer Normalization
        - Scaled Dot-Product Attention
        - Attention Dropout
        - Multi-Head Attention
    Paper:
      URL: https://arxiv.org/abs/2012.12877
      Title: "Training data-efficient image transformers & distillation through attention"
    README: configs/deit/README.md

Models:
  - Name: deit-tiny_3rdparty_pt-4xb256_in1k
    Metadata:
      FLOPs: 1080000000
      Parameters: 5720000
    In Collection: DeiT
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 72.13
          Top 5 Accuracy: 91.13
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth
    Converted From:
      Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth
      Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L63
    Config: configs/deit/deit-tiny_pt-4xb256_in1k.py
  - Name: deit-tiny-distilled_3rdparty_pt-4xb256_in1k
    Metadata:
      FLOPs: 1080000000
      Parameters: 5720000
    In Collection: DeiT
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 74.51
          Top 5 Accuracy: 91.90
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth
    Converted From:
      Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth
      Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L108
    Config: configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py
  - Name: deit-small_3rdparty_pt-4xb256_in1k
    Metadata:
      FLOPs: 4240000000
      Parameters: 22050000
    In Collection: DeiT
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 79.83
          Top 5 Accuracy: 94.95
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth
    Converted From:
      Weights: https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth
      Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L78
    Config: configs/deit/deit-small_pt-4xb256_in1k.py
  - Name: deit-small-distilled_3rdparty_pt-4xb256_in1k
    Metadata:
      FLOPs: 4240000000
      Parameters: 22050000
    In Collection: DeiT
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 81.17
          Top 5 Accuracy: 95.40
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth
    Converted From:
      Weights: https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth
      Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L123
    Config: configs/deit/deit-small-distilled_pt-4xb256_in1k.py
  - Name: deit-base_3rdparty_pt-16xb64_in1k
    Metadata:
      FLOPs: 16860000000
      Parameters: 86570000
    In Collection: DeiT
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 81.79
          Top 5 Accuracy: 95.59
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth
    Converted From:
      Weights: https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth
      Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L93
    Config: configs/deit/deit-base_pt-16xb64_in1k.py
  - Name: deit-base-distilled_3rdparty_pt-16xb64_in1k
    Metadata:
      FLOPs: 16860000000
      Parameters: 86570000
    In Collection: DeiT
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 83.33
          Top 5 Accuracy: 96.49
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth
    Converted From:
      Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_224-df68dfff.pth
      Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L138
    Config: configs/deit/deit-base-distilled_pt-16xb64_in1k.py
  - Name: deit-base_3rdparty_ft-16xb32_in1k-384px
    Metadata:
      FLOPs: 49370000000
      Parameters: 86860000
    In Collection: DeiT
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 83.04
          Top 5 Accuracy: 96.31
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth
    Converted From:
      Weights: https://dl.fbaipublicfiles.com/deit/deit_base_patch16_384-8de9b5d1.pth
      Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L153
    Config: configs/deit/deit-base_ft-16xb32_in1k-384px.py
  - Name: deit-base-distilled_3rdparty_ft-16xb32_in1k-384px
    Metadata:
      FLOPs: 49370000000
      Parameters: 86860000
    In Collection: DeiT
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 85.55
          Top 5 Accuracy: 97.35
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211124-91e88933.pth
    Converted From:
      Weights: https://dl.fbaipublicfiles.com/deit/deit_base_distilled_patch16_384-d0272ac0.pth
      Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L168
    Config: configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py
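Since the metafile is plain YAML, a quick sanity check with PyYAML (an illustrative snippet, not part of the commit's tooling) can list the registered models and their reported accuracy:

```python
import yaml

with open('configs/deit/metafile.yml') as f:
    meta = yaml.safe_load(f)
for model in meta['Models']:
    print(model['Name'], model['Results'][0]['Metrics']['Top 1 Accuracy'])
```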
docs/model_zoo.md
@@ -63,12 +63,17 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
 | T2T-ViT_t-24\* | 64.00 | 12.69 | 82.55 | 96.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_3rdparty_8xb64_in1k_20210928-fe95a61b.pth) | [log]()|
 | Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) | [log]()|
 | Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) | [log]()|
+| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) | [log]()|
+| DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211124-e71bdd9a.pth) | [log]()|
+| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) | [log]()|
+| DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211124-15e341b0.pth) | [log]()|
+| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) | [log]()|
+| DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211124-766d123d.pth) | [log]()|
 | Conformer-tiny-p16\* | 23.52 | 4.90 | 81.31 | 95.60 | [config](configs/conformer/conformer-tiny-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-tiny-p16_3rdparty_8xb128_in1k_20211206-f6860372.pth) | [log]()|
 | Conformer-small-p32 | 38.85 | 7.09 | 81.96 | 96.02 | [config](configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) | [log]()|
 | Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) | [log]()|
 | Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) | [log]()|
 
 Models with * are converted from other repos, others are trained by ourselves.
 
 ## CIFAR10
mmcls/models/backbones/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .alexnet import AlexNet
 from .conformer import Conformer
+from .deit import DistilledVisionTransformer
 from .lenet import LeNet5
 from .mlp_mixer import MlpMixer
 from .mobilenet_v2 import MobileNetV2
@@ -28,5 +29,5 @@ __all__ = [
     'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
     'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
     'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
-    'Conformer', 'MlpMixer'
+    'Conformer', 'MlpMixer', 'DistilledVisionTransformer'
 ]
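With the registration above, the new backbone becomes buildable from a plain config dict through the `BACKBONES` registry. A sketch assuming mmcls with this commit applied:

```python
import torch
from mmcls.models import build_backbone

backbone = build_backbone(
    dict(type='DistilledVisionTransformer', arch='deit-tiny'))
patch_token, cls_token, dist_token = backbone(torch.rand(1, 3, 224, 224))[0]
print(patch_token.shape, cls_token.shape, dist_token.shape)
# torch.Size([1, 192, 14, 14]) torch.Size([1, 192]) torch.Size([1, 192])
```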
mmcls/models/backbones/deit.py
@@ -0,0 +1,85 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn.utils.weight_init import trunc_normal_

from ..builder import BACKBONES
from .vision_transformer import VisionTransformer


@BACKBONES.register_module()
class DistilledVisionTransformer(VisionTransformer):
    """Distilled Vision Transformer.

    A PyTorch implementation of: `Training data-efficient image transformers
    & distillation through attention <https://arxiv.org/abs/2012.12877>`_

    Args:
        arch (str | dict): Vision Transformer architecture. Defaults to 'b'.
        img_size (int | tuple): Input image size.
        patch_size (int | tuple): The patch size.
        out_indices (Sequence | int): Output from which stages.
            Defaults to -1, meaning the last stage.
        drop_rate (float): Probability of an element to be zeroed.
            Defaults to 0.
        drop_path_rate (float): Stochastic depth rate. Defaults to 0.
        norm_cfg (dict): Config dict for the normalization layer.
            Defaults to ``dict(type='LN')``.
        final_norm (bool): Whether to add an additional layer to normalize
            the final feature map. Defaults to True.
        output_cls_token (bool): Whether to output the cls_token. If set to
            True, ``with_cls_token`` must be True. Defaults to True.
        interpolate_mode (str): Select the interpolation mode for the
            position embedding vector resize. Defaults to "bicubic".
        patch_cfg (dict): Configs of patch embedding. Defaults to an empty
            dict.
        layer_cfgs (Sequence | dict): Configs of each transformer layer in
            the encoder. Defaults to an empty dict.
        init_cfg (dict, optional): Initialization config dict.
            Defaults to None.
    """
    num_extra_tokens = 2  # cls_token, dist_token

    def __init__(self, *args, **kwargs):
        super(DistilledVisionTransformer, self).__init__(*args, **kwargs)
        self.dist_token = nn.Parameter(torch.zeros(1, 1, self.embed_dims))

    def forward(self, x):
        B = x.shape[0]
        x = self.patch_embed(x)
        patch_resolution = self.patch_embed.patches_resolution

        # stole cls_tokens impl from Phil Wang, thanks
        cls_tokens = self.cls_token.expand(B, -1, -1)
        dist_token = self.dist_token.expand(B, -1, -1)
        x = torch.cat((cls_tokens, dist_token, x), dim=1)
        x = x + self.pos_embed
        x = self.drop_after_pos(x)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)

            if i == len(self.layers) - 1 and self.final_norm:
                x = self.norm1(x)

            if i in self.out_indices:
                B, _, C = x.shape
                patch_token = x[:, 2:].reshape(B, *patch_resolution, C)
                patch_token = patch_token.permute(0, 3, 1, 2)
                cls_token = x[:, 0]
                dist_token = x[:, 1]
                if self.output_cls_token:
                    out = [patch_token, cls_token, dist_token]
                else:
                    out = patch_token
                outs.append(out)

        return tuple(outs)

    def init_weights(self):
        super(DistilledVisionTransformer, self).init_weights()

        if not (isinstance(self.init_cfg, dict)
                and self.init_cfg['type'] == 'Pretrained'):
            trunc_normal_(self.dist_token, std=0.02)
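One consequence of `num_extra_tokens = 2` is that DeiT's position table reserves one more row than plain ViT. A quick check, assuming this commit is installed:

```python
from mmcls.models.backbones import DistilledVisionTransformer, VisionTransformer

vit = VisionTransformer(arch='deit-tiny')
deit = DistilledVisionTransformer(arch='deit-tiny')
print(vit.pos_embed.shape)   # torch.Size([1, 197, 192]) -> 196 patches + cls
print(deit.pos_embed.shape)  # torch.Size([1, 198, 192]) -> ... + cls + dist
```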
mmcls/models/backbones/vision_transformer.py
@@ -1,5 +1,4 @@
 # Copyright (c) OpenMMLab. All rights reserved.
-from copy import deepcopy
 from typing import Sequence
 
 import numpy as np
@@ -8,6 +7,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from mmcv.cnn import build_norm_layer
 from mmcv.cnn.bricks.transformer import FFN
+from mmcv.cnn.utils.weight_init import trunc_normal_
 from mmcv.runner.base_module import BaseModule, ModuleList
 
 from mmcls.utils import get_root_logger
@@ -104,9 +104,8 @@ class TransformerEncoderLayer(BaseModule):
 class VisionTransformer(BaseBackbone):
     """Vision Transformer.
 
-    A PyTorch implement of : `An Image is Worth 16x16 Words:
-    Transformers for Image Recognition at Scale
-    <https://arxiv.org/abs/2010.11929>`_
+    A PyTorch implementation of: `An Image is Worth 16x16 Words: Transformers
+    for Image Recognition at Scale <https://arxiv.org/abs/2010.11929>`_
 
     Args:
         arch (str | dict): Vision Transformer architecture
@@ -155,7 +154,30 @@ class VisionTransformer(BaseBackbone):
             'num_heads': 16,
             'feedforward_channels': 4096
         }),
+        **dict.fromkeys(
+            ['deit-t', 'deit-tiny'], {
+                'embed_dims': 192,
+                'num_layers': 12,
+                'num_heads': 3,
+                'feedforward_channels': 192 * 4
+            }),
+        **dict.fromkeys(
+            ['deit-s', 'deit-small'], {
+                'embed_dims': 384,
+                'num_layers': 12,
+                'num_heads': 6,
+                'feedforward_channels': 384 * 4
+            }),
+        **dict.fromkeys(
+            ['deit-b', 'deit-base'], {
+                'embed_dims': 768,
+                'num_layers': 12,
+                'num_heads': 12,
+                'feedforward_channels': 768 * 4
+            }),
     }
+    # Some structures have multiple extra tokens, like DeiT.
+    num_extra_tokens = 1  # cls_token
 
     def __init__(self,
                  arch='b',
@@ -182,7 +204,7 @@ class VisionTransformer(BaseBackbone):
             essential_keys = {
                 'embed_dims', 'num_layers', 'num_heads', 'feedforward_channels'
             }
-            assert isinstance(arch, dict) and set(arch) == essential_keys, \
+            assert isinstance(arch, dict) and essential_keys <= set(arch), \
                 f'Custom arch needs a dict with keys {essential_keys}'
             self.arch_settings = arch
 
@@ -208,7 +230,8 @@ class VisionTransformer(BaseBackbone):
         # Set position embedding
         self.interpolate_mode = interpolate_mode
         self.pos_embed = nn.Parameter(
-            torch.zeros(1, num_patches + 1, self.embed_dims))
+            torch.zeros(1, num_patches + self.num_extra_tokens,
+                        self.embed_dims))
         self.drop_after_pos = nn.Dropout(p=drop_rate)
 
         if isinstance(out_indices, int):
@@ -247,65 +270,49 @@ class VisionTransformer(BaseBackbone):
             norm_cfg, self.embed_dims, postfix=1)
         self.add_module(self.norm1_name, norm1)
 
+        self._register_load_state_dict_pre_hook(self._prepare_checkpoint_hook)
+
     @property
     def norm1(self):
         return getattr(self, self.norm1_name)
 
     def init_weights(self):
-        # Suppress default init if use pretrained model.
-        # And use custom load_checkpoint function to load checkpoint.
-        if (isinstance(self.init_cfg, dict)
+        super(VisionTransformer, self).init_weights()
+
+        if not (isinstance(self.init_cfg, dict)
                 and self.init_cfg['type'] == 'Pretrained'):
-            init_cfg = deepcopy(self.init_cfg)
-            init_cfg.pop('type')
-            self._load_checkpoint(**init_cfg)
-        else:
-            super(VisionTransformer, self).init_weights()
-            # Modified from ClassyVision
-            nn.init.normal_(self.pos_embed, std=0.02)
+            trunc_normal_(self.pos_embed, std=0.02)
 
-    def _load_checkpoint(self, checkpoint, prefix=None, map_location=None):
-        from mmcv.runner import (_load_checkpoint,
-                                 _load_checkpoint_with_prefix, load_state_dict)
-        from mmcv.utils import print_log
-
-        logger = get_root_logger()
-
-        if prefix is None:
-            print_log(f'load model from: {checkpoint}', logger=logger)
-            checkpoint = _load_checkpoint(checkpoint, map_location, logger)
-            # get state_dict from checkpoint
-            if 'state_dict' in checkpoint:
-                state_dict = checkpoint['state_dict']
-            else:
-                state_dict = checkpoint
-        else:
-            print_log(
-                f'load {prefix} in model from: {checkpoint}', logger=logger)
-            state_dict = _load_checkpoint_with_prefix(prefix, checkpoint,
-                                                      map_location)
-
-        if 'pos_embed' in state_dict.keys():
-            ckpt_pos_embed_shape = state_dict['pos_embed'].shape
-            if self.pos_embed.shape != ckpt_pos_embed_shape:
-                print_log(
-                    f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
-                    f'to {self.pos_embed.shape}.',
-                    logger=logger)
-
-                ckpt_pos_embed_shape = to_2tuple(
-                    int(np.sqrt(ckpt_pos_embed_shape[1] - 1)))
-                pos_embed_shape = self.patch_embed.patches_resolution
-
-                state_dict['pos_embed'] = self.resize_pos_embed(
-                    state_dict['pos_embed'], ckpt_pos_embed_shape,
-                    pos_embed_shape, self.interpolate_mode)
-
-        # load state_dict
-        load_state_dict(self, state_dict, strict=False, logger=logger)
+    def _prepare_checkpoint_hook(self, state_dict, prefix, *args, **kwargs):
+        name = prefix + 'pos_embed'
+        if name not in state_dict.keys():
+            return
+
+        ckpt_pos_embed_shape = state_dict[name].shape
+        if self.pos_embed.shape != ckpt_pos_embed_shape:
+            from mmcv.utils import print_log
+            logger = get_root_logger()
+            print_log(
+                f'Resize the pos_embed shape from {ckpt_pos_embed_shape} '
+                f'to {self.pos_embed.shape}.',
+                logger=logger)
+
+            ckpt_pos_embed_shape = to_2tuple(
+                int(np.sqrt(ckpt_pos_embed_shape[1] - self.num_extra_tokens)))
+            pos_embed_shape = self.patch_embed.patches_resolution
+
+            state_dict[name] = self.resize_pos_embed(state_dict[name],
+                                                     ckpt_pos_embed_shape,
+                                                     pos_embed_shape,
+                                                     self.interpolate_mode,
+                                                     self.num_extra_tokens)
 
     @staticmethod
-    def resize_pos_embed(pos_embed, src_shape, dst_shape, mode='bicubic'):
+    def resize_pos_embed(pos_embed,
+                         src_shape,
+                         dst_shape,
+                         mode='bicubic',
+                         num_extra_tokens=1):
         """Resize pos_embed weights.
 
         Args:
@@ -324,17 +331,17 @@ class VisionTransformer(BaseBackbone):
         assert pos_embed.ndim == 3, 'shape of pos_embed must be [1, L, C]'
         _, L, C = pos_embed.shape
         src_h, src_w = src_shape
-        assert L == src_h * src_w + 1
-        cls_token = pos_embed[:, :1]
+        assert L == src_h * src_w + num_extra_tokens
+        extra_tokens = pos_embed[:, :num_extra_tokens]
 
-        src_weight = pos_embed[:, 1:]
+        src_weight = pos_embed[:, num_extra_tokens:]
         src_weight = src_weight.reshape(1, src_h, src_w, C).permute(0, 3, 1, 2)
 
         dst_weight = F.interpolate(
            src_weight, size=dst_shape, align_corners=False, mode=mode)
         dst_weight = torch.flatten(dst_weight, 2).transpose(1, 2)
 
-        return torch.cat((cls_token, dst_weight), dim=1)
+        return torch.cat((extra_tokens, dst_weight), dim=1)
 
     def forward(self, x):
         B = x.shape[0]
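The headline mechanism of this change is the `load_state_dict` pre-hook: it rewrites `pos_embed` inside the incoming `state_dict` before PyTorch copies it, so a checkpoint trained at one resolution loads into a model at another with no extra user code. A self-contained toy version in plain PyTorch (the `ToyViT` class is hypothetical, sketching the same trick):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ToyViT(nn.Module):
    def __init__(self, num_patches, embed_dims=8):
        super().__init__()
        # one extra row for the cls token, as in plain ViT
        self.pos_embed = nn.Parameter(
            torch.zeros(1, num_patches + 1, embed_dims))
        self._register_load_state_dict_pre_hook(self._resize_hook)

    def _resize_hook(self, state_dict, prefix, *args, **kwargs):
        name = prefix + 'pos_embed'
        if name in state_dict and state_dict[name].shape != self.pos_embed.shape:
            src = state_dict[name]
            extra, grid = src[:, :1], src[:, 1:]  # keep the cls row as-is
            src_hw = int(grid.shape[1] ** 0.5)
            dst_hw = int((self.pos_embed.shape[1] - 1) ** 0.5)
            # interpolate the patch grid, exactly like resize_pos_embed above
            grid = grid.reshape(1, src_hw, src_hw, -1).permute(0, 3, 1, 2)
            grid = F.interpolate(grid, size=(dst_hw, dst_hw), mode='bicubic',
                                 align_corners=False)
            grid = grid.flatten(2).transpose(1, 2)
            state_dict[name] = torch.cat((extra, grid), dim=1)


# a 14x14-grid checkpoint loads into a 24x24-grid model transparently
src, dst = ToyViT(14 * 14), ToyViT(24 * 24)
dst.load_state_dict(src.state_dict())
print(dst.pos_embed.shape)  # torch.Size([1, 577, 8])
```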
mmcls/models/heads/__init__.py
@@ -1,6 +1,7 @@
 # Copyright (c) OpenMMLab. All rights reserved.
 from .cls_head import ClsHead
 from .conformer_head import ConformerHead
+from .deit_head import DeiTClsHead
 from .linear_head import LinearClsHead
 from .multi_label_head import MultiLabelClsHead
 from .multi_label_linear_head import MultiLabelLinearClsHead
@@ -9,5 +10,6 @@ from .vision_transformer_head import VisionTransformerClsHead
 
 __all__ = [
     'ClsHead', 'LinearClsHead', 'StackedLinearClsHead', 'MultiLabelClsHead',
-    'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'ConformerHead'
+    'MultiLabelLinearClsHead', 'VisionTransformerClsHead', 'DeiTClsHead',
+    'ConformerHead'
 ]
mmcls/models/heads/deit_head.py
@@ -0,0 +1,36 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F

from mmcls.utils import get_root_logger
from ..builder import HEADS
from .vision_transformer_head import VisionTransformerClsHead


@HEADS.register_module()
class DeiTClsHead(VisionTransformerClsHead):
    """Distilled ViT classifier head, which averages the logits from the
    cls token and the distillation token."""

    def __init__(self, *args, **kwargs):
        super(DeiTClsHead, self).__init__(*args, **kwargs)
        self.head_dist = nn.Linear(self.in_channels, self.num_classes)

    def simple_test(self, x):
        """Test without augmentation."""
        x = x[-1]
        assert isinstance(x, list) and len(x) == 3
        _, cls_token, dist_token = x
        cls_score = (self.layers(cls_token) + self.head_dist(dist_token)) / 2
        pred = F.softmax(cls_score, dim=1) if cls_score is not None else None

        return self.post_process(pred)

    def forward_train(self, x, gt_label):
        logger = get_root_logger()
        logger.warning("MMClassification doesn't support training the "
                       'distilled version of DeiT.')
        x = x[-1]
        assert isinstance(x, list) and len(x) == 3
        _, cls_token, dist_token = x
        cls_score = (self.layers(cls_token) + self.head_dist(dist_token)) / 2
        losses = self.loss(cls_score, gt_label)
        return losses
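A quick sketch exercising the head through the `HEADS` registry (assuming mmcls with this commit applied; at test time the cls and distillation logits are simply averaged, as in the DeiT paper):

```python
import torch
from mmcls.models import build_head

head = build_head(dict(type='DeiTClsHead', num_classes=10, in_channels=16))
feats = ([torch.rand(2, 16, 7, 7),   # patch tokens (unused by the head)
          torch.rand(2, 16),         # cls_token
          torch.rand(2, 16)], )      # dist_token
pred = head.simple_test(feats)
print(len(pred), len(pred[0]))  # 2 samples, 10 class scores each
```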
model-index.yml
@@ -14,3 +14,4 @@ Import:
   - configs/t2t_vit/metafile.yml
   - configs/mlp_mixer/metafile.yml
   - configs/conformer/metafile.yml
+  - configs/deit/metafile.yml
tests/test_models/test_backbones/test_deit.py
@@ -0,0 +1,43 @@
# Copyright (c) OpenMMLab. All rights reserved.

import torch
from torch.nn.modules.batchnorm import _BatchNorm

from mmcls.models.backbones import DistilledVisionTransformer


def check_norm_state(modules, train_state):
    """Check if norm layer is in correct train state."""
    for mod in modules:
        if isinstance(mod, _BatchNorm):
            if mod.training != train_state:
                return False
    return True


def test_deit_backbone():
    cfg_ori = dict(arch='deit-b', img_size=224, patch_size=16)

    # Test structure
    model = DistilledVisionTransformer(**cfg_ori)
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), True)
    assert model.dist_token.shape == (1, 1, 768)
    assert model.pos_embed.shape == (1, model.patch_embed.num_patches + 2, 768)

    # Test forward
    imgs = torch.rand(1, 3, 224, 224)
    outs = model(imgs)
    patch_token, cls_token, dist_token = outs[0]
    assert patch_token.shape == (1, 768, 14, 14)
    assert cls_token.shape == (1, 768)
    assert dist_token.shape == (1, 768)

    # Test multiple out_indices
    model = DistilledVisionTransformer(
        **cfg_ori, out_indices=(0, 1, 2, 3), output_cls_token=False)
    outs = model(imgs)
    for out in outs:
        assert out.shape == (1, 768, 14, 14)
tests/test_models/test_heads.py
@@ -4,9 +4,9 @@ from unittest.mock import patch
 import pytest
 import torch
 
-from mmcls.models.heads import (ClsHead, LinearClsHead, MultiLabelClsHead,
-                                MultiLabelLinearClsHead, StackedLinearClsHead,
-                                VisionTransformerClsHead)
+from mmcls.models.heads import (ClsHead, DeiTClsHead, LinearClsHead,
+                                MultiLabelClsHead, MultiLabelLinearClsHead,
+                                StackedLinearClsHead, VisionTransformerClsHead)
 
 
 @pytest.mark.parametrize('feat', [torch.rand(4, 3), (torch.rand(4, 3), )])
@@ -157,3 +157,44 @@ def test_vit_head():
     # test assertion
     with pytest.raises(ValueError):
         VisionTransformerClsHead(-1, 100)
+
+
+def test_deit_head():
+    fake_features = ([
+        torch.rand(4, 7, 7, 16),
+        torch.rand(4, 100),
+        torch.rand(4, 100)
+    ], )
+    fake_gt_label = torch.randint(0, 10, (4, ))
+
+    # test deit head forward
+    head = DeiTClsHead(num_classes=10, in_channels=100)
+    losses = head.forward_train(fake_features, fake_gt_label)
+    assert not hasattr(head.layers, 'pre_logits')
+    assert not hasattr(head.layers, 'act')
+    assert losses['loss'].item() > 0
+
+    # test deit head forward with hidden layer
+    head = DeiTClsHead(num_classes=10, in_channels=100, hidden_dim=20)
+    losses = head.forward_train(fake_features, fake_gt_label)
+    assert hasattr(head.layers, 'pre_logits') and hasattr(head.layers, 'act')
+    assert losses['loss'].item() > 0
+
+    # test deit head init_weights
+    head = DeiTClsHead(10, 100, hidden_dim=20)
+    head.init_weights()
+    assert abs(head.layers.pre_logits.weight).sum() > 0
+
+    # test simple_test
+    head = DeiTClsHead(10, 100, hidden_dim=20)
+    pred = head.simple_test(fake_features)
+    assert isinstance(pred, list) and len(pred) == 4
+
+    with patch('torch.onnx.is_in_onnx_export', return_value=True):
+        head = DeiTClsHead(10, 100, hidden_dim=20)
+        pred = head.simple_test(fake_features)
+        assert pred.shape == (4, 10)
+
+    # test assertion
+    with pytest.raises(ValueError):
+        DeiTClsHead(-1, 100)