From 97679fb8e61bd7fd6e01ba6f4d4870acec8f1c78 Mon Sep 17 00:00:00 2001 From: KaiyangZhou Date: Thu, 3 May 2018 13:46:41 +0100 Subject: [PATCH] update seresnet --- models/SEResNet.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/models/SEResNet.py b/models/SEResNet.py index 9f12e62..c9b1d84 100644 --- a/models/SEResNet.py +++ b/models/SEResNet.py @@ -91,11 +91,9 @@ class SEModule(nn.Module): def __init__(self, channels, reduction): super(SEModule, self).__init__() self.avg_pool = nn.AdaptiveAvgPool2d(1) - self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, - padding=0) + self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0) self.relu = nn.ReLU(inplace=True) - self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, - padding=0) + self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0) self.sigmoid = nn.Sigmoid() def forward(self, x):