# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn

__all__ = ["NoBiasBatchNorm1d", "IBN"]


def NoBiasBatchNorm1d(in_features: int) -> nn.BatchNorm1d:
    """Build a ``BatchNorm1d`` layer whose bias is excluded from training.

    The bias parameter is still present on the layer (at whatever value
    ``nn.BatchNorm1d`` initialises it to); it is only marked as not
    requiring gradients, so optimizers will not update it.

    Args:
        in_features: Number of features passed to ``nn.BatchNorm1d``.

    Returns:
        The configured ``nn.BatchNorm1d`` module.
    """
    bn_layer = nn.BatchNorm1d(in_features)
    bn_layer.bias.requires_grad_(False)
    return bn_layer


class IBN(nn.Module):
    """Normalise one part of the channels with IN and the rest with BN.

    The first ``planes // 2`` channels go through an affine
    ``InstanceNorm2d``; the remaining ``planes - planes // 2`` channels go
    through a ``BatchNorm2d``. The two results are concatenated back along
    the channel dimension, so the output has the same shape as the input.

    Args:
        planes: Total number of input channels.
    """

    def __init__(self, planes):
        super().__init__()
        half1 = planes // 2
        half2 = planes - half1  # takes the extra channel when planes is odd
        self.half = half1
        self.IN = nn.InstanceNorm2d(half1, affine=True)
        self.BN = nn.BatchNorm2d(half2)

    def forward(self, x):
        """Apply the split normalisation to ``x``.

        Args:
            x: Tensor whose dimension 1 holds ``planes`` channels, in the
                layout accepted by the 2d norm layers, i.e. (N, C, H, W).

        Returns:
            Tensor of the same shape as ``x``.
        """
        # Explicit section sizes rather than a single chunk size: with an
        # odd channel count a chunk size of ``self.half`` would produce
        # three chunks and feed the BN branch the wrong number of channels.
        in_part, bn_part = torch.split(
            x, [self.half, self.BN.num_features], 1
        )
        out1 = self.IN(in_part.contiguous())
        out2 = self.BN(bn_part.contiguous())
        return torch.cat((out1, out2), 1)