mirror of https://github.com/JDAI-CV/fast-reid.git
26 lines
732 B
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
from torch import nn
|
|
|
|
|
|
class SEModule(nn.Module):
    """Squeeze-and-Excitation channel gating block.

    Globally average-pools the input to a 1x1 map ("squeeze"), passes it
    through a bottleneck of two bias-free 1x1 convolutions
    (``channels -> hidden -> channels``) with a ReLU in between, squashes the
    result with a sigmoid ("excitation"), and multiplies the original input
    by that per-channel gate.

    Args:
        channels: Number of input (and output) channels.
        reduciton: Bottleneck reduction ratio; the hidden width is
            ``channels // reduciton``, clamped to at least 1. The parameter
            keeps its historical misspelling so existing keyword callers
            continue to work.

    Raises:
        ValueError: If ``reduciton`` is less than 1.
    """

    def __init__(self, channels, reduciton):
        super().__init__()
        if reduciton < 1:
            raise ValueError(f"reduction ratio must be >= 1, got {reduciton}")
        # Clamp to at least one hidden channel: when channels < reduciton the
        # floor division yields 0, which would build degenerate zero-width
        # convolutions (an empty bottleneck carrying no input information).
        hidden = max(1, channels // reduciton)

        # Attribute names are kept as-is: they are the state_dict keys that
        # pretrained checkpoints are loaded against.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, hidden, kernel_size=1, padding=0, bias=False)
        self.relu = nn.ReLU(True)  # in-place ReLU
        self.fc2 = nn.Conv2d(hidden, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Rescale each channel of ``x`` by its learned gate.

        Args:
            x: Input feature map; presumably shaped (N, C, H, W) with
                C == ``channels`` -- TODO confirm against callers.

        Returns:
            ``x`` multiplied element-wise by a per-channel gate in (0, 1),
            broadcast over the spatial dimensions.
        """
        module_input = x
        x = self.avg_pool(x)  # squeeze: spatial dims -> 1x1
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        x = self.sigmoid(x)  # per-channel gate in (0, 1)
        return module_input * x