# Source: fast-reid/fastreid/layers/se_module.py (26 lines, 732 B, Python)

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
from torch import nn
class SEModule(nn.Module):
    """Squeeze-and-Excitation channel gate.

    The input feature map is globally average-pooled to one value per
    channel, passed through a two-layer 1x1-convolution bottleneck
    (``channels -> channels // reduciton -> channels``) with a ReLU in
    between, squashed with a sigmoid, and the resulting per-channel weights
    are multiplied back onto the input.

    Args:
        channels: Number of channels of the input feature map.
        reduciton: Bottleneck reduction ratio. The misspelled name is kept
            on purpose so existing keyword-argument callers keep working.

    Raises:
        ValueError: If ``channels // reduciton`` is smaller than 1, i.e. the
            bottleneck would have no channels.
        ZeroDivisionError: If ``reduciton`` is 0.
    """

    def __init__(self, channels, reduciton):
        super().__init__()

        squeezed = channels // reduciton
        if squeezed < 1:
            # Fail here with a clear message rather than handing a zero or
            # negative channel count to nn.Conv2d.
            raise ValueError(
                f"channels // reduciton must be >= 1, got "
                f"channels={channels}, reduciton={reduciton}"
            )

        # Attribute names are part of the state_dict keys; do not rename.
        self.avg_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Conv2d(channels, squeezed, kernel_size=1, padding=0, bias=False)
        self.relu = nn.ReLU(True)
        self.fc2 = nn.Conv2d(squeezed, channels, kernel_size=1, padding=0, bias=False)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Return ``x`` rescaled channel-wise by the learned gate.

        Args:
            x: Feature map with ``channels`` channels, in a layout accepted
                by ``nn.Conv2d``.

        Returns:
            The product of ``x`` and the sigmoid gate computed from its
            globally pooled channel descriptor.
        """
        module_input = x
        x = self.avg_pool(x)   # squeeze: one value per channel
        x = self.fc1(x)        # bottleneck down-projection
        x = self.relu(x)
        x = self.fc2(x)        # up-projection back to `channels`
        x = self.sigmoid(x)    # per-channel weights in (0, 1)
        return module_input * x