mirror of https://github.com/JDAI-CV/fast-reid.git
synced 2025-06-03 14:50:47 +08:00
feat: add SyncBN and GroupNorm support
This commit is contained in:
parent 5ae3d4fecf
commit 0356ef8c5c
@@ -2,4 +2,7 @@
 """
 @author: liaoxingyu
 @contact: sherlockliao01@gmail.com
 """
+
+
+__version__ = "0.1.0"
fastreid/layers/__init__.py
@@ -7,7 +7,7 @@ from torch import nn

 from .batch_drop import BatchDrop
 from .attention import *
-from .norm import *
+from .batch_norm import *
 from .context_block import ContextBlock
 from .non_local import Non_local
 from .se_layer import SELayer
fastreid/layers/batch_norm.py (new file, 206 lines)
@@ -0,0 +1,206 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
import logging
import torch.nn.functional as F
from torch import nn

__all__ = [
    "BatchNorm",
    "IBN",
    "GhostBatchNorm",
    "FrozenBatchNorm",
    "get_norm",
]


class BatchNorm(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                 bias_init=0.0):
        super().__init__(num_features, eps=eps, momentum=momentum)
        if weight_init is not None: self.weight.data.fill_(weight_init)
        if bias_init is not None: self.bias.data.fill_(bias_init)
        self.weight.requires_grad_(not weight_freeze)
        self.bias.requires_grad_(not bias_freeze)


class SyncBatchNorm(nn.SyncBatchNorm):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                 bias_init=0.0):
        super().__init__(num_features, eps=eps, momentum=momentum)
        if weight_init is not None: self.weight.data.fill_(weight_init)
        if bias_init is not None: self.bias.data.fill_(bias_init)
        self.weight.requires_grad_(not weight_freeze)
        self.bias.requires_grad_(not bias_freeze)


class IBN(nn.Module):
    def __init__(self, planes, bn_norm, num_splits):
        super(IBN, self).__init__()
        half1 = int(planes / 2)
        self.half = half1
        half2 = planes - half1
        self.IN = nn.InstanceNorm2d(half1, affine=True)
        self.BN = get_norm(bn_norm, half2, num_splits)

    def forward(self, x):
        split = torch.split(x, self.half, 1)
        out1 = self.IN(split[0].contiguous())
        out2 = self.BN(split[1].contiguous())
        out = torch.cat((out1, out2), 1)
        return out


class GhostBatchNorm(BatchNorm):
    def __init__(self, num_features, num_splits=1, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            self.running_mean = self.running_mean.repeat(self.num_splits)
            self.running_var = self.running_var.repeat(self.num_splits)
            outputs = nn.functional.batch_norm(
                input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(N, C, H, W)
            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
            return outputs
        else:
            return nn.functional.batch_norm(
                input, self.running_mean, self.running_var,
                self.weight, self.bias, False, self.momentum, self.eps)
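
GhostBatchNorm normalizes each of num_splits virtual sub-batches with its own batch statistics and then averages those statistics back into the shared running buffers. A minimal usage sketch (illustrative only, not part of the diff; it assumes the repository root is on PYTHONPATH so the new module is importable):

    import torch

    from fastreid.layers.batch_norm import GhostBatchNorm

    gbn = GhostBatchNorm(64, num_splits=4)   # 64 channels, 4 virtual sub-batches
    x = torch.randn(8, 64, 16, 16)           # batch of 8 -> each sub-batch normalizes 2 images

    gbn.train()
    y = gbn(x)                               # per-sub-batch statistics, then averaged into running_mean/var
    assert y.shape == x.shape

    gbn.eval()
    y_eval = gbn(x)                          # evaluation uses the averaged running statistics
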

class FrozenBatchNorm(BatchNorm):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.
    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.
    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.
    Other pre-trained backbone models may contain all 4 parameters.
    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            return x * scale + bias
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions.
            # This will silence the warnings.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        if version is not None and version < 3:
            logger = logging.getLogger(__name__)
            logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
            # In version < 3, running_var is used without +eps.
            state_dict[prefix + "running_var"] -= self.eps

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, in-place convert module and return it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res
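
The docstring above notes that the frozen transform reduces to a per-channel x * scale + bias; the folding done in forward() can be checked against F.batch_norm(..., training=False) with plain torch (illustrative only, not part of the diff):

    import torch
    import torch.nn.functional as F

    c, eps = 8, 1e-5
    x = torch.randn(2, c, 4, 4)
    weight, bias = torch.rand(c) + 0.5, torch.randn(c)
    running_mean, running_var = torch.randn(c), torch.rand(c) + 0.5

    # fold the fixed statistics into a per-channel scale and shift, as FrozenBatchNorm.forward does
    scale = weight * (running_var + eps).rsqrt()
    shift = bias - running_mean * scale
    y_folded = x * scale.reshape(1, -1, 1, 1) + shift.reshape(1, -1, 1, 1)

    # matches inference-mode batch norm with the same statistics
    y_bn = F.batch_norm(x, running_mean, running_var, weight, bias, training=False, eps=eps)
    assert torch.allclose(y_folded, y_bn, atol=1e-5)
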

def get_norm(norm, out_channels, num_splits=1, **kwargs):
    """
    Args:
        norm (str or callable):

    Returns:
        nn.Module or None: the normalization layer
    """
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm(out_channels, **kwargs),
            "GhostBN": GhostBatchNorm(out_channels, num_splits, **kwargs),
            "FrozenBN": FrozenBatchNorm(out_channels),
            "GN": nn.GroupNorm(32, out_channels),
            "syncBN": SyncBatchNorm(out_channels, **kwargs),  # it is unavailable now
        }[norm]
    return norm
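
For reference, what the newly enabled "GN" and "syncBN" keys resolve to, constructed directly (an illustrative sketch outside the diff, assuming the repository root is on PYTHONPATH):

    import torch
    from torch import nn

    from fastreid.layers.batch_norm import SyncBatchNorm

    channels = 64
    x = torch.randn(4, channels, 16, 16)

    # "GN": GroupNorm with 32 groups over the channel dimension; statistics are
    # computed per sample and per group, so they do not depend on the batch size
    gn = nn.GroupNorm(32, channels)
    print(gn(x).shape)                    # torch.Size([4, 64, 16, 16])

    # "syncBN": batch statistics are shared across GPUs, but the synchronization
    # only takes effect under torch.nn.parallel.DistributedDataParallel
    sync_bn = SyncBatchNorm(channels)
    print(sync_bn.weight.requires_grad)   # True (the freeze flags default to False)
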

fastreid/layers/non_local.py
@@ -3,7 +3,7 @@

 import torch
 from torch import nn
-from .norm import get_norm
+from .batch_norm import get_norm


 class Non_local(nn.Module):
fastreid/layers/norm.py (deleted, 87 lines)
@@ -1,87 +0,0 @@
# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn

__all__ = [
    "BatchNorm",
    "IBN",
    "GhostBatchNorm",
    "get_norm",
]


class BatchNorm(nn.BatchNorm2d):
    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
                 bias_init=0.0):
        super().__init__(num_features, eps=eps, momentum=momentum)
        if weight_init is not None: self.weight.data.fill_(weight_init)
        if bias_init is not None: self.bias.data.fill_(bias_init)
        self.weight.requires_grad = not weight_freeze
        self.bias.requires_grad = not bias_freeze


class IBN(nn.Module):
    def __init__(self, planes, bn_norm, num_splits):
        super(IBN, self).__init__()
        half1 = int(planes / 2)
        self.half = half1
        half2 = planes - half1
        self.IN = nn.InstanceNorm2d(half1, affine=True)
        self.BN = get_norm(bn_norm, half2, num_splits)

    def forward(self, x):
        split = torch.split(x, self.half, 1)
        out1 = self.IN(split[0].contiguous())
        out2 = self.BN(split[1].contiguous())
        out = torch.cat((out1, out2), 1)
        return out


class GhostBatchNorm(BatchNorm):
    def __init__(self, num_features, num_splits=1, **kwargs):
        super().__init__(num_features, **kwargs)
        self.num_splits = num_splits
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))

    def forward(self, input):
        N, C, H, W = input.shape
        if self.training or not self.track_running_stats:
            self.running_mean = self.running_mean.repeat(self.num_splits)
            self.running_var = self.running_var.repeat(self.num_splits)
            outputs = nn.functional.batch_norm(
                input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
                True, self.momentum, self.eps).view(N, C, H, W)
            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
            return outputs
        else:
            return nn.functional.batch_norm(
                input, self.running_mean, self.running_var,
                self.weight, self.bias, False, self.momentum, self.eps)


def get_norm(norm, out_channels, num_splits=1, **kwargs):
    """
    Args:
        norm (str or callable):

    Returns:
        nn.Module or None: the normalization layer
    """
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm(out_channels, **kwargs),
            "GhostBN": GhostBatchNorm(out_channels, num_splits, **kwargs),
            # "FrozenBN": FrozenBatchNorm2d,
            # "GN": lambda channels: nn.GroupNorm(32, channels),
            # "nnSyncBN": nn.SyncBatchNorm,  # keep for debugging
        }[norm]
    return norm
@@ -20,11 +20,11 @@ class ReductionHead(nn.Module):

         self.bottleneck = nn.Sequential(
             nn.Conv2d(in_feat, reduction_dim, 1, 1, bias=False),
-            BatchNorm(reduction_dim, bias_freeze=True),
+            get_norm(cfg.MODEL.HEADS.NORM, reduction_dim, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True),
             nn.LeakyReLU(0.1),
             nn.Dropout2d(0.5),
         )
-        self.bnneck = BatchNorm(reduction_dim, bias_freeze=True)
+        self.bnneck = get_norm(cfg.MODEL.HEADS.NORM, reduction_dim, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True)

         self.bottleneck.apply(weights_init_kaiming)
         self.bnneck.apply(weights_init_kaiming)
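
With this change the bottleneck and BNNeck normalization layers are selected by cfg.MODEL.HEADS.NORM and cfg.MODEL.HEADS.NORM_SPLIT instead of being hard-wired to BatchNorm. A short sketch of what bias_freeze=True means for the BNNeck (illustrative only; 256 is a hypothetical reduction_dim):

    import torch

    from fastreid.layers.batch_norm import BatchNorm

    reduction_dim = 256
    bnneck = BatchNorm(reduction_dim, bias_freeze=True)

    # the affine weight stays trainable while the bias is frozen at its initial value of 0
    print(bnneck.weight.requires_grad)    # True
    print(bnneck.bias.requires_grad)      # False

    feat = torch.randn(8, reduction_dim, 1, 1)
    print(bnneck(feat).shape)             # torch.Size([8, 256, 1, 1])
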