liaoxingyu a2dcd7b4ab feat(layers/norm): add ghost batchnorm
add a get_norm function to easily switch normalization between batchnorm, ghost bn and group bn
2020-05-01 09:02:46 +08:00


# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn

__all__ = [
    "BatchNorm",
    "IBN",
    "GhostBatchNorm",
    "get_norm",
]


class BatchNorm(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False,
                 weight_init=1.0, bias_init=0.0):
        super().__init__(num_features, eps=eps, momentum=momentum)
        if weight_init is not None:
            self.weight.data.fill_(weight_init)
        if bias_init is not None:
            self.bias.data.fill_(bias_init)
        # Optionally freeze the affine parameters, e.g. for fine-tuning.
        self.weight.requires_grad = not weight_freeze
        self.bias.requires_grad = not bias_freeze
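
# Note (illustrative usage, not from the original file): BatchNorm(64,
# weight_freeze=True, bias_freeze=True) keeps gamma/beta fixed during
# training while the running mean/var still update in train mode, since
# requires_grad only disables gradient updates to the affine parameters.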


class IBN(nn.Module):
    """Instance-Batch Normalization: instance norm on the first half of the
    channels, batch norm (via get_norm) on the second half."""

    def __init__(self, planes, bn_norm, num_splits):
        super(IBN, self).__init__()
        half1 = int(planes / 2)
        self.half = half1
        half2 = planes - half1
        self.IN = nn.InstanceNorm2d(half1, affine=True)
        self.BN = get_norm(bn_norm, half2, num_splits)

    def forward(self, x):
        split = torch.split(x, self.half, 1)
        out1 = self.IN(split[0].contiguous())
        out2 = self.BN(split[1].contiguous())
        out = torch.cat((out1, out2), 1)
        return out
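
# Example (illustrative numbers): IBN(64, "BN", 1) applies InstanceNorm2d to
# channels 0-31 and BatchNorm to channels 32-63 of an (N, 64, H, W) input,
# then concatenates the two halves back along the channel dimension.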


class GhostBatchNorm(BatchNorm):
    """Ghost Batch Normalization: computes batch-norm statistics over
    num_splits smaller "virtual" batches instead of the full batch."""

    def __init__(self, num_features, num_splits=1, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            # Tile the running stats and affine parameters so each virtual
            # batch gets its own copy, then fold the split factor out of the
            # batch dimension and into the channel dimension.
            self.running_mean = self.running_mean.repeat(self.num_splits)
            self.running_var = self.running_var.repeat(self.num_splits)
            outputs = nn.functional.batch_norm(
                input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(N, C, H, W)
            # Average the per-split statistics back into a single set.
            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
            return outputs
        else:
            return nn.functional.batch_norm(
                input, self.running_mean, self.running_var,
                self.weight, self.bias, False, self.momentum, self.eps)
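
# Shape walk-through (illustrative numbers): with num_splits=4 and an input of
# shape (32, 64, H, W), view(-1, 64 * 4, H, W) gives (8, 256, H, W), so
# batch_norm sees 4 copies of every channel and computes each copy's
# statistics over a disjoint virtual batch of 32 / 4 = 8 samples.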


def get_norm(norm, out_channels, num_splits=1, **kwargs):
    """
    Args:
        norm (str or callable): either "BN", "GhostBN", or an empty string
            (meaning no normalization); anything that is not a string is
            returned unchanged.
        out_channels (int): number of channels the layer normalizes.
        num_splits (int): number of virtual batches, used by "GhostBN".
    Returns:
        nn.Module or None: the normalization layer
    """
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm(out_channels, **kwargs),
            "GhostBN": GhostBatchNorm(out_channels, num_splits, **kwargs),
            # "FrozenBN": FrozenBatchNorm2d,
            # "GN": lambda channels: nn.GroupNorm(32, channels),
            # "nnSyncBN": nn.SyncBatchNorm,  # keep for debugging
        }[norm]
    return norm
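

# Usage sketch (illustrative, not part of the original module): the config
# string picks the normalization, so swapping batchnorm for ghost batchnorm
# is a one-line change. Shapes and split counts below are arbitrary.
if __name__ == "__main__":
    x = torch.randn(8, 64, 16, 16)
    bn = get_norm("BN", 64)
    gbn = get_norm("GhostBN", 64, num_splits=4)  # stats over 4 virtual batches of 2
    assert bn(x).shape == x.shape and gbn(x).shape == x.shape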