# Source: fast-reid/fastreid/layers/norm.py
# (34 lines, 799 B, Python)

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
__all__ = ["NoBiasBatchNorm1d", "IBN"]
def NoBiasBatchNorm1d(in_features):
    """Build a ``BatchNorm1d`` layer whose bias (shift) term is frozen.

    Gradient tracking is switched off on the bias, so it keeps the value
    ``nn.BatchNorm1d`` initialized it with. Only the weight (scale) stays
    trainable.

    Args:
        in_features: number of features normalized by the layer.

    Returns:
        nn.BatchNorm1d: the configured normalization layer.
    """
    layer = nn.BatchNorm1d(in_features)
    # The bias receives no gradient, so training never changes it.
    layer.bias.requires_grad = False
    return layer
class IBN(nn.Module):
    """Instance-Batch Normalization layer.

    Splits the input along dim 1 (channels) into two parts:

    - the first ``planes // 2`` channels go through an affine
      ``nn.InstanceNorm2d``;
    - the remaining ``planes - planes // 2`` channels go through
      ``nn.BatchNorm2d``.

    The two outputs are concatenated back along dim 1.

    Args:
        planes: total number of input channels.
    """

    def __init__(self, planes):
        super(IBN, self).__init__()
        half1 = planes // 2
        half2 = planes - half1
        # Number of leading channels handled by instance normalization.
        self.half = half1
        self.IN = nn.InstanceNorm2d(half1, affine=True)
        self.BN = nn.BatchNorm2d(half2)

    def forward(self, x):
        """Normalize ``x`` with IN on the first half of channels and BN on the rest.

        Args:
            x: feature map with ``planes`` channels on dim 1. ``BatchNorm2d``
                requires 4-D input, presumably (N, C, H, W).

        Returns:
            Tensor with the same shape as ``x``.
        """
        # Split with explicit sizes. With a single int chunk size,
        # torch.split returns three chunks when the channel count is odd,
        # and the BatchNorm half would then get the wrong channel count.
        split = torch.split(x, [self.half, x.size(1) - self.half], 1)
        # The split views are generally non-contiguous, so make them
        # contiguous before normalizing.
        out1 = self.IN(split[0].contiguous())
        out2 = self.BN(split[1].contiguous())
        return torch.cat((out1, out2), 1)