mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Support MMRazor searchable backbone (#453)
* update subnet cfg * add docs * update model link * fix lint * mdformat * update readme * fix lint * update link * rename folder * fix readme * update readme * make lint * rename * update readme * sync mmrazor cfg * fix cfg * install issue * require mmcls * fix yolo cfg --------- Co-authored-by: aptsunny <aptsunny@tongji.edu.cn> Co-authored-by: sunyue1 <sunyue1@sensetime.com>pull/547/head
parent
3a6899e232
commit
164c319493
|
@ -0,0 +1,79 @@
|
|||
# Projecs Based on MMRazor
|
||||
|
||||
There are many research works and pre-trained models built on MMRazor. We list some of them as examples of how to use MMRazor slimmable models for downstream frameworks. As the page might not be completed, please feel free to contribute more efficient mmrazor-models to update this page.
|
||||
|
||||
## Description
|
||||
|
||||
This is an implementation of MMRazor Searchable Backbone Application, we provide detection configs and models for MMRazor in MMYOLO.
|
||||
|
||||
### Backbone support
|
||||
|
||||
Here are the Neural Architecture Search(NAS) Models that come from MMRazor which support YOLO Series. If you are looking for MMRazor models only for Backbone, you could refer to MMRazor [ModelZoo](https://github.com/open-mmlab/mmrazor/blob/dev-1.x/docs/en/get_started/model_zoo.md) and corresponding repository.
|
||||
|
||||
- [x] [AttentiveMobileNetV3](https://github.com/open-mmlab/mmrazor/blob/dev-1.x/configs/_base_/nas_backbones/attentive_mobilenetv3_supernet.py)
|
||||
- [x] [SearchableShuffleNetV2](https://github.com/open-mmlab/mmrazor/blob/dev-1.x/configs/_base_/nas_backbones/spos_shufflenet_supernet.py)
|
||||
- [x] [SearchableMobileNetV2](https://github.com/open-mmlab/mmrazor/blob/dev-1.x/configs/_base_/nas_backbones/spos_mobilenet_supernet.py)
|
||||
|
||||
## Usage
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- [MMRazor v1.0.0rc2](https://github.com/open-mmlab/mmrazor/tree/v1.0.0rc2) or higher (dev-1.x)
|
||||
|
||||
Install MMRazor using MIM.
|
||||
|
||||
```shell
|
||||
mim install mmengine
|
||||
mim install "mmrazor>=1.0.0rc2"
|
||||
```
|
||||
|
||||
Install MMRazor from source
|
||||
|
||||
```
|
||||
git clone -b dev-1.x https://github.com/open-mmlab/mmrazor.git
|
||||
cd mmrazor
|
||||
# Install MMRazor
|
||||
mim install -v -e .
|
||||
```
|
||||
|
||||
### Training commands
|
||||
|
||||
In MMYOLO's root directory, if you want to use single GPU for training, run the following command to train the model:
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 PORT=29500 ./tools/dist_train.sh configs/razor/subnets/yolov5_s_spos_shufflenetv2_syncbn_8xb16-300e_coco.py
|
||||
```
|
||||
|
||||
If you want to use several of these GPUs to train in parallel, you can use the following command:
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 PORT=29500 ./tools/dist_train.sh configs/razor/subnets/yolov5_s_spos_shufflenetv2_syncbn_8xb16-300e_coco.py
|
||||
```
|
||||
|
||||
### Testing commands
|
||||
|
||||
In MMYOLO's root directory, run the following command to test the model:
|
||||
|
||||
```bash
|
||||
CUDA_VISIBLE_DEVICES=0 PORT=29500 ./tools/dist_test.sh configs/razor/subnets/yolov5_s_spos_shufflenetv2_syncbn_8xb16-300e_coco.py ${CHECKPOINT_PATH}
|
||||
```
|
||||
|
||||
## Results and Models
|
||||
|
||||
Here we provide the baseline version of YOLO Series with NAS backbone.
|
||||
|
||||
| Model | size | box AP | Params(M) | FLOPS(G) | Config | Download |
|
||||
| :------------------------: | :--: | :----: | :----------: | :------: | :----------------------------------------------------------------------------------------------------------------------------------------: | :---------------------------------------------------------------------------------------------------------------------------------------------------------------------: |
|
||||
| yolov5-s | 640 | 37.7 | 7.235 | 8.265 | [config](https://github.com/open-mmlab/mmyolo/blob/main/configs/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco/yolov5_s-v61_syncbn_fast_8xb16-300e_coco_20220918_084700-86e02187.pth) |
|
||||
| yolov5_s_spos_shufflenetv2 | 640 | 37.9 | 7.04(-2.7%) | 7.03 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/razor/subnets/yolov5_s_spos_shufflenetv2_syncbn_8xb16-300e_coco.py) | [model](https://download.openmmlab.com/mmrazor/v1/spos/yolov5/yolov5_s_spos_shufflenetv2_syncbn_8xb16-300e_coco_20230109_155302-777fd6f1.pth) |
|
||||
| yolov6-s | 640 | 44.0 | 18.869 | 24.253 | [config](https://github.com/open-mmlab/mmyolo/blob/main/configs/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/yolov6/yolov6_s_syncbn_fast_8xb32-400e_coco/yolov6_s_syncbn_fast_8xb32-400e_coco_20221102_203035-932e1d91.pth) |
|
||||
| yolov6_l_attentivenas_a6 | 640 | 44.5 | 18.38(-2.6%) | 8.49 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/razor/subnets/yolov6_l_attentivenas_a6_d12_syncbn_fast_16xb16-300e_coco.py) | [model](https://download.openmmlab.com/mmrazor/v1/attentivenas/yolov6/yolov6_l_attentivenas_a6_d12_syncbn_fast_16xb16-300e_coco_20230108_174944-4970f0b7.pth) |
|
||||
| RTMDet-tiny | 640 | 41.0 | 4.8 | 8.1 | [config](./rtmdet_l_syncbn_fast_8xb32-300e_coco.py) | [model](https://download.openmmlab.com/mmyolo/v0/rtmdet/rtmdet_tiny_syncbn_fast_8xb32-300e_coco/rtmdet_tiny_syncbn_fast_8xb32-300e_coco_20230102_140117-dbb1dc83.pth) |
|
||||
| rtmdet_tiny_ofa_lat31 | 960 | 41.1 | 3.91(-18.5%) | 6.09 | [config](https://github.com/open-mmlab/mmyolo/tree/dev/configs/razor/subnets/rtmdet_tiny_ofa_lat31_syncbn_16xb16-300e_coco.py) | [model](https://download.openmmlab.com/mmrazor/v1/ofa/rtmdet/rtmdet_tiny_ofa_lat31_syncbn_16xb16-300e_coco_20230108_222141-24ff87dex.pth) |
|
||||
|
||||
**Note**:
|
||||
|
||||
1. For fair comparison, the training configuration is consistent with the original configuration and results in an improvement of about 0.2-0.5% AP.
|
||||
2. `yolov5_s_spos_shufflenetv2` achieves 37.9% AP with only 7.042M parameters, directly instead of the backbone, and outperforms `yolov5_s` with a similar size by more than 0.2% AP.
|
||||
3. With the efficient backbone of `yolov6_l_attentivenas_a6`, the input channels of `YOLOv6RepPAFPN` are reduced. Meanwhile, modify the **deepen_factor** and the neck is made deeper to restore the AP.
|
||||
4. with the `rtmdet_tiny_ofa_lat31` backbone with only 3.315M parameters and 3.634G flops, we can modify the input resolution to 960, with a similar model size compared to `rtmdet_tiny` and exceeds `rtmdet_tiny` by 0.1% AP, reducing the size of the whole model to 3.91 MB.
|
|
@ -0,0 +1,118 @@
|
|||
_base_ = [
|
||||
'mmrazor::_base_/nas_backbones/ofa_mobilenetv3_supernet.py',
|
||||
'../../rtmdet/rtmdet_s_syncbn_fast_8xb32-300e_coco.py'
|
||||
]
|
||||
|
||||
checkpoint_file = 'https://download.openmmlab.com/mmrazor/v1/ofa/ofa_mobilenet_subnet_8xb256_in1k_note8_lat%4031ms_top1%4072.8_finetune%4025.py_20221214_0939-981a8b2a.pth' # noqa
|
||||
fix_subnet = 'https://download.openmmlab.com/mmrazor/v1/ofa/rtmdet/OFA_SUBNET_NOTE8_LAT31.yaml' # noqa
|
||||
deepen_factor = 0.167
|
||||
widen_factor = 1.0
|
||||
channels = [40, 112, 160]
|
||||
train_batch_size_per_gpu = 16
|
||||
img_scale = (960, 960)
|
||||
|
||||
_base_.base_lr = 0.002
|
||||
_base_.optim_wrapper.optimizer.lr = 0.002
|
||||
_base_.param_scheduler[1].eta_min = 0.002 * 0.05
|
||||
|
||||
_base_.nas_backbone.out_indices = (2, 4, 5)
|
||||
_base_.nas_backbone.conv_cfg = dict(type='mmrazor.OFAConv2d')
|
||||
_base_.nas_backbone.init_cfg = dict(
|
||||
type='Pretrained',
|
||||
checkpoint=checkpoint_file,
|
||||
prefix='architecture.backbone.')
|
||||
nas_backbone = dict(
|
||||
type='mmrazor.sub_model',
|
||||
fix_subnet=fix_subnet,
|
||||
cfg=_base_.nas_backbone,
|
||||
extra_prefix='backbone.')
|
||||
|
||||
_base_.model.backbone = nas_backbone
|
||||
_base_.model.neck.widen_factor = widen_factor
|
||||
_base_.model.neck.deepen_factor = deepen_factor
|
||||
_base_.model.neck.in_channels = channels
|
||||
_base_.model.neck.out_channels = channels[0]
|
||||
_base_.model.bbox_head.head_module.in_channels = channels[0]
|
||||
_base_.model.bbox_head.head_module.feat_channels = channels[0]
|
||||
_base_.model.bbox_head.head_module.widen_factor = widen_factor
|
||||
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
|
||||
dict(type='LoadAnnotations', with_bbox=True),
|
||||
dict(
|
||||
type='Mosaic',
|
||||
img_scale=img_scale,
|
||||
use_cached=True,
|
||||
max_cached_images=40,
|
||||
pad_val=114.0),
|
||||
dict(
|
||||
type='mmdet.RandomResize',
|
||||
# img_scale is (width, height)
|
||||
scale=(img_scale[0] * 2, img_scale[1] * 2),
|
||||
ratio_range=(0.5, 2.0), # note
|
||||
resize_type='mmdet.Resize',
|
||||
keep_ratio=True),
|
||||
dict(type='mmdet.RandomCrop', crop_size=img_scale),
|
||||
dict(type='mmdet.YOLOXHSVRandomAug'),
|
||||
dict(type='mmdet.RandomFlip', prob=0.5),
|
||||
dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
|
||||
dict(type='YOLOv5MixUp', use_cached=True, max_cached_images=20),
|
||||
dict(type='mmdet.PackDetInputs')
|
||||
]
|
||||
|
||||
train_pipeline_stage2 = [
|
||||
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
|
||||
dict(type='LoadAnnotations', with_bbox=True),
|
||||
dict(
|
||||
type='mmdet.RandomResize',
|
||||
scale=img_scale,
|
||||
ratio_range=(0.5, 2.0), # note
|
||||
resize_type='mmdet.Resize',
|
||||
keep_ratio=True),
|
||||
dict(type='mmdet.RandomCrop', crop_size=img_scale),
|
||||
dict(type='mmdet.YOLOXHSVRandomAug'),
|
||||
dict(type='mmdet.RandomFlip', prob=0.5),
|
||||
dict(type='mmdet.Pad', size=img_scale, pad_val=dict(img=(114, 114, 114))),
|
||||
dict(type='mmdet.PackDetInputs')
|
||||
]
|
||||
|
||||
train_dataloader = dict(
|
||||
batch_size=train_batch_size_per_gpu, dataset=dict(pipeline=train_pipeline))
|
||||
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile', file_client_args=_base_.file_client_args),
|
||||
dict(type='YOLOv5KeepRatioResize', scale=img_scale),
|
||||
dict(
|
||||
type='LetterResize',
|
||||
scale=img_scale,
|
||||
allow_scale_up=False,
|
||||
pad_val=dict(img=114)),
|
||||
dict(type='LoadAnnotations', with_bbox=True, _scope_='mmdet'),
|
||||
dict(
|
||||
type='mmdet.PackDetInputs',
|
||||
meta_keys=('img_id', 'img_path', 'ori_shape', 'img_shape',
|
||||
'scale_factor', 'pad_param'))
|
||||
]
|
||||
|
||||
batch_shapes_cfg = dict(img_size=img_scale[0])
|
||||
|
||||
val_dataloader = dict(
|
||||
dataset=dict(pipeline=test_pipeline, batch_shapes_cfg=batch_shapes_cfg))
|
||||
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
custom_hooks = [
|
||||
dict(
|
||||
type='EMAHook',
|
||||
ema_type='ExpMomentumEMA',
|
||||
momentum=0.0002,
|
||||
update_buffers=True,
|
||||
strict_load=False,
|
||||
priority=49),
|
||||
dict(
|
||||
type='mmdet.PipelineSwitchHook',
|
||||
switch_epoch=_base_.max_epochs - _base_.stage2_num_epochs,
|
||||
switch_pipeline=train_pipeline_stage2)
|
||||
]
|
||||
|
||||
find_unused_parameters = True
|
|
@ -0,0 +1,30 @@
|
|||
_base_ = [
|
||||
'mmrazor::_base_/nas_backbones/spos_shufflenet_supernet.py',
|
||||
'../../yolov5/yolov5_s-v61_syncbn_fast_8xb16-300e_coco.py'
|
||||
]
|
||||
|
||||
checkpoint_file = 'https://download.openmmlab.com/mmrazor/v1/spos/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d_v3.pth' # noqa
|
||||
fix_subnet = 'https://download.openmmlab.com/mmrazor/v1/spos/spos_shufflenetv2_subnet_8xb128_in1k_flops_0.33M_acc_73.87_20211222-1f0a0b4d_subnet_cfg_v3.yaml' # noqa
|
||||
|
||||
widen_factor = 1.0
|
||||
channels = [160, 320, 640]
|
||||
|
||||
_base_.nas_backbone.out_indices = (1, 2, 3)
|
||||
_base_.nas_backbone.init_cfg = dict(
|
||||
type='Pretrained',
|
||||
checkpoint=checkpoint_file,
|
||||
prefix='architecture.backbone.')
|
||||
nas_backbone = dict(
|
||||
type='mmrazor.sub_model',
|
||||
fix_subnet=fix_subnet,
|
||||
cfg=_base_.nas_backbone,
|
||||
extra_prefix='architecture.backbone.')
|
||||
|
||||
_base_.model.backbone = nas_backbone
|
||||
_base_.model.neck.widen_factor = widen_factor
|
||||
_base_.model.neck.in_channels = channels
|
||||
_base_.model.neck.out_channels = channels
|
||||
_base_.model.bbox_head.head_module.in_channels = channels
|
||||
_base_.model.bbox_head.head_module.widen_factor = widen_factor
|
||||
|
||||
find_unused_parameters = True
|
|
@ -0,0 +1,35 @@
|
|||
_base_ = [
|
||||
'mmrazor::_base_/nas_backbones/attentive_mobilenetv3_supernet.py',
|
||||
'../../yolov6/yolov6_l_syncbn_fast_8xb32-300e_coco.py'
|
||||
]
|
||||
|
||||
checkpoint_file = 'https://download.openmmlab.com/mmrazor/v1/bignas/attentive_mobilenet_subnet_8xb256_in1k_flops-0.93G_acc-80.81_20221229_200440-73d92cc6.pth' # noqa
|
||||
fix_subnet = 'https://download.openmmlab.com/mmrazor/v1/bignas/ATTENTIVE_SUBNET_A6.yaml' # noqa
|
||||
deepen_factor = 1.2
|
||||
widen_factor = 1
|
||||
channels = [40, 128, 224]
|
||||
mid_channels = [40, 128, 224]
|
||||
|
||||
_base_.train_dataloader.batch_size = 16
|
||||
_base_.nas_backbone.out_indices = (2, 4, 6)
|
||||
_base_.nas_backbone.conv_cfg = dict(type='mmrazor.BigNasConv2d')
|
||||
_base_.nas_backbone.norm_cfg = dict(type='mmrazor.DynamicBatchNorm2d')
|
||||
_base_.nas_backbone.init_cfg = dict(
|
||||
type='Pretrained',
|
||||
checkpoint=checkpoint_file,
|
||||
prefix='architecture.backbone.')
|
||||
nas_backbone = dict(
|
||||
type='mmrazor.sub_model',
|
||||
fix_subnet=fix_subnet,
|
||||
cfg=_base_.nas_backbone,
|
||||
extra_prefix='backbone.')
|
||||
|
||||
_base_.model.backbone = nas_backbone
|
||||
_base_.model.neck.widen_factor = widen_factor
|
||||
_base_.model.neck.deepen_factor = deepen_factor
|
||||
_base_.model.neck.in_channels = channels
|
||||
_base_.model.neck.out_channels = mid_channels
|
||||
_base_.model.bbox_head.head_module.in_channels = mid_channels
|
||||
_base_.model.bbox_head.head_module.widen_factor = widen_factor
|
||||
|
||||
find_unused_parameters = True
|
|
@ -5,6 +5,8 @@ isort==4.3.21
|
|||
# Note: used for kwarray.group_items, this may be ported to mmcv in the future.
|
||||
kwarray
|
||||
memory_profiler
|
||||
mmcls>=1.0.0rc4
|
||||
mmrazor>=1.0.0rc2
|
||||
parameterized
|
||||
protobuf<=3.20.1
|
||||
psutil
|
||||
|
|
|
@ -0,0 +1,21 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import copy
|
||||
|
||||
import pytest
|
||||
from mmcls.models.backbones.base_backbone import BaseBackbone
|
||||
|
||||
from mmyolo.testing import get_detector_cfg
|
||||
|
||||
|
||||
@pytest.mark.parametrize('cfg_file', [
|
||||
'razor/subnets/'
|
||||
'yolov5_s_spos_shufflenetv2_syncbn_8xb16-300e_coco.py', 'razor/subnets/'
|
||||
'rtmdet_tiny_ofa_lat31_syncbn_16xb16-300e_coco.py', 'razor/subnets/'
|
||||
'yolov6_l_attentivenas_a6_d12_syncbn_fast_16xb16-300e_coco.py'
|
||||
])
|
||||
def test_razor_backbone_forward(cfg_file):
|
||||
model = get_detector_cfg(cfg_file)
|
||||
model_cfg = copy.deepcopy(model.backbone)
|
||||
from mmrazor.registry import MODELS
|
||||
model = MODELS.build(model_cfg)
|
||||
assert isinstance(model, BaseBackbone)
|
Loading…
Reference in New Issue