diff --git a/fastreid/__init__.py b/fastreid/__init__.py
index 980c5fe..eaf1de3 100644
--- a/fastreid/__init__.py
+++ b/fastreid/__init__.py
@@ -2,4 +2,7 @@
 """
 @author: liaoxingyu
 @contact: sherlockliao01@gmail.com
-"""
\ No newline at end of file
+"""
+
+
+__version__ = "0.1.0"
\ No newline at end of file
diff --git a/fastreid/layers/__init__.py b/fastreid/layers/__init__.py
index f985e7c..ab6bec9 100644
--- a/fastreid/layers/__init__.py
+++ b/fastreid/layers/__init__.py
@@ -7,7 +7,7 @@ from torch import nn
 
 from .batch_drop import BatchDrop
 from .attention import *
-from .norm import *
+from .batch_norm import *
 from .context_block import ContextBlock
 from .non_local import Non_local
 from .se_layer import SELayer
diff --git a/fastreid/layers/batch_norm.py b/fastreid/layers/batch_norm.py
new file mode 100644
index 0000000..44bad44
--- /dev/null
+++ b/fastreid/layers/batch_norm.py
@@ -0,0 +1,206 @@
+# encoding: utf-8
+"""
+@author: liaoxingyu
+@contact: sherlockliao01@gmail.com
+"""
+
+import torch
+import logging
+import torch.nn.functional as F
+from torch import nn
+
+__all__ = [
+    "BatchNorm",
+    "IBN",
+    "GhostBatchNorm",
+    "FrozenBatchNorm",
+    "get_norm",
+]
+
+
+class BatchNorm(nn.BatchNorm2d):
+    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
+                 bias_init=0.0):
+        super().__init__(num_features, eps=eps, momentum=momentum)
+        if weight_init is not None: self.weight.data.fill_(weight_init)
+        if bias_init is not None: self.bias.data.fill_(bias_init)
+        self.weight.requires_grad_(not weight_freeze)
+        self.bias.requires_grad_(not bias_freeze)
+
+
+class SyncBatchNorm(nn.SyncBatchNorm):
+    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
+                 bias_init=0.0):
+        super().__init__(num_features, eps=eps, momentum=momentum)
+        if weight_init is not None: self.weight.data.fill_(weight_init)
+        if bias_init is not None: self.bias.data.fill_(bias_init)
+        self.weight.requires_grad_(not weight_freeze)
+        self.bias.requires_grad_(not bias_freeze)
+
+
+class IBN(nn.Module):
+    def __init__(self, planes, bn_norm, num_splits):
+        super(IBN, self).__init__()
+        half1 = int(planes / 2)
+        self.half = half1
+        half2 = planes - half1
+        self.IN = nn.InstanceNorm2d(half1, affine=True)
+        self.BN = get_norm(bn_norm, half2, num_splits)
+
+    def forward(self, x):
+        split = torch.split(x, self.half, 1)
+        out1 = self.IN(split[0].contiguous())
+        out2 = self.BN(split[1].contiguous())
+        out = torch.cat((out1, out2), 1)
+        return out
+
+
+class GhostBatchNorm(BatchNorm):
+    def __init__(self, num_features, num_splits=1, **kwargs):
+        super().__init__(num_features, **kwargs)
+        self.num_splits = num_splits
+        self.register_buffer('running_mean', torch.zeros(num_features))
+        self.register_buffer('running_var', torch.ones(num_features))
+
+    def forward(self, input):
+        N, C, H, W = input.shape
+        if self.training or not self.track_running_stats:
+            self.running_mean = self.running_mean.repeat(self.num_splits)
+            self.running_var = self.running_var.repeat(self.num_splits)
+            outputs = nn.functional.batch_norm(
+                input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
+                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
+                True, self.momentum, self.eps).view(N, C, H, W)
+            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
+            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
+            return outputs
+        else:
+            return nn.functional.batch_norm(
+                input, self.running_mean, self.running_var,
+                self.weight, self.bias, False, self.momentum, self.eps)
+
+
+class FrozenBatchNorm(nn.Module):
+    """
+    BatchNorm2d where the batch statistics and the affine parameters are fixed.
+    It contains non-trainable buffers called
+    "weight" and "bias", "running_mean", "running_var",
+    initialized to perform identity transformation.
+    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
+    which are computed from the original four parameters of BN.
+    The affine transform `x * weight + bias` will perform the equivalent
+    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
+    When loading a backbone model from Caffe2, "running_mean" and "running_var"
+    will be left unchanged as identity transformation.
+    Other pre-trained backbone models may contain all 4 parameters.
+    The forward is implemented by `F.batch_norm(..., training=False)`.
+    """
+
+    _version = 3
+
+    def __init__(self, num_features, eps=1e-5):
+        super().__init__()
+        self.num_features = num_features
+        self.eps = eps
+        self.register_buffer("weight", torch.ones(num_features))
+        self.register_buffer("bias", torch.zeros(num_features))
+        self.register_buffer("running_mean", torch.zeros(num_features))
+        self.register_buffer("running_var", torch.ones(num_features) - eps)
+
+    def forward(self, x):
+        if x.requires_grad:
+            # When gradients are needed, F.batch_norm will use extra memory
+            # because its backward op computes gradients for weight/bias as well.
+            scale = self.weight * (self.running_var + self.eps).rsqrt()
+            bias = self.bias - self.running_mean * scale
+            scale = scale.reshape(1, -1, 1, 1)
+            bias = bias.reshape(1, -1, 1, 1)
+            return x * scale + bias
+        else:
+            # When gradients are not needed, F.batch_norm is a single fused op
+            # and provides more optimization opportunities.
+            return F.batch_norm(
+                x,
+                self.running_mean,
+                self.running_var,
+                self.weight,
+                self.bias,
+                training=False,
+                eps=self.eps,
+            )
+
+    def _load_from_state_dict(
+            self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+    ):
+        version = local_metadata.get("version", None)
+
+        if version is None or version < 2:
+            # No running_mean/var in early versions.
+            # This will silence the warnings.
+            if prefix + "running_mean" not in state_dict:
+                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
+            if prefix + "running_var" not in state_dict:
+                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)
+
+        if version is not None and version < 3:
+            logger = logging.getLogger(__name__)
+            logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
+            # In version < 3, running_var was used without +eps.
+            state_dict[prefix + "running_var"] -= self.eps
+
+        super()._load_from_state_dict(
+            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
+        )
+
+    def __repr__(self):
+        return "FrozenBatchNorm(num_features={}, eps={})".format(self.num_features, self.eps)
+
+    @classmethod
+    def convert_frozen_batchnorm(cls, module):
+        """
+        Convert BatchNorm/SyncBatchNorm in module into FrozenBatchNorm.
+        Args:
+            module (torch.nn.Module):
+        Returns:
+            If module is BatchNorm/SyncBatchNorm, returns a new module.
+            Otherwise, in-place convert module and return it.
+        Similar to convert_sync_batchnorm in
+        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
+        """
+        bn_module = nn.modules.batchnorm
+        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
+        res = module
+        if isinstance(module, bn_module):
+            res = cls(module.num_features)
+            if module.affine:
+                res.weight.data = module.weight.data.clone().detach()
+                res.bias.data = module.bias.data.clone().detach()
+            res.running_mean.data = module.running_mean.data
+            res.running_var.data = module.running_var.data
+            res.eps = module.eps
+        else:
+            for name, child in module.named_children():
+                new_child = cls.convert_frozen_batchnorm(child)
+                if new_child is not child:
+                    res.add_module(name, new_child)
+        return res
+
+
+def get_norm(norm, out_channels, num_splits=1, **kwargs):
+    """
+    Args:
+        norm (str or callable):
+    Returns:
+        nn.Module or None: the normalization layer
+    """
+    if isinstance(norm, str):
+        if len(norm) == 0:
+            return None
+        norm = {
+            "BN": BatchNorm(out_channels, **kwargs),
+            "GhostBN": GhostBatchNorm(out_channels, num_splits, **kwargs),
+            "FrozenBN": FrozenBatchNorm(out_channels),
+            "GN": nn.GroupNorm(32, out_channels),
+            "syncBN": SyncBatchNorm(out_channels, **kwargs),  # not available yet
+        }[norm]
+    return norm
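Reviewer note on the new module: the snippet below is a usage sketch, not part of the patch. The import path fastreid.layers.batch_norm and the string keys come from the file added above; the channel sizes and the toy backbone are made up for illustration.

# Usage sketch (illustrative only): picking a norm layer via get_norm and
# freezing the BN layers of a small, made-up backbone.
import torch
from torch import nn
from fastreid.layers.batch_norm import get_norm, FrozenBatchNorm

# The string key selects the layer; extra kwargs such as bias_freeze are
# forwarded to the BatchNorm-style variants.
bn = get_norm("BN", 256, bias_freeze=True)
ghost_bn = get_norm("GhostBN", 256, num_splits=4)

x = torch.randn(8, 256, 16, 16)
print(bn(x).shape, ghost_bn(x).shape)  # both torch.Size([8, 256, 16, 16])

# convert_frozen_batchnorm walks a module tree and swaps every
# BatchNorm2d/SyncBatchNorm for a FrozenBatchNorm with copied statistics.
backbone = nn.Sequential(nn.Conv2d(3, 256, 3, padding=1), nn.BatchNorm2d(256), nn.ReLU())
backbone = FrozenBatchNorm.convert_frozen_batchnorm(backbone)
print(type(backbone[1]).__name__)  # FrozenBatchNorm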
diff --git a/fastreid/layers/non_local.py b/fastreid/layers/non_local.py
index a914790..876ec43 100644
--- a/fastreid/layers/non_local.py
+++ b/fastreid/layers/non_local.py
@@ -3,7 +3,7 @@
 
 import torch
 from torch import nn
-from .norm import get_norm
+from .batch_norm import get_norm
 
 
 class Non_local(nn.Module):
diff --git a/fastreid/layers/norm.py b/fastreid/layers/norm.py
deleted file mode 100644
index 8eab306..0000000
--- a/fastreid/layers/norm.py
+++ /dev/null
@@ -1,87 +0,0 @@
-# encoding: utf-8
-"""
-@author: liaoxingyu
-@contact: sherlockliao01@gmail.com
-"""
-
-import torch
-from torch import nn
-
-__all__ = [
-    "BatchNorm",
-    "IBN",
-    "GhostBatchNorm",
-    "get_norm",
-]
-
-
-class BatchNorm(nn.BatchNorm2d):
-    def __init__(self, num_features, eps=1e-05, momentum=0.1, weight_freeze=False, bias_freeze=False, weight_init=1.0,
-                 bias_init=0.0):
-        super().__init__(num_features, eps=eps, momentum=momentum)
-        if weight_init is not None: self.weight.data.fill_(weight_init)
-        if bias_init is not None: self.bias.data.fill_(bias_init)
-        self.weight.requires_grad = not weight_freeze
-        self.bias.requires_grad = not bias_freeze
-
-
-class IBN(nn.Module):
-    def __init__(self, planes, bn_norm, num_splits):
-        super(IBN, self).__init__()
-        half1 = int(planes / 2)
-        self.half = half1
-        half2 = planes - half1
-        self.IN = nn.InstanceNorm2d(half1, affine=True)
-        self.BN = get_norm(bn_norm, half2, num_splits)
-
-    def forward(self, x):
-        split = torch.split(x, self.half, 1)
-        out1 = self.IN(split[0].contiguous())
-        out2 = self.BN(split[1].contiguous())
-        out = torch.cat((out1, out2), 1)
-        return out
-
-
-class GhostBatchNorm(BatchNorm):
-    def __init__(self, num_features, num_splits=1, **kwargs):
-        super().__init__(num_features, **kwargs)
-        self.num_splits = num_splits
-        self.register_buffer('running_mean', torch.zeros(num_features))
-        self.register_buffer('running_var', torch.ones(num_features))
-
-    def forward(self, input):
-        N, C, H, W = input.shape
-        if self.training or not self.track_running_stats:
-            self.running_mean = self.running_mean.repeat(self.num_splits)
-            self.running_var = self.running_var.repeat(self.num_splits)
-            outputs = nn.functional.batch_norm(
-                input.view(-1, C * self.num_splits, H, W), self.running_mean, self.running_var,
-                self.weight.repeat(self.num_splits), self.bias.repeat(self.num_splits),
-                True, self.momentum, self.eps).view(N, C, H, W)
-            self.running_mean = torch.mean(self.running_mean.view(self.num_splits, self.num_features), dim=0)
-            self.running_var = torch.mean(self.running_var.view(self.num_splits, self.num_features), dim=0)
-            return outputs
-        else:
-            return nn.functional.batch_norm(
-                input, self.running_mean, self.running_var,
-                self.weight, self.bias, False, self.momentum, self.eps)
-
-
-def get_norm(norm, out_channels, num_splits=1, **kwargs):
-    """
-    Args:
-        norm (str or callable):
-    Returns:
-        nn.Module or None: the normalization layer
-    """
-    if isinstance(norm, str):
-        if len(norm) == 0:
-            return None
-        norm = {
-            "BN": BatchNorm(out_channels, **kwargs),
-            "GhostBN": GhostBatchNorm(out_channels, num_splits, **kwargs),
-            # "FrozenBN": FrozenBatchNorm2d,
-            # "GN": lambda channels: nn.GroupNorm(32, channels),
-            # "nnSyncBN": nn.SyncBatchNorm,  # keep for debugging
-        }[norm]
-    return norm
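One functional detail carried over from the deleted norm.py: weight_freeze/bias_freeze still exclude the affine parameters from training, now via the in-place requires_grad_() call instead of attribute assignment. The check below is a sketch with made-up sizes, not project code.

# Sketch (illustrative only): bias_freeze=True keeps the bias out of the
# trainable parameters, so it never receives a gradient.
import torch
from fastreid.layers.batch_norm import BatchNorm

bn = BatchNorm(64, bias_freeze=True)
print([n for n, p in bn.named_parameters() if p.requires_grad])  # ['weight']

bn(torch.randn(4, 64, 8, 8)).sum().backward()
print(bn.weight.grad is not None, bn.bias.grad is None)  # True True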
diff --git a/fastreid/modeling/heads/reduction_head.py b/fastreid/modeling/heads/reduction_head.py
index ee6be4f..f587c70 100644
--- a/fastreid/modeling/heads/reduction_head.py
+++ b/fastreid/modeling/heads/reduction_head.py
@@ -20,11 +20,11 @@ class ReductionHead(nn.Module):
 
         self.bottleneck = nn.Sequential(
             nn.Conv2d(in_feat, reduction_dim, 1, 1, bias=False),
-            BatchNorm(reduction_dim, bias_freeze=True),
+            get_norm(cfg.MODEL.HEADS.NORM, reduction_dim, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True),
             nn.LeakyReLU(0.1),
             nn.Dropout2d(0.5),
         )
-        self.bnneck = BatchNorm(reduction_dim, bias_freeze=True)
+        self.bnneck = get_norm(cfg.MODEL.HEADS.NORM, reduction_dim, cfg.MODEL.HEADS.NORM_SPLIT, bias_freeze=True)
 
         self.bottleneck.apply(weights_init_kaiming)
         self.bnneck.apply(weights_init_kaiming)
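The head now derives both norm layers from config. Below is a rough sketch of the resulting wiring, assuming cfg.MODEL.HEADS.NORM and cfg.MODEL.HEADS.NORM_SPLIT resolve to the stand-in values shown; the channel sizes are illustrative, not the project defaults.

# Sketch (illustrative only): how the config-selected norm ends up in the
# reduction bottleneck and the bnneck.
from torch import nn
from fastreid.layers.batch_norm import get_norm

norm_type, norm_split = "GhostBN", 2   # stand-ins for cfg.MODEL.HEADS.NORM / NORM_SPLIT
in_feat, reduction_dim = 2048, 512     # illustrative channel sizes

bottleneck = nn.Sequential(
    nn.Conv2d(in_feat, reduction_dim, 1, 1, bias=False),
    get_norm(norm_type, reduction_dim, norm_split, bias_freeze=True),
    nn.LeakyReLU(0.1),
    nn.Dropout2d(0.5),
)
bnneck = get_norm(norm_type, reduction_dim, norm_split, bias_freeze=True)
print(type(bottleneck[1]).__name__, type(bnneck).__name__)  # GhostBatchNorm GhostBatchNorm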