update seresnet

pull/17/head
KaiyangZhou 2018-05-03 13:46:41 +01:00
parent 70e6a0b5ff
commit 97679fb8e6
1 changed files with 2 additions and 4 deletions

View File

@ -91,11 +91,9 @@ class SEModule(nn.Module):
def __init__(self, channels, reduction): def __init__(self, channels, reduction):
super(SEModule, self).__init__() super(SEModule, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1) self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0)
padding=0)
self.relu = nn.ReLU(inplace=True) self.relu = nn.ReLU(inplace=True)
self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0)
padding=0)
self.sigmoid = nn.Sigmoid() self.sigmoid = nn.Sigmoid()
def forward(self, x): def forward(self, x):