mirror of https://github.com/RE-OWOD/RE-OWOD
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
import logging
import torch
import torch.distributed as dist
from torch import nn
from torch.autograd.function import Function
from torch.nn import functional as F

from detectron2.utils import comm, env

from .wrappers import BatchNorm2d


class FrozenBatchNorm2d(nn.Module):
    """
    BatchNorm2d where the batch statistics and the affine parameters are fixed.

    It contains non-trainable buffers called
    "weight" and "bias", "running_mean", "running_var",
    initialized to perform identity transformation.

    The pre-trained backbone models from Caffe2 only contain "weight" and "bias",
    which are computed from the original four parameters of BN.
    The affine transform `x * weight + bias` will perform the equivalent
    computation of `(x - running_mean) / sqrt(running_var) * weight + bias`.
    When loading a backbone model from Caffe2, "running_mean" and "running_var"
    will be left unchanged as identity transformation.

    Other pre-trained backbone models may contain all 4 parameters.

    The forward is implemented by `F.batch_norm(..., training=False)`.
    """

    _version = 3

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.register_buffer("weight", torch.ones(num_features))
        self.register_buffer("bias", torch.zeros(num_features))
        self.register_buffer("running_mean", torch.zeros(num_features))
        self.register_buffer("running_var", torch.ones(num_features) - eps)

    def forward(self, x):
        if x.requires_grad:
            # When gradients are needed, F.batch_norm will use extra memory
            # because its backward op computes gradients for weight/bias as well.
            scale = self.weight * (self.running_var + self.eps).rsqrt()
            bias = self.bias - self.running_mean * scale
            scale = scale.reshape(1, -1, 1, 1)
            bias = bias.reshape(1, -1, 1, 1)
            return x * scale + bias
        else:
            # When gradients are not needed, F.batch_norm is a single fused op
            # and provides more optimization opportunities.
            return F.batch_norm(
                x,
                self.running_mean,
                self.running_var,
                self.weight,
                self.bias,
                training=False,
                eps=self.eps,
            )

    def _load_from_state_dict(
        self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
    ):
        version = local_metadata.get("version", None)

        if version is None or version < 2:
            # No running_mean/var in early versions.
            # This silences the warnings about missing keys.
            if prefix + "running_mean" not in state_dict:
                state_dict[prefix + "running_mean"] = torch.zeros_like(self.running_mean)
            if prefix + "running_var" not in state_dict:
                state_dict[prefix + "running_var"] = torch.ones_like(self.running_var)

        if version is not None and version < 3:
            logger = logging.getLogger(__name__)
            logger.info("FrozenBatchNorm {} is upgraded to version 3.".format(prefix.rstrip(".")))
            # In version < 3, running_var is used without +eps.
            state_dict[prefix + "running_var"] -= self.eps

        super()._load_from_state_dict(
            state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
        )

    def __repr__(self):
        return "FrozenBatchNorm2d(num_features={}, eps={})".format(self.num_features, self.eps)

    @classmethod
    def convert_frozen_batchnorm(cls, module):
        """
        Convert all BatchNorm/SyncBatchNorm layers in module into FrozenBatchNorm.

        Args:
            module (torch.nn.Module):

        Returns:
            If module is BatchNorm/SyncBatchNorm, returns a new module.
            Otherwise, converts module in-place and returns it.

        Similar to convert_sync_batchnorm in
        https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/batchnorm.py
        """
        bn_module = nn.modules.batchnorm
        bn_module = (bn_module.BatchNorm2d, bn_module.SyncBatchNorm)
        res = module
        if isinstance(module, bn_module):
            res = cls(module.num_features)
            if module.affine:
                res.weight.data = module.weight.data.clone().detach()
                res.bias.data = module.bias.data.clone().detach()
            res.running_mean.data = module.running_mean.data
            res.running_var.data = module.running_var.data
            res.eps = module.eps
        else:
            for name, child in module.named_children():
                new_child = cls.convert_frozen_batchnorm(child)
                if new_child is not child:
                    res.add_module(name, new_child)
        return res


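# The sketch below is illustrative and not part of the original file: it converts a
# plain BatchNorm2d into FrozenBatchNorm2d via convert_frozen_batchnorm and checks
# that the frozen layer reproduces eval-mode BN. The helper name `_demo_frozen_batchnorm`
# is an assumption added purely for demonstration.
def _demo_frozen_batchnorm():
    bn = nn.BatchNorm2d(8).eval()
    # Give the layer non-trivial statistics and affine parameters.
    bn.running_mean.uniform_(-1.0, 1.0)
    bn.running_var.uniform_(0.5, 2.0)
    bn.weight.data.uniform_(0.5, 2.0)
    bn.bias.data.uniform_(-1.0, 1.0)

    frozen = FrozenBatchNorm2d.convert_frozen_batchnorm(bn)
    x = torch.randn(2, 8, 4, 4)
    # The frozen module should match eval-mode BatchNorm2d up to numerical precision.
    assert torch.allclose(frozen(x), bn(x), atol=1e-5)
    return frozen

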
def get_norm(norm, out_channels):
    """
    Args:
        norm (str or callable): either one of BN, SyncBN, FrozenBN, GN;
            or a callable that takes a channel number and returns
            the normalization layer as a nn.Module.

    Returns:
        nn.Module or None: the normalization layer
    """
    if isinstance(norm, str):
        if len(norm) == 0:
            return None
        norm = {
            "BN": BatchNorm2d,
            # Fixed in https://github.com/pytorch/pytorch/pull/36382
            "SyncBN": NaiveSyncBatchNorm if env.TORCH_VERSION <= (1, 5) else nn.SyncBatchNorm,
            "FrozenBN": FrozenBatchNorm2d,
            "GN": lambda channels: nn.GroupNorm(32, channels),
            # for debugging:
            "nnSyncBN": nn.SyncBatchNorm,
            "naiveSyncBN": NaiveSyncBatchNorm,
        }[norm]
    return norm(out_channels)


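# Illustrative sketch, not part of the original file: get_norm maps a short string
# (or a callable taking a channel count) to a normalization layer for the given
# number of output channels. The helper name `_demo_get_norm` is an assumption
# added purely for demonstration.
def _demo_get_norm():
    gn = get_norm("GN", 64)            # nn.GroupNorm(32, 64)
    frozen = get_norm("FrozenBN", 64)  # FrozenBatchNorm2d(num_features=64)
    none = get_norm("", 64)            # empty string means "no normalization" -> None
    custom = get_norm(lambda c: nn.GroupNorm(8, c), 64)  # callables are invoked directly
    return gn, frozen, none, custom

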
class AllReduce(Function):
    @staticmethod
    def forward(ctx, input):
        input_list = [torch.zeros_like(input) for k in range(dist.get_world_size())]
        # Use allgather instead of allreduce since I don't trust in-place operations ..
        dist.all_gather(input_list, input, async_op=False)
        inputs = torch.stack(input_list, dim=0)
        return torch.sum(inputs, dim=0)

    @staticmethod
    def backward(ctx, grad_output):
        dist.all_reduce(grad_output, async_op=False)
        return grad_output


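# Illustrative sketch, not part of the original file: AllReduce sums a tensor
# element-wise across all workers in the forward pass and all-reduces the incoming
# gradient in the backward pass. It only makes sense under an initialized
# torch.distributed process group, so the call is guarded here; the helper name
# `_demo_all_reduce` is an assumption added purely for demonstration.
def _demo_all_reduce(local_tensor):
    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        return AllReduce.apply(local_tensor)  # element-wise sum over all workers
    return local_tensor  # single-process fallback: nothing to reduce

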
class NaiveSyncBatchNorm(BatchNorm2d):
    """
    In PyTorch<=1.5, ``nn.SyncBatchNorm`` has incorrect gradient
    when the batch size on each worker is different.
    (e.g., when scale augmentation is used, or when it is applied to the mask head).

    This is a slower but correct alternative to `nn.SyncBatchNorm`.

    Note:
        There isn't a single definition of Sync BatchNorm.

        When ``stats_mode==""``, this module computes overall statistics by using
        statistics of each worker with equal weight. The result is true statistics
        of all samples (as if they are all on one worker) only when all workers
        have the same (N, H, W). This mode does not support inputs with zero batch size.

        When ``stats_mode=="N"``, this module computes overall statistics by weighting
        the statistics of each worker by their ``N``. The result is true statistics
        of all samples (as if they are all on one worker) only when all workers
        have the same (H, W). It is slower than ``stats_mode==""``.

        Even though the result of this module may not be the true statistics of all samples,
        it may still be reasonable because it might be preferable to assign equal weights
        to all workers, regardless of their (H, W) dimension, instead of putting larger weight
        on larger images. From preliminary experiments, little difference is found between such
        a simplified implementation and an accurate computation of overall mean & variance.
    """

    def __init__(self, *args, stats_mode="", **kwargs):
        super().__init__(*args, **kwargs)
        assert stats_mode in ["", "N"]
        self._stats_mode = stats_mode

    def forward(self, input):
        if comm.get_world_size() == 1 or not self.training:
            return super().forward(input)

        B, C = input.shape[0], input.shape[1]

        mean = torch.mean(input, dim=[0, 2, 3])
        meansqr = torch.mean(input * input, dim=[0, 2, 3])

        if self._stats_mode == "":
            assert B > 0, 'SyncBatchNorm(stats_mode="") does not support zero batch size.'
            vec = torch.cat([mean, meansqr], dim=0)
            vec = AllReduce.apply(vec) * (1.0 / dist.get_world_size())
            mean, meansqr = torch.split(vec, C)
            momentum = self.momentum
        else:
            if B == 0:
                vec = torch.zeros([2 * C + 1], device=mean.device, dtype=mean.dtype)
                vec = vec + input.sum()  # make sure there is gradient w.r.t input
            else:
                vec = torch.cat(
                    [mean, meansqr, torch.ones([1], device=mean.device, dtype=mean.dtype)], dim=0
                )
            vec = AllReduce.apply(vec * B)

            total_batch = vec[-1].detach()
            momentum = total_batch.clamp(max=1) * self.momentum  # no update if total_batch is 0
            total_batch = torch.max(total_batch, torch.ones_like(total_batch))  # avoid div-by-zero
            mean, meansqr, _ = torch.split(vec / total_batch, C)

        var = meansqr - mean * mean
        invstd = torch.rsqrt(var + self.eps)
        scale = self.weight * invstd
        bias = self.bias - mean * scale
        scale = scale.reshape(1, -1, 1, 1)
        bias = bias.reshape(1, -1, 1, 1)

        self.running_mean += momentum * (mean.detach() - self.running_mean)
        self.running_var += momentum * (var.detach() - self.running_var)
        return input * scale + bias
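

# Illustrative sketch, not part of the original file: NaiveSyncBatchNorm is a drop-in
# replacement for BatchNorm2d. With a single process, or in eval mode, it falls back
# to the plain BatchNorm2d forward; under multi-GPU training it synchronizes batch
# statistics across workers, optionally weighting each worker by its batch size N.
# The helper name `_demo_naive_sync_bn` is an assumption added purely for demonstration.
def _demo_naive_sync_bn():
    sync_bn = NaiveSyncBatchNorm(16)                    # stats_mode="": equal weight per worker
    sync_bn_n = NaiveSyncBatchNorm(16, stats_mode="N")  # weight each worker by its batch size
    x = torch.randn(4, 16, 8, 8)
    # With world size 1 both modules behave exactly like a regular BatchNorm2d.
    return sync_bn(x), sync_bn_n(x)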