mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add CBAM for experimentation
parent d725991870
commit 5e6dbbaf30
timm/models/layers/cbam.py (new file, +97 lines)
@@ -0,0 +1,97 @@
""" CBAM (sort-of) Attention

Experimental impl of CBAM: Convolutional Block Attention Module: https://arxiv.org/abs/1807.06521

Hacked together by Ross Wightman
"""
import torch
from torch import nn as nn

from .conv_bn_act import ConvBnAct


class ChannelAttn(nn.Module):
    """ Original CBAM channel attention module, currently avg + max pool variant only.
    """
    def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
        super(ChannelAttn, self).__init__()
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.max_pool = nn.AdaptiveMaxPool2d(1)
        self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False)
        self.act = act_layer(inplace=True)
        self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False)

    def forward(self, x):
        x_avg = self.avg_pool(x)
        x_max = self.max_pool(x)
        x_avg = self.fc2(self.act(self.fc1(x_avg)))
        x_max = self.fc2(self.act(self.fc1(x_max)))
        x_attn = x_avg + x_max
        return x * x_attn.sigmoid()


class LightChannelAttn(ChannelAttn):
    """An experimental 'lightweight' variant that sums avg + max pool first
    """
    def __init__(self, channels, reduction=16):
        super(LightChannelAttn, self).__init__(channels, reduction)

    def forward(self, x):
        x_pool = 0.5 * self.avg_pool(x) + 0.5 * self.max_pool(x)
        x_attn = self.fc2(self.act(self.fc1(x_pool)))
        return x * x_attn.sigmoid()


class SpatialAttn(nn.Module):
    """ Original CBAM spatial attention module
    """
    def __init__(self, kernel_size=7):
        super(SpatialAttn, self).__init__()
        self.conv = ConvBnAct(2, 1, kernel_size, act_layer=None)

    def forward(self, x):
        x_avg = torch.mean(x, dim=1, keepdim=True)
        x_max = torch.max(x, dim=1, keepdim=True)[0]
        x_attn = torch.cat([x_avg, x_max], dim=1)
        x_attn = self.conv(x_attn)
        return x * x_attn.sigmoid()


class LightSpatialAttn(nn.Module):
    """An experimental 'lightweight' variant that sums avg_pool and max_pool results.
    """
    def __init__(self, kernel_size=7):
        super(LightSpatialAttn, self).__init__()
        self.conv = ConvBnAct(1, 1, kernel_size, act_layer=None)

    def forward(self, x):
        x_avg = torch.mean(x, dim=1, keepdim=True)
        x_max = torch.max(x, dim=1, keepdim=True)[0]
        x_attn = 0.5 * x_avg + 0.5 * x_max
        x_attn = self.conv(x_attn)
        return x * x_attn.sigmoid()


class CbamModule(nn.Module):
    def __init__(self, channels, spatial_kernel_size=7):
        super(CbamModule, self).__init__()
        self.channel = ChannelAttn(channels)
        self.spatial = SpatialAttn(spatial_kernel_size)

    def forward(self, x):
        x = self.channel(x)
        x = self.spatial(x)
        return x


class LightCbamModule(nn.Module):
    def __init__(self, channels, spatial_kernel_size=7):
        super(LightCbamModule, self).__init__()
        self.channel = LightChannelAttn(channels)
        self.spatial = LightSpatialAttn(spatial_kernel_size)

    def forward(self, x):
        x = self.channel(x)
        x = self.spatial(x)
        return x
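A quick usage sketch (not part of the commit): it assumes the package is importable from this source tree, so the new module lives at timm.models.layers.cbam exactly as the file path above suggests, and it only checks that both variants re-weight a dummy feature map without changing its shape.

import torch
from timm.models.layers.cbam import CbamModule, LightCbamModule  # module path added by this commit

x = torch.randn(2, 64, 32, 32)        # NCHW feature map
cbam = CbamModule(channels=64)        # ChannelAttn followed by SpatialAttn
lcbam = LightCbamModule(channels=64)  # 'light' variants fuse avg/max pooling before the shared MLP / conv

assert cbam(x).shape == x.shape       # attention only rescales activations, shape is preserved
assert lcbam(x).shape == x.shape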
@@ -5,6 +5,7 @@ Hacked together by Ross Wightman
 import torch
 from .se import SEModule
 from .eca import EcaModule, CecaModule
+from .cbam import CbamModule, LightCbamModule


 def create_attn(attn_type, channels, **kwargs):
@@ -18,6 +19,10 @@ def create_attn(attn_type, channels, **kwargs):
                 module_cls = EcaModule
             elif attn_type == 'ceca':
                 module_cls = CecaModule
+            elif attn_type == 'cbam':
+                module_cls = CbamModule
+            elif attn_type == 'lcbam':
+                module_cls = LightCbamModule
             else:
                 assert False, "Invalid attn module (%s)" % attn_type
         elif isinstance(attn_type, bool):
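A second sketch showing how the new registry strings would be exercised. The patched filename is not visible in this capture, so the import below is an assumption (create_attn is re-exported from timm.models.layers in the package); the 'cbam' / 'lcbam' strings resolve to the modules exactly as in the added elif branches.

import torch
from timm.models.layers import create_attn  # assumed export path for the patched factory

attn = create_attn('cbam', channels=128)    # should resolve to CbamModule(128)
lattn = create_attn('lcbam', channels=128)  # should resolve to LightCbamModule(128)

x = torch.randn(4, 128, 14, 14)
print(attn(x).shape, lattn(x).shape)        # both remain torch.Size([4, 128, 14, 14])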