# encoding: utf-8
"""
@author: liaoxingyu
@contact: sherlockliao01@gmail.com
"""

import torch
from torch import nn
from torch.nn import ReLU, LeakyReLU
from torch.nn.modules.batchnorm import BatchNorm2d
from torch.nn.parameter import Parameter


class TLU(nn.Module):
    def __init__(self, num_features):
        """max(y, tau) = max(y - tau, 0) + tau = ReLU(y - tau) + tau"""
        super(TLU, self).__init__()
        self.num_features = num_features
        self.tau = Parameter(torch.Tensor(num_features))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.zeros_(self.tau)

    def extra_repr(self):
        return 'num_features={num_features}'.format(**self.__dict__)

    def forward(self, x):
        # Thresholded Linear Unit: elementwise max(x, tau) with a learned
        # per-channel threshold tau.
        return torch.max(x, self.tau.view(1, self.num_features, 1, 1))


class FRN(nn.Module):
    def __init__(self, num_features, eps=1e-6, is_eps_leanable=False):
        """
        weight = gamma, bias = beta
        beta, gamma:
            Variables of shape [1, 1, 1, C] in TensorFlow,
            of shape [1, C, 1, 1] in PyTorch.
        eps: A scalar constant or learnable variable.
        """
        super(FRN, self).__init__()

        self.num_features = num_features
        self.init_eps = eps
        self.is_eps_leanable = is_eps_leanable

        self.weight = Parameter(torch.Tensor(num_features))
        self.bias = Parameter(torch.Tensor(num_features))
        if is_eps_leanable:
            self.eps = Parameter(torch.Tensor(1))
        else:
            self.register_buffer('eps', torch.Tensor([eps]))
        self.reset_parameters()

    def reset_parameters(self):
        nn.init.ones_(self.weight)
        nn.init.zeros_(self.bias)
        if self.is_eps_leanable:
            nn.init.constant_(self.eps, self.init_eps)

    def extra_repr(self):
        return 'num_features={num_features}, eps={init_eps}'.format(**self.__dict__)

    def forward(self, x):
        """
        Axes 0, 1, 2, 3 -> (B, H, W, C) in TensorFlow
        Axes 0, 1, 2, 3 -> (B, C, H, W) in PyTorch

        Reference TensorFlow code:
            nu2 = tf.reduce_mean(tf.square(x), axis=[1, 2], keepdims=True)
            x = x * tf.rsqrt(nu2 + tf.abs(eps))
            # This code also includes the TLU function max(y, tau)
            return tf.maximum(gamma * x + beta, tau)
        """
        # Compute the mean squared norm of the activations per channel.
        nu2 = x.pow(2).mean(dim=[2, 3], keepdim=True)

        # Perform FRN: normalize by the filter response norm.
        x = x * torch.rsqrt(nu2 + self.eps.abs())

        # Scale and bias (gamma, beta).
        x = self.weight.view(1, self.num_features, 1, 1) * x \
            + self.bias.view(1, self.num_features, 1, 1)
        return x


def bnrelu_to_frn(module):
    """Recursively convert every 'BatchNorm2d + ReLU' pair into 'FRN + TLU'."""
    mod = module
    before_name = None
    before_child = None
    is_before_bn = False

    for name, child in module.named_children():
        if is_before_bn and isinstance(child, (ReLU, LeakyReLU)):
            # Convert the preceding BatchNorm2d to FRN.
            if isinstance(before_child, BatchNorm2d):
                mod.add_module(
                    before_name, FRN(num_features=before_child.num_features))
            else:
                raise NotImplementedError()

            # Convert the ReLU/LeakyReLU to TLU.
            mod.add_module(name, TLU(num_features=before_child.num_features))
        else:
            mod.add_module(name, bnrelu_to_frn(child))

        before_name = name
        before_child = child
        is_before_bn = isinstance(child, BatchNorm2d)
    return mod


def convert(module, flag_name):
    """Replace flagged BatchNorm2d/ReLU children with FRN/TLU."""
    mod = module
    before_ch = None
    for name, child in module.named_children():
        if hasattr(child, flag_name) and getattr(child, flag_name):
            if isinstance(child, BatchNorm2d):
                before_ch = child.num_features
                mod.add_module(name, FRN(num_features=child.num_features))
            # TODO bn is no good...
            if isinstance(child, (ReLU, LeakyReLU)):
                mod.add_module(name, TLU(num_features=before_ch))
        else:
            mod.add_module(name, convert(child, flag_name))
    return mod


def remove_flags(module, flag_name):
    """Recursively strip the temporary conversion flag from every submodule."""
    mod = module
    for name, child in module.named_children():
        if hasattr(child, flag_name):
            delattr(child, flag_name)
        mod.add_module(name, remove_flags(child, flag_name))
    return mod


def bnrelu_to_frn2(model, input_size=(3, 128, 128), batch_size=2, flag_name='is_convert_frn'):
    """Flag 'BatchNorm2d + ReLU' pairs by tracing one forward/backward pass,
    then convert the flagged modules to FRN + TLU."""
    forward_hooks = list()
    backward_hooks = list()

    is_before_bn = [False]

    def register_forward_hook(module):
        def hook(self, input, output):
            # Skip containers and the top-level model itself.
            if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
                is_before_bn.append(False)
                return

            # A ReLU/LeakyReLU that runs right after a BatchNorm2d gets flagged.
            # (`input` and `output` are required by the hook signature.)
            is_converted = is_before_bn[-1] and isinstance(self, (ReLU, LeakyReLU))
            if is_converted:
                setattr(self, flag_name, True)
            is_before_bn.append(isinstance(self, BatchNorm2d))

        forward_hooks.append(module.register_forward_hook(hook))

    is_before_relu = [False]

    def register_backward_hook(module):
        def hook(self, input, output):
            if isinstance(module, (nn.Sequential, nn.ModuleList)) or (module == model):
                is_before_relu.append(False)
                return

            # In the backward pass a BatchNorm2d that fires right after a
            # ReLU/LeakyReLU is the other half of the pair, so flag it too.
            is_converted = is_before_relu[-1] and isinstance(self, BatchNorm2d)
            if is_converted:
                setattr(self, flag_name, True)
            is_before_relu.append(isinstance(self, (ReLU, LeakyReLU)))

        backward_hooks.append(module.register_backward_hook(hook))

    # Support multiple inputs to the network.
    if isinstance(input_size, tuple):
        input_size = [input_size]

    # Use a batch size of at least 2 so BatchNorm statistics are well defined.
    x = [torch.rand(batch_size, *in_size) for in_size in input_size]

    # Register the hooks.
    model.apply(register_forward_hook)
    model.apply(register_backward_hook)

    # Make one forward/backward pass to trigger the hooks.
    output = model(*x)
    output.sum().backward()  # backward() needs a scalar, so reduce the raw output first.

    # Remove the hooks.
    for h in forward_hooks:
        h.remove()
    for h in backward_hooks:
        h.remove()

    model = convert(model, flag_name=flag_name)
    model = remove_flags(model, flag_name=flag_name)
    return model
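

# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module): a minimal, hypothetical
# example of applying `bnrelu_to_frn` to a small BatchNorm2d + ReLU backbone.
# The layer sizes, input shape, and `__main__` guard are illustrative
# assumptions, not an API defined elsewhere in this repository.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    backbone = nn.Sequential(
        nn.Conv2d(3, 16, kernel_size=3, padding=1),
        BatchNorm2d(16),
        ReLU(inplace=True),
        nn.Conv2d(16, 32, kernel_size=3, padding=1),
        BatchNorm2d(32),
        ReLU(inplace=True),
    )

    # Replace every BatchNorm2d + ReLU pair with FRN + TLU in place.
    backbone = bnrelu_to_frn(backbone)
    print(backbone)

    out = backbone(torch.rand(2, 3, 64, 64))
    print(out.shape)  # expected: torch.Size([2, 32, 64, 64])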