[Feature]Support plugin layers for backbone (#75)

* add plugin layer

* add docstring. add config

* clean the config

* add how_to-doc

* add ut

* update

* update

* update lint

* fix yoloxbug

* update

* update

* update doc

* update

* update

* update

* update

* update
pull/137/head
wanghonglie 2022-09-29 15:36:32 +08:00 committed by Haian Huang(深度眸)
parent 05a5b2aaa2
commit d705f1c57b
7 changed files with 249 additions and 7 deletions


@ -1 +1,58 @@
# How to
This tutorial collects answers to `How to xxx with MMYOLO` questions. Feel free to update this doc if you run into a new `How to` question and find the answer!
## Add plugins to the backbone network
MMYOLO supports adding plugins such as non-local and DropBlock layers after different stages of the backbone. Users can manage plugins directly by modifying the `plugins` parameter of `backbone` in the config. For example, to add the `GeneralizedAttention` plugin to `YOLOv5`, the config file is as follows:
```python
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'

model = dict(
    backbone=dict(
        plugins=[
            dict(
                cfg=dict(
                    type='mmdet.GeneralizedAttention',
                    spatial_range=-1,
                    num_heads=8,
                    attention_type='0011',
                    kv_stride=2),
                stages=(False, False, True, True)),
        ], ))
```
The `cfg` parameter holds the plugin's own configuration. The `stages` parameter indicates, for each backbone stage, whether to add the plugin after that stage. The length of `stages` must equal the number of backbone stages; if `stages` is omitted, the plugin is added after every stage.
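
As a rough illustration of how the `stages` mask is interpreted, the following standalone sketch mirrors the selection logic of `make_stage_plugins` without building any real layers. The helper name `select_plugin_types` and the placeholder plugin types `xxx` and `yyy` are hypothetical and exist only for this example.

```python
from typing import List, Optional, Sequence


def select_plugin_types(plugins: List[dict], idx: int,
                        num_stages: int) -> List[str]:
    """Return the plugin types that would be appended after stage ``idx``.

    Hypothetical helper for illustration only; it is not part of MMYOLO.
    """
    selected = []
    for plugin in plugins:
        plugin = plugin.copy()  # do not mutate the caller's config
        stages: Optional[Sequence[bool]] = plugin.pop('stages', None)
        # A mask, when given, must cover every backbone stage.
        assert stages is None or len(stages) == num_stages
        # A missing mask means "apply after all stages".
        if stages is None or stages[idx]:
            selected.append(plugin['cfg']['type'])
    return selected


plugins = [
    dict(cfg=dict(type='xxx'), stages=(False, False, True, True)),
    dict(cfg=dict(type='yyy')),  # no mask: applied after every stage
]

for idx in range(4):
    print(f'stage{idx + 1}:', select_plugin_types(plugins, idx, num_stages=4))
```

With these placeholder configs, only `yyy` is selected for the first two stages, while both `xxx` and `yyy` are selected for the last two.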
## Apply multiple Necks
To stack multiple Necks, set the `neck` field in the config to a `List` of Neck configs. MMYOLO chains them in order, so the output channels of each Neck must match the input channels of the next one. If the channel counts differ, insert an `mmdet.ChannelMapper` module between Necks to align them. An example configuration is as follows:
```python
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'

model = dict(
    type='YOLODetector',
    neck=[
        dict(
            type='YOLOv5PAFPN',
            # reuse the scaling factors defined in the base config
            deepen_factor=_base_.deepen_factor,
            widen_factor=_base_.widen_factor,
            in_channels=[256, 512, 1024],
            out_channels=[256, 512, 1024],
            num_csp_blocks=3,
            norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg=dict(type='SiLU', inplace=True)),
        dict(
            type='mmdet.ChannelMapper',
            in_channels=[128, 256, 512],
            out_channels=128,
        ),
        dict(
            type='mmdet.DyHead',
            in_channels=128,
            out_channels=256,
            num_blocks=2,
            # disable zero_init_offset to follow official implementation
            zero_init_offset=False)
    ])
```
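
The `in_channels=[128, 256, 512]` of the `mmdet.ChannelMapper` above follows from the width scaling applied by the preceding Neck. The sketch below works through that arithmetic under two assumptions: that the base config uses `widen_factor = 0.5`, and that scaled channel counts are rounded up to a multiple of 8. The helper `scaled_channels` is hypothetical and only approximates what the library does internally.

```python
import math
from typing import List


def scaled_channels(channels: List[int], widen_factor: float,
                    divisor: int = 8) -> List[int]:
    """Scale channel counts by ``widen_factor``, rounding up to ``divisor``.

    Hypothetical helper for illustration only; it is not part of MMYOLO.
    """
    return [math.ceil(c * widen_factor / divisor) * divisor for c in channels]


# Assumed value for the small model variant; check your own base config.
widen_factor = 0.5

# Nominal out_channels of the first Neck, before width scaling.
pafpn_out = scaled_channels([256, 512, 1024], widen_factor)

# The next Neck must declare exactly these channels as its input.
channel_mapper_in = [128, 256, 512]
assert pafpn_out == channel_mapper_in
print(pafpn_out)
```

If you change the width multiplier, the `in_channels` of the following Neck has to be updated to match.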


@ -1 +1,58 @@
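# How to
This tutorial collects answers to `How to xxx with MMYOLO` questions. If you come across a `How to` question and its answer, feel free to update this document!
## Add plugins to the backbone network
MMYOLO supports adding plugins such as non-local and DropBlock layers after different stages of the backbone. Users can manage plugins directly by modifying the `plugins` parameter of `backbone` in the config file. For example, to add the `GeneralizedAttention` plugin to `YOLOv5`, the config file is as follows: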
```python
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'

model = dict(
    backbone=dict(
        plugins=[
            dict(
                cfg=dict(
                    type='mmdet.GeneralizedAttention',
                    spatial_range=-1,
                    num_heads=8,
                    attention_type='0011',
                    kv_stride=2),
                stages=(False, False, True, True)),
        ], ))
```
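The `cfg` parameter holds the plugin's specific configuration. The `stages` parameter indicates whether to add the plugin after the corresponding backbone stage; its length must equal the number of backbone stages.
## Apply multiple Necks
To stack multiple Necks, set the `neck` field in the config file directly. MMYOLO supports chaining multiple Neck configs in the form of a `List`; you need to make sure that the output channels of each Neck match the input channels of the next one. If the channels need adjusting, insert an `mmdet.ChannelMapper` module to align the channel counts between Necks. The specific configuration is as follows: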
```python
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'

model = dict(
    type='YOLODetector',
    neck=[
        dict(
            type='YOLOv5PAFPN',
            # reuse the scaling factors defined in the base config
            deepen_factor=_base_.deepen_factor,
            widen_factor=_base_.widen_factor,
            in_channels=[256, 512, 1024],
            out_channels=[256, 512, 1024],
            num_csp_blocks=3,
            norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg=dict(type='SiLU', inplace=True)),
        dict(
            type='mmdet.ChannelMapper',
            in_channels=[128, 256, 512],
            out_channels=128,
        ),
        dict(
            type='mmdet.DyHead',
            in_channels=128,
            out_channels=256,
            num_blocks=2,
            # disable zero_init_offset to follow official implementation
            zero_init_offset=False)
    ])
```


@ -1,9 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from abc import ABCMeta, abstractmethod
from typing import Sequence
from typing import List, Sequence, Union
import torch
import torch.nn as nn
from mmcv.cnn import build_plugin_layer
from mmdet.utils import ConfigType, OptMultiConfig
from mmengine.model import BaseModule
from torch.nn.modules.batchnorm import _BatchNorm
@ -17,6 +18,11 @@ class BaseBackbone(BaseModule, metaclass=ABCMeta):
Args:
arch_setting (dict): Architecture of BaseBackbone.
plugins (list[dict]): List of plugins for stages, each dict contains:
- cfg (dict, required): Cfg dict to build plugin.
- stages (tuple[bool], optional): Stages to apply plugin, length
should be same as 'num_stages'.
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
@ -44,6 +50,7 @@ class BaseBackbone(BaseModule, metaclass=ABCMeta):
input_channels: int = 3,
out_indices: Sequence[int] = (2, 3, 4),
frozen_stages: int = -1,
plugins: Union[dict, List[dict]] = None,
norm_cfg: ConfigType = None,
act_cfg: ConfigType = None,
norm_eval: bool = False,
@ -69,6 +76,7 @@ class BaseBackbone(BaseModule, metaclass=ABCMeta):
self.norm_eval = norm_eval
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.plugins = plugins
self.stem = self.build_stem_layer()
self.layers = ['stem']
@ -76,6 +84,8 @@ class BaseBackbone(BaseModule, metaclass=ABCMeta):
for idx, setting in enumerate(arch_setting):
stage = []
stage += self.build_stage_layer(idx, setting)
if plugins is not None:
stage += self.make_stage_plugins(plugins, idx, setting)
self.add_module(f'stage{idx + 1}', nn.Sequential(*stage))
self.layers.append(f'stage{idx + 1}')
@ -94,6 +104,65 @@ class BaseBackbone(BaseModule, metaclass=ABCMeta):
"""
pass
    def make_stage_plugins(self, plugins, idx, setting):
        """Make plugins for the ``idx``-th stage of the backbone.

        Currently we support inserting ``context_block``,
        ``empirical_attention_block``, ``nonlocal_block`` and
        ``dropout_block`` into the backbone.

        An example of the plugins format could be:

        Examples:
            >>> plugins=[
            ...     dict(cfg=dict(type='xxx', arg1='xxx'),
            ...          stages=(False, True, True, True)),
            ...     dict(cfg=dict(type='yyy'),
            ...          stages=(True, True, True, True)),
            ... ]
            >>> model = YOLOv5CSPDarknet()
            >>> stage_plugins = model.make_stage_plugins(plugins, 0, setting)
            >>> assert len(stage_plugins) == 1

        Suppose ``idx=0``, the structure of blocks in the stage would be:

        .. code-block:: none

            conv1 -> conv2 -> conv3 -> yyy

        Suppose ``idx=1``, the structure of blocks in the stage would be:

        .. code-block:: none

            conv1 -> conv2 -> conv3 -> xxx -> yyy

        Args:
            plugins (list[dict]): List of plugins cfg to build. The postfix
                is required if multiple same type plugins are inserted.
                If ``stages`` is missing from a plugin dict, the plugin
                is applied to all stages.
            idx (int): Index of the stage to build.
            setting (list): The architecture setting of a stage layer.

        Returns:
            list[nn.Module]: Plugins for current stage
        """
        # TODO: It is not general enough to support any channel and needs
        # to be refactored
        # setting[1] is the stage's nominal output channels before scaling.
        in_channels = int(setting[1] * self.widen_factor)
        plugin_layers = []
        for plugin in plugins:
            plugin = plugin.copy()
            stages = plugin.pop('stages', None)
            assert stages is None or len(stages) == self.num_stages
            # Build the plugin only if it is enabled for this stage.
            if stages is None or stages[idx]:
                name, layer = build_plugin_layer(
                    plugin['cfg'], in_channels=in_channels)
                plugin_layers.append(layer)
        return plugin_layers

def _freeze_stages(self):
"""Freeze the parameters of the specified stage so that they are no
longer updated."""


@ -1,5 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
from typing import List, Tuple, Union
import torch
import torch.nn as nn
@ -20,6 +20,11 @@ class YOLOv5CSPDarknet(BaseBackbone):
Args:
arch (str): Architecture of CSP-Darknet, from {P5, P6}.
Defaults to P5.
plugins (list[dict]): List of plugins for stages, each dict contains:
- cfg (dict, required): Cfg dict to build plugin.
- stages (tuple[bool], optional): Stages to apply plugin, length
should be same as 'num_stages'.
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
@ -62,6 +67,7 @@ class YOLOv5CSPDarknet(BaseBackbone):
def __init__(self,
arch: str = 'P5',
plugins: Union[dict, List[dict]] = None,
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
input_channels: int = 3,
@ -78,6 +84,7 @@ class YOLOv5CSPDarknet(BaseBackbone):
widen_factor,
input_channels=input_channels,
out_indices=out_indices,
plugins=plugins,
frozen_stages=frozen_stages,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
@ -151,6 +158,11 @@ class YOLOXCSPDarknet(BaseBackbone):
Args:
arch (str): Architecture of CSP-Darknet, from {P5, P6}.
Defaults to P5.
plugins (list[dict]): List of plugins for stages, each dict contains:
- cfg (dict, required): Cfg dict to build plugin.
- stages (tuple[bool], optional): Stages to apply plugin, length
should be same as 'num_stages'.
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
@ -194,6 +206,7 @@ class YOLOXCSPDarknet(BaseBackbone):
def __init__(self,
arch: str = 'P5',
plugins: Union[dict, List[dict]] = None,
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
input_channels: int = 3,
@ -207,8 +220,8 @@ class YOLOXCSPDarknet(BaseBackbone):
init_cfg: OptMultiConfig = None):
self.spp_kernal_sizes = spp_kernal_sizes
super().__init__(self.arch_settings[arch], deepen_factor, widen_factor,
input_channels, out_indices, frozen_stages, norm_cfg,
act_cfg, norm_eval, init_cfg)
input_channels, out_indices, frozen_stages, plugins,
norm_cfg, act_cfg, norm_eval, init_cfg)
def build_stem_layer(self) -> nn.Module:
"""Build a stem layer."""


@ -1,5 +1,6 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple
from typing import List, Tuple, Union
import torch
import torch.nn as nn
@ -19,6 +20,11 @@ class YOLOv6EfficientRep(BaseBackbone):
Args:
arch (str): Architecture of BaseDarknet, from {P5, P6}.
Defaults to P5.
plugins (list[dict]): List of plugins for stages, each dict contains:
- cfg (dict, required): Cfg dict to build plugin.
- stages (tuple[bool], optional): Stages to apply plugin, length
should be same as 'num_stages'.
deepen_factor (float): Depth multiplier, multiply number of
blocks in CSP layer by this amount. Defaults to 1.0.
widen_factor (float): Width multiplier, multiply number of
@ -62,6 +68,7 @@ class YOLOv6EfficientRep(BaseBackbone):
def __init__(self,
arch: str = 'P5',
plugins: Union[dict, List[dict]] = None,
deepen_factor: float = 1.0,
widen_factor: float = 1.0,
input_channels: int = 3,
@ -80,6 +87,7 @@ class YOLOv6EfficientRep(BaseBackbone):
widen_factor,
input_channels=input_channels,
out_indices=out_indices,
plugins=plugins,
frozen_stages=frozen_stages,
norm_cfg=norm_cfg,
act_cfg=act_cfg,


@ -95,3 +95,22 @@ class TestCSPDarknet(TestCase):
assert feat[2].shape == torch.Size((1, 32, 8, 8))
assert feat[3].shape == torch.Size((1, 64, 4, 4))
assert feat[4].shape == torch.Size((1, 128, 2, 2))
# Test CSPDarknet with a DropBlock plugin
model = module_class(plugins=[
dict(
cfg=dict(type='mmdet.DropBlock', drop_prob=0.1, block_size=3),
stages=(False, False, True, True)),
])
assert len(model.stage1) == 2
assert len(model.stage2) == 2
assert len(model.stage3) == 3 # +DropBlock
assert len(model.stage4) == 4 # +SPPF+DropBlock
model.train()
imgs = torch.randn(1, 3, 256, 256)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size((1, 256, 32, 32))
assert feat[1].shape == torch.Size((1, 512, 16, 16))
assert feat[2].shape == torch.Size((1, 1024, 8, 8))


@ -92,3 +92,22 @@ class TestYOLOv6EfficientRep(TestCase):
assert feat[2].shape == torch.Size((1, 32, 8, 8))
assert feat[3].shape == torch.Size((1, 64, 4, 4))
assert feat[4].shape == torch.Size((1, 128, 2, 2))
# Test YOLOv6EfficientRep with a DropBlock plugin
model = YOLOv6EfficientRep(plugins=[
dict(
cfg=dict(type='mmdet.DropBlock', drop_prob=0.1, block_size=3),
stages=(False, False, True, True)),
])
assert len(model.stage1) == 1
assert len(model.stage2) == 1
assert len(model.stage3) == 2 # +DropBlock
assert len(model.stage4) == 3 # +SPPF+DropBlock
model.train()
imgs = torch.randn(1, 3, 256, 256)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size((1, 256, 32, 32))
assert feat[1].shape == torch.Size((1, 512, 16, 16))
assert feat[2].shape == torch.Size((1, 1024, 8, 8))