[Feature] Add attention module of CBAM (#246)

* Add Attention Modules

* Adde tutorials on the use of the attention module in How_to

* Update how_to.md

Added tutorials on the use of the attention module

* Update attention_layers.py

* Rename attention_layers.py to cbam_layer.py

* Update __init__.py

* Update how_to.md

* Update how_to.md

* Update how_to.md

* Update cbam_layer.py

* Update cbam_layer.py

* Update cbam_layer.py

* Update how_to.md

* update

* add docstring typehint

* add unit test

* refine unit test

* updata how_to

* add plugins directory

* refine plugin.md

* refine cbam.py and plugins.md

* refine cbam.py and plugins.md

* fix error in test_cbam.py

* refine cbam.py and fix error in test_cbam.py

* refine cbam.py and plugins.md

* refine cbam.py and docs
pull/276/head
kitecats 2022-11-10 10:03:04 +08:00 committed by huanghaian
parent 596ffd6617
commit f1279eb3de
7 changed files with 206 additions and 31 deletions

View File

@ -4,30 +4,12 @@
## 给主干网络增加插件
MMYOLO 支持在 Backbone 的不同 Stage 后增加如 `none_local`、`dropblock` 等插件,用户可以直接通过修改 config 文件中 `backbone``plugins` 参数来实现对插件的管理。例如为 `YOLOv5` 增加 `GeneralizedAttention` 插件,其配置文件如下:
```python
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'
model = dict(
backbone=dict(
plugins=[
dict(
cfg=dict(
type='mmdet.GeneralizedAttention',
spatial_range=-1,
num_heads=8,
attention_type='0011',
kv_stride=2),
stages=(False, False, True, True)),
], ))
```
`cfg` 参数表示插件的具体配置, `stages` 参数表示是否在 backbone 对应的 stage 后面增加插件,长度需要和 backbone 的 stage 数量相同。
[更多的插件使用](plugins.md)
## 应用多个 Neck
如果你想堆叠多个 Neck可以直接在配置文件中的 Neck 参数MMYOLO 支持以 `List` 形式拼接多个 Neck 配置,你需要保证上一个 Neck 的输出通道与下一个 Neck 的输入通道相匹配。如需要调整通道,可以插入 `mmdet.ChannelMapper` 模块用来对齐多个 Neck 之间的通道数量。具体配置如下:
如果你想堆叠多个 Neck可以直接在配置文件中的 Neck 参数MMYOLO 支持以 `List` 形式拼接多个 Neck 配置,你需要保证上一个 Neck 的输出通道与下一个 Neck
的输入通道相匹配。如需要调整通道,可以插入 `mmdet.ChannelMapper` 模块用来对齐多个 Neck 之间的通道数量。具体配置如下:
```python
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'
@ -42,7 +24,8 @@ model = dict(
deepen_factor=deepen_factor,
widen_factor=widen_factor,
in_channels=[256, 512, 1024],
out_channels=[256, 512, 1024], # 因为 out_channels 由 widen_factor 控制YOLOv5PAFPN 的 out_channels = out_channels * widen_factor
out_channels=[256, 512, 1024],
# 因为 out_channels 由 widen_factor 控制YOLOv5PAFPN 的 out_channels = out_channels * widen_factor
num_csp_blocks=3,
norm_cfg=dict(type='BN', momentum=0.03, eps=0.001),
act_cfg=dict(type='SiLU', inplace=True)),
@ -58,14 +41,16 @@ model = dict(
num_blocks=2,
# disable zero_init_offset to follow official implementation
zero_init_offset=False)
]
bbox_head=dict(head_module=dict(in_channels=[512,512,512])) # 因为 out_channels 由 widen_factor 控制YOLOv5HeadModuled 的 in_channels * widen_factor 才会等于最后一个 neck 的 out_channels
],
bbox_head=dict(head_module=dict(in_channels=[512, 512, 512]))
# 因为 out_channels 由 widen_factor 控制YOLOv5HeadModuled 的 in_channels * widen_factor 才会等于最后一个 neck 的 out_channels
)
```
## 跨库使用主干网络
OpenMMLab 2.0 体系中 MMYOLO、MMDetection、MMClassification、MMSegmentation 中的模型注册表都继承自 MMEngine 中的根注册表,允许这些 OpenMMLab 开源库直接使用彼此已经实现的模块。 因此用户可以在 MMYOLO 中使用来自 MMDetection、MMClassification 的主干网络,而无需重新实现。
OpenMMLab 2.0 体系中 MMYOLO、MMDetection、MMClassification、MMSegmentation 中的模型注册表都继承自 MMEngine 中的根注册表,允许这些 OpenMMLab
开源库直接使用彼此已经实现的模块。 因此用户可以在 MMYOLO 中使用来自 MMDetection、MMClassification 的主干网络,而无需重新实现。
```{note}
1. 使用其他主干网络时,你需要保证主干网络的输出通道与 Neck 的输入通道相匹配。
@ -235,7 +220,8 @@ OpenMMLab 2.0 体系中 MMYOLO、MMDetection、MMClassification、MMSegmentation
### 通过 MMClassification 使用 `timm` 中实现的主干网络
由于 MMClassification 提供了 Py**T**orch **Im**age **M**odels (`timm`) 主干网络的封装,用户也可以通过 MMClassification 直接使用 `timm` 中的主干网络。假设想将 `EfficientNet-B1`作为 `YOLOv5` 的主干网络,则配置文件如下:
由于 MMClassification 提供了 Py**T**orch **Im**age **M**odels (`timm`) 主干网络的封装,用户也可以通过 MMClassification 直接使用 `timm`
中的主干网络。假设想将 `EfficientNet-B1`作为 `YOLOv5` 的主干网络,则配置文件如下:
```python
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'
@ -251,9 +237,9 @@ channels = [40, 112, 320]
model = dict(
backbone=dict(
_delete_=True, # 将 _base_ 中关于 backbone 的字段删除
type='mmcls.TIMMBackbone', # 使用 mmcls 中的 timm 主干网络
model_name='efficientnet_b1', # 使用 TIMM 中的 efficientnet_b1
_delete_=True, # 将 _base_ 中关于 backbone 的字段删除
type='mmcls.TIMMBackbone', # 使用 mmcls 中的 timm 主干网络
model_name='efficientnet_b1', # 使用 TIMM 中的 efficientnet_b1
features_only=True,
pretrained=True,
out_indices=(2, 3, 4)),
@ -261,13 +247,13 @@ model = dict(
type='YOLOv5PAFPN',
deepen_factor=deepen_factor,
widen_factor=widen_factor,
in_channels=channels, # 注意EfficientNet-B1 输出的3个通道是 [40, 112, 320],和原先的 yolov5-s neck 不匹配,需要更改
in_channels=channels, # 注意EfficientNet-B1 输出的3个通道是 [40, 112, 320],和原先的 yolov5-s neck 不匹配,需要更改
out_channels=channels),
bbox_head=dict(
type='YOLOv5Head',
head_module=dict(
type='YOLOv5HeadModule',
in_channels=channels, # head 部分输入通道也要做相应更改
in_channels=channels, # head 部分输入通道也要做相应更改
widen_factor=widen_factor))
)
```

View File

@ -0,0 +1,35 @@
# 更多的插件使用
MMYOLO 支持在 Backbone 的不同 Stage 后增加如 `none_local`、`dropblock` 等插件,用户可以直接通过修改 config 文件中 `backbone``plugins`
参数来实现对插件的管理。例如为 `YOLOv5` 增加 `GeneralizedAttention` 插件,其配置文件如下:
```python
_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py'
model = dict(
backbone=dict(
plugins=[
dict(
cfg=dict(
type='GeneralizedAttention',
spatial_range=-1,
num_heads=8,
attention_type='0011',
kv_stride=2),
stages=(False, False, True, True))
]))
```
`cfg` 参数表示插件的具体配置, `stages` 参数表示是否在 backbone 对应的 stage 后面增加插件,长度需要和 backbone 的 stage 数量相同。
目前 `MMYOLO` 支持了如下插件:
<details open>
<summary><b>支持的插件</b></summary>
- [x] [CBAM](mmyolo/models/plugins)
- [x] [GeneralizedAttention](https://github.com/open-mmlab/mmcv/blob/2.x/mmcv/cnn/bricks/generalized_attention.py#L13)
- [x] [NonLocal2d](https://github.com/open-mmlab/mmcv/blob/2.x/mmcv/cnn/bricks/non_local.py#L250)
- [x] [ContextBlock](https://github.com/open-mmlab/mmcv/blob/2.x/mmcv/cnn/bricks/context_block.py#L18)
</details>

View File

@ -6,4 +6,5 @@ from .detectors import * # noqa: F401,F403
from .layers import * # noqa: F401,F403
from .losses import * # noqa: F401,F403
from .necks import * # noqa: F401,F403
from .plugins import * # noqa: F401,F403
from .task_modules import * # noqa: F401,F403

View File

@ -0,0 +1,4 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .cbam import CBAM
__all__ = ['CBAM']

View File

@ -0,0 +1,117 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmdet.utils import OptMultiConfig
from mmengine.model import BaseModule
from mmyolo.registry import MODELS
class ChannelAttention(BaseModule):
"""ChannelAttention
Args:
channels (int): The input (and output) channels of the
ChannelAttention.
reduce_ratio (int): Squeeze ratio in ChannelAttention, the intermediate
channel will be ``int(channels/ratio)``. Default: 16.
act_cfg (dict): Config dict for activation layer
Default: (dict(type='ReLU'), dict(type='Sigmoid')).
"""
def __init__(self,
channels: int,
reduce_ratio: int = 16,
act_cfg: dict = dict(type='ReLU')):
super().__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.max_pool = nn.AdaptiveMaxPool2d(1)
self.fc = nn.Sequential(
ConvModule(
in_channels=channels,
out_channels=int(channels / reduce_ratio),
kernel_size=1,
stride=1,
conv_cfg=None,
act_cfg=act_cfg),
ConvModule(
in_channels=int(channels / reduce_ratio),
out_channels=channels,
kernel_size=1,
stride=1,
conv_cfg=None,
act_cfg=None))
self.sigmoid = nn.Sigmoid()
def forward(self, x: torch.Tensor) -> torch.Tensor:
avgpool_out = self.fc(self.avg_pool(x))
maxpool_out = self.fc(self.max_pool(x))
out = self.sigmoid(avgpool_out + maxpool_out)
return out
class SpatialAttention(BaseModule):
"""SpatialAttention
Args:
kernel_size (int): The size of the convolution kernel in
SpatialAttention. Default: 7.
"""
def __init__(self, kernel_size: int = 7):
super().__init__()
self.conv = ConvModule(
in_channels=2,
out_channels=1,
kernel_size=kernel_size,
stride=1,
padding=kernel_size // 2,
conv_cfg=None,
act_cfg=dict(type='Sigmoid'))
def forward(self, x: torch.Tensor) -> torch.Tensor:
avg_out = torch.mean(x, dim=1, keepdim=True)
max_out, _ = torch.max(x, dim=1, keepdim=True)
out = torch.cat([avg_out, max_out], dim=1)
out = self.conv(out)
return out
@MODELS.register_module()
class CBAM(BaseModule):
"""Convolutional Block Attention Module.
arxiv link: https://arxiv.org/abs/1807.06521v2
Args:
in_channels (int): The input (and output) channels of the CBAM.
reduce_ratio (int): Squeeze ratio in ChannelAttention, the intermediate
channel will be ``int(channels/ratio)``. Default: 16.
kernel_size (int): The size of the convolution kernel in
SpatialAttention. Default: 7.
act_cfg (dict): Config dict for activation layer in ChannelAttention
Defaults: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Defaults to None.
"""
def __init__(
self,
in_channels: int,
reduce_ratio: int = 16,
kernel_size: int = 7,
act_cfg: dict = dict(type='ReLU'),
init_cfg: OptMultiConfig = None,
):
super().__init__(init_cfg)
self.channel_attention = ChannelAttention(
channels=in_channels, reduce_ratio=reduce_ratio, act_cfg=act_cfg)
self.spatial_attention = SpatialAttention(kernel_size)
def forward(self, x: torch.Tensor) -> torch.Tensor:
out = self.channel_attention(x) * x
out = self.spatial_attention(out) * out
return out

View File

@ -0,0 +1 @@
# Copyright (c) OpenMMLab. All rights reserved.

View File

@ -0,0 +1,31 @@
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
import torch
from mmyolo.models.plugins import CBAM
from mmyolo.utils import register_all_modules
register_all_modules()
class TestCBAM(TestCase):
def test_forward(self):
tensor_shape = (2, 16, 20, 20)
images = torch.randn(*tensor_shape)
cbam = CBAM(16)
out = cbam(images)
self.assertEqual(out.shape, tensor_shape)
# test other ratio
cbam = CBAM(16, reduce_ratio=8)
out = cbam(images)
self.assertEqual(out.shape, tensor_shape)
# test other act_cfg in ChannelAttention
cbam = CBAM(in_channels=16, act_cfg=dict(type='Sigmoid'))
out = cbam(images)
self.assertEqual(out.shape, tensor_shape)