fast-reid/fastreid/layers/frn.py

# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""
import torch
from torch import nn
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn import ReLU, LeakyReLU
from torch.nn.parameter import Parameter


class TLU(nn.Module):
    def __init__(self, num_features):
        """max(y, tau) = max(y - tau, 0) + tau = ReLU(y - tau) + tau"""
        super(TLU, self).__init__()
        self.num_features = num_features
        self.tau = Parameter(torch.Tensor(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.zeros_(self.tau)

    def extra_repr(self):
        return 'num_features={num_features}'.format(**self.__dict__)

    def forward(self, x):
        return torch.max(x, self.tau.view(1, self.num_features, 1, 1))
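
# Hedged usage sketch (not part of the original file): with tau initialized to
# zero, TLU reduces to a plain ReLU; in general it implements the per-channel
# thresholding max(y, tau) == relu(y - tau) + tau noted in the docstring above.
#
#   tlu = TLU(num_features=8)
#   y = torch.randn(4, 8, 16, 16)
#   assert torch.allclose(tlu(y), torch.relu(y))  # tau == 0 right after init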


class FRN(nn.Module):
    def __init__(self, num_features, eps=1e-6, is_eps_leanable=False):
        """
        weight = gamma, bias = beta

        beta, gamma:
            Variables of shape [1, 1, 1, C] in TensorFlow.
            Variables of shape [1, C, 1, 1] in PyTorch.
        eps: A scalar constant or learnable variable.
        """
        super(FRN, self).__init__()

        self.num_features = num_features
        self.init_eps = eps
        self.is_eps_leanable = is_eps_leanable

        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = Parameter(torch.Tensor(num_features))
        if is_eps_leanable:
            self.eps = Parameter(torch.Tensor(1))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.is_eps_leanable:
            nn.init.constant_(self.eps, self.init_eps)

    def extra_repr(self):
        return 'num_features={num_features}, eps={init_eps}'.format(**self.__dict__)

    def forward(self, x):
        """
        Axes 0, 1, 2, 3 -> (B, H, W, C) in TensorFlow
        Axes 0, 1, 2, 3 -> (B, C, H, W) in PyTorch

        Reference TensorFlow code:
            nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True)
            x = x * tf.rsqrt(nu2 + tf.abs(eps))

            # This code already includes the TLU function max(y, tau)
            return tf.maximum(gamma * x + beta, tau)
        """
        # Compute the mean squared norm of activations per channel.
        nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)

        # Perform FRN: normalize by the mean squared norm.
        x = x * torch.rsqrt(nu2 + self.eps.abs())

        # Scale and bias (gamma * x + beta), broadcast over (B, C, H, W).
        x = self.weight.view(1, self.num_features, 1, 1) * x + self.bias.view(1, self.num_features, 1, 1)
        return x
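
# Hedged usage sketch (assumed shapes, not from the original file): FRN followed
# by TLU on an NCHW tensor, matching the (B, C, H, W) layout described in forward().
#
#   frn = FRN(num_features=64)
#   tlu = TLU(num_features=64)
#   x = torch.rand(2, 64, 32, 32)
#   y = tlu(frn(x))  # same shape as x: (2, 64, 32, 32)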


def bnrelu_to_frn(module):
    """
    Convert 'BatchNorm2d + ReLU' to 'FRN + TLU'
    """
    mod = module
    before_name = None
    before_child = None
    is_before_bn = False

    for name, child in module.named_children():
        if is_before_bn and isinstance(child, (ReLU, LeakyReLU)):
            # Convert BN to FRN
            if isinstance(before_child, BatchNorm2d):
                mod.add_module(
                    before_name, FRN(num_features=before_child.num_features))
            else:
                raise NotImplementedError()

            # Convert ReLU to TLU
            mod.add_module(name, TLU(num_features=before_child.num_features))
        else:
            mod.add_module(name, bnrelu_to_frn(child))

        before_name = name
        before_child = child
        is_before_bn = isinstance(child, BatchNorm2d)
    return mod
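
# Hedged example (assumed usage, not from the original file): converting a small
# Conv-BN-ReLU stack in place.
#
#   block = nn.Sequential(
#       nn.Conv2d(3, 16, kernel_size=3, padding=1),
#       nn.BatchNorm2d(16),
#       nn.ReLU(),
#   )
#   block = bnrelu_to_frn(block)  # children become Conv2d, FRN(16), TLU(16)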


def convert(module, flag_name):
    mod = module
    before_ch = None
    for name, child in module.named_children():
        if hasattr(child, flag_name) and getattr(child, flag_name):
            if isinstance(child, BatchNorm2d):
                before_ch = child.num_features
                mod.add_module(name, FRN(num_features=child.num_features))
            # TODO bn is no good...
            if isinstance(child, (ReLU, LeakyReLU)):
                mod.add_module(name, TLU(num_features=before_ch))
        else:
            mod.add_module(name, convert(child, flag_name))
    return mod


def remove_flags(module, flag_name):
    mod = module
    for name, child in module.named_children():
        # Drop the temporary conversion flag if present, then recurse.
        if hasattr(child, flag_name):
            delattr(child, flag_name)
        mod.add_module(name, remove_flags(child, flag_name))
    return mod


def bnrelu_to_frn2(model, input_size=(3, 128, 128), batch_size=2, flag_name='is_convert_frn'):
    forward_hooks = list()
    backward_hooks = list()

    is_before_bn = [False]

    def register_forward_hook(module):
        def hook(self, input, output):
            if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
                is_before_bn.append(False)
                return

            # `input` and `output` are required by the hook signature but unused here.
            is_converted = is_before_bn[-1] and isinstance(self, (ReLU, LeakyReLU))
            if is_converted:
                setattr(self, flag_name, True)
            is_before_bn.append(isinstance(self, BatchNorm2d))

        forward_hooks.append(module.register_forward_hook(hook))

    is_before_relu = [False]

    def register_backward_hook(module):
        def hook(self, input, output):
            if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
                is_before_relu.append(False)
                return

            is_converted = is_before_relu[-1] and isinstance(self, BatchNorm2d)
            if is_converted:
                setattr(self, flag_name, True)
            is_before_relu.append(isinstance(self, (ReLU, LeakyReLU)))

        backward_hooks.append(module.register_backward_hook(hook))

    # Support multiple inputs to the network.
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # Use a batch size of at least 2 so BatchNorm statistics can be computed.
    x = [torch.rand(batch_size, *in_size) for in_size in input_size]

    # Register the hooks.
    model.apply(register_forward_hook)
    model.apply(register_backward_hook)

    # Make a forward pass.
    output = model(*x)

    # Reduce to a scalar so backward() can be called on the raw output.
    output.sum().backward()

    # Remove the hooks.
    for h in forward_hooks:
        h.remove()
    for h in backward_hooks:
        h.remove()

    model = convert(model, flag_name=flag_name)
    model = remove_flags(model, flag_name=flag_name)
    return model
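

if __name__ == '__main__':
    # Hedged, self-contained demo (not part of the original module): build a toy
    # Conv-BN-ReLU stack and convert every BatchNorm2d + ReLU pair to FRN + TLU
    # with the hook-based converter. Layer sizes and input resolution are
    # arbitrary choices for illustration only.
    toy = nn.Sequential(
        nn.Conv2d(3, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
        nn.Conv2d(8, 8, kernel_size=3, padding=1),
        nn.BatchNorm2d(8),
        nn.ReLU(),
    )
    toy = bnrelu_to_frn2(toy, input_size=(3, 32, 32), batch_size=2)
    print(toy)  # the BatchNorm2d/ReLU entries should now read FRN/TLU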