# encoding: utf-8
"""
@author:  liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import logging

import torch
import torch.nn.functional as F
from torch import nn

try:
    from apex import parallel
except ImportError:
    raise ImportError("Please install apex from https://www.github.com/nvidia/apex to run model with syncBN")

__all__ = ["IBN", "get_norm"]


class BatchNorm(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False,
                 weight_init=1.0, bias_init=0.0, **kwargs):
        super().__init__(num_features, eps=eps, momentum=momentum)
        if weight_init is not None:
            nn.init.constant_(self.weight, weight_init)
        if bias_init is not None:
            nn.init.constant_(self.bias, bias_init)
        self.weight.requires_grad_(not weight_freeze)
        self.bias.requires_grad_(not bias_freeze)


class SyncBatchNorm(parallel.SyncBatchNorm):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False,
                 weight_init=1.0, bias_init=0.0):
        super().__init__(num_features, eps=eps, momentum=momentum)
        if weight_init is not None:
            nn.init.constant_(self.weight, weight_init)
        if bias_init is not None:
            nn.init.constant_(self.bias, bias_init)
        self.weight.requires_grad_(not weight_freeze)
        self.bias.requires_grad_(not bias_freeze)


class IBN(nn.Module):
    def __init__(self, planes, bn_norm, **kwargs):
        super(IBN, self).__init__()
        half1 = int(planes / 2)
        self.half = half1
        half2 = planes - half1
        self.IN = nn.InstanceNorm2d(half1, affine=True)
        self.BN = get_norm(bn_norm, half2, **kwargs)

    def forward(self, x):
        # Normalize the first half of the channels with InstanceNorm and the
        # second half with the configured batch norm, then concatenate.
        split = torch.split(x, self.half, 1)
        out1 = self.IN(split[0].contiguous())
        out2 = self.BN(split[1].contiguous())
        out = torch.cat((out1, out2), 1)
        return out


class GhostBatchNorm(BatchNorm):
    def __init__(self, num_features, num_splits=1, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            # Fold `num_splits` chunks of the batch into the channel dimension so
            # that batch statistics are computed over "ghost" batches of size
            # N / num_splits, then average the running statistics back per channel.
            self.running_mean = self.running_mean.repeat(self.num_splits)
            self.running_var = self.running_var.repeat(self.num_splits)
            outputs = F.batch_norm(
                input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(N, C, H, W)
            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
            return outputs
        else:
            return F.batch_norm(
                input, self.running_mean, self.running_var,
                self.weight, self.bias, False, self.momentum, self.eps)
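# Usage sketch (illustrative only, not part of the original module): shows how
# GhostBatchNorm and IBN behave on a dummy feature map. The batch size, channel
# count and `num_splits` value below are arbitrary choices for the example.
def _ghost_bn_ibn_example():
    x = torch.randn(8, 64, 32, 32)
    ghost_bn = GhostBatchNorm(64, num_splits=4)  # statistics over 4 ghost batches of 2 samples each
    ibn = IBN(64, bn_norm="BN")                  # first 32 channels InstanceNorm, last 32 BatchNorm
    return ghost_bn(x).shape, ibn(x).shape       # both preserve the input shape (8, 64, 32, 32)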
class FrozenBatchNorm(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.
    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.
    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.
    Other pre-trained backbone models may contain all 4 parameters.
    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5, **kwargs):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            return x * scale + bias
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions.
            # This will silence the warnings.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        if version is not None and version < 3:
            logger = logging.getLogger(__name__)
            logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
            # In version < 3, running_var was used without +eps.
            state_dict[prefix + "running_var"] -= self.eps

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
        Args:
            module (torch.nn.Module):
        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.
        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
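# Usage sketch (illustrative only, not part of the original module): freezing
# the BN layers of a small toy backbone before fine-tuning. The Conv/ReLU stack
# is an arbitrary example, not something this module depends on.
def _freeze_bn_example():
    backbone = nn.Sequential(
        nn.Conv2d(3, 16, 3, padding=1),
        nn.BatchNorm2d(16),
        nn.ReLU(inplace=True),
    )
    frozen = FrozenBatchNorm.convert_frozen_batchnorm(backbone)
    # The BatchNorm2d child is swapped for FrozenBatchNorm; its statistics and
    # affine parameters become plain buffers and no longer update during training.
    return frozen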
def get_norm(norm, out_channels, **kwargs):
    """
    Args:
        norm (str or callable): either one of BN, syncBN, GhostBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module
        out_channels: number of channels for normalization layer
    Returns:
        nn.Module or None: the normalization layer
    """
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm,
            "syncBN": SyncBatchNorm,
            "GhostBN": GhostBatchNorm,
            "FrozenBN": FrozenBatchNorm,
            "GN": lambda channels, **args: nn.GroupNorm(32, channels),
        }[norm]
    return norm(out_channels, **kwargs)
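
# Usage sketch (illustrative only, not part of the original module): selecting a
# normalization layer by name. The channel count and the `bias_freeze` keyword
# are example choices; any keyword accepted by the chosen layer can be passed.
def _get_norm_example():
    bn = get_norm("BN", 64, bias_freeze=True)  # BatchNorm with a frozen bias
    gn = get_norm("GN", 64)                    # GroupNorm with 32 groups
    frozen = get_norm("FrozenBN", 64)          # fixed statistics, never updated
    none = get_norm("", 64)                    # empty string -> None (no norm layer)
    return bn, gn, frozen, none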