mirror of https://github.com/open-mmlab/mmyolo.git
[Feature]Support plugin layers for backbone (#75)
* add plugin layer * add docstring. add config * clean the config * add how_to-doc * add ut * update * update * update lint * fix yoloxbug * update * update * update doc * update * update * update * update * updatepull/137/head
parent
05a5b2aaa2
commit
d705f1c57b
|
@ -1 +1,58 @@
|
|||
# How to
|
||||
This tutorial collects answers to any `How to xxx with MMYOLO`. Feel free to update this doc if you meet new questions about `How to` and find the answers!
|
||||
|
||||
# Add plugins to the BackBone network
|
||||
|
||||
MMYOLO supports adding plug-ins such as none_local and dropout after different stages of BackBone. Users can directly manage plug-ins by modifying the plugins parameter of backbone in config. For example, add GeneralizedAttention plug-ins for `YOLOv5`. The configuration files are as follows:
|
||||
|
||||
```python
|
||||
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'
|
||||
|
||||
model = dict(
|
||||
backbone=dict(
|
||||
plugins=[
|
||||
dict(
|
||||
cfg=dict(
|
||||
type='mmdet.GeneralizedAttention',
|
||||
spatial_range=-1,
|
||||
num_heads=8,
|
||||
attention_type='0011',
|
||||
kv_stride=2),
|
||||
stages=(False, False, True, True)),
|
||||
], ))
|
||||
```
|
||||
|
||||
`cfg` parameter indicates the specific configuration of the plug-in. The `stages` parameter indicates whether to add plug-ins after the corresponding stage of the backbone. The length of list `stages` must be the same as the number of backbone stages.
|
||||
|
||||
## Apply multiple Necks
|
||||
|
||||
If you want to stack multiple Necks, you can directly set the Neck parameters in the config. MMYOLO supports concatenating multiple Necks in the form of `List`. You need to ensure that the output channel of the previous Neck matches the input channel of the next Neck. If you need to adjust the number of channels, you can insert the `mmdet.ChannelMapper` module to align the number of channels between multiple Necks. The specific configuration is as follows:
|
||||
|
||||
```python
|
||||
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'
|
||||
|
||||
model = dict(
|
||||
type='YOLODetector',
|
||||
neck=[
|
||||
dict(
|
||||
type='YOLOv5PAFPN',
|
||||
deepen_factor=deepen_factor,
|
||||
widen_factor=widen_factor,
|
||||
in_channels=[256, 512, 1024],
|
||||
out_channels=[256, 512, 1024],
|
||||
num_csp_blocks=3,
|
||||
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
|
||||
act_cfg=dict(type='SiLU', inplace=True)),
|
||||
dict(
|
||||
type='mmdet.ChannelMapper',
|
||||
in_channels=[128, 256, 512],
|
||||
out_channels=128,
|
||||
),
|
||||
dict(
|
||||
type='mmdet.DyHead',
|
||||
in_channels=128,
|
||||
out_channels=256,
|
||||
num_blocks=2,
|
||||
# disable zero_init_offset to follow official implementation
|
||||
zero_init_offset=False)
|
||||
]
|
||||
```
|
||||
|
|
|
@ -1 +1,58 @@
|
|||
# how to
|
||||
本教程收集了任何如何使用 MMYOLO 进行 xxx 的答案。 如果您遇到有关`如何做`的问题及答案,请随时更新此文档!
|
||||
|
||||
## 给骨干网络增加插件
|
||||
|
||||
MMYOLO 支持在 BackBone 的不同 Stage 后增加如 `none_local`、`dropblock` 等插件,用户可以直接通过修改 config 文件中 `backbone` 的 `plugins` 参数来实现对插件的管理。例如为 `YOLOv5` 增加 `GeneralizedAttention` 插件,其配置文件如下:
|
||||
|
||||
```python
|
||||
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'
|
||||
|
||||
model = dict(
|
||||
backbone=dict(
|
||||
plugins=[
|
||||
dict(
|
||||
cfg=dict(
|
||||
type='mmdet.GeneralizedAttention',
|
||||
spatial_range=-1,
|
||||
num_heads=8,
|
||||
attention_type='0011',
|
||||
kv_stride=2),
|
||||
stages=(False, False, True, True)),
|
||||
], ))
|
||||
```
|
||||
|
||||
`cfg` 参数表示插件的具体配置, `stages` 参数表示是否在 backbone 对应的 stage 后面增加插件,长度需要和 backbone 的 stage 数量相同。
|
||||
|
||||
## 应用多个 Neck
|
||||
|
||||
如果你想堆叠多个 Neck,可以直接在配置文件中的 Neck 参数,MMYOLO 支持以 `List` 形式拼接多个 Neck 配置,你需要保证的是上一个 Neck 的输出通道与下一个 Neck 的输入通道相匹配。如需要调整通道,可以插入 `mmdet.ChannelMapper` 模块用来对齐多个 Neck 之间的通道数量。具体配置如下:
|
||||
|
||||
```python
|
||||
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'
|
||||
|
||||
model = dict(
|
||||
type='YOLODetector',
|
||||
neck=[
|
||||
dict(
|
||||
type='YOLOv5PAFPN',
|
||||
deepen_factor=deepen_factor,
|
||||
widen_factor=widen_factor,
|
||||
in_channels=[256, 512, 1024],
|
||||
out_channels=[256, 512, 1024],
|
||||
num_csp_blocks=3,
|
||||
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
|
||||
act_cfg=dict(type='SiLU', inplace=True)),
|
||||
dict(
|
||||
type='mmdet.ChannelMapper',
|
||||
in_channels=[128, 256, 512],
|
||||
out_channels=128,
|
||||
),
|
||||
dict(
|
||||
type='mmdet.DyHead',
|
||||
in_channels=128,
|
||||
out_channels=256,
|
||||
num_blocks=2,
|
||||
# disable zero_init_offset to follow official implementation
|
||||
zero_init_offset=False)
|
||||
]
|
||||
```
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from abc import ABCMeta, abstractmethod
|
||||
from typing import Sequence
|
||||
from typing import List, Sequence, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_plugin_layer
|
||||
from mmdet.utils import ConfigType, OptMultiConfig
|
||||
from mmengine.model import BaseModule
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
@ -17,6 +18,11 @@ class BaseBackbone(BaseModule, metaclass=ABCMeta):
|
|||
|
||||
Args:
|
||||
arch_setting (dict): Architecture of BaseBackbone.
|
||||
plugins (list[dict]): List of plugins for stages, each dict contains:
|
||||
|
||||
- cfg (dict, required): Cfg dict to build plugin.
|
||||
- stages (tuple[bool], optional): Stages to apply plugin, length
|
||||
should be same as 'num_stages'.
|
||||
deepen_factor (float): Depth multiplier, multiply number of
|
||||
blocks in CSP layer by this amount. Defaults to 1.0.
|
||||
widen_factor (float): Width multiplier, multiply number of
|
||||
|
@ -44,6 +50,7 @@ class BaseBackbone(BaseModule, metaclass=ABCMeta):
|
|||
input_channels: int = 3,
|
||||
out_indices: Sequence[int] = (2, 3, 4),
|
||||
frozen_stages: int = -1,
|
||||
plugins: Union[dict, List[dict]] = None,
|
||||
norm_cfg: ConfigType = None,
|
||||
act_cfg: ConfigType = None,
|
||||
norm_eval: bool = False,
|
||||
|
@ -69,6 +76,7 @@ class BaseBackbone(BaseModule, metaclass=ABCMeta):
|
|||
self.norm_eval = norm_eval
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.plugins = plugins
|
||||
|
||||
self.stem = self.build_stem_layer()
|
||||
self.layers = ['stem']
|
||||
|
@ -76,6 +84,8 @@ class BaseBackbone(BaseModule, metaclass=ABCMeta):
|
|||
for idx, setting in enumerate(arch_setting):
|
||||
stage = []
|
||||
stage += self.build_stage_layer(idx, setting)
|
||||
if plugins is not None:
|
||||
stage += self.make_stage_plugins(plugins, idx, setting)
|
||||
self.add_module(f'stage{idx + 1}', nn.Sequential(*stage))
|
||||
self.layers.append(f'stage{idx + 1}')
|
||||
|
||||
|
@ -94,6 +104,65 @@ class BaseBackbone(BaseModule, metaclass=ABCMeta):
|
|||
"""
|
||||
pass
|
||||
|
||||
def make_stage_plugins(self, plugins, idx, setting):
|
||||
"""Make plugins for backbone ``stage_idx`` th stage.
|
||||
|
||||
Currently we support to insert ``context_block``,
|
||||
``empirical_attention_block``, ``nonlocal_block``, ``dropout_block``
|
||||
into the backbone.
|
||||
|
||||
|
||||
An example of plugins format could be:
|
||||
|
||||
Examples:
|
||||
>>> plugins=[
|
||||
... dict(cfg=dict(type='xxx', arg1='xxx'),
|
||||
... stages=(False, True, True, True)),
|
||||
... dict(cfg=dict(type='yyy'),
|
||||
... stages=(True, True, True, True)),
|
||||
... ]
|
||||
>>> model = YOLOv5CSPDarknet()
|
||||
>>> stage_plugins = model.make_stage_plugins(plugins, 0, setting)
|
||||
>>> assert len(stage_plugins) == 3
|
||||
|
||||
Suppose ``stage_idx=0``, the structure of blocks in the stage would be:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
conv1 -> conv2 -> conv3 -> yyy
|
||||
|
||||
Suppose 'stage_idx=1', the structure of blocks in the stage would be:
|
||||
|
||||
.. code-block:: none
|
||||
|
||||
conv1 -> conv2 -> conv3 -> xxx -> yyy
|
||||
|
||||
|
||||
Args:
|
||||
plugins (list[dict]): List of plugins cfg to build. The postfix is
|
||||
required if multiple same type plugins are inserted.
|
||||
stage_idx (int): Index of stage to build
|
||||
If stages is missing, the plugin would be applied to all
|
||||
stages.
|
||||
setting (list): The architecture setting of a stage layer.
|
||||
|
||||
Returns:
|
||||
list[nn.Module]: Plugins for current stage
|
||||
"""
|
||||
# TODO: It is not general enough to support any channel and needs
|
||||
# to be refactored
|
||||
in_channels = int(setting[1] * self.widen_factor)
|
||||
plugin_layers = []
|
||||
for plugin in plugins:
|
||||
plugin = plugin.copy()
|
||||
stages = plugin.pop('stages', None)
|
||||
assert stages is None or len(stages) == self.num_stages
|
||||
if stages is None or stages[idx]:
|
||||
name, layer = build_plugin_layer(
|
||||
plugin['cfg'], in_channels=in_channels)
|
||||
plugin_layers.append(layer)
|
||||
return plugin_layers
|
||||
|
||||
def _freeze_stages(self):
|
||||
"""Freeze the parameters of the specified stage so that they are no
|
||||
longer updated."""
|
||||
|
|
|
@ -1,5 +1,5 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Tuple
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -20,6 +20,11 @@ class YOLOv5CSPDarknet(BaseBackbone):
|
|||
Args:
|
||||
arch (str): Architecture of CSP-Darknet, from {P5, P6}.
|
||||
Defaults to P5.
|
||||
plugins (list[dict]): List of plugins for stages, each dict contains:
|
||||
|
||||
- cfg (dict, required): Cfg dict to build plugin.
|
||||
- stages (tuple[bool], optional): Stages to apply plugin, length
|
||||
should be same as 'num_stages'.
|
||||
deepen_factor (float): Depth multiplier, multiply number of
|
||||
blocks in CSP layer by this amount. Defaults to 1.0.
|
||||
widen_factor (float): Width multiplier, multiply number of
|
||||
|
@ -62,6 +67,7 @@ class YOLOv5CSPDarknet(BaseBackbone):
|
|||
|
||||
def __init__(self,
|
||||
arch: str = 'P5',
|
||||
plugins: Union[dict, List[dict]] = None,
|
||||
deepen_factor: float = 1.0,
|
||||
widen_factor: float = 1.0,
|
||||
input_channels: int = 3,
|
||||
|
@ -78,6 +84,7 @@ class YOLOv5CSPDarknet(BaseBackbone):
|
|||
widen_factor,
|
||||
input_channels=input_channels,
|
||||
out_indices=out_indices,
|
||||
plugins=plugins,
|
||||
frozen_stages=frozen_stages,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
|
@ -151,6 +158,11 @@ class YOLOXCSPDarknet(BaseBackbone):
|
|||
Args:
|
||||
arch (str): Architecture of CSP-Darknet, from {P5, P6}.
|
||||
Defaults to P5.
|
||||
plugins (list[dict]): List of plugins for stages, each dict contains:
|
||||
|
||||
- cfg (dict, required): Cfg dict to build plugin.
|
||||
- stages (tuple[bool], optional): Stages to apply plugin, length
|
||||
should be same as 'num_stages'.
|
||||
deepen_factor (float): Depth multiplier, multiply number of
|
||||
blocks in CSP layer by this amount. Defaults to 1.0.
|
||||
widen_factor (float): Width multiplier, multiply number of
|
||||
|
@ -194,6 +206,7 @@ class YOLOXCSPDarknet(BaseBackbone):
|
|||
|
||||
def __init__(self,
|
||||
arch: str = 'P5',
|
||||
plugins: Union[dict, List[dict]] = None,
|
||||
deepen_factor: float = 1.0,
|
||||
widen_factor: float = 1.0,
|
||||
input_channels: int = 3,
|
||||
|
@ -207,8 +220,8 @@ class YOLOXCSPDarknet(BaseBackbone):
|
|||
init_cfg: OptMultiConfig = None):
|
||||
self.spp_kernal_sizes = spp_kernal_sizes
|
||||
super().__init__(self.arch_settings[arch], deepen_factor, widen_factor,
|
||||
input_channels, out_indices, frozen_stages, norm_cfg,
|
||||
act_cfg, norm_eval, init_cfg)
|
||||
input_channels, out_indices, frozen_stages, plugins,
|
||||
norm_cfg, act_cfg, norm_eval, init_cfg)
|
||||
|
||||
def build_stem_layer(self) -> nn.Module:
|
||||
"""Build a stem layer."""
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from typing import Tuple
|
||||
|
||||
from typing import List, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
@ -19,6 +20,11 @@ class YOLOv6EfficientRep(BaseBackbone):
|
|||
Args:
|
||||
arch (str): Architecture of BaseDarknet, from {P5, P6}.
|
||||
Defaults to P5.
|
||||
plugins (list[dict]): List of plugins for stages, each dict contains:
|
||||
|
||||
- cfg (dict, required): Cfg dict to build plugin.
|
||||
- stages (tuple[bool], optional): Stages to apply plugin, length
|
||||
should be same as 'num_stages'.
|
||||
deepen_factor (float): Depth multiplier, multiply number of
|
||||
blocks in CSP layer by this amount. Defaults to 1.0.
|
||||
widen_factor (float): Width multiplier, multiply number of
|
||||
|
@ -62,6 +68,7 @@ class YOLOv6EfficientRep(BaseBackbone):
|
|||
|
||||
def __init__(self,
|
||||
arch: str = 'P5',
|
||||
plugins: Union[dict, List[dict]] = None,
|
||||
deepen_factor: float = 1.0,
|
||||
widen_factor: float = 1.0,
|
||||
input_channels: int = 3,
|
||||
|
@ -80,6 +87,7 @@ class YOLOv6EfficientRep(BaseBackbone):
|
|||
widen_factor,
|
||||
input_channels=input_channels,
|
||||
out_indices=out_indices,
|
||||
plugins=plugins,
|
||||
frozen_stages=frozen_stages,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
|
|
|
@ -95,3 +95,22 @@ class TestCSPDarknet(TestCase):
|
|||
assert feat[2].shape == torch.Size((1, 32, 8, 8))
|
||||
assert feat[3].shape == torch.Size((1, 64, 4, 4))
|
||||
assert feat[4].shape == torch.Size((1, 128, 2, 2))
|
||||
|
||||
# Test CSPDarknet with Dropout Block
|
||||
model = module_class(plugins=[
|
||||
dict(
|
||||
cfg=dict(type='mmdet.DropBlock', drop_prob=0.1, block_size=3),
|
||||
stages=(False, False, True, True)),
|
||||
])
|
||||
|
||||
assert len(model.stage1) == 2
|
||||
assert len(model.stage2) == 2
|
||||
assert len(model.stage3) == 3 # +DropBlock
|
||||
assert len(model.stage4) == 4 # +SPPF+DropBlock
|
||||
model.train()
|
||||
imgs = torch.randn(1, 3, 256, 256)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == torch.Size((1, 256, 32, 32))
|
||||
assert feat[1].shape == torch.Size((1, 512, 16, 16))
|
||||
assert feat[2].shape == torch.Size((1, 1024, 8, 8))
|
||||
|
|
|
@ -92,3 +92,22 @@ class TestYOLOv6EfficientRep(TestCase):
|
|||
assert feat[2].shape == torch.Size((1, 32, 8, 8))
|
||||
assert feat[3].shape == torch.Size((1, 64, 4, 4))
|
||||
assert feat[4].shape == torch.Size((1, 128, 2, 2))
|
||||
|
||||
# Test YOLOv6EfficientRep with BatchNorm forward
|
||||
model = YOLOv6EfficientRep(plugins=[
|
||||
dict(
|
||||
cfg=dict(type='mmdet.DropBlock', drop_prob=0.1, block_size=3),
|
||||
stages=(False, False, True, True)),
|
||||
])
|
||||
|
||||
assert len(model.stage1) == 1
|
||||
assert len(model.stage2) == 1
|
||||
assert len(model.stage3) == 2 # +DropBlock
|
||||
assert len(model.stage4) == 3 # +SPPF+DropBlock
|
||||
model.train()
|
||||
imgs = torch.randn(1, 3, 256, 256)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == torch.Size((1, 256, 32, 32))
|
||||
assert feat[1].shape == torch.Size((1, 512, 16, 16))
|
||||
assert feat[2].shape == torch.Size((1, 1024, 8, 8))
|
||||
|
|
Loading…
Reference in New Issue