[Feature] Support MMRazor searchable backbone (#453)

* update subnet cfg

* add docs

* update model link

* fix lint

* mdformat

* update readme

* fix lint

* update link

* rename folder

* fix readme

* update readme

* make lint

* rename

* update readme

* sync mmrazor cfg

* fix cfg

* install issue

* require mmcls

* fix yolo cfg

---------

Co-authored-by: aptsunny <aptsunny@tongji.edu.cn>
Co-authored-by: sunyue1 <sunyue1@sensetime.com>

@@ -0,0 +1,79 @@
# Projects Based on MMRazor

There are many research works and pre-trained models built on MMRazor. We list some of them below as examples of how to use MMRazor slimmable models in downstream frameworks. Since this page may not be complete, please feel free to contribute more efficient mmrazor-models to keep it up to date.

## Description

This is an implementation of the MMRazor searchable backbone application. We provide detection configs and models for MMRazor backbones in MMYOLO.

### Backbone support

Here are the Neural Architecture Search (NAS) models from MMRazor that support the YOLO series. If you are looking for MMRazor models for the backbone only, please refer to the MMRazor [ModelZoo](https://github.com/open-mmlab/mmrazor/blob/dev-1.x/docs/en/get_started/model_zoo.md) and the corresponding repositories.

- [x] [AttentiveMobileNetV3](https://github.com/open-mmlab/mmrazor/blob/dev-1.x/configs/_base_/nas_backbones/attentive_mobilenetv3_supernet.py)
- [x] [SearchableShuffleNetV2](https://github.com/open-mmlab/mmrazor/blob/dev-1.x/configs/_base_/nas_backbones/spos_shufflenet_supernet.py)
- [x] [SearchableMobileNetV2](https://github.com/open-mmlab/mmrazor/blob/dev-1.x/configs/_base_/nas_backbones/spos_mobilenet_supernet.py)

## Usage

### Prerequisites

- [MMRazor v1.0.0rc2](https://github.com/open-mmlab/mmrazor/tree/v1.0.0rc2) or higher (dev-1.x)

Install MMRazor using MIM:
```shell
mim install mmengine
mim install "mmrazor>=1.0.0rc2"
```
Or install MMRazor from source:

```shell
git clone -b dev-1.x https://github.com/open-mmlab/mmrazor.git
cd mmrazor
# Install MMRazor
mim install -v -e .
```
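
As a quick sanity check (a minimal sketch; the printed version depends on what was installed), confirm that the package imports:

```shell
# Should print the installed MMRazor version, e.g. 1.0.0rc2
python -c "import mmrazor; print(mmrazor.__version__)"
```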
### Training commands

In MMYOLO's root directory, run the following command to train the model on a single GPU:
```bash
CUDA_VISIBLE_DEVICES=0 PORT=29500 ./tools/dist_train.sh configs/razor/subnets/yolov5_s_spos_shufflenetv2_syncbn_8xb16-300e_coco.py 1
```
If you want to train in parallel on multiple GPUs, use the following command:
```bash
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh configs/razor/subnets/yolov5_s_spos_shufflenetv2_syncbn_8xb16-300e_coco.py 8
```
### Testing commands

In MMYOLO's root directory, run the following command to test the model:
```bash
CUDA_VISIBLE_DEVICES=0 PORT=29500 ./tools/dist_test.sh configs/razor/subnets/yolov5_s_spos_shufflenetv2_syncbn_8xb16-300e_coco.py ${CHECKPOINT_PATH} 1
```
## Results and Models

Here we provide baseline versions of the YOLO series with NAS backbones.

| Model | size | box AP | Params (M) | FLOPs (G) | Config | Download |
| :------------------------: | :--: | :----: | :-----------: | :-------: | :----: | :------: |
| yolov5-s | 640 | 37.7 | 7.235 | 8.265 | [config](https://github.com/open-mmlab/mmyolo/blob/main/configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth) |
| yolov5_s_spos_shufflenetv2 | 640 | 37.9 | 7.04 (-2.7%) | 7.03 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/razor/subnets/yolov5_s_spos_shufflenetv2_syncbn_8xb16-300e_coco.py) | [model](https://download.openmmlab.com/mmrazor/v1/spos/yolov5/yolov5_s_spos_shufflenetv2_syncbn_8xb16-300e_coco_20230109_155302-777fd6f1.pth) |
| yolov6-s | 640 | 44.0 | 18.869 | 24.253 | [config](https://github.com/open-mmlab/mmyolo/blob/main/configs/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco/yolov6_s_syncbn_fast_8xb32-400e_coco_20221102_203035-932e1d91.pth) |
| yolov6_l_attentivenas_a6 | 640 | 44.5 | 18.38 (-2.6%) | 8.49 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/razor/subnets/yolov6_l_attentivenas_a6_d12_syncbn_fast_16xb16-300e_coco.py) | [model](https://download.openmmlab.com/mmrazor/v1/attentivenas/yolov6/yolov6_l_attentivenas_a6_d12_syncbn_fast_16xb16-300e_coco_20230108_174944-4970f0b7.pth) |
| RTMDet-tiny | 640 | 41.0 | 4.8 | 8.1 | [config](https://github.com/open-mmlab/mmyolo/blob/main/configs/rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco/rtmdet_tiny_syncbn_fast_8xb32-300e_coco_20230102_140117-dbb1dc83.pth) |
| rtmdet_tiny_ofa_lat31 | 960 | 41.1 | 3.91 (-18.5%) | 6.09 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/razor/subnets/rtmdet_tiny_ofa_lat31_syncbn_16xb16-300e_coco.py) | [model](https://download.openmmlab.com/mmrazor/v1/ofa/rtmdet/rtmdet_tiny_ofa_lat31_syncbn_16xb16-300e_coco_20230108_222141-24ff87dex.pth) |

**Note**:

1. For a fair comparison, the training configurations are kept consistent with the original configurations, which brings an improvement of about 0.2-0.5% AP.
2. `yolov5_s_spos_shufflenetv2` achieves 37.9% AP with only 7.042M parameters by directly replacing the backbone, and outperforms the similarly sized `yolov5_s` by more than 0.2% AP.
3. With the efficient backbone of `yolov6_l_attentivenas_a6`, the input channels of `YOLOv6RepPAFPN` are reduced. Meanwhile, **deepen_factor** is modified so that the neck is made deeper, which restores the AP.
4. With the `rtmdet_tiny_ofa_lat31` backbone, which has only 3.315M parameters and 3.634G FLOPs, we can raise the input resolution to 960 while keeping a model size similar to `rtmdet_tiny`, exceeding `rtmdet_tiny` by 0.1% AP and reducing the whole model to 3.91M parameters.
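
All three subnet configs added in this PR (shown in the files below) follow the same wiring pattern: inherit the supernet definition from MMRazor's base configs, load the released supernet checkpoint through `init_cfg`, and build the backbone with `mmrazor.sub_model` plus the searched subnet YAML. A minimal sketch of that shared pattern, with the URLs elided on purpose (the real values are in the config files that follow):

```python
# Minimal sketch of the shared subnet-config pattern; see the full configs
# below for the real checkpoint and subnet URLs.
_base_ = [
    'mmrazor::_base_/nas_backbones/spos_shufflenet_supernet.py',  # supernet
    '../../yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py',  # detector
]

checkpoint_file = 'https://download.openmmlab.com/mmrazor/v1/spos/...'  # elided
fix_subnet = 'https://download.openmmlab.com/mmrazor/v1/spos/...'  # elided

# Load the released supernet weights into the searchable backbone.
_base_.nas_backbone.init_cfg = dict(
    type='Pretrained',
    checkpoint=checkpoint_file,
    prefix='architecture.backbone.')

# mmrazor.sub_model fixes the searched subnet and builds it as a plain backbone.
_base_.model.backbone = dict(
    type='mmrazor.sub_model',
    fix_subnet=fix_subnet,
    cfg=_base_.nas_backbone,
    extra_prefix='architecture.backbone.')
```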

@@ -0,0 +1,118 @@
_base_ = [
'mmrazor::_base_/nas_backbones/ofa_mobilenetv3_supernet.py',
'../../rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco.py'
]
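# OFA supernet weights plus the fixed subnet searched for a ~31 ms latency
# budget on a Note8 device (per the file names below).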
checkpoint_file = 'https://download.openmmlab.com/mmrazor/v1/ofa/ofa_mobilenet_subnet_8xb256_in1k_note8_lat%4031ms_top1%4072.8_finetune%4025.py_20221214_0939-981a8b2a.pth' # noqa
fix_subnet = 'https://download.openmmlab.com/mmrazor/v1/ofa/rtmdet/OFA_SUBNET_NOTE8_LAT31.yaml' # noqa
deepen_factor = 0.167
widen_factor = 1.0
channels = [40, 112, 160]
train_batch_size_per_gpu = 16
img_scale = (960, 960)
_base_.base_lr = 0.002
_base_.optim_wrapper.optimizer.lr = 0.002
_base_.param_scheduler[1].eta_min = 0.002 * 0.05
_base_.nas_backbone.out_indices = (2, 4, 5)
_base_.nas_backbone.conv_cfg = dict(type='mmrazor.OFAConv2d')
_base_.nas_backbone.init_cfg = dict(
type='Pretrained',
checkpoint=checkpoint_file,
prefix='architecture.backbone.')
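# mmrazor.sub_model slices the fixed subnet out of the supernet and builds it
# as an ordinary backbone.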
nas_backbone = dict(
type='mmrazor.sub_model',
fix_subnet=fix_subnet,
cfg=_base_.nas_backbone,
extra_prefix='backbone.')
_base_.model.backbone = nas_backbone
_base_.model.neck.widen_factor = widen_factor
_base_.model.neck.deepen_factor = deepen_factor
_base_.model.neck.in_channels = channels
_base_.model.neck.out_channels = channels[0]
_base_.model.bbox_head.head_module.in_channels = channels[0]
_base_.model.bbox_head.head_module.feat_channels = channels[0]
_base_.model.bbox_head.head_module.widen_factor = widen_factor
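# Training pipeline re-declared at the larger 960x960 input scale.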
train_pipeline = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='Mosaic',
img_scale=img_scale,
use_cached=True,
max_cached_images=40,
pad_val=114.0),
dict(
type='mmdet.RandomResize',
# img_scale is (width, height)
scale=(img_scale[0] * 2, img_scale[1] * 2),
ratio_range=(0.5, 2.0), # note
resize_type='mmdet.Resize',
keep_ratio=True),
dict(type='mmdet.RandomCrop', crop_size=img_scale),
dict(type='mmdet.YOLOXHSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
dict(type='YOLOv5MixUp', use_cached=True, max_cached_images=20),
dict(type='mmdet.PackDetInputs')
]
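# Stage-2 pipeline: Mosaic/MixUp are dropped for the final epochs; the switch
# is handled by the PipelineSwitchHook in custom_hooks below.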
train_pipeline_stage2 = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='LoadAnnotations', with_bbox=True),
dict(
type='mmdet.RandomResize',
scale=img_scale,
ratio_range=(0.5, 2.0), # note
resize_type='mmdet.Resize',
keep_ratio=True),
dict(type='mmdet.RandomCrop', crop_size=img_scale),
dict(type='mmdet.YOLOXHSVRandomAug'),
dict(type='mmdet.RandomFlip', prob=0.5),
dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
dict(type='mmdet.PackDetInputs')
]
train_dataloader = dict(
batch_size=train_batch_size_per_gpu, dataset=dict(pipeline=train_pipeline))
test_pipeline = [
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
dict(
type='LetterResize',
scale=img_scale,
allow_scale_up=False,
pad_val=dict(img=114)),
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
dict(
type='mmdet.PackDetInputs',
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
'scale_factor', 'pad_param'))
]
batch_shapes_cfg = dict(img_size=img_scale[0])
val_dataloader = dict(
dataset=dict(pipeline=test_pipeline, batch_shapes_cfg=batch_shapes_cfg))
test_dataloader = val_dataloader
custom_hooks = [
dict(
type='EMAHook',
ema_type='ExpMomentumEMA',
momentum=0.0002,
update_buffers=True,
strict_load=False,
priority=49),
dict(
type='mmdet.PipelineSwitchHook',
switch_epoch=_base_.max_epochs - _base_.stage2_num_epochs,
switch_pipeline=train_pipeline_stage2)
]
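# Let DDP tolerate parameters that receive no gradients in the sliced subnet.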
find_unused_parameters = True

@@ -0,0 +1,30 @@
_base_ = [
'mmrazor::_base_/nas_backbones/spos_shufflenet_supernet.py',
'../../yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py'
]
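# SPOS ShuffleNetV2 supernet weights and the searched subnet description.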
checkpoint_file = 'https://download.openmmlab.com/mmrazor/v1/spos/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d_v3.pth' # noqa
fix_subnet = 'https://download.openmmlab.com/mmrazor/v1/spos/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d_subnet_cfg_v3.yaml' # noqa
widen_factor = 1.0
channels = [160, 320, 640]
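# Feed the last three stages of the searched ShuffleNetV2 to the YOLOv5 neck.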
_base_.nas_backbone.out_indices = (1, 2, 3)
_base_.nas_backbone.init_cfg = dict(
type='Pretrained',
checkpoint=checkpoint_file,
prefix='architecture.backbone.')
nas_backbone = dict(
type='mmrazor.sub_model',
fix_subnet=fix_subnet,
cfg=_base_.nas_backbone,
extra_prefix='architecture.backbone.')
_base_.model.backbone = nas_backbone
_base_.model.neck.widen_factor = widen_factor
_base_.model.neck.in_channels = channels
_base_.model.neck.out_channels = channels
_base_.model.bbox_head.head_module.in_channels = channels
_base_.model.bbox_head.head_module.widen_factor = widen_factor
find_unused_parameters = True

@@ -0,0 +1,35 @@
_base_ = [
'mmrazor::_base_/nas_backbones/attentive_mobilenetv3_supernet.py',
'../../yolov6/yolov6_l_syncbn_fast_8xb32-300e_coco.py'
]
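# AttentiveNAS supernet weights and the searched A6 subnet description.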
checkpoint_file = 'https://download.openmmlab.com/mmrazor/v1/bignas/attentive_mobilenet_subnet_8xb256_in1k_flops-0.93G_acc-80.81_20221229_200440-73d92cc6.pth' # noqa
fix_subnet = 'https://download.openmmlab.com/mmrazor/v1/bignas/ATTENTIVE_SUBNET_A6.yaml' # noqa
deepen_factor = 1.2
widen_factor = 1
channels = [40, 128, 224]
mid_channels = [40, 128, 224]
_base_.train_dataloader.batch_size = 16
_base_.nas_backbone.out_indices = (2, 4, 6)
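# Use MMRazor's dynamic conv/BN implementations required by the BigNAS-style
# supernet so its weights load correctly.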
_base_.nas_backbone.conv_cfg = dict(type='mmrazor.BigNasConv2d')
_base_.nas_backbone.norm_cfg = dict(type='mmrazor.DynamicBatchNorm2d')
_base_.nas_backbone.init_cfg = dict(
type='Pretrained',
checkpoint=checkpoint_file,
prefix='architecture.backbone.')
nas_backbone = dict(
type='mmrazor.sub_model',
fix_subnet=fix_subnet,
cfg=_base_.nas_backbone,
extra_prefix='backbone.')
_base_.model.backbone = nas_backbone
_base_.model.neck.widen_factor = widen_factor
_base_.model.neck.deepen_factor = deepen_factor
_base_.model.neck.in_channels = channels
_base_.model.neck.out_channels = mid_channels
_base_.model.bbox_head.head_module.in_channels = mid_channels
_base_.model.bbox_head.head_module.widen_factor = widen_factor
find_unused_parameters = True

@@ -5,6 +5,8 @@ isort==4.3.21
# Note: used for kwarray.group_items, this may be ported to mmcv in the future.
kwarray
memory_profiler
mmcls>=1.0.0rc4
mmrazor>=1.0.0rc2
parameterized
protobuf<=3.20.1
psutil

@@ -0,0 +1,21 @@
# Copyright (c) OpenMMLab. All rights reserved.
import copy

import pytest
from mmcls.models.backbones.base_backbone import BaseBackbone

from mmyolo.testing import get_detector_cfg


@pytest.mark.parametrize('cfg_file', [
    'razor/subnets/yolov5_s_spos_shufflenetv2_syncbn_8xb16-300e_coco.py',
    'razor/subnets/rtmdet_tiny_ofa_lat31_syncbn_16xb16-300e_coco.py',
    'razor/subnets/yolov6_l_attentivenas_a6_d12_syncbn_fast_16xb16-300e_coco.py'
])
def test_razor_backbone_forward(cfg_file):
    # Build only the backbone from each razor subnet config and check that it
    # resolves to a regular mmcls backbone.
    model = get_detector_cfg(cfg_file)
    model_cfg = copy.deepcopy(model.backbone)

    from mmrazor.registry import MODELS
    model = MODELS.build(model_cfg)
    assert isinstance(model, BaseBackbone)