# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""

from torch import nn


class SEModule(nn.Module):
    """Squeeze-and-Excitation channel gate.

    Global-average-pools the input to 1x1 per channel, passes the result
    through a 1x1-conv bottleneck (channels -> hidden -> channels, ReLU in
    between, no biases), squashes it with a sigmoid, and rescales the input
    channel-wise by the resulting weights.

    Args:
        channels: Number of input (and output) channels; must match the
            channel dimension of the tensor passed to ``forward``.
        reduciton: Bottleneck reduction ratio (spelling kept for backward
            compatibility with keyword callers). The hidden width is
            ``channels // reduciton``, clamped to at least 1.

    Raises:
        ValueError: If ``reduciton`` is less than 1.
    """

    def __init__(self, channels, reduciton):
        super().__init__()
        if reduciton < 1:
            raise ValueError(f"reduction must be >= 1, got {reduciton}")
        # Clamp to at least one hidden channel: when reduciton > channels the
        # floor division yields 0, producing a zero-width bottleneck whose
        # output is always zero, i.e. a constant, untrainable gate of 0.5.
        hidden = max(channels // reduciton, 1)
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, hidden, kernel_size=1, padding=0, bias=False)
        self.relu = nn.ReLU(True)
        self.fc2 = nn.Conv2d(hidden, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` rescaled per channel by the learned gate.

        Assumes ``x`` is (N, C, H, W) with C == ``channels`` (TODO confirm
        callers never pass unbatched input). Output has the same shape as ``x``.
        """
        module_input = x
        # Squeeze: average over the spatial dims down to 1x1 per channel.
        x = self.avg_pool(x)
        # Excitation: bottleneck MLP implemented as 1x1 convolutions.
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        # Per-channel weights in (0, 1).
        x = self.sigmoid(x)
        # The 1x1 gate broadcasts across the spatial dims of the input.
        return module_input * x