[Reproduction] Reproduce training results of DeiT. (#711)
* Update deit training settings
* Update decay config
* Add mixup&cutmix and drop path rate
* Update training configs
* Update model-zoo
* Add comments

parent 1a28f9ace6
commit 9fd35dd7b5
```diff
@@ -10,7 +10,7 @@ paramwise_cfg = dict(
 # lr = 5e-4 * 128 * 8 / 512 = 0.001
 optimizer = dict(
     type='AdamW',
-    lr=5e-4 * 128 * 8 / 512,
+    lr=5e-4 * 1024 / 512,
     weight_decay=0.05,
     eps=1e-8,
     betas=(0.9, 0.999),
```
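Both expressions follow the linear scaling rule noted in the comment: a base rate of 5e-4 per 512 images is scaled by the total batch size, so the value stays 0.001 and only the way it is written changes (8 GPUs × 128 images, 4 GPUs × 256 images and 16 GPUs × 64 images all give a total batch of 1024). A minimal sketch of the rule, using a hypothetical `scaled_lr` helper:

```python
import math

def scaled_lr(base_lr: float = 5e-4, total_batch: int = 1024, base_batch: int = 512) -> float:
    """Linear LR scaling: the learning rate grows with the total batch size."""
    return base_lr * total_batch / base_batch

# 8 GPUs x 128 images and a single factor of 1024 give the same rate, 0.001.
assert math.isclose(scaled_lr(total_batch=128 * 8), scaled_lr(total_batch=1024))
print(scaled_lr())  # 0.001
```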
```diff
@@ -19,11 +19,12 @@ The teacher of the distilled version DeiT is RegNetY-16GF.
 
 | Model | Pretrain | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
 |:---------------------:|:------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
-| DeiT-tiny\* | From scratch | 5.72 | 1.08 | 72.13 | 91.13 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) |
+| DeiT-tiny | From scratch | 5.72 | 1.08 | 74.50 | 92.24 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.log.json) |
 | DeiT-tiny distilled\* | From scratch | 5.72 | 1.08 | 74.51 | 91.90 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth) |
-| DeiT-small\* | From scratch | 22.05 | 4.24 | 79.83 | 94.95 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) |
+| DeiT-small | From scratch | 22.05 | 4.24 | 80.69 | 95.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.log.json) |
 | DeiT-small distilled\*| From scratch | 22.05 | 4.24 | 81.17 | 95.40 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth) |
-| DeiT-base\* | From scratch | 86.57 | 16.86 | 81.79 | 95.59 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) |
+| DeiT-base | From scratch | 86.57 | 16.86 | 81.76 | 95.81 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.log.json) |
+| DeiT-base\* | From scratch | 86.57 | 16.86 | 81.79 | 95.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) |
 | DeiT-base distilled\* | From scratch | 86.57 | 16.86 | 83.33 | 96.49 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth) |
 | DeiT-base 384px\* | ImageNet-1k | 86.86 | 49.37 | 83.04 | 96.31 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) |
 | DeiT-base distilled 384px\* | ImageNet-1k | 86.86 | 49.37 | 85.55 | 97.35 | [config](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth) |
```
```diff
@@ -2,9 +2,12 @@ _base_ = './deit-small_pt-4xb256_in1k.py'
 
 # model settings
 model = dict(
-    backbone=dict(type='VisionTransformer', arch='deit-base'),
+    backbone=dict(
+        type='VisionTransformer', arch='deit-base', drop_path_rate=0.1),
     head=dict(type='VisionTransformerClsHead', in_channels=768),
 )
 
 # data settings
 data = dict(samples_per_gpu=64, workers_per_gpu=5)
+
+custom_hooks = [dict(type='EMAHook', momentum=4e-5, priority='ABOVE_NORMAL')]
```
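The base model additionally gets stochastic depth (`drop_path_rate=0.1`) and weight averaging via the `EMAHook` (an exponential moving average of the weights; with a momentum of 4e-5 the average tracks the live weights very slowly). As a rough illustration of what the drop-path rate controls, here is a minimal stochastic-depth sketch in PyTorch; it is not the mmcls implementation, just the standard per-sample branch dropping that `drop_path_rate=0.1` enables:

```python
import torch

def drop_path(x: torch.Tensor, drop_prob: float = 0.1, training: bool = True) -> torch.Tensor:
    """Stochastic depth: randomly zero a whole residual branch per sample."""
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli draw per sample, broadcast over the remaining dimensions.
    mask = x.new_empty((x.shape[0],) + (1,) * (x.ndim - 1)).bernoulli_(keep_prob)
    return x * mask / keep_prob  # rescale so the expectation matches eval mode

# Inside a transformer block the residual branch is combined roughly as:
#   x = x + drop_path(attn(norm(x)), drop_prob=0.1, training=self.training)
```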
```diff
@@ -1,6 +1,8 @@
+# In small and tiny arch, remove drop path and EMA hook comparing with the
+# original config
 _base_ = [
-    '../_base_/datasets/imagenet_bs64_pil_resize_autoaug.py',
-    '../_base_/schedules/imagenet_bs4096_AdamW.py',
+    '../_base_/datasets/imagenet_bs64_swin_224.py',
+    '../_base_/schedules/imagenet_bs1024_adamw_swin.py',
     '../_base_/default_runtime.py'
 ]
 
```
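Since the small/tiny configs now inherit the Swin-style dataset and schedule bases, the quickest way to check the effective settings is to load the config with mmcv and inspect the merged result. A short sketch (the config path is the one touched in this PR; run from the repository root):

```python
from mmcv import Config

# _base_ files are resolved and merged automatically.
cfg = Config.fromfile('configs/deit/deit-small_pt-4xb256_in1k.py')

print(cfg.optimizer.lr)      # effective learning rate after linear scaling
print(cfg.model.train_cfg)   # batch-level mixup / cutmix augments
print(cfg.pretty_text)       # the full merged config, rendered as Python
```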
```diff
@@ -23,7 +25,20 @@ model = dict(
     init_cfg=[
         dict(type='TruncNormal', layer='Linear', std=.02),
         dict(type='Constant', layer='LayerNorm', val=1., bias=0.),
-    ])
+    ],
+    train_cfg=dict(augments=[
+        dict(type='BatchMixup', alpha=0.8, num_classes=1000, prob=0.5),
+        dict(type='BatchCutMix', alpha=1.0, num_classes=1000, prob=0.5)
+    ]))
 
 # data settings
 data = dict(samples_per_gpu=256, workers_per_gpu=5)
+
+paramwise_cfg = dict(
+    norm_decay_mult=0.0,
+    bias_decay_mult=0.0,
+    custom_keys={
+        '.cls_token': dict(decay_mult=0.0),
+        '.pos_embed': dict(decay_mult=0.0)
+    })
+optimizer = dict(paramwise_cfg=paramwise_cfg)
```
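With the two probabilities summing to 1.0, each training batch should receive either mixup or CutMix, as in the DeiT recipe. The sketch below only illustrates what a batch-level mixup with `alpha=0.8` does; the function name and interface are made up and it is not the mmcls `BatchMixup` implementation:

```python
import torch

def batch_mixup(imgs: torch.Tensor, one_hot: torch.Tensor, alpha: float = 0.8):
    """Blend each sample (and its one-hot label) with a shuffled partner."""
    lam = torch.distributions.Beta(alpha, alpha).sample().item()
    index = torch.randperm(imgs.size(0))
    mixed_imgs = lam * imgs + (1.0 - lam) * imgs[index]
    mixed_labels = lam * one_hot + (1.0 - lam) * one_hot[index]
    return mixed_imgs, mixed_labels
```

The new `paramwise_cfg` turns off weight decay for normalization layers, biases, the class token and the position embedding. In plain PyTorch terms this roughly corresponds to building two optimizer parameter groups, as in the hypothetical helper below; feeding its output to `torch.optim.AdamW(..., lr=1e-3)` applies decay only to the first group.

```python
import torch

def split_param_groups(model: torch.nn.Module, weight_decay: float = 0.05):
    """Two groups: regular weights decay; norms, biases, tokens and embeddings do not."""
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        if not param.requires_grad:
            continue
        # 1-D parameters cover norm weights and biases; names cover the tokens.
        if param.ndim == 1 or name.endswith('.bias') \
                or 'cls_token' in name or 'pos_embed' in name:
            no_decay.append(param)   # decay_mult=0.0
        else:
            decay.append(param)
    return [
        dict(params=decay, weight_decay=weight_decay),
        dict(params=no_decay, weight_decay=0.0),
    ]
```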
```diff
@@ -16,7 +16,7 @@ Collections:
       Version: https://github.com/open-mmlab/mmclassification/blob/v0.19.0/mmcls/models/backbones/deit.py
 
 Models:
-  - Name: deit-tiny_3rdparty_pt-4xb256_in1k
+  - Name: deit-tiny_pt-4xb256_in1k
     Metadata:
       FLOPs: 1080000000
       Parameters: 5720000
@@ -24,13 +24,10 @@ Models:
     Results:
       - Dataset: ImageNet-1k
         Metrics:
-          Top 1 Accuracy: 72.13
-          Top 5 Accuracy: 91.13
+          Top 1 Accuracy: 74.50
+          Top 5 Accuracy: 92.24
         Task: Image Classification
-    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth
-    Converted From:
-      Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth
-      Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L63
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.pth
     Config: configs/deit/deit-tiny_pt-4xb256_in1k.py
   - Name: deit-tiny-distilled_3rdparty_pt-4xb256_in1k
     Metadata:
@@ -48,7 +45,7 @@ Models:
       Weights: https://dl.fbaipublicfiles.com/deit/deit_tiny_distilled_patch16_224-b40b3cf7.pth
       Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L108
     Config: configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py
-  - Name: deit-small_3rdparty_pt-4xb256_in1k
+  - Name: deit-small_pt-4xb256_in1k
     Metadata:
       FLOPs: 4240000000
       Parameters: 22050000
@@ -56,13 +53,10 @@ Models:
     Results:
       - Dataset: ImageNet-1k
         Metrics:
-          Top 1 Accuracy: 79.83
-          Top 5 Accuracy: 94.95
+          Top 1 Accuracy: 80.69
+          Top 5 Accuracy: 95.06
         Task: Image Classification
-    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth
-    Converted From:
-      Weights: https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth
-      Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L78
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.pth
     Config: configs/deit/deit-small_pt-4xb256_in1k.py
   - Name: deit-small-distilled_3rdparty_pt-4xb256_in1k
     Metadata:
@@ -80,6 +74,19 @@ Models:
      Weights: https://dl.fbaipublicfiles.com/deit/deit_small_distilled_patch16_224-649709d9.pth
      Code: https://github.com/facebookresearch/deit/blob/f5123946205daf72a88783dae94cabff98c49c55/models.py#L123
     Config: configs/deit/deit-small-distilled_pt-4xb256_in1k.py
+  - Name: deit-base_pt-16xb64_in1k
+    Metadata:
+      FLOPs: 16860000000
+      Parameters: 86570000
+    In Collection: DeiT
+    Results:
+      - Dataset: ImageNet-1k
+        Metrics:
+          Top 1 Accuracy: 81.76
+          Top 5 Accuracy: 95.81
+        Task: Image Classification
+    Weights: https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth
+    Config: configs/deit/deit-base_pt-16xb64_in1k.py
   - Name: deit-base_3rdparty_pt-16xb64_in1k
     Metadata:
       FLOPs: 16860000000
```
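The metafile is plain YAML, so the updated entries can be listed programmatically, for example in a docs or regression check. A small sketch assuming the script runs from the repository root:

```python
import yaml  # PyYAML

with open('configs/deit/metafile.yml') as f:
    metafile = yaml.safe_load(f)

for model in metafile['Models']:
    metrics = model['Results'][0]['Metrics']
    print(f"{model['Name']}: top-1 {metrics['Top 1 Accuracy']}, "
          f"top-5 {metrics['Top 5 Accuracy']}")
```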
```diff
@@ -71,11 +71,11 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
 | T2T-ViT_t-24 | 64.00 | 12.69 | 82.71 | 96.09 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/t2t_vit/t2t-vit-t-24_8xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_8xb64_in1k_20211214-b2a68ae3.pth) | [log](https://download.openmmlab.com/mmclassification/v0/t2t-vit/t2t-vit-t-24_8xb64_in1k_20211214-b2a68ae3.log.json)|
 | Mixer-B/16\* | 59.88 | 12.61 | 76.68 | 92.25 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-base-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-base-p16_3rdparty_64xb64_in1k_20211124-1377e3e0.pth) |
 | Mixer-L/16\* | 208.2 | 44.57 | 72.34 | 88.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/mlp_mixer/mlp-mixer-large-p16_64xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/mlp-mixer/mixer-large-p16_3rdparty_64xb64_in1k_20211124-5a2519d2.pth) |
-| DeiT-tiny\* | 5.72 | 1.08 | 72.13 | 91.13 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_3rdparty_pt-4xb256_in1k_20211124-e930093b.pth) |
+| DeiT-tiny | 5.72 | 1.08 | 74.50 | 92.24 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny_pt-4xb256_in1k_20220218-13b382a0.log.json) |
 | DeiT-tiny distilled\* | 5.72 | 1.08 | 74.51 | 91.90 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-tiny-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-tiny-distilled_3rdparty_pt-4xb256_in1k_20211216-c429839a.pth) |
-| DeiT-small\* | 22.05 | 4.24 | 79.83 | 94.95 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_3rdparty_pt-4xb256_in1k_20211124-ffe94edd.pth) |
+| DeiT-small | 22.05 | 4.24 | 80.69 | 95.06 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-small_pt-4xb256_in1k_20220218-9425b9bb.log.json) |
 | DeiT-small distilled\* | 22.05 | 4.24 | 81.17 | 95.40 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-small-distilled_pt-4xb256_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-small-distilled_3rdparty_pt-4xb256_in1k_20211216-4de1d725.pth) |
-| DeiT-base\* | 86.57 | 16.86 | 81.79 | 95.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_pt-16xb64_in1k_20211124-6f40c188.pth) |
+| DeiT-base | 86.57 | 16.86 | 81.76 | 95.81 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.pth) | [log](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_pt-16xb64_in1k_20220216-db63c16c.log.json) |
 | DeiT-base distilled\* | 86.57 | 16.86 | 83.33 | 96.49 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base-distilled_pt-16xb64_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_pt-16xb64_in1k_20211216-42891296.pth) |
 | DeiT-base 384px\* | 86.86 | 49.37 | 83.04 | 96.31 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base_3rdparty_ft-16xb32_in1k-384px_20211124-822d02f2.pth) |
 | DeiT-base distilled 384px\* | 86.86 | 49.37 | 85.55 | 97.35 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/deit/deit-base-distilled_ft-16xb32_in1k-384px.py) | [model](https://download.openmmlab.com/mmclassification/v0/deit/deit-base-distilled_3rdparty_ft-16xb32_in1k-384px_20211216-e48d6000.pth) |
```