[Feature] Add RepVGG backbone and checkpoints. (#414)

* Add RepVGG code.
* Add se_module as plugin.
* Add the repvggA0 primitive config
* Change repvggA0.py to fit mmcls
* Add RepVGG configs
* Add repvgg_to_mmcls
* Add tools/deployment/convert_repvggblock_param_to_deploy.py
* Change configs/repvgg/README.md
* Streamline the number of configuration files.
* Fix lints
* Delete plugins
* Delete code about plugins.
* Modify the code for using the SE module.
* Modify config to fit RepVGG with SE.
* Change se_cfg to allow loading of pre-training parameters.
* Reduce the complexity of the configuration file.
* Finish unit tests for RepVGG.
* Fix bug about SE in repvgg_to_mmcls.
* Rename convert_repvggblock_param_to_deploy.py to reparameterize_repvgg.py, and delete setting about device.
* test commit
* test commit
* test commit command
* Modify repvgg.py to make the code more readable.
* Add value=0 in F.pad()
* Add se_cfg to arch_settings.
* Fix bug.
* Modify some attribute names and update unit tests
* Rename stage_0 to stem and branch_identity to branch_norm
* Update unit tests
* Add m.eval in unit tests
* [Enhance] Enhance SE layer to support custom squeeze channels. (#417)
* Add enhanced SE
* Update
* Remove basechannel
* Fix docstring
* Update se_layer.py, fix docstring
* [Docs] Add algorithm readme and update meta yml (#418)
* Add README.md for models without checkpoints.
* Update model-index.yml
* Update metafile.yml of seresnet
* [Enhance] Add `hparams` argument in `AutoAugment` and `RandAugment` and some other improvements. (#398)
* Add `hparams` argument in `AutoAugment` and `RandAugment`, and let `pad_val` support sequences instead of tuples only.
* Add unit tests for `AutoAugment` and `hparams` in `RandAugment`.
* Use smaller test images to speed up unit tests.
* Use hparams to simplify RandAugment config in swin-transformer.
* Rename augment config name from `pipeline` to `pipelines`.
* Add some comments and docstrings.
* [Feature] Support classwise weight in losses (#388)
* Add classwise weight in losses: CE, BCE, soft-BCE
* Update unit test
* Remove some extra code
* Remove some extra code
* Fix broadcast
* Fix broadcast
* Update unit tests
* Use new_tensor
* Fix lint
* [Enhance] Better result visualization (#419)
* Improve result visualization to support wait time and change the backend to matplotlib.
* Add unit test for visualization
* Add adaptive dpi function
* Rename `imshow_cls_result` to `imshow_infos`.
* Support str in `imshow_infos`
* Improve docstring.
* Bump version to v0.15.0 (#426)
* [CI] Add PyTorch 1.9 and Python 3.9 build workflow, and remove some CI. (#422)
* Add PyTorch 1.9 build workflow, and remove some CI.
* Add Python 3.9 CI
* Show Python 3.9 support.
* [Enhance] Rename the option `--options` in some tools to `--cfg-options`. (#425)
* [Docs] Fix sphinx version (#429)
* [Docs] Add `CITATION.cff` (#428)
* Add CITATION.cff
* Fix typo in setup.py
* Change author in setup.py
* Modify some attribute names and update unit tests
* Rename stage_0 to stem and branch_identity to branch_norm
* Update unit tests
* Add m.eval in unit tests
* Update unit tests
* Refactor
* Refactor
* Align inference accuracy
* Update configs, readme and metafile
* Update readme
* Return tuple and fix metafile
* Fix unit test
* Remove regnet and classifier changes
* Update auto_aug
* Update metafile & readme
* Use delattr
* Rename cfgs
* Update checkpoint url
* Update readme
* Rename config files.
* Update readme and metafile
* Add comment
* Update mmcls/models/backbones/repvgg.py (Co-authored-by: Ma Zerun <mzr1996@163.com>)
* Update docstring
* Improve docstring.
* Update unittest_testblock

Co-authored-by: Ezra-Yu <1105212286@qq.com>
Co-authored-by: Ma Zerun <mzr1996@163.com>
parent 8b7d38b243
commit 90496b4687
@ -0,0 +1,43 @@
_base_ = ['./pipelines/auto_aug.py']

# dataset settings
dataset_type = 'ImageNet'
img_norm_cfg = dict(
    mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
train_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='RandomResizedCrop', size=224),
    dict(type='RandomFlip', flip_prob=0.5, direction='horizontal'),
    dict(type='AutoAugment', policies={{_base_.auto_increasing_policies}}),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='ToTensor', keys=['gt_label']),
    dict(type='Collect', keys=['img', 'gt_label'])
]
test_pipeline = [
    dict(type='LoadImageFromFile'),
    dict(type='Resize', size=(256, -1)),
    dict(type='CenterCrop', crop_size=224),
    dict(type='Normalize', **img_norm_cfg),
    dict(type='ImageToTensor', keys=['img']),
    dict(type='Collect', keys=['img'])
]
data = dict(
    samples_per_gpu=64,
    workers_per_gpu=2,
    train=dict(
        type=dataset_type,
        data_prefix='data/imagenet/train',
        pipeline=train_pipeline),
    val=dict(
        type=dataset_type,
        data_prefix='data/imagenet/val',
        ann_file='data/imagenet/meta/val.txt',
        pipeline=test_pipeline),
    test=dict(
        # replace `data/val` with `data/test` for standard test
        type=dataset_type,
        data_prefix='data/imagenet/val',
        ann_file='data/imagenet/meta/val.txt',
        pipeline=test_pipeline))
evaluation = dict(interval=1, metric='accuracy')
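# Note: `{{_base_.auto_increasing_policies}}` in `train_pipeline` above is
# mmcv's base-variable substitution syntax; at config parse time it is
# replaced by the variable of that name defined in the inherited base file
# './pipelines/auto_aug.py'.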
@ -0,0 +1,96 @@
# Policy for ImageNet, refers to
# https://github.com/DeepVoltaire/AutoAugment/blame/master/autoaugment.py
policy_imagenet = [
    [
        dict(type='Posterize', bits=4, prob=0.4),
        dict(type='Rotate', angle=30., prob=0.6)
    ],
    [
        dict(type='Solarize', thr=256 / 9 * 4, prob=0.6),
        dict(type='AutoContrast', prob=0.6)
    ],
    [dict(type='Equalize', prob=0.8),
     dict(type='Equalize', prob=0.6)],
    [
        dict(type='Posterize', bits=5, prob=0.6),
        dict(type='Posterize', bits=5, prob=0.6)
    ],
    [
        dict(type='Equalize', prob=0.4),
        dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)
    ],
    [
        dict(type='Equalize', prob=0.4),
        dict(type='Rotate', angle=30 / 9 * 8, prob=0.8)
    ],
    [
        dict(type='Solarize', thr=256 / 9 * 6, prob=0.6),
        dict(type='Equalize', prob=0.6)
    ],
    [dict(type='Posterize', bits=6, prob=0.8),
     dict(type='Equalize', prob=1.)],
    [
        dict(type='Rotate', angle=10., prob=0.2),
        dict(type='Solarize', thr=256 / 9, prob=0.6)
    ],
    [
        dict(type='Equalize', prob=0.6),
        dict(type='Posterize', bits=5, prob=0.4)
    ],
    [
        dict(type='Rotate', angle=30 / 9 * 8, prob=0.8),
        dict(type='ColorTransform', magnitude=0., prob=0.4)
    ],
    [
        dict(type='Rotate', angle=30., prob=0.4),
        dict(type='Equalize', prob=0.6)
    ],
    [dict(type='Equalize', prob=0.0),
     dict(type='Equalize', prob=0.8)],
    [dict(type='Invert', prob=0.6),
     dict(type='Equalize', prob=1.)],
    [
        dict(type='ColorTransform', magnitude=0.4, prob=0.6),
        dict(type='Contrast', magnitude=0.8, prob=1.)
    ],
    [
        dict(type='Rotate', angle=30 / 9 * 8, prob=0.8),
        dict(type='ColorTransform', magnitude=0.2, prob=1.)
    ],
    [
        dict(type='ColorTransform', magnitude=0.8, prob=0.8),
        dict(type='Solarize', thr=256 / 9 * 2, prob=0.8)
    ],
    [
        dict(type='Sharpness', magnitude=0.7, prob=0.4),
        dict(type='Invert', prob=0.6)
    ],
    [
        dict(
            type='Shear',
            magnitude=0.3 / 9 * 5,
            prob=0.6,
            direction='horizontal'),
        dict(type='Equalize', prob=1.)
    ],
    [
        dict(type='ColorTransform', magnitude=0., prob=0.4),
        dict(type='Equalize', prob=0.6)
    ],
    [
        dict(type='Equalize', prob=0.4),
        dict(type='Solarize', thr=256 / 9 * 5, prob=0.2)
    ],
    [
        dict(type='Solarize', thr=256 / 9 * 4, prob=0.6),
        dict(type='AutoContrast', prob=0.6)
    ],
    [dict(type='Invert', prob=0.6),
     dict(type='Equalize', prob=1.)],
    [
        dict(type='ColorTransform', magnitude=0.4, prob=0.6),
        dict(type='Contrast', magnitude=0.8, prob=1.)
    ],
    [dict(type='Equalize', prob=0.8),
     dict(type='Equalize', prob=0.6)],
]
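# A hedged usage sketch (not part of the original file): apply the first
# sub-policy above to a random image via mmcls's pipeline `Compose`. The
# `__main__` guard keeps mmcv's config parser, which executes this file,
# from running it.
if __name__ == '__main__':
    import numpy as np
    from mmcls.datasets.pipelines import Compose

    sub_policy = Compose(policy_imagenet[0])  # Posterize, then Rotate
    results = dict(
        img=np.random.randint(0, 255, (224, 224, 3), dtype=np.uint8))
    out = sub_policy(results)  # each transform fires with its own `prob`
    print(out['img'].shape)  # (224, 224, 3)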
@ -0,0 +1,15 @@
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='RepVGG',
        arch='A0',
        out_indices=(3, ),
    ),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=1280,
        loss=dict(type='CrossEntropyLoss', loss_weight=1.0),
        topk=(1, 5),
    ))
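# A minimal sanity check of the backbone above (a hedged sketch, not part of
# the original config). `build_backbone` is mmcls's standard model builder;
# the `__main__` guard keeps the config parser from executing this block.
if __name__ == '__main__':
    import torch
    from mmcls.models import build_backbone

    backbone = build_backbone(
        dict(type='RepVGG', arch='A0', out_indices=(3, )))
    backbone.eval()
    feat = backbone(torch.randn(1, 3, 224, 224))[0]
    print(feat.shape)  # torch.Size([1, 1280, 7, 7]), matching in_channels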
@ -0,0 +1,23 @@
model = dict(
    type='ImageClassifier',
    backbone=dict(
        type='RepVGG',
        arch='B3',
        out_indices=(3, ),
    ),
    neck=dict(type='GlobalAveragePooling'),
    head=dict(
        type='LinearClsHead',
        num_classes=1000,
        in_channels=2560,
        loss=dict(
            type='LabelSmoothLoss',
            loss_weight=1.0,
            label_smooth_val=0.1,
            mode='classy_vision',
            num_classes=1000),
        topk=(1, 5),
    ),
    train_cfg=dict(
        augments=dict(type='BatchMixup', alpha=0.2, num_classes=1000,
                      prob=1.)))
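# Note: `BatchMixup` with alpha=0.2 draws its mixing weight from a
# Beta(0.2, 0.2) distribution and, with prob=1., mixes every training batch;
# images and their (label-smoothed) targets are combined with the same weight.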
@ -0,0 +1,11 @@
# optimizer
optimizer = dict(type='SGD', lr=0.1, momentum=0.9, weight_decay=0.0001)
optimizer_config = dict(grad_clip=None)
# learning policy
lr_config = dict(
    policy='CosineAnnealing',
    min_lr=0,
    warmup='linear',
    warmup_iters=25025,
    warmup_ratio=0.25)
runner = dict(type='EpochBasedRunner', max_epochs=200)
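# Note: `warmup_iters` counts iterations, not epochs. At 4 GPUs x 64 images
# per GPU, ImageNet's ~1.28M training images give about 5005 iterations per
# epoch, so 25025 iterations correspond to a 5-epoch linear warmup.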
@ -0,0 +1,48 @@
# RepVGG: Making VGG-style ConvNets Great Again

## Introduction

<!-- [ALGORITHM] -->

```latex
@inproceedings{ding2021repvgg,
  title={Repvgg: Making vgg-style convnets great again},
  author={Ding, Xiaohan and Zhang, Xiangyu and Ma, Ningning and Han, Jungong and Ding, Guiguang and Sun, Jian},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={13733--13742},
  year={2021}
}
```

## Pretrained models

| Model | Epochs | Params(M) | Flops(G) | Top-1 (%) | Top-5 (%) | Config | Download |
| :---------: | :----: | :-------------------------------: | :-----------------------------: | :-------: | :-------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
| RepVGG-A0 | 120 | 9.11 (train) \| 8.31 (deploy) | 1.52 (train) \| 1.36 (deploy) | 72.41 | 90.50 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-A0_4xb64-coslr-120e_in1k.py) \| [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-A0_deploy_4xb64-coslr-120e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_3rdparty_4xb64-coslr-120e_in1k_20210909-883ab98c.pth) |
| RepVGG-A1 | 120 | 14.09 (train) \| 12.79 (deploy) | 2.64 (train) \| 2.37 (deploy) | 74.47 | 91.85 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-A1_4xb64-coslr-120e_in1k.py) \| [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-A1_deploy_4xb64-coslr-120e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_3rdparty_4xb64-coslr-120e_in1k_20210909-24003a24.pth) |
| RepVGG-A2 | 120 | 28.21 (train) \| 25.5 (deploy) | 5.7 (train) \| 5.12 (deploy) | 76.48 | 93.01 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-A2_4xb64-coslr-120e_in1k.py) \| [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-A2_deploy_4xb64-coslr-120e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_3rdparty_4xb64-coslr-120e_in1k_20210909-97d7695a.pth) |
| RepVGG-B0 | 120 | 15.82 (train) \| 14.34 (deploy) | 3.42 (train) \| 3.06 (deploy) | 75.14 | 92.42 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-B0_4xb64-coslr-120e_in1k.py) \| [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-B0_deploy_4xb64-coslr-120e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_3rdparty_4xb64-coslr-120e_in1k_20210909-446375f4.pth) |
| RepVGG-B1 | 120 | 57.42 (train) \| 51.83 (deploy) | 13.16 (train) \| 11.82 (deploy) | 78.37 | 94.11 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-B1_4xb64-coslr-120e_in1k.py) \| [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-B1_deploy_4xb64-coslr-120e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_3rdparty_4xb64-coslr-120e_in1k_20210909-750cdf67.pth) |
| RepVGG-B1g2 | 120 | 45.78 (train) \| 41.36 (deploy) | 9.82 (train) \| 8.82 (deploy) | 77.79 | 93.88 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-B1g2_4xb64-coslr-120e_in1k.py) \| [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-B1g2_deploy_4xb64-coslr-120e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_3rdparty_4xb64-coslr-120e_in1k_20210909-344f6422.pth) |
| RepVGG-B1g4 | 120 | 39.97 (train) \| 36.13 (deploy) | 8.15 (train) \| 7.32 (deploy) | 77.58 | 93.84 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-B1g4_4xb64-coslr-120e_in1k.py) \| [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-B1g4_deploy_4xb64-coslr-120e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_3rdparty_4xb64-coslr-120e_in1k_20210909-d4c1a642.pth) |
| RepVGG-B2 | 120 | 89.02 (train) \| 80.32 (deploy) | 20.46 (train) \| 18.39 (deploy) | 78.78 | 94.42 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-B2_4xb64-coslr-120e_in1k.py) \| [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-B2_deploy_4xb64-coslr-120e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_3rdparty_4xb64-coslr-120e_in1k_20210909-bd6b937c.pth) |
| RepVGG-B2g4 | 200 | 61.76 (train) \| 55.78 (deploy) | 12.63 (train) \| 11.34 (deploy) | 79.38 | 94.68 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-B2g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) \| [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-B2g4_deploy_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-7b7955f0.pth) |
| RepVGG-B3 | 200 | 123.09 (train) \| 110.96 (deploy) | 29.17 (train) \| 26.22 (deploy) | 80.52 | 95.26 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-B3_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) \| [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-B3_deploy_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-dda968bf.pth) |
| RepVGG-B3g4 | 200 | 83.83 (train) \| 75.63 (deploy) | 17.9 (train) \| 16.08 (deploy) | 80.22 | 95.10 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-B3g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) \| [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-B3g4_deploy_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-4e54846a.pth) |
| RepVGG-D2se | 200 | 133.33 (train) \| 120.39 (deploy) | 36.56 (train) \| 32.85 (deploy) | 81.81 | 95.94 | [config (train)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/repvgg-D2se_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) \| [config (deploy)](https://github.com/open-mmlab/mmclassification/blob/master/configs/repvgg/deploy/repvgg-D2se_deploy_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py) | [model](https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-D2se_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-cf3139b7.pth) |

## Reparameterize RepVGG

The checkpoints provided above are all in `train` form. Use the reparameterization tool to convert them to the more efficient `deploy` form, which has both fewer parameters and a lower computational cost.

```bash
python ./tools/convert_models/reparameterize_repvgg.py ${CFG_PATH} ${SRC_CKPT_PATH} ${TARGET_CKPT_PATH}
```

`${CFG_PATH}` is the config file path, `${SRC_CKPT_PATH}` is the source checkpoint file path, and `${TARGET_CKPT_PATH}` is the path of the target deploy weight file.

To use the reparameterized RepVGG weights, the config file must switch to [the deploy config files](./deploy) as below:

```bash
python ./tools/test.py ${RepVGG_Deploy_CFG} ${CHECKPOINT}
```
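For reference, the conversion the tool performs can also be exercised directly in Python through the `switch_to_deploy()` method added by this PR. A minimal sketch, checking that both forms produce the same features:

```python
import torch
from mmcls.models.backbones import RepVGG

model = RepVGG('A0')    # multi-branch `train` form
model.eval()            # BN statistics must be frozen before fusing
x = torch.randn(1, 3, 224, 224)
out_train = model(x)[0]

model.switch_to_deploy()  # fuse the 3x3, 1x1 and identity branches in place
out_deploy = model(x)[0]
assert torch.allclose(out_train, out_deploy, atol=1e-5, rtol=1e-4)
```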
@ -0,0 +1,3 @@
_base_ = '../repvgg-A0_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(deploy=True))

@ -0,0 +1,3 @@
_base_ = '../repvgg-A1_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(deploy=True))

@ -0,0 +1,3 @@
_base_ = '../repvgg-A2_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(deploy=True))

@ -0,0 +1,3 @@
_base_ = '../repvgg-B0_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(deploy=True))

@ -0,0 +1,3 @@
_base_ = '../repvgg-B1_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(deploy=True))

@ -0,0 +1,3 @@
_base_ = '../repvgg-B1g2_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(deploy=True))

@ -0,0 +1,3 @@
_base_ = '../repvgg-B1g4_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(deploy=True))

@ -0,0 +1,3 @@
_base_ = '../repvgg-B2_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(deploy=True))

@ -0,0 +1,3 @@
_base_ = '../repvgg-B2g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'

model = dict(backbone=dict(deploy=True))

@ -0,0 +1,3 @@
_base_ = '../repvgg-B3_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'

model = dict(backbone=dict(deploy=True))

@ -0,0 +1,3 @@
_base_ = '../repvgg-B3g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'

model = dict(backbone=dict(deploy=True))

@ -0,0 +1,3 @@
_base_ = '../repvgg-D2se_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'

model = dict(backbone=dict(deploy=True))
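# Note: each deploy config above only flips `backbone.deploy` to True; the
# checkpoint it loads must already be in deploy form, i.e. converted with
# tools/convert_models/reparameterize_repvgg.py (see configs/repvgg/README.md).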
@ -0,0 +1,205 @@
Collections:
  - Name: RepVGG
    Metadata:
      Training Data: ImageNet-1k
      Architecture:
        - re-parameterization Convolution
        - VGG-style Neural Network
    Paper:
      URL: https://arxiv.org/abs/2101.03697
      Title: 'RepVGG: Making VGG-style ConvNets Great Again'
    README: configs/repvgg/README.md

Models:
  - Name: repvgg-A0_4xb64-coslr-120e_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-A0_4xb64-coslr-120e_in1k.py
    Metadata:
      FLOPs: 1520000000
      Parameters: 9110000
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 72.41
          Top 5 Accuracy: 90.50
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A0_3rdparty_4xb64-coslr-120e_in1k_20210909-883ab98c.pth
    Converted From:
      Weights: https://drive.google.com/drive/folders/1Avome4KvNp0Lqh2QwhXO6L5URQjzCjUq
      Code: https://github.com/DingXiaoH/RepVGG/blob/9f272318abfc47a2b702cd0e916fca8d25d683e7/repvgg.py#L196
  - Name: repvgg-A1_4xb64-coslr-120e_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-A1_4xb64-coslr-120e_in1k.py
    Metadata:
      FLOPs: 2640000000
      Parameters: 14090000
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 74.47
          Top 5 Accuracy: 91.85
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A1_3rdparty_4xb64-coslr-120e_in1k_20210909-24003a24.pth
    Converted From:
      Weights: https://drive.google.com/drive/folders/1Avome4KvNp0Lqh2QwhXO6L5URQjzCjUq
      Code: https://github.com/DingXiaoH/RepVGG/blob/9f272318abfc47a2b702cd0e916fca8d25d683e7/repvgg.py#L200
  - Name: repvgg-A2_4xb64-coslr-120e_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-A2_4xb64-coslr-120e_in1k.py
    Metadata:
      FLOPs: 5700000000
      Parameters: 28210000
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 76.48
          Top 5 Accuracy: 93.01
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-A2_3rdparty_4xb64-coslr-120e_in1k_20210909-97d7695a.pth
    Converted From:
      Weights: https://drive.google.com/drive/folders/1Avome4KvNp0Lqh2QwhXO6L5URQjzCjUq
      Code: https://github.com/DingXiaoH/RepVGG/blob/9f272318abfc47a2b702cd0e916fca8d25d683e7/repvgg.py#L204
  - Name: repvgg-B0_4xb64-coslr-120e_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-B0_4xb64-coslr-120e_in1k.py
    Metadata:
      FLOPs: 3420000000
      Parameters: 15820000
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 75.14
          Top 5 Accuracy: 92.42
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B0_3rdparty_4xb64-coslr-120e_in1k_20210909-446375f4.pth
    Converted From:
      Weights: https://drive.google.com/drive/folders/1Avome4KvNp0Lqh2QwhXO6L5URQjzCjUq
      Code: https://github.com/DingXiaoH/RepVGG/blob/9f272318abfc47a2b702cd0e916fca8d25d683e7/repvgg.py#L208
  - Name: repvgg-B1_4xb64-coslr-120e_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-B1_4xb64-coslr-120e_in1k.py
    Metadata:
      FLOPs: 13160000000
      Parameters: 57420000
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 78.37
          Top 5 Accuracy: 94.11
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1_3rdparty_4xb64-coslr-120e_in1k_20210909-750cdf67.pth
    Converted From:
      Weights: https://drive.google.com/drive/folders/1Avome4KvNp0Lqh2QwhXO6L5URQjzCjUq
      Code: https://github.com/DingXiaoH/RepVGG/blob/9f272318abfc47a2b702cd0e916fca8d25d683e7/repvgg.py#L212
  - Name: repvgg-B1g2_4xb64-coslr-120e_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-B1g2_4xb64-coslr-120e_in1k.py
    Metadata:
      FLOPs: 9820000000
      Parameters: 45780000
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 77.79
          Top 5 Accuracy: 93.88
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g2_3rdparty_4xb64-coslr-120e_in1k_20210909-344f6422.pth
    Converted From:
      Weights: https://drive.google.com/drive/folders/1Avome4KvNp0Lqh2QwhXO6L5URQjzCjUq
      Code: https://github.com/DingXiaoH/RepVGG/blob/9f272318abfc47a2b702cd0e916fca8d25d683e7/repvgg.py#L216
  - Name: repvgg-B1g4_4xb64-coslr-120e_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-B1g4_4xb64-coslr-120e_in1k.py
    Metadata:
      FLOPs: 8150000000
      Parameters: 39970000
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 77.58
          Top 5 Accuracy: 93.84
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B1g4_3rdparty_4xb64-coslr-120e_in1k_20210909-d4c1a642.pth
    Converted From:
      Weights: https://drive.google.com/drive/folders/1Avome4KvNp0Lqh2QwhXO6L5URQjzCjUq
      Code: https://github.com/DingXiaoH/RepVGG/blob/9f272318abfc47a2b702cd0e916fca8d25d683e7/repvgg.py#L220
  - Name: repvgg-B2_4xb64-coslr-120e_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-B2_4xb64-coslr-120e_in1k.py
    Metadata:
      FLOPs: 20420000000
      Parameters: 89020000
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 78.78
          Top 5 Accuracy: 94.42
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2_3rdparty_4xb64-coslr-120e_in1k_20210909-bd6b937c.pth
    Converted From:
      Weights: https://drive.google.com/drive/folders/1Avome4KvNp0Lqh2QwhXO6L5URQjzCjUq
      Code: https://github.com/DingXiaoH/RepVGG/blob/9f272318abfc47a2b702cd0e916fca8d25d683e7/repvgg.py#L225
  - Name: repvgg-B2g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-B2g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py
    Metadata:
      FLOPs: 12630000000
      Parameters: 61760000
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 79.38
          Top 5 Accuracy: 94.68
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B2g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-7b7955f0.pth
    Converted From:
      Weights: https://drive.google.com/drive/folders/1Avome4KvNp0Lqh2QwhXO6L5URQjzCjUq
      Code: https://github.com/DingXiaoH/RepVGG/blob/9f272318abfc47a2b702cd0e916fca8d25d683e7/repvgg.py#L229
  - Name: repvgg-B3_4xb64-autoaug-lbs-mixup-coslr-200e_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-B3_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py
    Metadata:
      FLOPs: 29170000000
      Parameters: 123090000
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 80.52
          Top 5 Accuracy: 95.26
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-dda968bf.pth
    Converted From:
      Weights: https://drive.google.com/drive/folders/1Avome4KvNp0Lqh2QwhXO6L5URQjzCjUq
      Code: https://github.com/DingXiaoH/RepVGG/blob/9f272318abfc47a2b702cd0e916fca8d25d683e7/repvgg.py#L238
  - Name: repvgg-B3g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-B3g4_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py
    Metadata:
      FLOPs: 17900000000
      Parameters: 83830000
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 80.22
          Top 5 Accuracy: 95.10
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-B3g4_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-4e54846a.pth
    Converted From:
      Weights: https://drive.google.com/drive/folders/1Avome4KvNp0Lqh2QwhXO6L5URQjzCjUq
      Code: https://github.com/DingXiaoH/RepVGG/blob/9f272318abfc47a2b702cd0e916fca8d25d683e7/repvgg.py#L238
  - Name: repvgg-D2se_4xb64-autoaug-lbs-mixup-coslr-200e_in1k
    In Collection: RepVGG
    Config: configs/repvgg/repvgg-D2se_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py
    Metadata:
      FLOPs: 36560000000
      Parameters: 133330000
    Results:
      - Dataset: ImageNet-1k
        Task: Image Classification
        Metrics:
          Top 1 Accuracy: 81.81
          Top 5 Accuracy: 95.94
    Weights: https://download.openmmlab.com/mmclassification/v0/repvgg/repvgg-D2se_3rdparty_4xb64-autoaug-lbs-mixup-coslr-200e_in1k_20210909-cf3139b7.pth
    Converted From:
      Weights: https://drive.google.com/drive/folders/1Avome4KvNp0Lqh2QwhXO6L5URQjzCjUq
      Code: https://github.com/DingXiaoH/RepVGG/blob/9f272318abfc47a2b702cd0e916fca8d25d683e7/repvgg.py#L250
@ -0,0 +1,8 @@
_base_ = [
    '../_base_/models/repvgg-A0_in1k.py',
    '../_base_/datasets/imagenet_bs64_pil_resize.py',
    '../_base_/schedules/imagenet_bs256_coslr.py',
    '../_base_/default_runtime.py'
]

runner = dict(max_epochs=120)
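# Note: the four `_base_` files contribute the model, dataset, schedule and
# runtime settings respectively; keys redefined here, such as
# `runner.max_epochs`, override the values inherited from the bases.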
@ -0,0 +1,3 @@
_base_ = './repvgg-A0_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(arch='A1'))

@ -0,0 +1,3 @@
_base_ = './repvgg-A0_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(arch='A2'), head=dict(in_channels=1408))

@ -0,0 +1,3 @@
_base_ = './repvgg-A0_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(arch='B0'), head=dict(in_channels=1280))

@ -0,0 +1,3 @@
_base_ = './repvgg-A0_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(arch='B1'), head=dict(in_channels=2048))

@ -0,0 +1,3 @@
_base_ = './repvgg-A0_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(arch='B1g2'), head=dict(in_channels=2048))

@ -0,0 +1,3 @@
_base_ = './repvgg-A0_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(arch='B1g4'), head=dict(in_channels=2048))

@ -0,0 +1,3 @@
_base_ = './repvgg-A0_4xb64-coslr-120e_in1k.py'

model = dict(backbone=dict(arch='B2'), head=dict(in_channels=2560))

@ -0,0 +1,3 @@
_base_ = './repvgg-B3_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'

model = dict(backbone=dict(arch='B2g4'))

@ -0,0 +1,6 @@
_base_ = [
    '../_base_/models/repvgg-B3_lbs-mixup_in1k.py',
    '../_base_/datasets/imagenet_bs64_pil_resize.py',
    '../_base_/schedules/imagenet_bs256_200e_coslr_warmup.py',
    '../_base_/default_runtime.py'
]

@ -0,0 +1,3 @@
_base_ = './repvgg-B3_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'

model = dict(backbone=dict(arch='B3g4'))

@ -0,0 +1,3 @@
_base_ = './repvgg-B3_4xb64-autoaug-lbs-mixup-coslr-200e_in1k.py'

model = dict(backbone=dict(arch='D2se'))
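# Note: `head.in_channels` for each arch equals int(512 * width_factor[-1])
# from `arch_settings`: A0/A1/B0 keep 1280 (2.5 * 512), A2 uses 1408
# (2.75 * 512), the B1 variants use 2048 (4 * 512), and the B2/B3/D2se
# variants use 2560 (5 * 512).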
@ -4,6 +4,7 @@ from .lenet import LeNet5
 from .mobilenet_v2 import MobileNetV2
 from .mobilenet_v3 import MobileNetV3
 from .regnet import RegNet
+from .repvgg import RepVGG
 from .resnest import ResNeSt
 from .resnet import ResNet, ResNetV1d
 from .resnet_cifar import ResNet_CIFAR

@ -22,5 +23,5 @@ __all__ = [
     'LeNet5', 'AlexNet', 'VGG', 'RegNet', 'ResNet', 'ResNeXt', 'ResNetV1d',
     'ResNeSt', 'ResNet_CIFAR', 'SEResNet', 'SEResNeXt', 'ShuffleNetV1',
     'ShuffleNetV2', 'MobileNetV2', 'MobileNetV3', 'VisionTransformer',
-    'SwinTransformer', 'TNT', 'TIMMBackbone'
+    'SwinTransformer', 'TNT', 'RepVGG', 'TIMMBackbone'
 ]
@ -0,0 +1,537 @@
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule, Sequential
from mmcv.utils.parrots_wrapper import _BatchNorm

from ..builder import BACKBONES
from ..utils.se_layer import SELayer
from .base_backbone import BaseBackbone


class RepVGGBlock(BaseModule):
    """RepVGG block for RepVGG backbone.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): Stride of the 3x3 and 1x1 convolution layer. Default: 1.
        padding (int): Padding of the 3x3 convolution layer. Default: 1.
        dilation (int): Dilation of the 3x3 convolution layer. Default: 1.
        groups (int): Groups of the 3x3 and 1x1 convolution layer. Default: 1.
        padding_mode (str): Padding mode of the 3x3 convolution layer.
            Default: 'zeros'.
        se_cfg (None or dict): The configuration of the SE module.
            Default: None.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        deploy (bool): Whether to switch the model structure to
            deployment mode. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros',
                 se_cfg=None,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 deploy=False,
                 init_cfg=None):
        super(RepVGGBlock, self).__init__(init_cfg)

        assert se_cfg is None or isinstance(se_cfg, dict)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.se_cfg = se_cfg
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.deploy = deploy

        if deploy:
            self.branch_reparam = build_conv_layer(
                conv_cfg,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
                padding_mode=padding_mode)
        else:
            # Check whether the input shape and output shape are the same.
            # If so, add a normalized identity shortcut.
            if out_channels == in_channels and stride == 1 and \
                    padding == dilation:
                self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1]
            else:
                self.branch_norm = None

            self.branch_3x3 = self.create_conv_bn(
                kernel_size=3,
                dilation=dilation,
                padding=padding,
            )
            self.branch_1x1 = self.create_conv_bn(kernel_size=1)

        if se_cfg is not None:
            self.se_layer = SELayer(channels=out_channels, **se_cfg)
        else:
            self.se_layer = None

        self.act = build_activation_layer(act_cfg)

    def create_conv_bn(self, kernel_size, dilation=1, padding=0):
        conv_bn = Sequential()
        conv_bn.add_module(
            'conv',
            build_conv_layer(
                self.conv_cfg,
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                dilation=dilation,
                padding=padding,
                groups=self.groups,
                bias=False))
        conv_bn.add_module(
            'norm',
            build_norm_layer(self.norm_cfg, num_features=self.out_channels)[1])

        return conv_bn

    def forward(self, x):

        def _inner_forward(inputs):
            if self.deploy:
                return self.branch_reparam(inputs)

            if self.branch_norm is None:
                branch_norm_out = 0
            else:
                branch_norm_out = self.branch_norm(inputs)

            inner_out = self.branch_3x3(inputs) + self.branch_1x1(
                inputs) + branch_norm_out

            if self.se_cfg is not None:
                inner_out = self.se_layer(inner_out)

            return inner_out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.act(out)

        return out
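    # Reparameterization background: folding a BN layer with running
    # statistics (mu, sigma^2) and affine parameters (gamma, beta) into the
    # preceding convolution yields
    #     w' = w * gamma / sqrt(sigma^2 + eps)
    #     b' = beta - mu * gamma / sqrt(sigma^2 + eps).
    # `_fuse_conv_bn` below computes exactly this. A 1x1 kernel zero-padded
    # to 3x3 and the identity shortcut rewritten as a 3x3 convolution
    # (`_norm_to_conv3x3`) can then be summed with the 3x3 branch to obtain
    # the single fused `branch_reparam` convolution.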
    def switch_to_deploy(self):
        """Switch the model structure from training mode to deployment
        mode."""
        if self.deploy:
            return
        assert self.norm_cfg['type'] == 'BN', \
            "Switch is not allowed when norm_cfg['type'] != 'BN'."

        reparam_weight, reparam_bias = self.reparameterize()
        self.branch_reparam = build_conv_layer(
            self.conv_cfg,
            self.in_channels,
            self.out_channels,
            kernel_size=3,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            bias=True)
        self.branch_reparam.weight.data = reparam_weight
        self.branch_reparam.bias.data = reparam_bias

        for param in self.parameters():
            param.detach_()
        delattr(self, 'branch_3x3')
        delattr(self, 'branch_1x1')
        delattr(self, 'branch_norm')

        self.deploy = True

    def reparameterize(self):
        """Fuse the parameters of all branches.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Parameters after fusion of all
                branches. The first element is the weights and the second is
                the bias.
        """
        weight_3x3, bias_3x3 = self._fuse_conv_bn(self.branch_3x3)
        weight_1x1, bias_1x1 = self._fuse_conv_bn(self.branch_1x1)
        # pad a conv1x1 weight to a conv3x3 weight
        weight_1x1 = F.pad(weight_1x1, [1, 1, 1, 1], value=0)

        weight_norm, bias_norm = 0, 0
        if self.branch_norm:
            tmp_conv_bn = self._norm_to_conv3x3(self.branch_norm)
            weight_norm, bias_norm = self._fuse_conv_bn(tmp_conv_bn)

        return (weight_3x3 + weight_1x1 + weight_norm,
                bias_3x3 + bias_1x1 + bias_norm)

    def _fuse_conv_bn(self, branch):
        """Fuse the parameters in a branch with a conv and bn.

        Args:
            branch (mmcv.runner.Sequential): A branch with conv and bn.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The parameters obtained after
                fusing the parameters of conv and bn in one branch.
                The first element is the weight and the second is the bias.
        """
        if branch is None:
            return 0, 0
        conv_weight = branch.conv.weight
        running_mean = branch.norm.running_mean
        running_var = branch.norm.running_var
        gamma = branch.norm.weight
        beta = branch.norm.bias
        eps = branch.norm.eps

        std = (running_var + eps).sqrt()
        fused_weight = (gamma / std).reshape(-1, 1, 1, 1) * conv_weight
        fused_bias = -running_mean * gamma / std + beta

        return fused_weight, fused_bias
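    # The identity shortcut is equivalent to a 3x3 convolution whose kernel
    # has a single 1 at the center of the channel it passes through;
    # `_norm_to_conv3x3` builds that kernel so the BN-only branch can reuse
    # `_fuse_conv_bn`.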
    def _norm_to_conv3x3(self, branch_norm):
        """Convert a norm layer to a conv3x3-bn sequence.

        Args:
            branch_norm (nn.BatchNorm2d): A branch that only contains a norm
                layer.

        Returns:
            mmcv.runner.Sequential: A sequential with a conv3x3 and a bn.
        """
        input_dim = self.in_channels // self.groups
        conv_weight = torch.zeros((self.in_channels, input_dim, 3, 3),
                                  dtype=branch_norm.weight.dtype)

        for i in range(self.in_channels):
            conv_weight[i, i % input_dim, 1, 1] = 1
        conv_weight = conv_weight.to(branch_norm.weight.device)

        tmp_conv3x3 = self.create_conv_bn(kernel_size=3)
        tmp_conv3x3.conv.weight.data = conv_weight
        tmp_conv3x3.norm = branch_norm
        return tmp_conv3x3


@BACKBONES.register_module()
class RepVGG(BaseBackbone):
    """RepVGG backbone.

    A PyTorch impl of : `RepVGG: Making VGG-style ConvNets Great Again
    <https://arxiv.org/abs/2101.03697>`_

    Args:
        arch (str | dict): The parameter of RepVGG.
            If it's a dict, it should contain the following keys:

            - num_blocks (Sequence[int]): Number of blocks in each stage.
            - width_factor (Sequence[float]): Width factor of each stage.
            - group_layer_map (dict | None): Mapping from the global index
              of a RepVGG block to the number of groups of its convolutions.
            - se_cfg (dict | None): SE layer config.
        in_channels (int): Number of input image channels. Default: 3.
        base_channels (int): Base channels of RepVGG backbone, works
            together with width_factor. Default: 64.
        out_indices (Sequence[int]): Output from which stages. Default: (3, ).
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: (2, 2, 2, 2).
        dilations (Sequence[int]): Dilation of each stage.
            Default: (1, 1, 1, 1).
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters. Default: -1.
        conv_cfg (dict | None): The config dict for conv layers. Default: None.
        norm_cfg (dict): The config dict for norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        deploy (bool): Whether to switch the model structure to deployment
            mode. Default: False.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
    """

    groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
    g2_layer_map = {layer: 2 for layer in groupwise_layers}
    g4_layer_map = {layer: 4 for layer in groupwise_layers}

    arch_settings = {
        'A0':
        dict(
            num_blocks=[2, 4, 14, 1],
            width_factor=[0.75, 0.75, 0.75, 2.5],
            group_layer_map=None,
            se_cfg=None),
        'A1':
        dict(
            num_blocks=[2, 4, 14, 1],
            width_factor=[1, 1, 1, 2.5],
            group_layer_map=None,
            se_cfg=None),
        'A2':
        dict(
            num_blocks=[2, 4, 14, 1],
            width_factor=[1.5, 1.5, 1.5, 2.75],
            group_layer_map=None,
            se_cfg=None),
        'B0':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[1, 1, 1, 2.5],
            group_layer_map=None,
            se_cfg=None),
        'B1':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2, 2, 2, 4],
            group_layer_map=None,
            se_cfg=None),
        'B1g2':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2, 2, 2, 4],
            group_layer_map=g2_layer_map,
            se_cfg=None),
        'B1g4':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2, 2, 2, 4],
            group_layer_map=g4_layer_map,
            se_cfg=None),
        'B2':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2.5, 2.5, 2.5, 5],
            group_layer_map=None,
            se_cfg=None),
        'B2g2':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2.5, 2.5, 2.5, 5],
            group_layer_map=g2_layer_map,
            se_cfg=None),
        'B2g4':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2.5, 2.5, 2.5, 5],
            group_layer_map=g4_layer_map,
            se_cfg=None),
        'B3':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[3, 3, 3, 5],
            group_layer_map=None,
            se_cfg=None),
        'B3g2':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[3, 3, 3, 5],
            group_layer_map=g2_layer_map,
            se_cfg=None),
        'B3g4':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[3, 3, 3, 5],
            group_layer_map=g4_layer_map,
            se_cfg=None),
        'D2se':
        dict(
            num_blocks=[8, 14, 24, 1],
            width_factor=[2.5, 2.5, 2.5, 5],
            group_layer_map=None,
            se_cfg=dict(ratio=16, divisor=1))
    }

    def __init__(self,
                 arch,
                 in_channels=3,
                 base_channels=64,
                 out_indices=(3, ),
                 strides=(2, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False,
                 deploy=False,
                 norm_eval=False,
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ]):
        super(RepVGG, self).__init__(init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'"arch": "{arch}" is not one of the arch_settings'
            arch = self.arch_settings[arch]
        elif not isinstance(arch, dict):
            raise TypeError('Expect "arch" to be either a string '
                            f'or a dict, got {type(arch)}')

        assert len(arch['num_blocks']) == len(
            arch['width_factor']) == len(strides) == len(dilations)
        assert max(out_indices) < len(arch['num_blocks'])
        if arch['group_layer_map'] is not None:
            assert max(arch['group_layer_map'].keys()) <= sum(
                arch['num_blocks'])

        if arch['se_cfg'] is not None:
            assert isinstance(arch['se_cfg'], dict)

        self.arch = arch
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.out_indices = out_indices
        self.strides = strides
        self.dilations = dilations
        self.deploy = deploy
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval

        channels = min(64, int(base_channels * self.arch['width_factor'][0]))
        self.stem = RepVGGBlock(
            self.in_channels,
            channels,
            stride=2,
            se_cfg=arch['se_cfg'],
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            deploy=deploy)

        next_create_block_idx = 1
        self.stages = []
        for i in range(len(arch['num_blocks'])):
            num_blocks = self.arch['num_blocks'][i]
            stride = self.strides[i]
            dilation = self.dilations[i]
            out_channels = int(base_channels * 2**i *
                               self.arch['width_factor'][i])

            stage, next_create_block_idx = self._make_stage(
                channels, out_channels, num_blocks, stride, dilation,
                next_create_block_idx, init_cfg)
            stage_name = f'stage_{i + 1}'
            self.add_module(stage_name, stage)
            self.stages.append(stage_name)

            channels = out_channels

    def _make_stage(self, in_channels, out_channels, num_blocks, stride,
                    dilation, next_create_block_idx, init_cfg):
        strides = [stride] + [1] * (num_blocks - 1)
        dilations = [dilation] * num_blocks

        blocks = []
        for i in range(num_blocks):
            groups = self.arch['group_layer_map'].get(
                next_create_block_idx,
                1) if self.arch['group_layer_map'] is not None else 1
            blocks.append(
                RepVGGBlock(
                    in_channels,
                    out_channels,
                    stride=strides[i],
                    padding=dilations[i],
                    dilation=dilations[i],
                    groups=groups,
                    se_cfg=self.arch['se_cfg'],
                    with_cp=self.with_cp,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    deploy=self.deploy,
                    init_cfg=init_cfg))
            in_channels = out_channels
            next_create_block_idx += 1

        return Sequential(*blocks), next_create_block_idx

    def forward(self, x):
        x = self.stem(x)
        outs = []
        for i, stage_name in enumerate(self.stages):
            stage = getattr(self, stage_name)
            x = stage(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False
            for i in range(self.frozen_stages):
                stage = getattr(self, f'stage_{i + 1}')
                stage.eval()
                for param in stage.parameters():
                    param.requires_grad = False

    def train(self, mode=True):
        super(RepVGG, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

    def switch_to_deploy(self):
        for m in self.modules():
            if isinstance(m, RepVGGBlock):
                m.switch_to_deploy()
        self.deploy = True
@ -8,4 +8,5 @@ Import:
   - configs/shufflenet_v2/metafile.yml
   - configs/swin_transformer/metafile.yml
   - configs/vgg/metafile.yml
+  - configs/repvgg/metafile.yml
   - configs/tnt/metafile.yml
@ -0,0 +1,293 @@
|
|||
import os
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from mmcv.runner import load_checkpoint, save_checkpoint
|
||||
from torch import nn
|
||||
from torch.nn.modules import GroupNorm
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import RepVGG
|
||||
from mmcls.models.backbones.repvgg import RepVGGBlock
|
||||
from mmcls.models.utils import SELayer
|
||||
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
"""Check if norm layer is in correct train state."""
|
||||
for mod in modules:
|
||||
if isinstance(mod, _BatchNorm):
|
||||
if mod.training != train_state:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def is_norm(modules):
|
||||
"""Check if is one of the norms."""
|
||||
if isinstance(modules, (GroupNorm, _BatchNorm)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_repvgg_block(modules):
|
||||
if isinstance(modules, RepVGGBlock):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def test_repvgg_repvggblock():
|
||||
# Test RepVGGBlock with in_channels != out_channels, stride = 1
|
||||
block = RepVGGBlock(5, 10, stride=1)
|
||||
block.eval()
|
||||
x = torch.randn(1, 5, 16, 16)
|
||||
x_out_not_deploy = block(x)
|
||||
assert block.branch_norm is None
|
||||
assert not hasattr(block, 'branch_reparam')
|
||||
assert hasattr(block, 'branch_1x1')
|
||||
assert hasattr(block, 'branch_3x3')
|
||||
assert hasattr(block, 'branch_norm')
|
||||
assert block.se_cfg is None
|
||||
assert x_out_not_deploy.shape == torch.Size((1, 10, 16, 16))
|
||||
block.switch_to_deploy()
|
||||
assert block.deploy is True
|
||||
x_out_deploy = block(x)
|
||||
assert x_out_deploy.shape == torch.Size((1, 10, 16, 16))
|
||||
assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
|
||||
|
||||
# Test RepVGGBlock with in_channels == out_channels, stride = 1
|
||||
block = RepVGGBlock(12, 12, stride=1)
|
||||
block.eval()
|
||||
x = torch.randn(1, 12, 8, 8)
|
||||
x_out_not_deploy = block(x)
|
||||
assert isinstance(block.branch_norm, nn.BatchNorm2d)
|
||||
assert not hasattr(block, 'branch_reparam')
|
||||
assert x_out_not_deploy.shape == torch.Size((1, 12, 8, 8))
|
||||
block.switch_to_deploy()
|
||||
assert block.deploy is True
|
||||
x_out_deploy = block(x)
|
||||
assert x_out_deploy.shape == torch.Size((1, 12, 8, 8))
|
||||
assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
|
||||
|
||||
# Test RepVGGBlock with in_channels == out_channels, stride = 2
|
||||
block = RepVGGBlock(16, 16, stride=2)
|
||||
block.eval()
|
||||
x = torch.randn(1, 16, 8, 8)
|
||||
x_out_not_deploy = block(x)
|
||||
assert block.branch_norm is None
|
||||
assert x_out_not_deploy.shape == torch.Size((1, 16, 4, 4))
|
||||
block.switch_to_deploy()
|
||||
assert block.deploy is True
|
||||
x_out_deploy = block(x)
|
||||
assert x_out_deploy.shape == torch.Size((1, 16, 4, 4))
|
||||
assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
|
||||
|
||||
# Test RepVGGBlock with padding == dilation == 2
|
||||
block = RepVGGBlock(14, 14, stride=1, padding=2, dilation=2)
|
||||
block.eval()
|
||||
x = torch.randn(1, 14, 16, 16)
|
||||
x_out_not_deploy = block(x)
|
||||
assert isinstance(block.branch_norm, nn.BatchNorm2d)
|
||||
assert x_out_not_deploy.shape == torch.Size((1, 14, 16, 16))
|
||||
block.switch_to_deploy()
|
||||
assert block.deploy is True
|
||||
x_out_deploy = block(x)
|
||||
assert x_out_deploy.shape == torch.Size((1, 14, 16, 16))
|
||||
assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
|
||||
|
||||
# Test RepVGGBlock with groups = 2
|
||||
block = RepVGGBlock(4, 4, stride=1, groups=2)
|
||||
block.eval()
|
||||
x = torch.randn(1, 4, 5, 6)
|
||||
x_out_not_deploy = block(x)
|
||||
assert x_out_not_deploy.shape == torch.Size((1, 4, 5, 6))
|
||||
block.switch_to_deploy()
|
||||
assert block.deploy is True
|
||||
x_out_deploy = block(x)
|
||||
assert x_out_deploy.shape == torch.Size((1, 4, 5, 6))
|
||||
assert torch.allclose(x_out_not_deploy, x_out_deploy, atol=1e-5, rtol=1e-4)
|
||||
|
||||
# Test RepVGGBlock with se
|
||||
se_cfg = dict(ratio=4, divisor=1)
|
||||
block = RepVGGBlock(18, 18, stride=1, se_cfg=se_cfg)
|
||||
block.train()
|
||||
x = torch.randn(1, 18, 5, 5)
|
||||
x_out_not_deploy = block(x)
|
||||
assert isinstance(block.se_layer, SELayer)
|
||||
assert x_out_not_deploy.shape == torch.Size((1, 18, 5, 5))
|
||||
|
||||
# Test RepVGGBlock with checkpoint forward
|
||||
block = RepVGGBlock(24, 24, stride=1, with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 24, 7, 7)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size((1, 24, 7, 7))
|
||||
|
||||
# Test RepVGGBlock with deploy == True
|
||||
block = RepVGGBlock(8, 8, stride=1, deploy=True)
|
||||
assert isinstance(block.branch_reparam, nn.Conv2d)
|
||||
assert not hasattr(block, 'branch_3x3')
|
||||
assert not hasattr(block, 'branch_1x1')
|
||||
assert not hasattr(block, 'branch_norm')
|
||||
x = torch.randn(1, 8, 16, 16)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size((1, 8, 16, 16))
|
||||
|
||||
|
||||
def test_repvgg_backbone():
|
||||
    with pytest.raises(TypeError):
        # arch must be str or dict
        RepVGG(arch=[4, 6, 16, 1])

    with pytest.raises(AssertionError):
        # arch must be in arch_settings
        RepVGG(arch='A3')

    with pytest.raises(KeyError):
        # arch must have num_blocks and width_factor
        arch = dict(num_blocks=[2, 4, 14, 1])
        RepVGG(arch=arch)

    # len(arch['num_blocks']) == len(arch['width_factor'])
    # == len(strides) == len(dilations)
    with pytest.raises(AssertionError):
        arch = dict(num_blocks=[2, 4, 14, 1], width_factor=[0.75, 0.75, 0.75])
        RepVGG(arch=arch)

    # len(strides) must equal 4
    with pytest.raises(AssertionError):
        RepVGG('A0', strides=(1, 1, 1))

    # len(dilations) must equal 4
    with pytest.raises(AssertionError):
        RepVGG('A0', strides=(1, 1, 1, 1), dilations=(1, 1, 2))

    # max(out_indices) < len(arch['num_blocks'])
    with pytest.raises(AssertionError):
        RepVGG('A0', out_indices=(5, ))

    # max(arch['group_idx'].keys()) <= sum(arch['num_blocks'])
    with pytest.raises(AssertionError):
        arch = dict(
            num_blocks=[2, 4, 14, 1],
            width_factor=[0.75, 0.75, 0.75],
            group_idx={22: 2})
        RepVGG(arch=arch)

    # Test RepVGG norm state
    model = RepVGG('A0')
    model.train()
    assert check_norm_state(model.modules(), True)

    # Test RepVGG with first stage frozen
    frozen_stages = 1
    model = RepVGG('A0', frozen_stages=frozen_stages)
    model.train()
    for param in model.stem.parameters():
        assert param.requires_grad is False
    for i in range(0, frozen_stages):
        stage_name = model.stages[i]
        stage = getattr(model, stage_name)
        for mod in stage:
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in stage.parameters():
            assert param.requires_grad is False

    # Test RepVGG with norm_eval
    model = RepVGG('A0', norm_eval=True)
    model.train()
    assert check_norm_state(model.modules(), False)

    # Test RepVGG forward with output from stage 3 only
    model = RepVGG('A0', out_indices=(3, ))
    model.init_weights()
    model.train()

    for m in model.modules():
        if is_norm(m):
            assert isinstance(m, _BatchNorm)

    imgs = torch.randn(1, 3, 224, 224)
    feat = model(imgs)
    assert isinstance(feat, tuple)
    assert len(feat) == 1
    assert isinstance(feat[0], torch.Tensor)
    assert feat[0].shape == torch.Size((1, 1280, 7, 7))

    # Test RepVGG forward
    model_test_settings = [
        dict(model_name='A0', out_sizes=(48, 96, 192, 1280)),
        dict(model_name='A1', out_sizes=(64, 128, 256, 1280)),
        dict(model_name='A2', out_sizes=(96, 192, 384, 1408)),
        dict(model_name='B0', out_sizes=(64, 128, 256, 1280)),
        dict(model_name='B1', out_sizes=(128, 256, 512, 2048)),
        dict(model_name='B1g2', out_sizes=(128, 256, 512, 2048)),
        dict(model_name='B1g4', out_sizes=(128, 256, 512, 2048)),
        dict(model_name='B2', out_sizes=(160, 320, 640, 2560)),
        dict(model_name='B2g2', out_sizes=(160, 320, 640, 2560)),
        dict(model_name='B2g4', out_sizes=(160, 320, 640, 2560)),
        dict(model_name='B3', out_sizes=(192, 384, 768, 2560)),
        dict(model_name='B3g2', out_sizes=(192, 384, 768, 2560)),
        dict(model_name='B3g4', out_sizes=(192, 384, 768, 2560)),
        dict(model_name='D2se', out_sizes=(160, 320, 640, 2560))
    ]

    # Only forward a subset of the architectures to keep the test fast.
    choose_models = ['A0', 'B1', 'B1g2', 'D2se']
    # Test RepVGG model forward
    for model_test_setting in model_test_settings:
        if model_test_setting['model_name'] not in choose_models:
            continue
        model = RepVGG(
            model_test_setting['model_name'], out_indices=(0, 1, 2, 3))
        model.init_weights()

        # Test Norm
        for m in model.modules():
            if is_norm(m):
                assert isinstance(m, _BatchNorm)

        model.train()
        imgs = torch.randn(1, 3, 224, 224)
        feat = model(imgs)
        assert feat[0].shape == torch.Size(
            (1, model_test_setting['out_sizes'][0], 56, 56))
        assert feat[1].shape == torch.Size(
            (1, model_test_setting['out_sizes'][1], 28, 28))
        assert feat[2].shape == torch.Size(
            (1, model_test_setting['out_sizes'][2], 14, 14))
        assert feat[3].shape == torch.Size(
            (1, model_test_setting['out_sizes'][3], 7, 7))

        # Test equivalence of "train" mode and "deploy" mode in eval
        gap = nn.AdaptiveAvgPool2d(output_size=1)
        fc = nn.Linear(model_test_setting['out_sizes'][3], 10)
        model.eval()
        feat = model(imgs)
        pred = fc(gap(feat[3]).flatten(1))
        model.switch_to_deploy()
        for m in model.modules():
            if isinstance(m, RepVGGBlock):
                assert m.deploy is True
        feat_deploy = model(imgs)
        pred_deploy = fc(gap(feat_deploy[3]).flatten(1))
        # The outputs before and after re-parameterization should match
        # within floating-point tolerance (the bare `torch.allclose` calls
        # here previously discarded their results; the tolerances follow the
        # block-level check above).
        for i in range(4):
            assert torch.allclose(
                feat[i], feat_deploy[i], atol=1e-5, rtol=1e-4)
        assert torch.allclose(pred, pred_deploy, atol=1e-5, rtol=1e-4)
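Outside the test suite, the deploy switch exercised above is meant to be called once after training. A minimal usage sketch follows; the import path matches mmcls conventions but is an assumption of this note, not something stated in the diff.

from mmcls.models.backbones import RepVGG

model = RepVGG('A0')
# ... train the model or load trained weights ...
model.eval()
model.switch_to_deploy()  # fold each block's branches into a single 3x3 conv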


def test_repvgg_load():
    # Test that outputs match before and after loading a deploy checkpoint
    model = RepVGG('A1', out_indices=(0, 1, 2, 3))
    inputs = torch.randn((1, 3, 224, 224))
    ckpt_path = os.path.join(tempfile.gettempdir(), 'ckpt.pth')
    model.switch_to_deploy()
    model.eval()
    outputs = model(inputs)

    model_deploy = RepVGG('A1', out_indices=(0, 1, 2, 3), deploy=True)
    save_checkpoint(model, ckpt_path)
    load_checkpoint(model_deploy, ckpt_path, strict=True)

    outputs_load = model_deploy(inputs)
    for feat, feat_load in zip(outputs, outputs_load):
        assert torch.allclose(feat, feat_load)
@@ -0,0 +1,46 @@
import argparse
from pathlib import Path

import torch

from mmcls.apis import init_model


def convert_repvggblock_param(config_path, checkpoint_path, save_path):
    model = init_model(config_path, checkpoint=checkpoint_path)
    print('Converting...')

    # Fuse the branches of every RepVGG block into single 3x3 convs.
    model.backbone.switch_to_deploy()
    torch.save(model.state_dict(), save_path)

    print('Done! Saved the converted checkpoint to "{}".'.format(save_path))


def main():
    parser = argparse.ArgumentParser(
        description='Convert the parameters of the repvgg block '
        'from training mode to deployment mode.')
    parser.add_argument(
        'config_path',
        help='The path to the configuration file of the network '
        'containing the repvgg block.')
    parser.add_argument(
        'checkpoint_path',
        help='The path to the checkpoint file corresponding to the model.')
    parser.add_argument(
        'save_path',
        help='The path where the converted checkpoint file is stored.')
    args = parser.parse_args()

    save_path = Path(args.save_path)
    if save_path.suffix != '.pth':
        print('The save path must end with ".pth".')
        exit()
    save_path.parent.mkdir(parents=True, exist_ok=True)

    convert_repvggblock_param(args.config_path, args.checkpoint_path,
                              args.save_path)


if __name__ == '__main__':
    main()
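Assuming the script lives under the repository's tools directory (the exact location in this PR may differ; the path below is illustrative), it would be invoked as:

  python tools/convert_models/reparameterize_repvgg.py ${CONFIG_PATH} ${CKPT_PATH} ${SAVE_PATH}

where ${SAVE_PATH} must end with '.pth'; its parent directories are created automatically.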
@@ -0,0 +1,59 @@
import argparse
from collections import OrderedDict
from pathlib import Path

import torch


def convert(src, dst):
    """Convert the keys of an official RepVGG checkpoint to mmcls style."""
    print('Converting...')
    blobs = torch.load(src, map_location='cpu')
    converted_state_dict = OrderedDict()

    for key in blobs:
        split_key = key.split('.')
        # Rename the branches of each block to their mmcls names.
        split_key = ['norm' if i == 'bn' else i for i in split_key]
        split_key = [
            'branch_norm' if i == 'rbr_identity' else i for i in split_key
        ]
        split_key = [
            'branch_1x1' if i == 'rbr_1x1' else i for i in split_key
        ]
        split_key = [
            'branch_3x3' if i == 'rbr_dense' else i for i in split_key
        ]
        # 'stage0' is the stem in mmcls, so it must be handled before the
        # generic 'stage' rule below.
        split_key = [
            'backbone.stem' if i[:6] == 'stage0' else i for i in split_key
        ]
        split_key = [
            'backbone.stage_' + i[5] if i[:5] == 'stage' else i
            for i in split_key
        ]
        split_key = ['se_layer' if i == 'se' else i for i in split_key]
        split_key = ['conv1.conv' if i == 'down' else i for i in split_key]
        split_key = ['conv2.conv' if i == 'up' else i for i in split_key]
        split_key = ['head.fc' if i == 'linear' else i for i in split_key]
        new_key = '.'.join(split_key)
        converted_state_dict[new_key] = blobs[key]

    torch.save(converted_state_dict, dst)
    print('Done!')


def main():
    parser = argparse.ArgumentParser(description='Convert model keys')
    parser.add_argument('src', help='source official RepVGG checkpoint path')
    parser.add_argument('dst', help='save path')
    args = parser.parse_args()

    dst = Path(args.dst)
    if dst.suffix != '.pth':
        print('The save path must end with ".pth".')
        exit()
    dst.parent.mkdir(parents=True, exist_ok=True)

    convert(args.src, args.dst)


if __name__ == '__main__':
    main()
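This converter targets official RepVGG checkpoints, whose keys use names such as rbr_dense, rbr_identity, and stage0, as the rename rules above show. Assuming the same tools directory as before (again illustrative), a typical call would be:

  python tools/convert_models/repvgg_to_mmcls.py ${SRC_PATH} ${DST_PATH}

with ${SRC_PATH} pointing at the official checkpoint and ${DST_PATH} ending with '.pth'.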