[Feature] Support EfficientNet (#649)

* add config for resnest test

* fix config

* add label smoothing

* add memcached

* minor fix

* fix bug

* fix config

* add config

* minor fix

* fix configs

* use EResize

* change interpolation

* add more configs

* add docstring

* add unittest

* remove unnecessary changes

* minor fix

* add more docstring

* fix linting

* add efficient backbone

* add config

* add Edge Residual

* fix bug

* remove unnecessary files

* refactor

* add resize in crop to ensure crop size is output size

* fix bug and add comments

* test

* fix

* add more configs

* add more configs

* add more configs

* fix bug

* add model zoo

* fix

* reorganize code

* add edge tpu

* add edge tpu converter

* rename

* update readme

* reorganize code and config

* Rename configs of EfficientNet, and add metafile & model_zoo

* Remove `backend='pillow'`

* Add comments about EfficientNet-EdgeTPU

* Rename the convert tool of EfficientNet.

* Refactor EfficientNet and update docstring.

* Update EfficientNet-EdgeTPU config

* Fix unit tests

Co-authored-by: lixinran <lixr423@outlook.com>
Co-authored-by: lixinran <lixinran@sensetime.com>
Co-authored-by: mzr1996 <mzr1996@163.com>
Zhicheng Chen 2022-01-25 12:14:17 +08:00 committed by GitHub
parent 16864c7495
commit d56170a734
43 changed files with 2187 additions and 20 deletions


@@ -77,7 +77,7 @@ Results and models are available in the [model zoo](https://mmclassification.rea
- [x] [DeiT](https://github.com/open-mmlab/mmclassification/tree/master/configs/deit)
- [x] [Conformer](https://github.com/open-mmlab/mmclassification/tree/master/configs/conformer)
- [x] [T2T-ViT](https://github.com/open-mmlab/mmclassification/tree/master/configs/t2t_vit)
- [ ] EfficientNet
- [x] [EfficientNet](https://github.com/open-mmlab/mmclassification/tree/master/configs/efficientnet)
- [ ] Twins
- [ ] HRNet


@@ -0,0 +1,12 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='EfficientNet', arch='b0'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1280,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
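
A quick way to sanity-check the backbone registered by this PR is to build it through the mmcls registry and inspect its output. A minimal sketch, assuming `mmcls` (with this PR applied) and PyTorch are installed:

```python
# Build the EfficientNet-B0 backbone and check that its final feature map
# matches the head's `in_channels` in the config above.
import torch
from mmcls.models import build_backbone

backbone = build_backbone(dict(type='EfficientNet', arch='b0'))
backbone.eval()

with torch.no_grad():
    outs = backbone(torch.randn(1, 3, 224, 224))  # returns a tuple of stages

# The final 1x1 conv projects to 1280 channels; a 224 input gives a 7x7 map.
print(outs[-1].shape)  # expected: torch.Size([1, 1280, 7, 7])
```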


@@ -0,0 +1,12 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='EfficientNet', arch='b1'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1280,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))


@@ -0,0 +1,12 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='EfficientNet', arch='b2'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1408,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))


@@ -0,0 +1,12 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='EfficientNet', arch='b3'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1536,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))


@@ -0,0 +1,12 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='EfficientNet', arch='b4'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1792,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))


@@ -0,0 +1,12 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='EfficientNet', arch='b5'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=2048,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))


@@ -0,0 +1,12 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='EfficientNet', arch='b6'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=2304,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))


@@ -0,0 +1,12 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='EfficientNet', arch='b7'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=2560,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))


@@ -0,0 +1,12 @@
# model settings
model = dict(
    type='ImageClassifier',
    backbone=dict(type='EfficientNet', arch='b8'),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=2816,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))


@@ -0,0 +1,13 @@
# model settings
model = dict(
    type='ImageClassifier',
    # `em` means EfficientNet-EdgeTPU-M arch
    backbone=dict(type='EfficientNet', arch='em', act_cfg=dict(type='ReLU')),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1280,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))


@@ -0,0 +1,13 @@
# model settings
model = dict(
    type='ImageClassifier',
    # `es` means EfficientNet-EdgeTPU-S arch
    backbone=dict(type='EfficientNet', arch='es', act_cfg=dict(type='ReLU')),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1280,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))


@@ -0,0 +1,62 @@
# EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks
<!-- {EfficientNet} -->
<!-- [ALGORITHM] -->
## Abstract
<!-- [ABSTRACT] -->
Convolutional Neural Networks (ConvNets) are commonly developed at a fixed resource budget, and then scaled up for better accuracy if more resources are available. In this paper, we systematically study model scaling and identify that carefully balancing network depth, width, and resolution can lead to better performance. Based on this observation, we propose a new scaling method that uniformly scales all dimensions of depth/width/resolution using a simple yet highly effective compound coefficient. We demonstrate the effectiveness of this method on scaling up MobileNets and ResNet. To go even further, we use neural architecture search to design a new baseline network and scale it up to obtain a family of models, called EfficientNets, which achieve much better accuracy and efficiency than previous ConvNets. In particular, our EfficientNet-B7 achieves state-of-the-art 84.3% top-1 accuracy on ImageNet, while being 8.4x smaller and 6.1x faster on inference than the best existing ConvNet. Our EfficientNets also transfer well and achieve state-of-the-art accuracy on CIFAR-100 (91.7%), Flowers (98.8%), and 3 other transfer learning datasets, with an order of magnitude fewer parameters.
<!-- [IMAGE] -->
<div align=center>
<img src="https://user-images.githubusercontent.com/26739999/150078232-d28c91fc-d0e8-43e3-9d20-b5162f0fb463.png" width="60%"/>
</div>

## Citation
```latex
@inproceedings{tan2019efficientnet,
title={Efficientnet: Rethinking model scaling for convolutional neural networks},
author={Tan, Mingxing and Le, Quoc},
booktitle={International Conference on Machine Learning},
pages={6105--6114},
year={2019},
organization={PMLR}
}
```
## Results and models

### ImageNet-1k

In the table below, AA means the model is trained with AutoAugment pre-processing (more details in the [paper](https://arxiv.org/abs/1805.09501)), and AdvProp is a method that trains with adversarial examples (more details in the [paper](https://arxiv.org/abs/1911.09665)).

Note: in MMClassification, we support training with AutoAugment, but don't support AdvProp for now.

| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:------:|:--------:|
| EfficientNet-B0\* | 5.29 | 0.02 | 76.74 | 93.17 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b0_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32_in1k_20220119-a7e2a0b1.pth) |
| EfficientNet-B0 (AA)\* | 5.29 | 0.02 | 77.26 | 93.41 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b0_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32-aa_in1k_20220119-8d939117.pth) |
| EfficientNet-B0 (AA + AdvProp)\* | 5.29 | 0.02 | 77.53 | 93.61 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b0_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32-aa-advprop_in1k_20220119-26434485.pth) |
| EfficientNet-B1\* | 7.79 | 0.03 | 78.68 | 94.28 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b1_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty_8xb32_in1k_20220119-002556d9.pth) |
| EfficientNet-B1 (AA)\* | 7.79 | 0.03 | 79.20 | 94.42 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b1_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty_8xb32-aa_in1k_20220119-619d8ae3.pth) |
| EfficientNet-B1 (AA + AdvProp)\* | 7.79 | 0.03 | 79.52 | 94.43 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b1_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty_8xb32-aa-advprop_in1k_20220119-5715267d.pth) |
| EfficientNet-B2\* | 9.11 | 0.03 | 79.64 | 94.80 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b2_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b2_3rdparty_8xb32_in1k_20220119-ea374a30.pth) |
| EfficientNet-B2 (AA)\* | 9.11 | 0.03 | 80.21 | 94.96 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b2_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b2_3rdparty_8xb32-aa_in1k_20220119-dd61e80b.pth) |
| EfficientNet-B2 (AA + AdvProp)\* | 9.11 | 0.03 | 80.45 | 95.07 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b2_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b2_3rdparty_8xb32-aa-advprop_in1k_20220119-1655338a.pth) |
| EfficientNet-B3\* | 12.23 | 0.06 | 81.01 | 95.34 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b3_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty_8xb32_in1k_20220119-4b4d7487.pth) |
| EfficientNet-B3 (AA)\* | 12.23 | 0.06 | 81.58 | 95.67 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b3_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty_8xb32-aa_in1k_20220119-5b4887a0.pth) |
| EfficientNet-B3 (AA + AdvProp)\* | 12.23 | 0.06 | 81.81 | 95.69 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b3_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty_8xb32-aa-advprop_in1k_20220119-53b41118.pth) |
| EfficientNet-B4\* | 19.34 | 0.12 | 82.57 | 96.09 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b4_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty_8xb32_in1k_20220119-81fd4077.pth) |
| EfficientNet-B4 (AA)\* | 19.34 | 0.12 | 82.95 | 96.26 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b4_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty_8xb32-aa_in1k_20220119-45b8bd2b.pth) |
| EfficientNet-B4 (AA + AdvProp)\* | 19.34 | 0.12 | 83.25 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b4_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty_8xb32-aa-advprop_in1k_20220119-38c2238c.pth) |
| EfficientNet-B5\* | 30.39 | 0.24 | 83.18 | 96.47 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b5_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty_8xb32_in1k_20220119-e9814430.pth) |
| EfficientNet-B5 (AA)\* | 30.39 | 0.24 | 83.82 | 96.76 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b5_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty_8xb32-aa_in1k_20220119-2cab8b78.pth) |
| EfficientNet-B5 (AA + AdvProp)\* | 30.39 | 0.24 | 84.21 | 96.98 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b5_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty_8xb32-aa-advprop_in1k_20220119-f57a895a.pth) |
| EfficientNet-B6 (AA)\* | 43.04 | 0.41 | 84.05 | 96.82 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b6_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b6_3rdparty_8xb32-aa_in1k_20220119-45b03310.pth) |
| EfficientNet-B6 (AA + AdvProp)\* | 43.04 | 0.41 | 84.74 | 97.14 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b6_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b6_3rdparty_8xb32-aa-advprop_in1k_20220119-bfe3485e.pth) |
| EfficientNet-B7 (AA)\* | 66.35 | 0.72 | 84.38 | 96.88 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b7_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b7_3rdparty_8xb32-aa_in1k_20220119-bf03951c.pth) |
| EfficientNet-B7 (AA + AdvProp)\* | 66.35 | 0.72 | 85.14 | 97.23 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b7_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k_20220119-c6dbff10.pth) |
| EfficientNet-B8 (AA + AdvProp)\* | 87.41 | 1.09 | 85.38 | 97.28 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b8_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b8_3rdparty_8xb32-aa-advprop_in1k_20220119-297ce1b7.pth) |

*Models with \* are converted from the [official repo](https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet). The config files of these models are only for inference. We don't guarantee the training accuracy of these configs, and you are welcome to contribute your reproduction results.*
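
For reference, a converted checkpoint from the table above can be used for inference through the standard MMClassification API. A minimal sketch, assuming `mmcls` is installed; the checkpoint and image paths are illustrative and must point at local copies:

```python
# Inference sketch with a converted EfficientNet-B0 checkpoint. The config is
# part of this PR; the checkpoint filename matches the table above, and the
# demo image ships with the repository.
from mmcls.apis import inference_model, init_model

config = 'configs/efficientnet/efficientnet-b0_8xb32_in1k.py'
checkpoint = 'efficientnet-b0_3rdparty_8xb32_in1k_20220119-a7e2a0b1.pth'

model = init_model(config, checkpoint, device='cpu')
result = inference_model(model, 'demo/demo.JPEG')
print(result['pred_class'], result['pred_score'])
```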


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b0.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=224,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=224,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))
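
In the pipelines above, `efficientnet_style=True` switches the crops from the default mmcls behaviour to the original TF EfficientNet preprocessing. For the test-time `CenterCrop`, that means taking a central square whose side is `crop_size / (crop_size + crop_padding)` of the image's short edge, then resizing it back to `crop_size` with the configured interpolation. A rough sketch of the geometry, assuming the default `crop_padding=32`:

```python
# Sketch of the EfficientNet-style center crop used at test time (assumes the
# default crop_padding=32); returns the square region cropped before resizing.
def efficientnet_center_crop_box(height, width, crop_size, crop_padding=32):
    side = int(crop_size / (crop_size + crop_padding) * min(height, width))
    y0 = (height - side) // 2
    x0 = (width - side) // 2
    return y0, x0, side

# For a 500x375 image and crop_size=224, the crop covers a 328x328 region.
print(efficientnet_center_crop_box(500, 375, 224))  # (86, 23, 328)
```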


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b0.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=224,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=224,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))
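
The only difference between this config and the previous one is `img_norm_cfg`: per the README, the 127.5/127.5 variant presumably backs the `*-01norm` configs used by the AdvProp checkpoints, which expect inputs scaled to roughly [-1, 1], while this file uses the standard ImageNet statistics. A small numeric sketch of the two mappings:

```python
# Numeric sketch of the two normalizations used by the EfficientNet configs.
import numpy as np

pixels = np.array([0.0, 127.5, 255.0])  # darkest, middle, brightest values

advprop_style = (pixels - 127.5) / 127.5      # -> [-1.0, 0.0, 1.0]
imagenet_style = (pixels - 123.675) / 58.395  # R-channel ImageNet statistics
print(advprop_style)   # [-1.  0.  1.]
print(imagenet_style)  # approx [-2.118, 0.065, 2.249]
```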


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b1.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=240,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=240,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b1.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=240,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=240,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b2.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=260,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=260,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b2.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=260,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=260,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b3.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=300,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=300,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b3.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=300,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=300,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b4.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=380,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=380,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b4.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=380,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=380,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b5.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=456,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=456,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b5.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=456,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=456,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b6.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=528,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=528,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b6.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=528,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=528,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b7.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=600,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=600,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b7.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=600,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=600,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b8.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=672,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=672,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_b8.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=672,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=672,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_em.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=240,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=240,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,39 @@
_base_ = [
    '../_base_/models/efficientnet_es.py',
    '../_base_/datasets/imagenet_bs32.py',
    '../_base_/schedules/imagenet_bs256.py',
    '../_base_/default_runtime.py',
]

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[127.5, 127.5, 127.5], std=[127.5, 127.5, 127.5], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='RandomResizedCrop',
        size=224,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(
        type='CenterCrop',
        crop_size=224,
        efficientnet_style=True,
        interpolation='bicubic'),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    train=dict(pipeline=train_pipeline),
    val=dict(pipeline=test_pipeline),
    test=dict(pipeline=test_pipeline))


@@ -0,0 +1,391 @@
Collections:
  - Name: EfficientNet
    Metadata:
      Training Data: ImageNet-1k
      Architecture:
        - 1x1 Convolution
        - Average Pooling
        - Convolution
        - Dense Connections
        - Dropout
        - Inverted Residual Block
        - RMSProp
        - Squeeze-and-Excitation Block
        - Swish
    Paper:
      URL: https://arxiv.org/abs/1905.11946v5
      Title: "EfficientNet: Rethinking Model Scaling for Convolutional Neural Networks"
    README: configs/efficientnet/README.md
    Code:
      Version: v0.20.0
      URL: https://github.com/open-mmlab/mmclassification/blob/v0.20.0/mmcls/models/backbones/efficientnet.py

Models:
  - Name: efficientnet-b0_3rdparty_8xb32_in1k
    Metadata:
      FLOPs: 16481180
      Parameters: 5288548
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 76.74
          Top 5 Accuracy: 93.17
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32_in1k_20220119-a7e2a0b1.pth
    Config: configs/efficientnet/efficientnet-b0_8xb32_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b0.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b0_3rdparty_8xb32-aa_in1k
    Metadata:
      FLOPs: 16481180
      Parameters: 5288548
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 77.26
          Top 5 Accuracy: 93.41
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32-aa_in1k_20220119-8d939117.pth
    Config: configs/efficientnet/efficientnet-b0_8xb32_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b0.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b0_3rdparty_8xb32-aa-advprop_in1k
    Metadata:
      FLOPs: 16481180
      Parameters: 5288548
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 77.53
          Top 5 Accuracy: 93.61
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32-aa-advprop_in1k_20220119-26434485.pth
    Config: configs/efficientnet/efficientnet-b0_8xb32-01norm_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b0.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b1_3rdparty_8xb32_in1k
    Metadata:
      FLOPs: 27052224
      Parameters: 7794184
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 78.68
          Top 5 Accuracy: 94.28
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty_8xb32_in1k_20220119-002556d9.pth
    Config: configs/efficientnet/efficientnet-b1_8xb32_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b1.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b1_3rdparty_8xb32-aa_in1k
    Metadata:
      FLOPs: 27052224
      Parameters: 7794184
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 79.20
          Top 5 Accuracy: 94.42
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty_8xb32-aa_in1k_20220119-619d8ae3.pth
    Config: configs/efficientnet/efficientnet-b1_8xb32_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b1.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b1_3rdparty_8xb32-aa-advprop_in1k
    Metadata:
      FLOPs: 27052224
      Parameters: 7794184
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 79.52
          Top 5 Accuracy: 94.43
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty_8xb32-aa-advprop_in1k_20220119-5715267d.pth
    Config: configs/efficientnet/efficientnet-b1_8xb32-01norm_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b1.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b2_3rdparty_8xb32_in1k
    Metadata:
      FLOPs: 34346386
      Parameters: 9109994
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 79.64
          Top 5 Accuracy: 94.80
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b2_3rdparty_8xb32_in1k_20220119-ea374a30.pth
    Config: configs/efficientnet/efficientnet-b2_8xb32_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b2.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b2_3rdparty_8xb32-aa_in1k
    Metadata:
      FLOPs: 34346386
      Parameters: 9109994
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 80.21
          Top 5 Accuracy: 94.96
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b2_3rdparty_8xb32-aa_in1k_20220119-dd61e80b.pth
    Config: configs/efficientnet/efficientnet-b2_8xb32_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b2.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b2_3rdparty_8xb32-aa-advprop_in1k
    Metadata:
      FLOPs: 34346386
      Parameters: 9109994
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 80.45
          Top 5 Accuracy: 95.07
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b2_3rdparty_8xb32-aa-advprop_in1k_20220119-1655338a.pth
    Config: configs/efficientnet/efficientnet-b2_8xb32-01norm_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b2.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b3_3rdparty_8xb32_in1k
    Metadata:
      FLOPs: 58641904
      Parameters: 12233232
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 81.01
          Top 5 Accuracy: 95.34
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty_8xb32_in1k_20220119-4b4d7487.pth
    Config: configs/efficientnet/efficientnet-b3_8xb32_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b3.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b3_3rdparty_8xb32-aa_in1k
    Metadata:
      FLOPs: 58641904
      Parameters: 12233232
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 81.58
          Top 5 Accuracy: 95.67
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty_8xb32-aa_in1k_20220119-5b4887a0.pth
    Config: configs/efficientnet/efficientnet-b3_8xb32_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b3.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b3_3rdparty_8xb32-aa-advprop_in1k
    Metadata:
      FLOPs: 58641904
      Parameters: 12233232
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 81.81
          Top 5 Accuracy: 95.69
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty_8xb32-aa-advprop_in1k_20220119-53b41118.pth
    Config: configs/efficientnet/efficientnet-b3_8xb32-01norm_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b3.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b4_3rdparty_8xb32_in1k
    Metadata:
      FLOPs: 121870624
      Parameters: 19341616
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 82.57
          Top 5 Accuracy: 96.09
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty_8xb32_in1k_20220119-81fd4077.pth
    Config: configs/efficientnet/efficientnet-b4_8xb32_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b4.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b4_3rdparty_8xb32-aa_in1k
    Metadata:
      FLOPs: 121870624
      Parameters: 19341616
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 82.95
          Top 5 Accuracy: 96.26
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty_8xb32-aa_in1k_20220119-45b8bd2b.pth
    Config: configs/efficientnet/efficientnet-b4_8xb32_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b4.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b4_3rdparty_8xb32-aa-advprop_in1k
    Metadata:
      FLOPs: 121870624
      Parameters: 19341616
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 83.25
          Top 5 Accuracy: 96.44
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty_8xb32-aa-advprop_in1k_20220119-38c2238c.pth
    Config: configs/efficientnet/efficientnet-b4_8xb32-01norm_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b4.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b5_3rdparty_8xb32_in1k
    Metadata:
      FLOPs: 243879440
      Parameters: 30389784
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 83.18
          Top 5 Accuracy: 96.47
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty_8xb32_in1k_20220119-e9814430.pth
    Config: configs/efficientnet/efficientnet-b5_8xb32_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckpts/efficientnet-b5.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b5_3rdparty_8xb32-aa_in1k
    Metadata:
      FLOPs: 243879440
      Parameters: 30389784
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 83.82
          Top 5 Accuracy: 96.76
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty_8xb32-aa_in1k_20220119-2cab8b78.pth
    Config: configs/efficientnet/efficientnet-b5_8xb32_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b5.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b5_3rdparty_8xb32-aa-advprop_in1k
    Metadata:
      FLOPs: 243879440
      Parameters: 30389784
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 84.21
          Top 5 Accuracy: 96.98
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty_8xb32-aa-advprop_in1k_20220119-f57a895a.pth
    Config: configs/efficientnet/efficientnet-b5_8xb32-01norm_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b5.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b6_3rdparty_8xb32-aa_in1k
    Metadata:
      FLOPs: 412002408
      Parameters: 43040704
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 84.05
          Top 5 Accuracy: 96.82
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b6_3rdparty_8xb32-aa_in1k_20220119-45b03310.pth
    Config: configs/efficientnet/efficientnet-b6_8xb32_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b6.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b6_3rdparty_8xb32-aa-advprop_in1k
    Metadata:
      FLOPs: 412002408
      Parameters: 43040704
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 84.74
          Top 5 Accuracy: 97.14
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b6_3rdparty_8xb32-aa-advprop_in1k_20220119-bfe3485e.pth
    Config: configs/efficientnet/efficientnet-b6_8xb32-01norm_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b6.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b7_3rdparty_8xb32-aa_in1k
    Metadata:
      FLOPs: 715526512
      Parameters: 66347960
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 84.38
          Top 5 Accuracy: 96.88
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b7_3rdparty_8xb32-aa_in1k_20220119-bf03951c.pth
    Config: configs/efficientnet/efficientnet-b7_8xb32_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/ckptsaug/efficientnet-b7.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k
    Metadata:
      FLOPs: 715526512
      Parameters: 66347960
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 85.14
          Top 5 Accuracy: 97.23
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k_20220119-c6dbff10.pth
    Config: configs/efficientnet/efficientnet-b7_8xb32-01norm_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b7.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
  - Name: efficientnet-b8_3rdparty_8xb32-aa-advprop_in1k
    Metadata:
      FLOPs: 1092755326
      Parameters: 87413142
    In Collections: EfficientNet
    Results:
      - Dataset: ImageNet-1k
        Metrics:
          Top 1 Accuracy: 85.38
          Top 5 Accuracy: 97.28
        Task: Image Classification
    Weights: https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b8_3rdparty_8xb32-aa-advprop_in1k_20220119-297ce1b7.pth
    Config: configs/efficientnet/efficientnet-b8_8xb32-01norm_in1k.py
    Converted From:
      Weights: https://storage.googleapis.com/cloud-tpu-checkpoints/efficientnet/advprop/efficientnet-b8.tar.gz
      Code: https://github.com/tensorflow/tpu/tree/master/models/official/efficientnet
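
The metafile is plain YAML, so downstream tooling (docs, model-zoo summaries) can consume it directly. A small sketch, assuming PyYAML is installed and the file lives at the path this PR adds:

```python
# List every EfficientNet model in the metafile with its top-1 accuracy.
import yaml

with open('configs/efficientnet/metafile.yml') as f:
    meta = yaml.safe_load(f)

for model in meta['Models']:
    top1 = model['Results'][0]['Metrics']['Top 1 Accuracy']
    print(f"{model['Name']}: {top1}% top-1")
```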


@@ -26,7 +26,7 @@ The depth of representations is of central importance for many visual recognitio
## Results and models
## Cifar10
### Cifar10
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:---------:|:--------:|
@@ -36,7 +36,7 @@ The depth of representations is of central importance for many visual recognitio
| ResNet-101-b16x8 | 42.51 | 2.52 | 95.58 | 99.87 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet101_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_b16x8_cifar10_20210528-2d29e936.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet101_b16x8_cifar10_20210528-2d29e936.log.json) |
| ResNet-152-b16x8 | 58.16 | 3.74 | 95.76 | 99.89 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/resnet/resnet152_8xb16_cifar10.py) | [model](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_b16x8_cifar10_20210528-3e8e9178.pth) &#124; [log](https://download.openmmlab.com/mmclassification/v0/resnet/resnet152_b16x8_cifar10_20210528-3e8e9178.log.json) |
## Cifar100
### Cifar100
| Model | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
|:---------------------:|:---------:|:--------:|:---------:|:---------:|:---------:|:--------:|


@@ -83,6 +83,29 @@ The ResNet family models below are trained by standard data augmentations, i.e.,
| Conformer-small-p32\* | 38.85 | 7.09 | 81.96 | 96.02 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-small-p32_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p32_8xb128_in1k_20211206-947a0816.pth) |
| Conformer-small-p16\* | 37.67 | 10.31 | 83.32 | 96.46 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-small-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-small-p16_3rdparty_8xb128_in1k_20211206-3065dcf5.pth) |
| Conformer-base-p16\* | 83.29 | 22.89 | 83.82 | 96.59 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/conformer/conformer-base-p16_8xb128_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/conformer/conformer-base-p16_3rdparty_8xb128_in1k_20211206-bfdf8637.pth) |
| EfficientNet-B0\* | 5.29 | 0.02 | 76.74 | 93.17 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b0_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32_in1k_20220119-a7e2a0b1.pth) |
| EfficientNet-B0 (AA)\* | 5.29 | 0.02 | 77.26 | 93.41 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b0_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32-aa_in1k_20220119-8d939117.pth) |
| EfficientNet-B0 (AA + AdvProp)\* | 5.29 | 0.02 | 77.53 | 93.61 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b0_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32-aa-advprop_in1k_20220119-26434485.pth) |
| EfficientNet-B1\* | 7.79 | 0.03 | 78.68 | 94.28 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b1_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty_8xb32_in1k_20220119-002556d9.pth) |
| EfficientNet-B1 (AA)\* | 7.79 | 0.03 | 79.20 | 94.42 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b1_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty_8xb32-aa_in1k_20220119-619d8ae3.pth) |
| EfficientNet-B1 (AA + AdvProp)\* | 7.79 | 0.03 | 79.52 | 94.43 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b1_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b1_3rdparty_8xb32-aa-advprop_in1k_20220119-5715267d.pth) |
| EfficientNet-B2\* | 9.11 | 0.03 | 79.64 | 94.80 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b2_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b2_3rdparty_8xb32_in1k_20220119-ea374a30.pth) |
| EfficientNet-B2 (AA)\* | 9.11 | 0.03 | 80.21 | 94.96 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b2_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b2_3rdparty_8xb32-aa_in1k_20220119-dd61e80b.pth) |
| EfficientNet-B2 (AA + AdvProp)\* | 9.11 | 0.03 | 80.45 | 95.07 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b2_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b2_3rdparty_8xb32-aa-advprop_in1k_20220119-1655338a.pth) |
| EfficientNet-B3\* | 12.23 | 0.06 | 81.01 | 95.34 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b3_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty_8xb32_in1k_20220119-4b4d7487.pth) |
| EfficientNet-B3 (AA)\* | 12.23 | 0.06 | 81.58 | 95.67 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b3_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty_8xb32-aa_in1k_20220119-5b4887a0.pth) |
| EfficientNet-B3 (AA + AdvProp)\* | 12.23 | 0.06 | 81.81 | 95.69 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b3_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b3_3rdparty_8xb32-aa-advprop_in1k_20220119-53b41118.pth) |
| EfficientNet-B4\* | 19.34 | 0.12 | 82.57 | 96.09 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b4_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty_8xb32_in1k_20220119-81fd4077.pth) |
| EfficientNet-B4 (AA)\* | 19.34 | 0.12 | 82.95 | 96.26 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b4_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty_8xb32-aa_in1k_20220119-45b8bd2b.pth) |
| EfficientNet-B4 (AA + AdvProp)\* | 19.34 | 0.12 | 83.25 | 96.44 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b4_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b4_3rdparty_8xb32-aa-advprop_in1k_20220119-38c2238c.pth) |
| EfficientNet-B5\* | 30.39 | 0.24 | 83.18 | 96.47 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b5_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty_8xb32_in1k_20220119-e9814430.pth) |
| EfficientNet-B5 (AA)\* | 30.39 | 0.24 | 83.82 | 96.76 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b5_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty_8xb32-aa_in1k_20220119-2cab8b78.pth) |
| EfficientNet-B5 (AA + AdvProp)\* | 30.39 | 0.24 | 84.21 | 96.98 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b5_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b5_3rdparty_8xb32-aa-advprop_in1k_20220119-f57a895a.pth) |
| EfficientNet-B6 (AA)\* | 43.04 | 0.41 | 84.05 | 96.82 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b6_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b6_3rdparty_8xb32-aa_in1k_20220119-45b03310.pth) |
| EfficientNet-B6 (AA + AdvProp)\* | 43.04 | 0.41 | 84.74 | 97.14 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b6_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b6_3rdparty_8xb32-aa-advprop_in1k_20220119-bfe3485e.pth) |
| EfficientNet-B7 (AA)\* | 66.35 | 0.72 | 84.38 | 96.88 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b7_8xb32_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b7_3rdparty_8xb32-aa_in1k_20220119-bf03951c.pth) |
| EfficientNet-B7 (AA + AdvProp)\* | 66.35 | 0.72 | 85.14 | 97.23 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b7_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b7_3rdparty_8xb32-aa-advprop_in1k_20220119-c6dbff10.pth) |
| EfficientNet-B8 (AA + AdvProp)\* | 87.41 | 1.09 | 85.38 | 97.28 | [config](https://github.com/open-mmlab/mmclassification/blob/master/configs/efficientnet/efficientnet-b8_8xb32-01norm_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b8_3rdparty_8xb32-aa-advprop_in1k_20220119-297ce1b7.pth) |
*Models with \* are converted from other repos; the others are trained by us.*
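
The converted checkpoints can be used directly for inference. A minimal sketch, assuming the `mmcls.apis` high-level API and the EfficientNet-B0 config/checkpoint from the table above (the demo image path is the one shipped with the repository):

```python
from mmcls.apis import inference_model, init_model

config = 'configs/efficientnet/efficientnet-b0_8xb32_in1k.py'
checkpoint = 'https://download.openmmlab.com/mmclassification/v0/efficientnet/efficientnet-b0_3rdparty_8xb32_in1k_20220119-a7e2a0b1.pth'

model = init_model(config, checkpoint, device='cpu')
result = inference_model(model, 'demo/demo.JPEG')
print(result['pred_class'], result['pred_score'])
```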
View File
@ -2,6 +2,7 @@
from .alexnet import AlexNet
from .conformer import Conformer
from .deit import DistilledVisionTransformer
from .efficientnet import EfficientNet
from .lenet import LeNet5
from .mlp_mixer import MlpMixer
from .mobilenet_v2 import MobileNetV2
@ -29,5 +30,5 @@ __all__ = [
'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
'SwinTransformer', 'TNT', 'TIMMBackbone', 'T2T_ViT', 'Res2Net', 'RepVGG',
'Conformer', 'MlpMixer', 'DistilledVisionTransformer'
'Conformer', 'MlpMixer', 'DistilledVisionTransformer', 'EfficientNet'
]
View File
@ -0,0 +1,407 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy
import math
from functools import partial
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn.bricks import ConvModule, DropPath
from mmcv.runner import BaseModule, Sequential
from mmcls.models.backbones.base_backbone import BaseBackbone
from mmcls.models.utils import InvertedResidual, SELayer, make_divisible
from ..builder import BACKBONES
class EdgeResidual(BaseModule):
"""Edge Residual Block.
Args:
in_channels (int): The input channels of this module.
out_channels (int): The output channels of this module.
mid_channels (int): The input channels of the second convolution.
kernel_size (int): The kernel size of the first convolution.
Defaults to 3.
stride (int): The stride of the first convolution. Defaults to 1.
se_cfg (dict, optional): Config dict for se layer. Defaults to None,
which means no se layer.
with_residual (bool): Use residual connection. Defaults to True.
conv_cfg (dict, optional): Config dict for convolution layer.
Defaults to None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Defaults to ``dict(type='BN')``.
act_cfg (dict): Config dict for activation layer.
Defaults to ``dict(type='ReLU')``.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
init_cfg (dict | list[dict], optional): Initialization config dict.
"""
def __init__(self,
in_channels,
out_channels,
mid_channels,
kernel_size=3,
stride=1,
se_cfg=None,
with_residual=True,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
drop_path_rate=0.,
with_cp=False,
init_cfg=None):
super(EdgeResidual, self).__init__(init_cfg=init_cfg)
assert stride in [1, 2]
self.with_cp = with_cp
self.drop_path = DropPath(
drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.with_se = se_cfg is not None
self.with_residual = (
stride == 1 and in_channels == out_channels and with_residual)
if self.with_se:
assert isinstance(se_cfg, dict)
self.conv1 = ConvModule(
in_channels=in_channels,
out_channels=mid_channels,
kernel_size=kernel_size,
stride=1,
padding=kernel_size // 2,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
if self.with_se:
self.se = SELayer(**se_cfg)
self.conv2 = ConvModule(
in_channels=mid_channels,
out_channels=out_channels,
kernel_size=1,
stride=stride,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None)
def forward(self, x):
def _inner_forward(x):
out = x
out = self.conv1(out)
if self.with_se:
out = self.se(out)
out = self.conv2(out)
if self.with_residual:
return x + self.drop_path(out)
else:
return out
if self.with_cp and x.requires_grad:
out = cp.checkpoint(_inner_forward, x)
else:
out = _inner_forward(x)
return out
def model_scaling(layer_setting, arch_setting):
"""Scaling operation to the layer's parameters according to the
arch_setting."""
# scale width
new_layer_setting = copy.deepcopy(layer_setting)
for layer_cfg in new_layer_setting:
for block_cfg in layer_cfg:
block_cfg[1] = make_divisible(block_cfg[1] * arch_setting[0], 8)
# scale depth
split_layer_setting = [new_layer_setting[0]]
for layer_cfg in new_layer_setting[1:-1]:
tmp_index = [0]
for i in range(len(layer_cfg) - 1):
if layer_cfg[i + 1][1] != layer_cfg[i][1]:
tmp_index.append(i + 1)
tmp_index.append(len(layer_cfg))
for i in range(len(tmp_index) - 1):
split_layer_setting.append(layer_cfg[tmp_index[i]:tmp_index[i +
1]])
split_layer_setting.append(new_layer_setting[-1])
num_of_layers = [len(layer_cfg) for layer_cfg in split_layer_setting[1:-1]]
new_layers = [
int(math.ceil(arch_setting[1] * num)) for num in num_of_layers
]
merge_layer_setting = [split_layer_setting[0]]
for i, layer_cfg in enumerate(split_layer_setting[1:-1]):
if new_layers[i] <= num_of_layers[i]:
tmp_layer_cfg = layer_cfg[:new_layers[i]]
else:
tmp_layer_cfg = copy.deepcopy(layer_cfg) + [layer_cfg[-1]] * (
new_layers[i] - num_of_layers[i])
if tmp_layer_cfg[0][3] == 1 and i != 0:
merge_layer_setting[-1] += tmp_layer_cfg.copy()
else:
merge_layer_setting.append(tmp_layer_cfg.copy())
merge_layer_setting.append(split_layer_setting[-1])
return merge_layer_setting
@BACKBONES.register_module()
class EfficientNet(BaseBackbone):
"""EfficientNet backbone.
Args:
arch (str): Architecture of efficientnet. Defaults to b0.
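drop_path_rate (float): stochastic depth rate. Defaults to 0.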
out_indices (Sequence[int]): Output from which stages.
Defaults to (6, ).
frozen_stages (int): Stages to be frozen (all param fixed).
Defaults to 0, which means not freezing any parameters.
conv_cfg (dict): Config dict for convolution layer.
Defaults to None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Defaults to dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Defaults to dict(type='Swish').
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Defaults to False.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Defaults to False.
"""
# Parameters to build layers.
# 'b' represents the normal EfficientNet family, which includes
# 'b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b6', 'b7' and 'b8'.
# 'e' represents the EfficientNet-EdgeTPU family, which includes
# 'es', 'em' and 'el'.
# Six parameters are needed to construct a layer. From left to right:
# - kernel_size: The kernel size of the block.
# - out_channel: The number of output channels of the block.
# - se_ratio: The squeeze ratio of SELayer.
# - stride: The stride of the block.
# - expand_ratio: The expand ratio of the mid_channels.
# - block_type: -1: not a block, 0: InvertedResidual, 1: EdgeResidual.
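# For example, [5, 40, 4, 2, 6, 0] is a 5x5 InvertedResidual block with
# 40 output channels, an SE squeeze ratio of 4, stride 2 and expand ratio 6.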
layer_settings = {
'b': [[[3, 32, 0, 2, 0, -1]],
[[3, 16, 4, 1, 1, 0]],
[[3, 24, 4, 2, 6, 0],
[3, 24, 4, 1, 6, 0]],
[[5, 40, 4, 2, 6, 0],
[5, 40, 4, 1, 6, 0]],
[[3, 80, 4, 2, 6, 0],
[3, 80, 4, 1, 6, 0],
[3, 80, 4, 1, 6, 0],
[5, 112, 4, 1, 6, 0],
[5, 112, 4, 1, 6, 0],
[5, 112, 4, 1, 6, 0]],
[[5, 192, 4, 2, 6, 0],
[5, 192, 4, 1, 6, 0],
[5, 192, 4, 1, 6, 0],
[5, 192, 4, 1, 6, 0],
[3, 320, 4, 1, 6, 0]],
[[1, 1280, 0, 1, 0, -1]]
],
'e': [[[3, 32, 0, 2, 0, -1]],
[[3, 24, 0, 1, 3, 1]],
[[3, 32, 0, 2, 8, 1],
[3, 32, 0, 1, 8, 1]],
[[3, 48, 0, 2, 8, 1],
[3, 48, 0, 1, 8, 1],
[3, 48, 0, 1, 8, 1],
[3, 48, 0, 1, 8, 1]],
[[5, 96, 0, 2, 8, 0],
[5, 96, 0, 1, 8, 0],
[5, 96, 0, 1, 8, 0],
[5, 96, 0, 1, 8, 0],
[5, 96, 0, 1, 8, 0],
[5, 144, 0, 1, 8, 0],
[5, 144, 0, 1, 8, 0],
[5, 144, 0, 1, 8, 0],
[5, 144, 0, 1, 8, 0]],
[[5, 192, 0, 2, 8, 0],
[5, 192, 0, 1, 8, 0]],
[[1, 1280, 0, 1, 0, -1]]
]
} # yapf: disable
# Parameters to build different kinds of architecture.
# From left to right: scaling factor for width, scaling factor for depth,
# resolution.
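# For example, 'b4' widens every stage of 'b0' by 1.4x (rounded to
# multiples of 8), deepens it by 1.8x, and is intended for 380x380 input.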
arch_settings = {
'b0': (1.0, 1.0, 224),
'b1': (1.0, 1.1, 240),
'b2': (1.1, 1.2, 260),
'b3': (1.2, 1.4, 300),
'b4': (1.4, 1.8, 380),
'b5': (1.6, 2.2, 456),
'b6': (1.8, 2.6, 528),
'b7': (2.0, 3.1, 600),
'b8': (2.2, 3.6, 672),
'es': (1.0, 1.0, 224),
'em': (1.0, 1.1, 240),
'el': (1.2, 1.4, 300)
}
def __init__(self,
arch='b0',
drop_path_rate=0.,
out_indices=(6, ),
frozen_stages=0,
conv_cfg=dict(type='Conv2dAdaptivePadding'),
norm_cfg=dict(type='BN', eps=1e-3),
act_cfg=dict(type='Swish'),
norm_eval=False,
with_cp=False,
init_cfg=[
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant',
layer=['_BatchNorm', 'GroupNorm'],
val=1)
]):
super(EfficientNet, self).__init__(init_cfg)
assert arch in self.arch_settings, \
f'"{arch}" is not one of the arch_settings ' \
f'({", ".join(self.arch_settings.keys())})'
self.arch_setting = self.arch_settings[arch]
self.layer_setting = self.layer_settings[arch[:1]]
for index in out_indices:
if index not in range(0, len(self.layer_setting)):
raise ValueError('the item in out_indices must be in '
f'range(0, {len(self.layer_setting)}). '
f'But received {index}')
if frozen_stages not in range(len(self.layer_setting) + 1):
raise ValueError('frozen_stages must be in range(0, '
f'{len(self.layer_setting) + 1}). '
f'But received {frozen_stages}')
self.drop_path_rate = drop_path_rate
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
self.layer_setting = model_scaling(self.layer_setting,
self.arch_setting)
block_cfg_0 = self.layer_setting[0][0]
block_cfg_last = self.layer_setting[-1][0]
self.in_channels = make_divisible(block_cfg_0[1], 8)
self.out_channels = block_cfg_last[1]
self.layers = nn.ModuleList()
self.layers.append(
ConvModule(
in_channels=3,
out_channels=self.in_channels,
kernel_size=block_cfg_0[0],
stride=block_cfg_0[3],
padding=block_cfg_0[0] // 2,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
self.make_layer()
self.layers.append(
ConvModule(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=block_cfg_last[0],
stride=block_cfg_last[3],
padding=block_cfg_last[0] // 2,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg))
def make_layer(self):
# Build all stages except the first and the final conv blocks.
layer_setting = self.layer_setting[1:-1]
total_num_blocks = sum([len(x) for x in layer_setting])
block_idx = 0
dpr = [
x.item()
for x in torch.linspace(0, self.drop_path_rate, total_num_blocks)
] # stochastic depth decay rule
for layer_cfg in layer_setting:
layer = []
for i, block_cfg in enumerate(layer_cfg):
(kernel_size, out_channels, se_ratio, stride, expand_ratio,
block_type) = block_cfg
mid_channels = int(self.in_channels * expand_ratio)
out_channels = make_divisible(out_channels, 8)
if se_ratio <= 0:
se_cfg = None
else:
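# Note: since mid_channels = in_channels * expand_ratio, setting
# ratio = expand_ratio * se_ratio squeezes the SE layer down to
# in_channels / se_ratio channels.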
se_cfg = dict(
channels=mid_channels,
ratio=expand_ratio * se_ratio,
divisor=1,
act_cfg=(self.act_cfg, dict(type='Sigmoid')))
if block_type == 1: # edge tpu
if i > 0 and expand_ratio == 3:
with_residual = False
expand_ratio = 4
else:
with_residual = True
mid_channels = int(self.in_channels * expand_ratio)
if se_cfg is not None:
se_cfg = dict(
channels=mid_channels,
ratio=se_ratio * expand_ratio,
divisor=1,
act_cfg=(self.act_cfg, dict(type='Sigmoid')))
block = partial(EdgeResidual, with_residual=with_residual)
else:
block = InvertedResidual
layer.append(
block(
in_channels=self.in_channels,
out_channels=out_channels,
mid_channels=mid_channels,
kernel_size=kernel_size,
stride=stride,
se_cfg=se_cfg,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
drop_path_rate=dpr[block_idx],
with_cp=self.with_cp))
self.in_channels = out_channels
block_idx += 1
self.layers.append(Sequential(*layer))
def forward(self, x):
outs = []
for i, layer in enumerate(self.layers):
x = layer(x)
if i in self.out_indices:
outs.append(x)
return tuple(outs)
def _freeze_stages(self):
for i in range(self.frozen_stages):
m = self.layers[i]
m.eval()
for param in m.parameters():
param.requires_grad = False
def train(self, mode=True):
super(EfficientNet, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
View File
@ -1,35 +1,35 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule
from mmcv.cnn.bricks import DropPath
from mmcv.runner import BaseModule
from .se_layer import SELayer
class InvertedResidual(BaseModule):
"""Inverted Residual Block.
Args:
in_channels (int): The input channels of this Module.
out_channels (int): The output channels of this Module.
in_channels (int): The input channels of this module.
out_channels (int): The output channels of this module.
mid_channels (int): The input channels of the depthwise convolution.
kernel_size (int): The kernel size of the depthwise convolution.
Default: 3.
stride (int): The stride of the depthwise convolution. Default: 1.
se_cfg (dict): Config dict for se layer. Default: None, which means no
se layer.
conv_cfg (dict): Config dict for convolution layer. Default: None,
Defaults to 3.
stride (int): The stride of the depthwise convolution. Defaults to 1.
se_cfg (dict, optional): Config dict for se layer. Defaults to None,
which means no se layer.
conv_cfg (dict): Config dict for convolution layer. Defaults to None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
Defaults to ``dict(type='BN')``.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
Defaults to ``dict(type='ReLU')``.
drop_path_rate (float): stochastic depth rate. Defaults to 0.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed. Default: False.
Returns:
Tensor: The output tensor.
memory while slowing down the training speed. Defaults to False.
init_cfg (dict | list[dict], optional): Initialization config dict.
"""
def __init__(self,
@ -42,12 +42,15 @@ class InvertedResidual(BaseModule):
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
drop_path_rate=0.,
with_cp=False,
init_cfg=None):
super(InvertedResidual, self).__init__(init_cfg)
self.with_res_shortcut = (stride == 1 and in_channels == out_channels)
assert stride in [1, 2]
self.with_cp = with_cp
self.drop_path = DropPath(
drop_path_rate) if drop_path_rate > 0 else nn.Identity()
self.with_se = se_cfg is not None
self.with_expand_conv = (mid_channels != in_channels)
@ -87,6 +90,14 @@ class InvertedResidual(BaseModule):
act_cfg=None)
def forward(self, x):
"""Forward function.
Args:
x (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor.
"""
def _inner_forward(x):
out = x
@ -102,7 +113,7 @@ class InvertedResidual(BaseModule):
out = self.linear_conv(out)
if self.with_res_shortcut:
return x + out
return x + self.drop_path(out)
else:
return out
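
The change above replaces the plain residual add with a stochastic-depth residual. A minimal sketch of the behaviour, assuming mmcv's `DropPath` brick (imported above), which randomly zeroes the residual branch per sample during training and is the identity at eval time:

```python
import torch
from mmcv.cnn.bricks import DropPath

drop_path = DropPath(drop_prob=0.2)
x = torch.randn(2, 8, 4, 4)          # shortcut branch
residual = torch.randn(2, 8, 4, 4)   # output of the block's conv branch

drop_path.train()
out = x + drop_path(residual)  # residual dropped per sample with prob 0.2

drop_path.eval()
out = x + drop_path(residual)  # identity: a plain residual connection
```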
View File
@ -16,3 +16,4 @@ Import:
- configs/conformer/metafile.yml
- configs/regnet/metafile.yml
- configs/deit/metafile.yml
- configs/efficientnet/metafile.yml
View File
@ -14,7 +14,7 @@ line_length = 79
multi_line_output = 0
known_standard_library = pkg_resources,setuptools
known_first_party = mmcls
known_third_party = PIL,matplotlib,mmcv,mmdet,modelindex,numpy,onnxruntime,packaging,pytest,pytorch_sphinx_theme,requests,rich,sphinx,torch,torchvision,ts
known_third_party = PIL,matplotlib,mmcv,mmdet,modelindex,numpy,onnxruntime,packaging,pytest,pytorch_sphinx_theme,requests,rich,sphinx,tensorflow,torch,torchvision,ts
no_lines_before = STDLIB,LOCALFOLDER
default_section = THIRDPARTY
View File
@ -0,0 +1,143 @@
import pytest
import torch
from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.backbones import EfficientNet
def is_norm(modules):
"""Check if is one of the norms."""
if isinstance(modules, (GroupNorm, _BatchNorm)):
return True
return False
def check_norm_state(modules, train_state):
"""Check if norm layer is in correct train state."""
for mod in modules:
if isinstance(mod, _BatchNorm):
if mod.training != train_state:
return False
return True
def test_efficientnet_backbone():
archs = ['b0', 'b1', 'b2', 'b3', 'b4', 'b5', 'b7', 'b8', 'es', 'em', 'el']
with pytest.raises(TypeError):
# pretrained must be a string path
model = EfficientNet()
model.init_weights(pretrained=0)
with pytest.raises(AssertionError):
# arch must be in arch_settings
EfficientNet(arch='others')
for arch in archs:
with pytest.raises(ValueError):
# frozen_stages must be in range(0, 8)
EfficientNet(arch=arch, frozen_stages=12)
# Test EfficientNet
model = EfficientNet()
model.init_weights()
model.train()
# Test EfficientNet with the first 7 stages frozen
frozen_stages = 7
model = EfficientNet(arch='b0', frozen_stages=frozen_stages)
model.init_weights()
model.train()
for i in range(frozen_stages):
layer = model.layers[i]
for mod in layer.modules():
if isinstance(mod, _BatchNorm):
assert mod.training is False
for param in layer.parameters():
assert param.requires_grad is False
# Test EfficientNet with norm eval
model = EfficientNet(norm_eval=True)
model.init_weights()
model.train()
assert check_norm_state(model.modules(), False)
# Test EfficientNet forward with 'b0' arch
out_channels = [32, 16, 24, 40, 112, 320, 1280]
model = EfficientNet(arch='b0', out_indices=(0, 1, 2, 3, 4, 5, 6))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 7
assert feat[0].shape == torch.Size([1, out_channels[0], 112, 112])
assert feat[1].shape == torch.Size([1, out_channels[1], 112, 112])
assert feat[2].shape == torch.Size([1, out_channels[2], 56, 56])
assert feat[3].shape == torch.Size([1, out_channels[3], 28, 28])
assert feat[4].shape == torch.Size([1, out_channels[4], 14, 14])
assert feat[5].shape == torch.Size([1, out_channels[5], 7, 7])
assert feat[6].shape == torch.Size([1, out_channels[6], 7, 7])
# Test EfficientNet forward with 'b0' arch and GroupNorm
out_channels = [32, 16, 24, 40, 112, 320, 1280]
model = EfficientNet(
arch='b0',
out_indices=(0, 1, 2, 3, 4, 5, 6),
norm_cfg=dict(type='GN', num_groups=2, requires_grad=True))
for m in model.modules():
if is_norm(m):
assert isinstance(m, GroupNorm)
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 7
assert feat[0].shape == torch.Size([1, out_channels[0], 112, 112])
assert feat[1].shape == torch.Size([1, out_channels[1], 112, 112])
assert feat[2].shape == torch.Size([1, out_channels[2], 56, 56])
assert feat[3].shape == torch.Size([1, out_channels[3], 28, 28])
assert feat[4].shape == torch.Size([1, out_channels[4], 14, 14])
assert feat[5].shape == torch.Size([1, out_channels[5], 7, 7])
assert feat[6].shape == torch.Size([1, out_channels[6], 7, 7])
# Test EfficientNet forward with 'es' arch
out_channels = [32, 24, 32, 48, 144, 192, 1280]
model = EfficientNet(arch='es', out_indices=(0, 1, 2, 3, 4, 5, 6))
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 7
assert feat[0].shape == torch.Size([1, out_channels[0], 112, 112])
assert feat[1].shape == torch.Size([1, out_channels[1], 112, 112])
assert feat[2].shape == torch.Size([1, out_channels[2], 56, 56])
assert feat[3].shape == torch.Size([1, out_channels[3], 28, 28])
assert feat[4].shape == torch.Size([1, out_channels[4], 14, 14])
assert feat[5].shape == torch.Size([1, out_channels[5], 7, 7])
assert feat[6].shape == torch.Size([1, out_channels[6], 7, 7])
# Test EfficientNet forward with 'es' arch and GroupNorm
out_channels = [32, 24, 32, 48, 144, 192, 1280]
model = EfficientNet(
arch='es',
out_indices=(0, 1, 2, 3, 4, 5, 6),
norm_cfg=dict(type='GN', num_groups=2, requires_grad=True))
for m in model.modules():
if is_norm(m):
assert isinstance(m, GroupNorm)
model.init_weights()
model.train()
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 7
assert feat[0].shape == torch.Size([1, out_channels[0], 112, 112])
assert feat[1].shape == torch.Size([1, out_channels[1], 112, 112])
assert feat[2].shape == torch.Size([1, out_channels[2], 56, 56])
assert feat[3].shape == torch.Size([1, out_channels[3], 28, 28])
assert feat[4].shape == torch.Size([1, out_channels[4], 14, 14])
assert feat[5].shape == torch.Size([1, out_channels[5], 7, 7])
assert feat[6].shape == torch.Size([1, out_channels[6], 7, 7])
View File
@ -0,0 +1,214 @@
import argparse
import os
import numpy as np
import torch
from mmcv.runner import Sequential
from tensorflow.python.training import py_checkpoint_reader
from mmcls.models.backbones.efficientnet import EfficientNet
def tf2pth(v):
if v.ndim == 4:
return np.ascontiguousarray(v.transpose(3, 2, 0, 1))
elif v.ndim == 2:
return np.ascontiguousarray(v.transpose())
return v
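# For example, a TF conv kernel stored as (H, W, C_in, C_out) = (3, 3, 32, 16)
# becomes a PyTorch kernel of shape (C_out, C_in, H, W) = (16, 32, 3, 3),
# and 2-D dense kernels are transposed from (in, out) to (out, in).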
def read_ckpt(ckpt):
reader = py_checkpoint_reader.NewCheckpointReader(ckpt)
weights = {
n: torch.as_tensor(tf2pth(reader.get_tensor(n)))
for (n, _) in reader.get_variable_to_shape_map().items()
}
return weights
def map_key(weight):
m = dict()
has_expand_conv = set()
is_MBConv = set()
max_idx = 0
name = None
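# First pass: blocks with a third BN ('tpu_batch_normalization_2') have an
# expand conv, and blocks with a depthwise kernel are MBConv
# (InvertedResidual); the remaining blocks are EdgeResiduals.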
for k, v in weight.items():
seg = k.split('/')
if len(seg) == 1:
continue
if 'edgetpu' in seg[0]:
name = 'e' + seg[0][21:].lower()
else:
name = seg[0][13:]
if seg[2] == 'tpu_batch_normalization_2':
has_expand_conv.add(seg[1])
if seg[1].startswith('blocks_'):
idx = int(seg[1][7:]) + 1
max_idx = max(max_idx, idx)
if 'depthwise' in k:
is_MBConv.add(seg[1])
model = EfficientNet(name)
idx2key = []
for idx, module in enumerate(model.layers):
if isinstance(module, Sequential):
for j in range(len(module)):
idx2key.append('{}.{}'.format(idx, j))
else:
idx2key.append('{}'.format(idx))
for k, v in weight.items():
if 'Exponential' in k or 'RMS' in k:
continue
seg = k.split('/')
if len(seg) == 1:
continue
if seg[2] == 'depthwise_conv2d':
v = v.transpose(1, 0)
if seg[1] == 'stem':
prefix = 'backbone.layers.{}'.format(idx2key[0])
mapping = {
'conv2d/kernel': 'conv.weight',
'tpu_batch_normalization/beta': 'bn.bias',
'tpu_batch_normalization/gamma': 'bn.weight',
'tpu_batch_normalization/moving_mean': 'bn.running_mean',
'tpu_batch_normalization/moving_variance': 'bn.running_var',
}
suffix = mapping['/'.join(seg[2:])]
m[prefix + '.' + suffix] = v
elif seg[1].startswith('blocks_'):
idx = int(seg[1][7:]) + 1
prefix = '.'.join(['backbone', 'layers', idx2key[idx]])
if seg[1] not in is_MBConv:
mapping = {
'conv2d/kernel':
'conv1.conv.weight',
'tpu_batch_normalization/gamma':
'conv1.bn.weight',
'tpu_batch_normalization/beta':
'conv1.bn.bias',
'tpu_batch_normalization/moving_mean':
'conv1.bn.running_mean',
'tpu_batch_normalization/moving_variance':
'conv1.bn.running_var',
'conv2d_1/kernel':
'conv2.conv.weight',
'tpu_batch_normalization_1/gamma':
'conv2.bn.weight',
'tpu_batch_normalization_1/beta':
'conv2.bn.bias',
'tpu_batch_normalization_1/moving_mean':
'conv2.bn.running_mean',
'tpu_batch_normalization_1/moving_variance':
'conv2.bn.running_var',
}
else:
base_mapping = {
'depthwise_conv2d/depthwise_kernel':
'depthwise_conv.conv.weight',
'se/conv2d/kernel': 'se.conv1.conv.weight',
'se/conv2d/bias': 'se.conv1.conv.bias',
'se/conv2d_1/kernel': 'se.conv2.conv.weight',
'se/conv2d_1/bias': 'se.conv2.conv.bias'
}
if seg[1] not in has_expand_conv:
mapping = {
'conv2d/kernel':
'linear_conv.conv.weight',
'tpu_batch_normalization/beta':
'depthwise_conv.bn.bias',
'tpu_batch_normalization/gamma':
'depthwise_conv.bn.weight',
'tpu_batch_normalization/moving_mean':
'depthwise_conv.bn.running_mean',
'tpu_batch_normalization/moving_variance':
'depthwise_conv.bn.running_var',
'tpu_batch_normalization_1/beta':
'linear_conv.bn.bias',
'tpu_batch_normalization_1/gamma':
'linear_conv.bn.weight',
'tpu_batch_normalization_1/moving_mean':
'linear_conv.bn.running_mean',
'tpu_batch_normalization_1/moving_variance':
'linear_conv.bn.running_var',
}
else:
mapping = {
'depthwise_conv2d/depthwise_kernel':
'depthwise_conv.conv.weight',
'conv2d/kernel':
'expand_conv.conv.weight',
'conv2d_1/kernel':
'linear_conv.conv.weight',
'tpu_batch_normalization/beta':
'expand_conv.bn.bias',
'tpu_batch_normalization/gamma':
'expand_conv.bn.weight',
'tpu_batch_normalization/moving_mean':
'expand_conv.bn.running_mean',
'tpu_batch_normalization/moving_variance':
'expand_conv.bn.running_var',
'tpu_batch_normalization_1/beta':
'depthwise_conv.bn.bias',
'tpu_batch_normalization_1/gamma':
'depthwise_conv.bn.weight',
'tpu_batch_normalization_1/moving_mean':
'depthwise_conv.bn.running_mean',
'tpu_batch_normalization_1/moving_variance':
'depthwise_conv.bn.running_var',
'tpu_batch_normalization_2/beta':
'linear_conv.bn.bias',
'tpu_batch_normalization_2/gamma':
'linear_conv.bn.weight',
'tpu_batch_normalization_2/moving_mean':
'linear_conv.bn.running_mean',
'tpu_batch_normalization_2/moving_variance':
'linear_conv.bn.running_var',
}
mapping.update(base_mapping)
suffix = mapping['/'.join(seg[2:])]
m[prefix + '.' + suffix] = v
elif seg[1] == 'head':
seq_key = idx2key[max_idx + 1]
mapping = {
'conv2d/kernel':
'backbone.layers.{}.conv.weight'.format(seq_key),
'tpu_batch_normalization/beta':
'backbone.layers.{}.bn.bias'.format(seq_key),
'tpu_batch_normalization/gamma':
'backbone.layers.{}.bn.weight'.format(seq_key),
'tpu_batch_normalization/moving_mean':
'backbone.layers.{}.bn.running_mean'.format(seq_key),
'tpu_batch_normalization/moving_variance':
'backbone.layers.{}.bn.running_var'.format(seq_key),
'dense/kernel':
'head.fc.weight',
'dense/bias':
'head.fc.bias'
}
key = mapping['/'.join(seg[2:])]
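# The EdgeTPU checkpoints appear to include an extra background class, so
# the first entry of the fc weight/bias is dropped to keep 1000 classes.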
if name.startswith('e') and 'fc' in key:
v = v[1:]
m[key] = v
return m
if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('infile', type=str, help='Path to the ckpt.')
parser.add_argument('outfile', type=str, help='Output file.')
args = parser.parse_args()
assert args.outfile
outdir = os.path.dirname(os.path.abspath(args.outfile))
if not os.path.exists(outdir):
os.makedirs(outdir)
weights = read_ckpt(args.infile)
weights = map_key(weights)
torch.save(weights, args.outfile)
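
A hypothetical end-to-end use of this converter, with placeholder paths (the input must be a TensorFlow EfficientNet or EfficientNet-EdgeTPU checkpoint):

```python
# Hypothetical invocation of this script (file names are placeholders):
#   python efficientnet_to_mmcls.py efficientnet-b0/model.ckpt efficientnet-b0.pth
import torch

weights = torch.load('efficientnet-b0.pth')
# Keys follow the mmcls layout produced by map_key above, e.g.
# 'backbone.layers.0.conv.weight', ..., 'head.fc.weight', 'head.fc.bias'.
print(len(weights), sorted(weights)[0])
```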