add a get_norm function to easily switch normalization between batch norm, ghost batch norm and group norm

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn

__all__ = [
    "BatchNorm",
    "IBN",
    "GhostBatchNorm",
    "get_norm",
]


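# nn.BatchNorm2d wrapper whose affine weight/bias can be initialized to fixed
# values and optionally frozen (requires_grad=False) directly from the constructor.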
class BatchNorm(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                 bias_init=0.0):
        super().__init__(num_features, eps=eps, momentum=momentum)
        if weight_init is not None: self.weight.data.fill_(weight_init)
        if bias_init is not None: self.bias.data.fill_(bias_init)
        self.weight.requires_grad = not weight_freeze
        self.bias.requires_grad = not bias_freeze


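# Instance-Batch Normalization (IBN) block: the first half of the channels is
# normalized with InstanceNorm2d, the second half with the batch-norm variant
# selected by `bn_norm`, and the two halves are concatenated back together.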
class IBN(nn.Module):
    def __init__(self, planes, bn_norm, num_splits):
        super(IBN, self).__init__()
        half1 = int(planes / 2)
        self.half = half1
        half2 = planes - half1
        self.IN = nn.InstanceNorm2d(half1, affine=True)
        self.BN = get_norm(bn_norm, half2, num_splits)

    def forward(self, x):
        split = torch.split(x, self.half, 1)
        out1 = self.IN(split[0].contiguous())
        out2 = self.BN(split[1].contiguous())
        out = torch.cat((out1, out2), 1)
        return out


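# Ghost batch norm: during training the mini-batch is split into `num_splits`
# virtual sub-batches that are normalized independently; the running statistics
# are then averaged back into a single set of buffers of size `num_features`.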
class GhostBatchNorm(BatchNorm):
    def __init__(self, num_features, num_splits=1, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            self.running_mean = self.running_mean.repeat(self.num_splits)
            self.running_var = self.running_var.repeat(self.num_splits)
            outputs = nn.functional.batch_norm(
                input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(N, C, H, W)
            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
            return outputs
        else:
            return nn.functional.batch_norm(
                input, self.running_mean, self.running_var,
                self.weight, self.bias, False, self.momentum, self.eps)


def get_norm(norm, out_channels, num_splits=1, **kwargs):
    """
    Args:
        norm (str or callable): one of "BN" or "GhostBN", or an already-built
            normalization layer, which is returned unchanged. An empty string
            returns None.
        out_channels (int): number of channels the normalization layer operates on.
        num_splits (int): number of virtual sub-batches used by "GhostBN".
    Returns:
        nn.Module or None: the normalization layer
    """
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm(out_channels, **kwargs),
            "GhostBN": GhostBatchNorm(out_channels, num_splits, **kwargs),
            # "FrozenBN": FrozenBatchNorm2d,
            # "GN": lambda channels: nn.GroupNorm(32, channels),
            # "nnSyncBN": nn.SyncBatchNorm,  # keep for debugging
        }[norm]
    return norm
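

# ----------------------------------------------------------------------------
# Minimal usage sketch of `get_norm` (illustrative only; the channel count and
# the dummy input shape below are assumptions, not values taken from fast-reid).
# ----------------------------------------------------------------------------
if __name__ == '__main__':
    bn = get_norm("BN", 64)                       # plain BatchNorm over 64 channels
    gbn = get_norm("GhostBN", 64, num_splits=4)   # ghost BN with 4 virtual sub-batches

    x = torch.randn(8, 64, 32, 32)                # dummy feature map: N=8, C=64, H=W=32
    print(bn(x).shape)                            # torch.Size([8, 64, 32, 32])
    print(gbn(x).shape)                           # torch.Size([8, 64, 32, 32])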