mirror of https://github.com/JDAI-CV/fast-reid.git
34 lines
799 B
Python
34 lines
799 B
Python
# encoding: utf-8
|
|
"""
|
|
@author: liaoxingyu
|
|
@contact: sherlockliao01@gmail.com
|
|
"""
|
|
|
|
import torch
|
|
from torch import nn
|
|
|
|
__all__ = ["NoBiasBatchNorm1d", "IBN"]
|
|
|
|
|
|
def NoBiasBatchNorm1d(in_features):
    """Create a ``BatchNorm1d`` layer whose bias (shift) parameter is frozen.

    The scale (``weight``) stays trainable; only ``bias`` is excluded from
    gradient computation. PyTorch initialises the BatchNorm bias to zero, so
    the shift stays at its initial value during training.

    Args:
        in_features: number of features/channels normalized by the layer.

    Returns:
        The configured ``nn.BatchNorm1d`` instance.
    """
    layer = nn.BatchNorm1d(in_features)
    # Equivalent to ``layer.bias.requires_grad_(False)`` for a leaf Parameter.
    layer.bias.requires_grad = False
    return layer
|
|
|
|
|
|
class IBN(nn.Module):
    """Instance-Batch Normalization layer (IBN-Net style).

    The channel dimension (dim 1) of the input is split in two:

    * the first ``planes // 2`` channels go through an affine
      ``nn.InstanceNorm2d``;
    * the remaining ``planes - planes // 2`` channels go through an
      ``nn.BatchNorm2d``.

    The two normalized parts are concatenated back in their original channel
    order, so the output has the same shape as the input.

    Args:
        planes: number of input (and output) channels.
    """

    def __init__(self, planes):
        super().__init__()
        half1 = planes // 2
        # Split point used by forward(); kept under its original name.
        self.half = half1
        half2 = planes - half1
        self.IN = nn.InstanceNorm2d(half1, affine=True)
        self.BN = nn.BatchNorm2d(half2)

    def forward(self, x):
        """Normalize ``x``, expected to be ``(N, planes, H, W)``.

        Returns:
            A tensor with the same shape as ``x``.
        """
        # Use explicit section sizes so the split always yields exactly two
        # pieces. Splitting by the single chunk size ``self.half`` breaks for
        # odd channel counts (e.g. 5 -> [2, 2, 1]): ``self.BN`` would get the
        # wrong number of channels and the trailing chunk would be dropped.
        x_in, x_bn = torch.split(x, [self.half, x.size(1) - self.half], dim=1)
        out1 = self.IN(x_in.contiguous())
        out2 = self.BN(x_bn.contiguous())
        return torch.cat((out1, out2), 1)
|