From f1279eb3deab1ed3642385ce5f3d380b631139c3 Mon Sep 17 00:00:00 2001 From: kitecats <90194592+kitecats@users.noreply.github.com> Date: Thu, 10 Nov 2022 10:03:04 +0800 Subject: [PATCH] [Feature] Add attention module of CBAM (#246) * Add Attention Modules * Adde tutorials on the use of the attention module in How_to * Update how_to.md Added tutorials on the use of the attention module * Update attention_layers.py * Rename attention_layers.py to cbam_layer.py * Update __init__.py * Update how_to.md * Update how_to.md * Update how_to.md * Update cbam_layer.py * Update cbam_layer.py * Update cbam_layer.py * Update how_to.md * update * add docstring typehint * add unit test * refine unit test * updata how_to * add plugins directory * refine plugin.md * refine cbam.py and plugins.md * refine cbam.py and plugins.md * fix error in test_cbam.py * refine cbam.py and fix error in test_cbam.py * refine cbam.py and plugins.md * refine cbam.py and docs --- docs/zh_cn/advanced_guides/how_to.md | 48 +++----- docs/zh_cn/advanced_guides/plugins.md | 35 ++++++ mmyolo/models/__init__.py | 1 + mmyolo/models/plugins/__init__.py | 4 + mmyolo/models/plugins/cbam.py | 117 ++++++++++++++++++++ tests/test_models/test_plugins/__init__.py | 1 + tests/test_models/test_plugins/test_cbam.py | 31 ++++++ 7 files changed, 206 insertions(+), 31 deletions(-) create mode 100644 docs/zh_cn/advanced_guides/plugins.md create mode 100644 mmyolo/models/plugins/__init__.py create mode 100644 mmyolo/models/plugins/cbam.py create mode 100644 tests/test_models/test_plugins/__init__.py create mode 100644 tests/test_models/test_plugins/test_cbam.py diff --git a/docs/zh_cn/advanced_guides/how_to.md b/docs/zh_cn/advanced_guides/how_to.md index b0834020..00bd8246 100644 --- a/docs/zh_cn/advanced_guides/how_to.md +++ b/docs/zh_cn/advanced_guides/how_to.md @@ -4,30 +4,12 @@ ## 给主干网络增加插件 -MMYOLO 支持在 Backbone 的不同 Stage 后增加如 `none_local`、`dropblock` 等插件,用户可以直接通过修改 config 文件中 `backbone` 的 `plugins` 参数来实现对插件的管理。例如为 `YOLOv5` 增加 `GeneralizedAttention` 插件,其配置文件如下: - -```python -_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py' - -model = dict( - backbone=dict( - plugins=[ - dict( - cfg=dict( - type='mmdet.GeneralizedAttention', - spatial_range=-1, - num_heads=8, - attention_type='0011', - kv_stride=2), - stages=(False, False, True, True)), - ], )) -``` - -`cfg` 参数表示插件的具体配置, `stages` 参数表示是否在 backbone 对应的 stage 后面增加插件,长度需要和 backbone 的 stage 数量相同。 +[更多的插件使用](plugins.md) ## 应用多个 Neck -如果你想堆叠多个 Neck,可以直接在配置文件中的 Neck 参数,MMYOLO 支持以 `List` 形式拼接多个 Neck 配置,你需要保证上一个 Neck 的输出通道与下一个 Neck 的输入通道相匹配。如需要调整通道,可以插入 `mmdet.ChannelMapper` 模块用来对齐多个 Neck 之间的通道数量。具体配置如下: +如果你想堆叠多个 Neck,可以直接在配置文件中的 Neck 参数,MMYOLO 支持以 `List` 形式拼接多个 Neck 配置,你需要保证上一个 Neck 的输出通道与下一个 Neck +的输入通道相匹配。如需要调整通道,可以插入 `mmdet.ChannelMapper` 模块用来对齐多个 Neck 之间的通道数量。具体配置如下: ```python _base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py' @@ -42,7 +24,8 @@ model = dict( deepen_factor=deepen_factor, widen_factor=widen_factor, in_channels=[256, 512, 1024], - out_channels=[256, 512, 1024], # 因为 out_channels 由 widen_factor 控制,YOLOv5PAFPN 的 out_channels = out_channels * widen_factor + out_channels=[256, 512, 1024], + # 因为 out_channels 由 widen_factor 控制,YOLOv5PAFPN 的 out_channels = out_channels * widen_factor num_csp_blocks=3, norm_cfg=dict(type='BN', momentum=0.03, eps=0.001), act_cfg=dict(type='SiLU', inplace=True)), @@ -58,14 +41,16 @@ model = dict( num_blocks=2, # disable zero_init_offset to follow official implementation zero_init_offset=False) - ] - bbox_head=dict(head_module=dict(in_channels=[512,512,512])) # 因为 out_channels 由 widen_factor 控制,YOLOv5HeadModuled 的 in_channels * widen_factor 才会等于最后一个 neck 的 out_channels + ], + bbox_head=dict(head_module=dict(in_channels=[512, 512, 512])) + # 因为 out_channels 由 widen_factor 控制,YOLOv5HeadModuled 的 in_channels * widen_factor 才会等于最后一个 neck 的 out_channels ) ``` ## 跨库使用主干网络 -OpenMMLab 2.0 体系中 MMYOLO、MMDetection、MMClassification、MMSegmentation 中的模型注册表都继承自 MMEngine 中的根注册表,允许这些 OpenMMLab 开源库直接使用彼此已经实现的模块。 因此用户可以在 MMYOLO 中使用来自 MMDetection、MMClassification 的主干网络,而无需重新实现。 +OpenMMLab 2.0 体系中 MMYOLO、MMDetection、MMClassification、MMSegmentation 中的模型注册表都继承自 MMEngine 中的根注册表,允许这些 OpenMMLab +开源库直接使用彼此已经实现的模块。 因此用户可以在 MMYOLO 中使用来自 MMDetection、MMClassification 的主干网络,而无需重新实现。 ```{note} 1. 使用其他主干网络时,你需要保证主干网络的输出通道与 Neck 的输入通道相匹配。 @@ -235,7 +220,8 @@ OpenMMLab 2.0 体系中 MMYOLO、MMDetection、MMClassification、MMSegmentation ### 通过 MMClassification 使用 `timm` 中实现的主干网络 -由于 MMClassification 提供了 Py**T**orch **Im**age **M**odels (`timm`) 主干网络的封装,用户也可以通过 MMClassification 直接使用 `timm` 中的主干网络。假设想将 `EfficientNet-B1`作为 `YOLOv5` 的主干网络,则配置文件如下: +由于 MMClassification 提供了 Py**T**orch **Im**age **M**odels (`timm`) 主干网络的封装,用户也可以通过 MMClassification 直接使用 `timm` +中的主干网络。假设想将 `EfficientNet-B1`作为 `YOLOv5` 的主干网络,则配置文件如下: ```python _base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py' @@ -251,9 +237,9 @@ channels = [40, 112, 320] model = dict( backbone=dict( - _delete_=True, # 将 _base_ 中关于 backbone 的字段删除 - type='mmcls.TIMMBackbone', # 使用 mmcls 中的 timm 主干网络 - model_name='efficientnet_b1', # 使用 TIMM 中的 efficientnet_b1 + _delete_=True, # 将 _base_ 中关于 backbone 的字段删除 + type='mmcls.TIMMBackbone', # 使用 mmcls 中的 timm 主干网络 + model_name='efficientnet_b1', # 使用 TIMM 中的 efficientnet_b1 features_only=True, pretrained=True, out_indices=(2, 3, 4)), @@ -261,13 +247,13 @@ model = dict( type='YOLOv5PAFPN', deepen_factor=deepen_factor, widen_factor=widen_factor, - in_channels=channels, # 注意:EfficientNet-B1 输出的3个通道是 [40, 112, 320],和原先的 yolov5-s neck 不匹配,需要更改 + in_channels=channels, # 注意:EfficientNet-B1 输出的3个通道是 [40, 112, 320],和原先的 yolov5-s neck 不匹配,需要更改 out_channels=channels), bbox_head=dict( type='YOLOv5Head', head_module=dict( type='YOLOv5HeadModule', - in_channels=channels, # head 部分输入通道也要做相应更改 + in_channels=channels, # head 部分输入通道也要做相应更改 widen_factor=widen_factor)) ) ``` diff --git a/docs/zh_cn/advanced_guides/plugins.md b/docs/zh_cn/advanced_guides/plugins.md new file mode 100644 index 00000000..7fcbc2bf --- /dev/null +++ b/docs/zh_cn/advanced_guides/plugins.md @@ -0,0 +1,35 @@ +# 更多的插件使用 + +MMYOLO 支持在 Backbone 的不同 Stage 后增加如 `none_local`、`dropblock` 等插件,用户可以直接通过修改 config 文件中 `backbone` 的 `plugins` +参数来实现对插件的管理。例如为 `YOLOv5` 增加 `GeneralizedAttention` 插件,其配置文件如下: + +```python +_base_ = './yolov5_s-v61_syncbn_8xb16-300e_coco.py' + +model = dict( + backbone=dict( + plugins=[ + dict( + cfg=dict( + type='GeneralizedAttention', + spatial_range=-1, + num_heads=8, + attention_type='0011', + kv_stride=2), + stages=(False, False, True, True)) + ])) +``` + +`cfg` 参数表示插件的具体配置, `stages` 参数表示是否在 backbone 对应的 stage 后面增加插件,长度需要和 backbone 的 stage 数量相同。 + +目前 `MMYOLO` 支持了如下插件: + +
+支持的插件 + +- [x] [CBAM](mmyolo/models/plugins) +- [x] [GeneralizedAttention](https://github.com/open-mmlab/mmcv/blob/2.x/mmcv/cnn/bricks/generalized_attention.py#L13) +- [x] [NonLocal2d](https://github.com/open-mmlab/mmcv/blob/2.x/mmcv/cnn/bricks/non_local.py#L250) +- [x] [ContextBlock](https://github.com/open-mmlab/mmcv/blob/2.x/mmcv/cnn/bricks/context_block.py#L18) + +
diff --git a/mmyolo/models/__init__.py b/mmyolo/models/__init__.py index b290017a..51c37f04 100644 --- a/mmyolo/models/__init__.py +++ b/mmyolo/models/__init__.py @@ -6,4 +6,5 @@ from .detectors import * # noqa: F401,F403 from .layers import * # noqa: F401,F403 from .losses import * # noqa: F401,F403 from .necks import * # noqa: F401,F403 +from .plugins import * # noqa: F401,F403 from .task_modules import * # noqa: F401,F403 diff --git a/mmyolo/models/plugins/__init__.py b/mmyolo/models/plugins/__init__.py new file mode 100644 index 00000000..497233ac --- /dev/null +++ b/mmyolo/models/plugins/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from .cbam import CBAM + +__all__ = ['CBAM'] diff --git a/mmyolo/models/plugins/cbam.py b/mmyolo/models/plugins/cbam.py new file mode 100644 index 00000000..512cf21f --- /dev/null +++ b/mmyolo/models/plugins/cbam.py @@ -0,0 +1,117 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +import torch +import torch.nn as nn +from mmcv.cnn import ConvModule +from mmdet.utils import OptMultiConfig +from mmengine.model import BaseModule + +from mmyolo.registry import MODELS + + +class ChannelAttention(BaseModule): + """ChannelAttention + Args: + channels (int): The input (and output) channels of the + ChannelAttention. + reduce_ratio (int): Squeeze ratio in ChannelAttention, the intermediate + channel will be ``int(channels/ratio)``. Default: 16. + act_cfg (dict): Config dict for activation layer + Default: (dict(type='ReLU'), dict(type='Sigmoid')). + """ + + def __init__(self, + channels: int, + reduce_ratio: int = 16, + act_cfg: dict = dict(type='ReLU')): + super().__init__() + + self.avg_pool = nn.AdaptiveAvgPool2d(1) + self.max_pool = nn.AdaptiveMaxPool2d(1) + + self.fc = nn.Sequential( + ConvModule( + in_channels=channels, + out_channels=int(channels / reduce_ratio), + kernel_size=1, + stride=1, + conv_cfg=None, + act_cfg=act_cfg), + ConvModule( + in_channels=int(channels / reduce_ratio), + out_channels=channels, + kernel_size=1, + stride=1, + conv_cfg=None, + act_cfg=None)) + self.sigmoid = nn.Sigmoid() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + avgpool_out = self.fc(self.avg_pool(x)) + maxpool_out = self.fc(self.max_pool(x)) + out = self.sigmoid(avgpool_out + maxpool_out) + return out + + +class SpatialAttention(BaseModule): + """SpatialAttention + Args: + kernel_size (int): The size of the convolution kernel in + SpatialAttention. Default: 7. + """ + + def __init__(self, kernel_size: int = 7): + super().__init__() + + self.conv = ConvModule( + in_channels=2, + out_channels=1, + kernel_size=kernel_size, + stride=1, + padding=kernel_size // 2, + conv_cfg=None, + act_cfg=dict(type='Sigmoid')) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + avg_out = torch.mean(x, dim=1, keepdim=True) + max_out, _ = torch.max(x, dim=1, keepdim=True) + out = torch.cat([avg_out, max_out], dim=1) + out = self.conv(out) + return out + + +@MODELS.register_module() +class CBAM(BaseModule): + """Convolutional Block Attention Module. + + arxiv link: https://arxiv.org/abs/1807.06521v2 + Args: + in_channels (int): The input (and output) channels of the CBAM. + reduce_ratio (int): Squeeze ratio in ChannelAttention, the intermediate + channel will be ``int(channels/ratio)``. Default: 16. + kernel_size (int): The size of the convolution kernel in + SpatialAttention. Default: 7. + act_cfg (dict): Config dict for activation layer in ChannelAttention + Defaults: dict(type='ReLU'). + init_cfg (dict or list[dict], optional): Initialization config dict. + Defaults to None. + """ + + def __init__( + self, + in_channels: int, + reduce_ratio: int = 16, + kernel_size: int = 7, + act_cfg: dict = dict(type='ReLU'), + init_cfg: OptMultiConfig = None, + ): + super().__init__(init_cfg) + self.channel_attention = ChannelAttention( + channels=in_channels, reduce_ratio=reduce_ratio, act_cfg=act_cfg) + + self.spatial_attention = SpatialAttention(kernel_size) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + out = self.channel_attention(x) * x + out = self.spatial_attention(out) * out + return out diff --git a/tests/test_models/test_plugins/__init__.py b/tests/test_models/test_plugins/__init__.py new file mode 100644 index 00000000..ef101fec --- /dev/null +++ b/tests/test_models/test_plugins/__init__.py @@ -0,0 +1 @@ +# Copyright (c) OpenMMLab. All rights reserved. diff --git a/tests/test_models/test_plugins/test_cbam.py b/tests/test_models/test_plugins/test_cbam.py new file mode 100644 index 00000000..4af547c0 --- /dev/null +++ b/tests/test_models/test_plugins/test_cbam.py @@ -0,0 +1,31 @@ +# Copyright (c) OpenMMLab. All rights reserved. + +from unittest import TestCase + +import torch + +from mmyolo.models.plugins import CBAM +from mmyolo.utils import register_all_modules + +register_all_modules() + + +class TestCBAM(TestCase): + + def test_forward(self): + tensor_shape = (2, 16, 20, 20) + + images = torch.randn(*tensor_shape) + cbam = CBAM(16) + out = cbam(images) + self.assertEqual(out.shape, tensor_shape) + + # test other ratio + cbam = CBAM(16, reduce_ratio=8) + out = cbam(images) + self.assertEqual(out.shape, tensor_shape) + + # test other act_cfg in ChannelAttention + cbam = CBAM(in_channels=16, act_cfg=dict(type='Sigmoid')) + out = cbam(images) + self.assertEqual(out.shape, tensor_shape)