# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn

__all__ = [
    "BatchNorm",
    "IBN",
    "GhostBatchNorm",
    "get_norm",
]


class BatchNorm(nn.BatchNorm2d):
    """BatchNorm2d whose affine parameters can be initialized to fixed values
    and optionally frozen (excluded from gradient updates)."""

    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False,
                 weight_init=1.0, bias_init=0.0):
        super().__init__(num_features, eps=eps, momentum=momentum)
        if weight_init is not None:
            self.weight.data.fill_(weight_init)
        if bias_init is not None:
            self.bias.data.fill_(bias_init)
        self.weight.requires_grad = not weight_freeze
        self.bias.requires_grad = not bias_freeze


class IBN(nn.Module):
    """Instance-Batch Normalization: applies InstanceNorm to the first half of
    the channels and BatchNorm to the second half."""

    def __init__(self, planes, bn_norm, num_splits):
        super().__init__()
        half1 = planes // 2
        self.half = half1
        half2 = planes - half1
        self.IN = nn.InstanceNorm2d(half1, affine=True)
        self.BN = get_norm(bn_norm, half2, num_splits)

    def forward(self, x):
        split = torch.split(x, self.half, 1)
        out1 = self.IN(split[0].contiguous())
        out2 = self.BN(split[1].contiguous())
        out = torch.cat((out1, out2), 1)
        return out


class GhostBatchNorm(BatchNorm):
    """BatchNorm that computes statistics over `num_splits` virtual ("ghost")
    mini-batches during training, then averages the running statistics back
    into a single set of buffers."""

    def __init__(self, num_features, num_splits=1, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            # Tile the running stats and affine parameters so each ghost batch
            # is normalized independently, then fold the stats back by averaging.
            self.running_mean = self.running_mean.repeat(self.num_splits)
            self.running_var = self.running_var.repeat(self.num_splits)
            outputs = nn.functional.batch_norm(
                input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(N, C, H, W)
            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
            return outputs
        else:
            return nn.functional.batch_norm(
                input, self.running_mean, self.running_var,
                self.weight, self.bias, False, self.momentum, self.eps)


def get_norm(norm, out_channels, num_splits=1, **kwargs):
    """
    Args:
        norm (str or callable): one of "BN" or "GhostBN"; an empty string
            means no normalization. A callable is returned unchanged.
        out_channels (int): number of channels to normalize.
        num_splits (int): number of virtual mini-batches for "GhostBN".
    Returns:
        nn.Module or None: the normalization layer
    """
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm(out_channels, **kwargs),
            "GhostBN": GhostBatchNorm(out_channels, num_splits, **kwargs),
            # "FrozenBN": FrozenBatchNorm2d,
            # "GN": lambda channels: nn.GroupNorm(32, channels),
            # "nnSyncBN": nn.SyncBatchNorm,  # keep for debugging
        }[norm]
    return norm
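

if __name__ == "__main__":
    # Minimal smoke test (added for illustration; not part of the original module).
    # The batch size, channel count, and the choice of "GhostBN" with 4 splits
    # are arbitrary assumptions used only to demonstrate the interfaces above.
    x = torch.randn(8, 64, 32, 32)

    bn = get_norm("BN", 64)                           # plain BatchNorm2d wrapper
    gbn = get_norm("GhostBN", 64, num_splits=4)       # ghost BN over 4 virtual mini-batches
    ibn = IBN(planes=64, bn_norm="BN", num_splits=1)  # IN on first half of channels, BN on the rest

    for layer in (bn, gbn, ibn):
        layer.train()
        assert layer(x).shape == x.shape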