Bring in code that should fix 1d circular padding properly, tweaks to ECA impl, using CECA in MobileNetV3 experiment

Ross Wightman 2020-02-25 16:12:41 -08:00
parent 67e759f710
commit ade1ba5fe3
2 changed files with 29 additions and 12 deletions


@@ -34,6 +34,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 SOFTWARE.
 """
 import math
+import torch
 from torch import nn
 import torch.nn.functional as F
@@ -62,14 +63,30 @@ class EfficientChannelAttn(nn.Module):
         self.gate_fn = gate_fn
 
     def forward(self, x):
-        y = self.avg_pool(x)  # Feature descriptor on the global spatial information
-        y = y.view(x.shape[0], 1, -1)  # Reshape for convolution
+        y = self.avg_pool(x)
+        y = y.view(x.shape[0], 1, -1)  # Reshape 4d -> 3d for 1d convolution
         y = self.conv(y)
-        y = y.view(x.shape[0], -1, 1, 1)
+        y = y.view(x.shape[0], -1, 1, 1)  # Back to 4d
         y = y.sigmoid() if self.gate_fn is None else self.gate_fn(y)
         return x * y.expand_as(x)
 
 
+def padding1d_circular(input, pad):
+    r"""input: torch.tensor([[[0., 1., 2.],
+                              [3., 4., 5.]]])
+    pad: (1, 2)
+    output: tensor([[[2., 0., 1., 2., 0., 1.],
+                     [5., 3., 4., 5., 3., 4.]]])
+    from: https://github.com/pytorch/pytorch/issues/24504
+    """
+    input = torch.cat([input, input[:, :, 0:pad[-1]]], dim=2)
+    if pad[-1] == 0 and pad[-2] != 0:
+        return torch.cat([input[:, :, -(pad[-1] + pad[-2]):], input], dim=2)
+    else:
+        return torch.cat([input[:, :, -(pad[-1] + pad[-2]):-pad[-1]], input], dim=2)
+
+
 class CircularEfficientChannelAttn(nn.Module):
     """Constructs a circular ECA module.
@@ -97,20 +114,20 @@ class CircularEfficientChannelAttn(nn.Module):
         if channels is not None:
             t = int(abs(math.log(channels, 2) + beta) / gamma)
             kernel_size = max(t if t % 2 else t + 1, 3)
-        self.padding = (kernel_size - 1) // 2
+        # pytorch conv circular padding mode is buggy as of pytorch 1.4, will implement manually
+        # see https://github.com/pytorch/pytorch/pull/17240
+        # https://github.com/pytorch/pytorch/issues/24504
+        p = (kernel_size - 1) // 2
+        self.padding = (p, p)
         self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        #pytorch circular padding mode is buggy as of pytorch 1.4
-        #see https://github.com/pytorch/pytorch/pull/17240
-        #implement manual circular padding
         self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
         self.gate_fn = gate_fn
 
     def forward(self, x):
-        y = self.avg_pool(x)  # Feature descriptor on the global spatial information
-        # Manually implement circular padding, F.pad does not seemed to be bugged
-        y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
+        y = self.avg_pool(x)
+        y = padding1d_circular(y.view(x.shape[0], 1, -1), self.padding)  # manual circular padding
         y = self.conv(y)
         y = y.view(x.shape[0], -1, 1, 1)
         y = y.sigmoid() if self.gate_fn is None else self.gate_fn(y)
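A quick sanity check of the new helper (not part of this commit; it assumes padding1d_circular is importable from the module edited above) reproduces the docstring example:

import torch

# padding1d_circular as defined in the diff above.
# pad = (left, right): values wrap around the last dimension instead of zero-filling.
x = torch.tensor([[[0., 1., 2.],
                   [3., 4., 5.]]])
y = padding1d_circular(x, (1, 2))
print(y)
# tensor([[[2., 0., 1., 2., 0., 1.],
#          [5., 3., 4., 5., 3., 4.]]])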

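For reference, the kernel size that CircularEfficientChannelAttn derives from the channel count follows the adaptive formula shown in the hunk above. The gamma=2, beta=1 defaults below are an assumption (they are not visible in this diff), matching the values from the original ECA paper:

import math

# Adaptive kernel size as computed in CircularEfficientChannelAttn.__init__
# (gamma=2, beta=1 are assumed defaults).
def eca_kernel_size(channels, gamma=2, beta=1):
    t = int(abs(math.log(channels, 2) + beta) / gamma)
    return max(t if t % 2 else t + 1, 3)

for c in (16, 64, 256, 1024):
    k = eca_kernel_size(c)
    p = (k - 1) // 2
    print(c, k, (p, p))  # e.g. 16 -> 3, 64 -> 3, 256 -> 5, 1024 -> 5; self.padding is (p, p)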

@@ -413,7 +413,7 @@ def _gen_mobilenet_v3_eca(variant, channel_multiplier=1.0, pretrained=False, **k
         channel_multiplier=channel_multiplier,
         norm_kwargs=resolve_bn_args(kwargs),
         act_layer=act_layer,
-        attn_layer='eca',
+        attn_layer='ceca',
         attn_kwargs=dict(gate_fn=hard_sigmoid),
         **kwargs,
     )
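The MobileNetV3 change switches the experimental generator's attention from 'eca' to 'ceca'. The factory that resolves that string is not part of this diff, so the snippet below is only a hypothetical illustration of what the config is intended to build per block: a CircularEfficientChannelAttn gated by hard_sigmoid rather than the default sigmoid (assuming both names are importable from the modules touched here, and that the constructor accepts channels and gate_fn keywords as the code above suggests):

import torch

# Hypothetical illustration, not code from this commit.
channels = 64
attn = CircularEfficientChannelAttn(channels=channels, gate_fn=hard_sigmoid)
x = torch.randn(2, channels, 7, 7)
out = attn(x)                 # channels re-weighted by the circular ECA gate
assert out.shape == x.shape   # attention preserves the (N, C, H, W) shape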