mirror of https://github.com/open-mmlab/mmyolo.git
[Feature] Add attention module of CBAM (#246)
* Add Attention Modules
* Add tutorials on the use of the attention module in How_to
* Update how_to.md: add tutorials on the use of the attention module
* Update attention_layers.py
* Rename attention_layers.py to cbam_layer.py
* Update __init__.py
* Update how_to.md
* Update how_to.md
* Update how_to.md
* Update cbam_layer.py
* Update cbam_layer.py
* Update cbam_layer.py
* Update how_to.md
* update
* add docstring typehint
* add unit test
* refine unit test
* update how_to
* add plugins directory
* refine plugins.md
* refine cbam.py and plugins.md
* refine cbam.py and plugins.md
* fix error in test_cbam.py
* refine cbam.py and fix error in test_cbam.py
* refine cbam.py and plugins.md
* refine cbam.py and docs
parent 596ffd6617
commit f1279eb3de

how_to.md
@@ -4,30 +4,12 @@
## Add plugins to the backbone network

MMYOLO supports adding plugins such as `non_local` and `dropblock` after the different stages of the Backbone. Users can manage plugins directly through the `plugins` parameter of `backbone` in the config file. For example, to add a `GeneralizedAttention` plugin to `YOLOv5`, the config file is as follows:

```python
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'

model = dict(
    backbone=dict(
        plugins=[
            dict(
                cfg=dict(
                    type='mmdet.GeneralizedAttention',
                    spatial_range=-1,
                    num_heads=8,
                    attention_type='0011',
                    kv_stride=2),
                stages=(False, False, True, True)),
        ]))
```

The `cfg` parameter specifies the concrete configuration of the plugin. The `stages` parameter controls whether the plugin is added after the corresponding stage of the backbone; its length must equal the number of backbone stages (the YOLOv5 backbone has 4 stages, hence the 4-element tuple).

[More plugin usage](plugins.md)
## Apply multiple Necks

If you want to stack multiple Necks, you can set the Neck parameter of the config file directly: MMYOLO supports concatenating multiple Neck configs in the form of a `List`. You need to ensure that the output channels of the previous Neck match the input channels of the next Neck. If the channels need adjusting, you can insert an `mmdet.ChannelMapper` module to align the channel counts between the Necks. The specific configuration is as follows:

```python
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'

# @@ -42,7 +24,8 @@ model = dict(
            deepen_factor=deepen_factor,
            widen_factor=widen_factor,
            in_channels=[256, 512, 1024],
            out_channels=[256, 512, 1024],
            # out_channels is controlled by widen_factor, so the actual
            # out_channels of YOLOv5PAFPN = out_channels * widen_factor
            num_csp_blocks=3,
            norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
            act_cfg=dict(type='SiLU', inplace=True)),

# @@ -58,14 +41,16 @@ model = dict(
            num_blocks=2,
            # disable zero_init_offset to follow official implementation
            zero_init_offset=False)
    ],
    bbox_head=dict(head_module=dict(in_channels=[512, 512, 512]))
    # out_channels is controlled by widen_factor, so in_channels of
    # YOLOv5HeadModule * widen_factor must equal the out_channels of
    # the last neck
)
```
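
Since the config above relies on channel counts lining up between Necks, it can help to sanity-check `mmdet.ChannelMapper` in isolation. The following is a minimal sketch, assuming MMDetection's `ChannelMapper` neck API; the channel counts and feature-map sizes are illustrative, not taken from the config above:

```python
import torch
from mmdet.models.necks import ChannelMapper

# Map three pyramid levels with mismatched channels to a uniform 128,
# so the next Neck in the list can consume them.
mapper = ChannelMapper(in_channels=[128, 256, 512], out_channels=128)
feats = tuple(
    torch.randn(1, c, s, s) for c, s in [(128, 80), (256, 40), (512, 20)])
outs = mapper(feats)
print([o.shape for o in outs])  # every level now has 128 channels
```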
## Use backbone networks across libraries

In the OpenMMLab 2.0 ecosystem, the model registries of MMYOLO, MMDetection, MMClassification, and MMSegmentation all inherit from the root registry in MMEngine, which allows these OpenMMLab open-source libraries to use each other's implemented modules directly. Users can therefore use backbone networks from MMDetection or MMClassification in MMYOLO without re-implementing them.

```{note}
1. When using another library's backbone network, you need to ensure that the output channels of the backbone match the input channels of the Neck.
```
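
For instance, MMDetection's `ResNet` can be dropped into a YOLOv5 config through the registry. The following is a minimal sketch rather than a config taken from this commit: the `[512, 1024, 2048]` channels are the standard ResNet-50 outputs for `out_indices=(1, 2, 3)`, and `widen_factor = 1.0` stops the Neck and Head from re-scaling them:

```python
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'

deepen_factor = 0.33  # value used by yolov5-s
widen_factor = 1.0  # disable channel re-scaling in the Neck and Head
channels = [512, 1024, 2048]  # standard ResNet-50 outputs for C3/C4/C5

model = dict(
    backbone=dict(
        _delete_=True,  # delete the backbone field from _base_
        type='mmdet.ResNet',  # cross-library: use ResNet from MMDetection
        depth=50,
        out_indices=(1, 2, 3),  # take the C3, C4, C5 feature maps
        init_cfg=dict(type='Pretrained', checkpoint='torchvision://resnet50')),
    neck=dict(
        type='YOLOv5PAFPN',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        in_channels=channels,  # must match the backbone's output channels
        out_channels=channels),
    bbox_head=dict(
        type='YOLOv5Head',
        head_module=dict(
            type='YOLOv5HeadModule',
            in_channels=channels,
            widen_factor=widen_factor)))
```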

@@ -235,7 +220,8 @@
### Use backbone networks implemented in `timm` through MMClassification

MMClassification provides a wrapper for Py**T**orch **Im**age **M**odels (`timm`) backbones, so users can also use backbone networks from `timm` directly through MMClassification. Suppose you want to use `EfficientNet-B1` as the backbone of `YOLOv5`; the config file is as follows:

```python
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'

# @@ -251,9 +237,9 @@ channels = [40, 112, 320]
model = dict(
    backbone=dict(
        _delete_=True,  # delete the backbone field from _base_
        type='mmcls.TIMMBackbone',  # use the timm backbone wrapper in mmcls
        model_name='efficientnet_b1',  # use efficientnet_b1 from timm
        features_only=True,
        pretrained=True,
        out_indices=(2, 3, 4)),

# @@ -261,13 +247,13 @@ model = dict(
        type='YOLOv5PAFPN',
        deepen_factor=deepen_factor,
        widen_factor=widen_factor,
        # Note: EfficientNet-B1 outputs 3 feature maps with channels
        # [40, 112, 320], which do not match the original yolov5-s Neck,
        # so they must be changed
        in_channels=channels,
        out_channels=channels),
    bbox_head=dict(
        type='YOLOv5Head',
        head_module=dict(
            type='YOLOv5HeadModule',
            # the input channels of the head part must also be changed
            in_channels=channels,
            widen_factor=widen_factor))
)
```
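
The `channels = [40, 112, 320]` above must match what the `timm` backbone actually emits for the chosen `out_indices`. One way to look these values up is to query `timm` directly; a minimal sketch, assuming the `timm` package is installed:

```python
import timm
import torch

# Build essentially the feature extractor the TIMMBackbone wrapper creates.
model = timm.create_model(
    'efficientnet_b1', features_only=True, out_indices=(2, 3, 4))
print(model.feature_info.channels())  # -> [40, 112, 320]

# Sanity-check the actual feature-map shapes for a 640x640 input.
feats = model(torch.randn(1, 3, 640, 640))
print([f.shape for f in feats])
```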

plugins.md (new file)
@@ -0,0 +1,35 @@
# More plugin usage

MMYOLO supports adding plugins such as `non_local` and `dropblock` after the different stages of the Backbone. Users can manage plugins directly through the `plugins` parameter of `backbone` in the config file. For example, to add a `GeneralizedAttention` plugin to `YOLOv5`, the config file is as follows:

```python
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'

model = dict(
    backbone=dict(
        plugins=[
            dict(
                cfg=dict(
                    type='GeneralizedAttention',
                    spatial_range=-1,
                    num_heads=8,
                    attention_type='0011',
                    kv_stride=2),
                stages=(False, False, True, True))
        ]))
```

The `cfg` parameter specifies the concrete configuration of the plugin. The `stages` parameter controls whether the plugin is added after the corresponding stage of the backbone; its length must equal the number of backbone stages.

Currently `MMYOLO` supports the following plugins:

<details open>
<summary><b>Supported plugins</b></summary>

- [x] [CBAM](mmyolo/models/plugins)
- [x] [GeneralizedAttention](https://github.com/open-mmlab/mmcv/blob/2.x/mmcv/cnn/bricks/generalized_attention.py#L13)
- [x] [NonLocal2d](https://github.com/open-mmlab/mmcv/blob/2.x/mmcv/cnn/bricks/non_local.py#L250)
- [x] [ContextBlock](https://github.com/open-mmlab/mmcv/blob/2.x/mmcv/cnn/bricks/context_block.py#L18)

</details>
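
Since this commit registers `CBAM` in the MMYOLO model registry, it should be usable through the same `plugins` mechanism. A minimal sketch, not a config shipped with this commit: `reduce_ratio` and `kernel_size` are simply the defaults from `cbam.py`, and `in_channels` is filled in per stage by the backbone, as in the `GeneralizedAttention` example above:

```python
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'

model = dict(
    backbone=dict(
        plugins=[
            dict(
                # CBAM is exported by mmyolo.models.plugins, so the plain
                # type name resolves in the MMYOLO registry.
                cfg=dict(type='CBAM', reduce_ratio=16, kernel_size=7),
                # attach the plugin only after the last two backbone stages
                stages=(False, False, True, True))
        ]))
```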

mmyolo/models/__init__.py
@@ -6,4 +6,5 @@ from .detectors import *  # noqa: F401,F403
from .layers import *  # noqa: F401,F403
from .losses import *  # noqa: F401,F403
from .necks import *  # noqa: F401,F403
from .plugins import *  # noqa: F401,F403
from .task_modules import *  # noqa: F401,F403

mmyolo/models/plugins/__init__.py (new file)
@@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cbam import CBAM

__all__ = ['CBAM']

mmyolo/models/plugins/cbam.py (new file)
@@ -0,0 +1,117 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.utils import OptMultiConfig
from mmengine.model import BaseModule

from mmyolo.registry import MODELS


class ChannelAttention(BaseModule):
    """ChannelAttention.

    Args:
        channels (int): The input (and output) channels of the
            ChannelAttention.
        reduce_ratio (int): Squeeze ratio in ChannelAttention, the
            intermediate channel will be ``int(channels/reduce_ratio)``.
            Defaults to 16.
        act_cfg (dict): Config dict for activation layer.
            Defaults to dict(type='ReLU').
    """

    def __init__(self,
                 channels: int,
                 reduce_ratio: int = 16,
                 act_cfg: dict = dict(type='ReLU')):
        super().__init__()

        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)

        self.fc = nn.Sequential(
            ConvModule(
                in_channels=channels,
                out_channels=int(channels / reduce_ratio),
                kernel_size=1,
                stride=1,
                conv_cfg=None,
                act_cfg=act_cfg),
            ConvModule(
                in_channels=int(channels / reduce_ratio),
                out_channels=channels,
                kernel_size=1,
                stride=1,
                conv_cfg=None,
                act_cfg=None))
        self.sigmoid = nn.Sigmoid()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avgpool_out = self.fc(self.avg_pool(x))
        maxpool_out = self.fc(self.max_pool(x))
        out = self.sigmoid(avgpool_out + maxpool_out)
        return out


class SpatialAttention(BaseModule):
    """SpatialAttention.

    Args:
        kernel_size (int): The size of the convolution kernel in
            SpatialAttention. Defaults to 7.
    """

    def __init__(self, kernel_size: int = 7):
        super().__init__()

        self.conv = ConvModule(
            in_channels=2,
            out_channels=1,
            kernel_size=kernel_size,
            stride=1,
            padding=kernel_size // 2,
            conv_cfg=None,
            act_cfg=dict(type='Sigmoid'))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        avg_out = torch.mean(x, dim=1, keepdim=True)
        max_out, _ = torch.max(x, dim=1, keepdim=True)
        out = torch.cat([avg_out, max_out], dim=1)
        out = self.conv(out)
        return out


@MODELS.register_module()
class CBAM(BaseModule):
    """Convolutional Block Attention Module.

    arxiv link: https://arxiv.org/abs/1807.06521v2

    Args:
        in_channels (int): The input (and output) channels of the CBAM.
        reduce_ratio (int): Squeeze ratio in ChannelAttention, the
            intermediate channel will be ``int(channels/reduce_ratio)``.
            Defaults to 16.
        kernel_size (int): The size of the convolution kernel in
            SpatialAttention. Defaults to 7.
        act_cfg (dict): Config dict for activation layer in ChannelAttention.
            Defaults to dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 reduce_ratio: int = 16,
                 kernel_size: int = 7,
                 act_cfg: dict = dict(type='ReLU'),
                 init_cfg: OptMultiConfig = None):
        super().__init__(init_cfg)
        self.channel_attention = ChannelAttention(
            channels=in_channels, reduce_ratio=reduce_ratio, act_cfg=act_cfg)

        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        out = self.channel_attention(x) * x
        out = self.spatial_attention(out) * out
        return out

@@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.

test_cbam.py (new file)
@@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase

import torch

from mmyolo.models.plugins import CBAM
from mmyolo.utils import register_all_modules

register_all_modules()


class TestCBAM(TestCase):

    def test_forward(self):
        tensor_shape = (2, 16, 20, 20)

        images = torch.randn(*tensor_shape)
        cbam = CBAM(16)
        out = cbam(images)
        self.assertEqual(out.shape, tensor_shape)

        # test other ratio
        cbam = CBAM(16, reduce_ratio=8)
        out = cbam(images)
        self.assertEqual(out.shape, tensor_shape)

        # test other act_cfg in ChannelAttention
        cbam = CBAM(in_channels=16, act_cfg=dict(type='Sigmoid'))
        out = cbam(images)
        self.assertEqual(out.shape, tensor_shape)