Implement ECA modules
Implement the ECA module by 1. adapting the original eca_module.py into the models folder, 2. adding a use_eca option beside every instance of the SE layer
pull/82/head
parent 697e05cb3e
commit f87fcd7e88
.gitignore
@@ -105,3 +105,4 @@ venv.bak/
 *.pth
 *.gz
 Untitled.ipynb
+Testing notebook.ipynb
timm/models/eca_module.py (new file)
@@ -0,0 +1,110 @@
'''
ECA module from ECA-Net

Original paper: ECA-Net: Efficient Channel Attention for Deep Convolutional Neural Networks
https://arxiv.org/abs/1910.03151
https://github.com/BangguWu/ECANet

The original ECA module is borrowed from the repo above. The circular ECA
implementation was modified, and both were adapted for use in the pytorch
image models package by Chris Ha https://github.com/VRandme

MIT License

Copyright (c) 2019 BangguWu, Qilong Wang

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
'''
import torch
from torch import nn


class eca_layer(nn.Module):
    """Constructs an ECA module.

    Args:
        channel: Number of channels of the input feature map
        k_size: Kernel size of the channel conv; adaptively selected in the paper, must be odd
    """
    def __init__(self, channel, k_size=3):
        super(eca_layer, self).__init__()
        assert k_size % 2 == 1
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: input features with shape [b, c, h, w]
        # feature descriptor on the global spatial information: [b, c, 1, 1]
        y = self.avg_pool(x)

        # reshape to [b, 1, c], convolve across neighboring channels, restore [b, c, 1, 1]
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)

        # gate the channels with a sigmoid to form attention weights
        y = self.sigmoid(y)

        return x * y.expand_as(x)

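For reference, a minimal usage sketch (not part of this commit): it applies eca_layer to a dummy feature map and derives the kernel size with the adaptive formula from the ECA-Net paper, k = |log2(C)/gamma + b/gamma|_odd with gamma=2, b=1. The import path assumes this file lands at timm/models/eca_module.py.

import math

import torch

from timm.models.eca_module import eca_layer  # assumed install path


def adaptive_k(channels, gamma=2, b=1):
    # nearest odd value of (log2(C) + b) / gamma, per the paper
    t = int(abs((math.log2(channels) + b) / gamma))
    return t if t % 2 else t + 1


x = torch.randn(2, 64, 56, 56)                      # [b, c, h, w]
eca = eca_layer(channel=64, k_size=adaptive_k(64))  # adaptive_k(64) == 3
y = eca(x)
assert y.shape == x.shape                           # attention only rescales channels
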
class ceca_layer(nn.Module):
    """Constructs a circular ECA module.

    The primary difference from eca_layer is that the conv uses circular rather than
    zero padding. Unlike the rows and columns of an image, the channels have no inherent
    ordering or locality; although this module still assumes some, there is no reason to
    prevent the channels at either "edge" from attending to each other. This increases
    connectivity and may improve performance metrics (accuracy, robustness) without
    significantly impacting resource metrics (parameter size, throughput, latency, etc.).

    Args:
        channel: Number of channels of the input feature map
        k_size: Kernel size of the channel conv; adaptively selected in the paper, must be odd
    """
    def __init__(self, channel, k_size=3):
        super(ceca_layer, self).__init__()
        assert k_size % 2 == 1
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        # PyTorch's circular padding mode is bugged as of PyTorch 1.4
        # (see https://github.com/pytorch/pytorch/pull/17240),
        # so circular padding is implemented manually and the conv itself uses none
        self.conv = nn.Conv1d(1, 1, kernel_size=k_size, padding=0, bias=False)
        self.padding = (k_size - 1) // 2
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        # x: input features with shape [b, c, h, w]
        # feature descriptor on the global spatial information: [b, c, 1, 1]
        y = self.avg_pool(x)

        # manual circular padding along the channel dim: wrap the last channels
        # around to the front and the first channels around to the back
        y = torch.cat((y[:, -self.padding:, :, :], y, y[:, :self.padding, :, :]), dim=1)

        # reshape to [b, 1, c + 2 * padding], convolve (valid), restore [b, c, 1, 1]
        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)

        # gate the channels with a sigmoid to form attention weights
        y = self.sigmoid(y)

        return x * y.expand_as(x)
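A quick sanity check (an illustration added here, not in the commit): on recent PyTorch versions the manual wrap above matches F.pad with mode='circular' applied to the [b, 1, c] layout.

import torch
import torch.nn.functional as F

p = 1                        # (k_size - 1) // 2 for k_size=3
y = torch.randn(2, 8, 1, 1)  # pooled descriptor [b, c, 1, 1]
manual = torch.cat((y[:, -p:, :, :], y, y[:, :p, :, :]), dim=1)
circular = F.pad(y.squeeze(-1).transpose(-1, -2), (p, p), mode='circular')
assert torch.equal(manual.squeeze(-1).transpose(-1, -2), circular)
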
timm/models/resnet.py
@@ -14,6 +14,7 @@ import torch.nn.functional as F
 from .registry import register_model
 from .helpers import load_pretrained
 from .adaptive_avgmax_pool import SelectAdaptivePool2d
+from .eca_module import eca_layer
 from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
@@ -100,6 +101,10 @@ default_cfgs = {
     'seresnext26tn_32x4d': _cfg(
         url='https://github.com/rwightman/pytorch-image-models/releases/download/v0.1-weights/seresnext26tn_32x4d-569cb627.pth',
         interpolation='bicubic'),
+    'ecaresnext26tn_32x4d': _cfg(
+        url='',
+        interpolation='bicubic'),
+
 }
@@ -132,7 +137,7 @@ class BasicBlock(nn.Module):
     expansion = 1

     def __init__(self, inplanes, planes, stride=1, downsample=None,
-                 cardinality=1, base_width=64, use_se=False,
+                 cardinality=1, base_width=64, use_se=False, use_eca=False,
                  reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
         super(BasicBlock, self).__init__()
@@ -150,7 +155,10 @@ class BasicBlock(nn.Module):
             first_planes, outplanes, kernel_size=3, padding=previous_dilation,
             dilation=previous_dilation, bias=False)
         self.bn2 = norm_layer(outplanes)
+
         self.se = SEModule(outplanes, planes // 4) if use_se else None
+        self.eca = eca_layer(outplanes) if use_eca else None
+
         self.act2 = act_layer(inplace=True)
         self.downsample = downsample
         self.stride = stride
@@ -167,6 +175,8 @@ class BasicBlock(nn.Module):

         if self.se is not None:
             out = self.se(out)
+        if self.eca is not None:
+            out = self.eca(out)

         if self.downsample is not None:
             residual = self.downsample(x)
@@ -182,7 +192,7 @@ class Bottleneck(nn.Module):
     expansion = 4

     def __init__(self, inplanes, planes, stride=1, downsample=None,
-                 cardinality=1, base_width=64, use_se=False,
+                 cardinality=1, base_width=64, use_se=False, use_eca=False,
                  reduce_first=1, dilation=1, previous_dilation=1, act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d):
         super(Bottleneck, self).__init__()
@@ -200,7 +210,10 @@ class Bottleneck(nn.Module):
         self.act2 = act_layer(inplace=True)
         self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
         self.bn3 = norm_layer(outplanes)
+
         self.se = SEModule(outplanes, planes // 4) if use_se else None
+        self.eca = eca_layer(outplanes) if use_eca else None
+
         self.act3 = act_layer(inplace=True)
         self.downsample = downsample
         self.stride = stride
@@ -222,6 +235,8 @@ class Bottleneck(nn.Module):

         if self.se is not None:
             out = self.se(out)
+        if self.eca is not None:
+            out = self.eca(out)

         if self.downsample is not None:
             residual = self.downsample(x)
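To make the resource trade-off concrete, a comparison sketch (not part of the commit; the SE layout below is a stand-in mirroring SEModule(outplanes, planes // 4) above): for a Bottleneck with planes=64 and outplanes=256, SE adds roughly 8k parameters per block while ECA adds 3.

import torch.nn as nn

from timm.models.eca_module import eca_layer  # assumed install path

outplanes, reduction_channels = 256, 16       # Bottleneck: planes=64 -> planes // 4

# rough stand-in for SEModule's two pointwise convs (layout assumed)
se = nn.Sequential(
    nn.Conv2d(outplanes, reduction_channels, kernel_size=1),
    nn.ReLU(inplace=True),
    nn.Conv2d(reduction_channels, outplanes, kernel_size=1),
)
eca = eca_layer(outplanes)

print(sum(p.numel() for p in se.parameters()))   # 8464
print(sum(p.numel() for p in eca.parameters()))  # 3
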
@@ -275,6 +290,8 @@ class ResNet(nn.Module):
         Number of input (color) channels.
     use_se : bool, default False
         Enable Squeeze-Excitation module in blocks
+    use_eca : bool, default False
+        Enable ECA module in blocks
     cardinality : int, default 1
         Number of convolution groups for 3x3 conv in Bottleneck.
     base_width : int, default 64
@@ -303,7 +320,7 @@ class ResNet(nn.Module):
     global_pool : str, default 'avg'
         Global pooling type. One of 'avg', 'max', 'avgmax', 'catavgmax'
     """
-    def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False,
+    def __init__(self, block, layers, num_classes=1000, in_chans=3, use_se=False, use_eca=False,
                  cardinality=1, base_width=64, stem_width=64, stem_type='',
                  block_reduce_first=1, down_kernel_size=1, avg_down=False, output_stride=32,
                  act_layer=nn.ReLU, norm_layer=nn.BatchNorm2d, drop_rate=0.0, global_pool='avg',
@@ -350,7 +367,7 @@ class ResNet(nn.Module):
             assert output_stride == 32
         llargs = list(zip(channels, layers, strides, dilations))
         lkwargs = dict(
-            use_se=use_se, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
+            use_se=use_se, use_eca=use_eca, reduce_first=block_reduce_first, act_layer=act_layer, norm_layer=norm_layer,
             avg_down=avg_down, down_kernel_size=down_kernel_size, **block_args)
         self.layer1 = self._make_layer(block, *llargs[0], **lkwargs)
         self.layer2 = self._make_layer(block, *llargs[1], **lkwargs)
@@ -375,7 +392,7 @@ class ResNet(nn.Module):
                 nn.init.constant_(m.bias, 0.)

     def _make_layer(self, block, planes, blocks, stride=1, dilation=1, reduce_first=1,
-                    use_se=False, avg_down=False, down_kernel_size=1, **kwargs):
+                    use_se=False, use_eca=False, avg_down=False, down_kernel_size=1, **kwargs):
         norm_layer = kwargs.get('norm_layer')
         downsample = None
         down_kernel_size = 1 if stride == 1 and dilation == 1 else down_kernel_size
@@ -396,7 +413,7 @@ class ResNet(nn.Module):
         first_dilation = 1 if dilation in (1, 2) else 2
         bkwargs = dict(
             cardinality=self.cardinality, base_width=self.base_width, reduce_first=reduce_first,
-            use_se=use_se, **kwargs)
+            use_se=use_se, use_eca=use_eca, **kwargs)
         layers = [block(
             self.inplanes, planes, stride, downsample, dilation=first_dilation, previous_dilation=dilation, **bkwargs)]
         self.inplanes = planes * block.expansion
@@ -944,3 +961,20 @@ def seresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
     if pretrained:
         load_pretrained(model, default_cfg, num_classes, in_chans)
     return model
+
+@register_model
+def ecaresnext26tn_32x4d(pretrained=False, num_classes=1000, in_chans=3, **kwargs):
+    """Constructs an ECA-ResNeXt-26-TN model.
+    This is technically a 28 layer ResNet, like a 'D' bag-of-tricks model, but with tiered 24, 32, 64 channels
+    in the deep stem. The channel number of the middle stem conv is narrower than in the 'T' variant.
+    This model replaces the SE module with the ECA module.
+    """
+    default_cfg = default_cfgs['ecaresnext26tn_32x4d']
+    model = ResNet(
+        Bottleneck, [2, 2, 2, 2], cardinality=32, base_width=4,
+        stem_width=32, stem_type='deep_tiered_narrow', avg_down=True, use_eca=True,
+        num_classes=num_classes, in_chans=in_chans, **kwargs)
+    model.default_cfg = default_cfg
+    if pretrained:
+        load_pretrained(model, default_cfg, num_classes, in_chans)
+    return model
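Once merged, the new variant is constructible through the registry; a usage sketch (assuming timm's create_model API; url='' above means no pretrained weights yet, so pretrained must stay False):

import torch
import timm

model = timm.create_model('ecaresnext26tn_32x4d', pretrained=False)
model.eval()
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # torch.Size([1, 1000])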