mirror of https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Bring in code that should fix 1d circular padding properly, tweaks to ECA impl, using CECA in MobileNetV3 experiment
This commit is contained in:
parent 67e759f710
commit ade1ba5fe3
@@ -34,6 +34,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 """
 import math
+import torch
 from torch import nn
 import torch.nn.functional as F
 
@@ -62,14 +63,30 @@ class EfficientChannelAttn(nn.Module):
         self.gate_fn = gate_fn
 
     def forward(self, x):
-        y = self.avg_pool(x)  # Feature descriptor on the global spatial information
-        y = y.view(x.shape[0], 1, -1)  # Reshape for convolution
+        y = self.avg_pool(x)
+        y = y.view(x.shape[0], 1, -1)  # Reshape 4d -> 3d for 1d convolution
         y = self.conv(y)
-        y = y.view(x.shape[0], -1, 1, 1)
+        y = y.view(x.shape[0], -1, 1, 1)  # Back to 4d
         y = y.sigmoid() if self.gate_fn is None else self.gate_fn(y)
         return x * y.expand_as(x)
 
 
+def padding1d_circular(input, pad):
+    r"""input: torch.tensor([[[0., 1., 2.],
+                              [3., 4., 5.]]])
+        pad: (1, 2)
+        output: tensor([[[2., 0., 1., 2., 0., 1.],
+                         [5., 3., 4., 5., 3., 4.]]])
+
+        from: https://github.com/pytorch/pytorch/issues/24504
+    """
+    input = torch.cat([input, input[:, :, 0:pad[-1]]], dim=2)
+    if pad[-1] == 0 and pad[-2] != 0:
+        return torch.cat([input[:, :, -(pad[-1] + pad[-2]):], input], dim=2)
+    else:
+        return torch.cat([input[:, :, -(pad[-1] + pad[-2]):-pad[-1]], input], dim=2)
+
+
 class CircularEfficientChannelAttn(nn.Module):
     """Constructs a circular ECA module.
 
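
Illustrative sketch, not part of the commit: a standalone check of the manual circular padding helper against the example in its docstring, assuming padding1d_circular as defined in the hunk above is in scope. On recent PyTorch releases the built-in F.pad(..., mode='circular') should give the same result for this case.

    import torch
    import torch.nn.functional as F

    x = torch.tensor([[[0., 1., 2.],
                       [3., 4., 5.]]])
    y = padding1d_circular(x, (1, 2))  # wrap 1 element onto the left, 2 onto the right
    expected = torch.tensor([[[2., 0., 1., 2., 0., 1.],
                              [5., 3., 4., 5., 3., 4.]]])
    assert torch.equal(y, expected)
    # Comparison with the built-in circular padding that this commit works around:
    assert torch.equal(F.pad(x, (1, 2), mode='circular'), expected)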
@@ -97,20 +114,20 @@ class CircularEfficientChannelAttn(nn.Module):
         if channels is not None:
             t = int(abs(math.log(channels, 2) + beta) / gamma)
             kernel_size = max(t if t % 2 else t + 1, 3)
-        self.padding = (kernel_size - 1) // 2
+
+        # pytorch conv circular padding mode is buggy as of pytorch 1.4, will implement manually
+        # see https://github.com/pytorch/pytorch/pull/17240
+        # https://github.com/pytorch/pytorch/issues/24504
+        p = (kernel_size - 1) // 2
+        self.padding = (p, p)
 
         self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        #pytorch circular padding mode is buggy as of pytorch 1.4
-        #see https://github.com/pytorch/pytorch/pull/17240
-
-        #implement manual circular padding
         self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
         self.gate_fn = gate_fn
 
     def forward(self, x):
-        y = self.avg_pool(x)  # Feature descriptor on the global spatial information
-        # Manually implement circular padding, F.pad does not seemed to be bugged
-        y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
+        y = self.avg_pool(x)
+        y = padding1d_circular(y.view(x.shape[0], 1, -1), self.padding)  # manual circular padding
         y = self.conv(y)
         y = y.view(x.shape[0], -1, 1, 1)
         y = y.sigmoid() if self.gate_fn is None else self.gate_fn(y)
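
Illustrative sketch, not part of the commit: how the adaptive kernel size above works out for common channel counts, assuming the ECA-paper defaults gamma=2 and beta=1 (those defaults are not shown in this hunk, and the helper name below is hypothetical).

    import math

    def eca_kernel_size(channels, gamma=2, beta=1):
        # Same formula as in __init__ above: an odd 1d kernel size derived from the channel count.
        t = int(abs(math.log(channels, 2) + beta) / gamma)
        return max(t if t % 2 else t + 1, 3)

    for channels in (64, 128, 256, 512):
        k = eca_kernel_size(channels)
        p = (k - 1) // 2  # per-side padding, stored as the tuple (p, p) for padding1d_circular
        print(channels, k, p)  # with these defaults: 64 -> (3, 1), 128 -> (5, 2), 256 -> (5, 2), 512 -> (5, 2)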
@@ -413,7 +413,7 @@ def _gen_mobilenet_v3_eca(variant, channel_multiplier=1.0, pretrained=False, **k
         channel_multiplier=channel_multiplier,
         norm_kwargs=resolve_bn_args(kwargs),
         act_layer=act_layer,
-        attn_layer='eca',
+        attn_layer='ceca',
         attn_kwargs=dict(gate_fn=hard_sigmoid),
         **kwargs,
     )
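
Illustrative sketch, not part of the commit: the attn_layer string is resolved to an attention module elsewhere in the model builder; a hypothetical mapping showing the intent of the 'eca' -> 'ceca' switch (the function name and dict below are illustrative, not the repo's actual factory).

    def resolve_attn_layer(name):
        # Hypothetical lookup; the real resolution happens in the model builder code.
        attn_layers = {
            'eca': EfficientChannelAttn,            # ECA module defined above
            'ceca': CircularEfficientChannelAttn,   # circular-padding variant used in this experiment
        }
        return attn_layers[name]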