AdaptiveAvgPool2d -> mean((2,3)) for all SE/attn layers to avoid NaN with AMP + channels_last layout. See https://github.com/pytorch/pytorch/issues/43992
parent c2cd1a332e
commit 110a7c4982
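Note: for a 1x1 output the change is numerically a no-op, since AdaptiveAvgPool2d(1) and a spatial mean compute the same reduction; the mean form simply avoids the pooling kernel that misbehaved under AMP with channels_last. A minimal sketch checking this (shapes arbitrary):

    import torch
    import torch.nn as nn

    x = torch.randn(2, 32, 7, 7)
    pooled = nn.AdaptiveAvgPool2d(1)(x)        # (2, 32, 1, 1)
    meaned = x.mean((2, 3), keepdim=True)      # same values, no pooling module needed
    assert torch.allclose(pooled, meaned, atol=1e-6)

    # the mean also runs cleanly on a channels_last tensor, the layout that
    # triggered the NaNs with AMP in the linked PyTorch issue
    x_cl = x.to(memory_format=torch.channels_last)
    assert torch.allclose(x_cl.mean((2, 3), keepdim=True), meaned, atol=1e-6)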
@@ -106,20 +106,18 @@ class SqueezeExcite(nn.Module):
     def __init__(self, in_chs, se_ratio=0.25, reduced_base_chs=None,
                  act_layer=nn.ReLU, gate_fn=sigmoid, divisor=1, **_):
         super(SqueezeExcite, self).__init__()
+        self.gate_fn = gate_fn
         reduced_chs = make_divisible((reduced_base_chs or in_chs) * se_ratio, divisor)
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
         self.act1 = act_layer(inplace=True)
         self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)
-        self.gate_fn = gate_fn

     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.conv_reduce(x_se)
         x_se = self.act1(x_se)
         x_se = self.conv_expand(x_se)
-        x = x * self.gate_fn(x_se)
-        return x
+        return x * self.gate_fn(x_se)


 class ConvBnAct(nn.Module):
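For reference, a self-contained restatement of the block as it stands after this hunk, with simplified defaults (gate fixed to sigmoid, a plain max/int standing in for make_divisible); this is a sketch, not the timm source:

    import torch
    import torch.nn as nn

    class SqueezeExcite(nn.Module):
        def __init__(self, in_chs, se_ratio=0.25):
            super().__init__()
            reduced_chs = max(1, int(in_chs * se_ratio))   # stand-in for make_divisible
            self.conv_reduce = nn.Conv2d(in_chs, reduced_chs, 1, bias=True)
            self.act1 = nn.ReLU(inplace=True)
            self.conv_expand = nn.Conv2d(reduced_chs, in_chs, 1, bias=True)

        def forward(self, x):
            x_se = x.mean((2, 3), keepdim=True)            # global pool, no AdaptiveAvgPool2d
            x_se = self.conv_expand(self.act1(self.conv_reduce(x_se)))
            return x * x_se.sigmoid()

    x = torch.randn(2, 64, 14, 14).to(memory_format=torch.channels_last)
    print(SqueezeExcite(64)(x).shape)  # torch.Size([2, 64, 14, 14])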
@@ -10,6 +10,7 @@ Hacked together by / Copyright 2020 Ross Wightman

 import torch
 from torch import nn as nn
+import torch.nn.functional as F

 from .conv_bn_act import ConvBnAct

@@ -18,15 +19,13 @@ class ChannelAttn(nn.Module):
     """
     def __init__(self, channels, reduction=16, act_layer=nn.ReLU):
         super(ChannelAttn, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.max_pool = nn.AdaptiveMaxPool2d(1)
         self.fc1 = nn.Conv2d(channels, channels // reduction, 1, bias=False)
         self.act = act_layer(inplace=True)
         self.fc2 = nn.Conv2d(channels // reduction, channels, 1, bias=False)

     def forward(self, x):
-        x_avg = self.avg_pool(x)
-        x_max = self.max_pool(x)
+        x_avg = x.mean((2, 3), keepdim=True)
+        x_max = F.adaptive_max_pool2d(x, 1)
         x_avg = self.fc2(self.act(self.fc1(x_avg)))
         x_max = self.fc2(self.act(self.fc1(x_max)))
         x_attn = x_avg + x_max
@@ -40,7 +39,7 @@ class LightChannelAttn(ChannelAttn):
         super(LightChannelAttn, self).__init__(channels, reduction)

     def forward(self, x):
-        x_pool = 0.5 * self.avg_pool(x) + 0.5 * self.max_pool(x)
+        x_pool = 0.5 * x.mean((2, 3), keepdim=True) + 0.5 * F.adaptive_max_pool2d(x, 1)
         x_attn = self.fc2(self.act(self.fc1(x_pool)))
         return x * x_attn.sigmoid()

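Here only the avg branch moves to a tensor mean; the max branch keeps a functional pool, which for a 1x1 target is just a global spatial max. A quick check (amax needs a reasonably recent PyTorch):

    import torch
    import torch.nn.functional as F

    x = torch.randn(2, 16, 7, 7)
    avg = x.mean((2, 3), keepdim=True)                     # new avg branch
    mx = F.adaptive_max_pool2d(x, 1)                       # new max branch
    assert torch.equal(mx, x.amax((2, 3), keepdim=True))   # same as a global max
    print(avg.shape, mx.shape)  # torch.Size([2, 16, 1, 1]) both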
@@ -52,22 +52,15 @@ class EcaModule(nn.Module):
     def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
         super(EcaModule, self).__init__()
         assert kernel_size % 2 == 1
-
         if channels is not None:
             t = int(abs(math.log(channels, 2) + beta) / gamma)
             kernel_size = max(t if t % 2 else t + 1, 3)

-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)

     def forward(self, x):
-        # Feature descriptor on the global spatial information
-        y = self.avg_pool(x)
-        # Reshape for convolution
-        y = y.view(x.shape[0], 1, -1)
-        # Two different branches of ECA module
+        y = x.mean((2, 3)).view(x.shape[0], 1, -1)  # view for 1d conv
         y = self.conv(y)
-        # Multi-scale information fusion
         y = y.view(x.shape[0], -1, 1, 1).sigmoid()
         return x * y.expand_as(x)

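The fused line packs the descriptor and the reshape together; the shapes step through as follows (standalone sketch, assuming the default kernel_size of 3):

    import torch
    import torch.nn as nn

    B, C = 2, 32
    x = torch.randn(B, C, 7, 7)
    conv = nn.Conv1d(1, 1, kernel_size=3, padding=1, bias=False)

    y = x.mean((2, 3))          # (B, C): no keepdim, unlike the SE blocks
    y = y.view(B, 1, -1)        # (B, 1, C): channels become the 1d sequence
    y = conv(y)                 # (B, 1, C): local cross-channel interaction
    y = y.view(B, -1, 1, 1)     # (B, C, 1, 1): back to a per-channel gate
    print(y.shape)              # torch.Size([2, 32, 1, 1])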
@@ -95,30 +88,20 @@ class CecaModule(nn.Module):
     def __init__(self, channels=None, kernel_size=3, gamma=2, beta=1):
         super(CecaModule, self).__init__()
         assert kernel_size % 2 == 1
-
         if channels is not None:
             t = int(abs(math.log(channels, 2) + beta) / gamma)
             kernel_size = max(t if t % 2 else t + 1, 3)

-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        #pytorch circular padding mode is buggy as of pytorch 1.4
-        #see https://github.com/pytorch/pytorch/pull/17240
-
-        #implement manual circular padding
+        # PyTorch circular padding mode is buggy as of pytorch 1.4
+        # see https://github.com/pytorch/pytorch/pull/17240
+        # implement manual circular padding
         self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=0, bias=False)
         self.padding = (kernel_size - 1) // 2

     def forward(self, x):
-        # Feature descriptor on the global spatial information
-        y = self.avg_pool(x)
-
+        y = x.mean((2, 3)).view(x.shape[0], 1, -1)
         # Manually implement circular padding, F.pad does not seemed to be bugged
-        y = F.pad(y.view(x.shape[0], 1, -1), (self.padding, self.padding), mode='circular')
-
-        # Two different branches of ECA module
+        y = F.pad(y, (self.padding, self.padding), mode='circular')
         y = self.conv(y)
-
-        # Multi-scale information fusion
         y = y.view(x.shape[0], -1, 1, 1).sigmoid()
-
         return x * y.expand_as(x)
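The wrap-around that the retained comment describes, isolated on a toy sequence:

    import torch
    import torch.nn.functional as F

    y = torch.arange(5.).view(1, 1, 5)          # one channel sequence of length 5
    padded = F.pad(y, (1, 1), mode='circular')  # wraps ends: [4, 0, 1, 2, 3, 4, 0]
    print(padded)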
@@ -1,40 +1,36 @@
 from torch import nn as nn
-from .create_act import get_act_fn
+from .create_act import create_act_layer


 class SEModule(nn.Module):

     def __init__(self, channels, reduction=16, act_layer=nn.ReLU, min_channels=8, reduction_channels=None,
-                 gate_fn='sigmoid'):
+                 gate_layer='sigmoid'):
         super(SEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         reduction_channels = reduction_channels or max(channels // reduction, min_channels)
-        self.fc1 = nn.Conv2d(
-            channels, reduction_channels, kernel_size=1, padding=0, bias=True)
+        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
         self.act = act_layer(inplace=True)
-        self.fc2 = nn.Conv2d(
-            reduction_channels, channels, kernel_size=1, padding=0, bias=True)
-        self.gate_fn = get_act_fn(gate_fn)
+        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
+        self.gate = create_act_layer(gate_layer)

     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.fc1(x_se)
         x_se = self.act(x_se)
         x_se = self.fc2(x_se)
-        return x * self.gate_fn(x_se)
+        return x * self.gate(x_se)


 class EffectiveSEModule(nn.Module):
     """ 'Effective Squeeze-Excitation
     From `CenterMask : Real-Time Anchor-Free Instance Segmentation` - https://arxiv.org/abs/1911.06667
     """
-    def __init__(self, channels, gate_fn='hard_sigmoid'):
+    def __init__(self, channels, gate_layer='hard_sigmoid'):
         super(EffectiveSEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         self.fc = nn.Conv2d(channels, channels, kernel_size=1, padding=0)
-        self.gate_fn = get_act_fn(gate_fn)
+        self.gate = create_act_layer(gate_layer, inplace=True)

     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.fc(x_se)
-        return x * self.gate_fn(x_se, inplace=True)
+        return x * self.gate(x_se)
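The gate_fn/get_act_fn to gate/create_act_layer rename swaps a bare callable for an activation module, so the gate shows up in the module tree; for the default sigmoid gate the forward math is unchanged. A stand-in sketch using torch.nn directly rather than timm's factory:

    import torch
    import torch.nn as nn

    x_se = torch.randn(2, 16, 1, 1)
    fn_gate = torch.sigmoid     # old style: plain function
    mod_gate = nn.Sigmoid()     # new style: nn.Module
    assert torch.equal(fn_gate(x_se), mod_gate(x_se))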
@@ -27,7 +27,6 @@ class SelectiveKernelAttn(nn.Module):
         """
         super(SelectiveKernelAttn, self).__init__()
         self.num_paths = num_paths
-        self.pool = nn.AdaptiveAvgPool2d(1)
         self.fc_reduce = nn.Conv2d(channels, attn_channels, kernel_size=1, bias=False)
         self.bn = norm_layer(attn_channels)
         self.act = act_layer(inplace=True)
@@ -35,8 +34,7 @@ class SelectiveKernelAttn(nn.Module):

     def forward(self, x):
         assert x.shape[1] == self.num_paths
-        x = torch.sum(x, dim=1)
-        x = self.pool(x)
+        x = x.sum(1).mean((2, 3), keepdim=True)
         x = self.fc_reduce(x)
         x = self.bn(x)
         x = self.act(x)
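The two pooling steps collapse into one chained expression over the (B, num_paths, C, H, W) input; a sketch confirming it matches the old sum-then-pool form:

    import torch
    import torch.nn.functional as F

    B, P, C, H, W = 2, 3, 16, 7, 7                        # P = num_paths
    x = torch.randn(B, P, C, H, W)
    old = F.adaptive_avg_pool2d(torch.sum(x, dim=1), 1)   # old two-step form
    new = x.sum(1).mean((2, 3), keepdim=True)             # fused form from the diff
    assert torch.allclose(old, new, atol=1e-6)
    print(new.shape)  # torch.Size([2, 16, 1, 1])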
@@ -59,18 +59,15 @@ class SEWithNorm(nn.Module):
     def __init__(self, channels, reduction=16, act_layer=nn.ReLU, divisor=1, reduction_channels=None,
                  gate_layer='sigmoid'):
         super(SEWithNorm, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
         reduction_channels = reduction_channels or make_divisible(channels // reduction, divisor=divisor)
-        self.fc1 = nn.Conv2d(
-            channels, reduction_channels, kernel_size=1, padding=0, bias=True)
+        self.fc1 = nn.Conv2d(channels, reduction_channels, kernel_size=1, bias=True)
         self.bn = nn.BatchNorm2d(reduction_channels)
         self.act = act_layer(inplace=True)
-        self.fc2 = nn.Conv2d(
-            reduction_channels, channels, kernel_size=1, padding=0, bias=True)
+        self.fc2 = nn.Conv2d(reduction_channels, channels, kernel_size=1, bias=True)
         self.gate = create_act_layer(gate_layer)

     def forward(self, x):
-        x_se = self.avg_pool(x)
+        x_se = x.mean((2, 3), keepdim=True)
         x_se = self.fc1(x_se)
         x_se = self.bn(x_se)
         x_se = self.act(x_se)
@@ -71,17 +71,14 @@ class SEModule(nn.Module):

     def __init__(self, channels, reduction):
         super(SEModule, self).__init__()
-        self.avg_pool = nn.AdaptiveAvgPool2d(1)
-        self.fc1 = nn.Conv2d(
-            channels, channels // reduction, kernel_size=1, padding=0)
+        self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1)
         self.relu = nn.ReLU(inplace=True)
-        self.fc2 = nn.Conv2d(
-            channels // reduction, channels, kernel_size=1, padding=0)
+        self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1)
         self.sigmoid = nn.Sigmoid()

     def forward(self, x):
         module_input = x
-        x = self.avg_pool(x)
+        x = x.mean((2, 3), keepdim=True)
         x = self.fc1(x)
         x = self.relu(x)
         x = self.fc2(x)
@@ -56,10 +56,9 @@ class FastGlobalAvgPool2d(nn.Module):

     def forward(self, x):
         if self.flatten:
-            in_size = x.size()
-            return x.view((in_size[0], in_size[1], -1)).mean(dim=2)
+            return x.mean((2, 3))
         else:
-            return x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
+            return x.mean((2, 3), keepdim=True)

     def feat_mult(self):
         return 1
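Both branches of FastGlobalAvgPool2d reduce over the spatial dims; a sketch checking the old view-based forms against the new means:

    import torch

    x = torch.randn(2, 16, 7, 7)

    # flatten=True path: old view+mean vs direct mean over (2, 3)
    old_flat = x.view(x.size(0), x.size(1), -1).mean(dim=2)
    assert torch.allclose(old_flat, x.mean((2, 3)), atol=1e-6)                # (2, 16)

    # flatten=False path keeps the 1x1 spatial dims
    old_keep = x.view(x.size(0), x.size(1), -1).mean(-1).view(x.size(0), x.size(1), 1, 1)
    assert torch.allclose(old_keep, x.mean((2, 3), keepdim=True), atol=1e-6)  # (2, 16, 1, 1)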