# Borrows some code from https://github.com/DingXiaoH/RepVGG/repvgg.py (MIT 2.0)
import warnings

import numpy as np
import torch
import torch.nn as nn

from easycv.models.utils.ops import make_divisible


def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    '''Basic cell for rep-style block, including conv and bn'''
    result = nn.Sequential()
    result.add_module(
        'conv',
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False))
    result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return result


class RepVGGBlock(nn.Module):
    """Basic block of RepVGG.

    An efficient block whose three training-time branches are reparameterized
    into a single convolution for evaluation (deploy=True).
    Usage: RepVGGBlock(in_channels, out_channels, ksize=3, stride=stride)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 ksize=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros',
                 deploy=False,
                 act=None):
        super(RepVGGBlock, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels

        assert ksize == 3
        assert padding == 1
        padding_11 = padding - ksize // 2

        self.nonlinearity = nn.ReLU()
        self.se = nn.Identity()

        if deploy:
            # Deploy mode: a single fused 3x3 conv replaces all branches.
            self.rbr_reparam = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=ksize,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
                padding_mode=padding_mode)
        else:
            # Identity (BN-only) branch exists only when input/output shapes match.
            self.rbr_identity = nn.BatchNorm2d(
                num_features=in_channels
            ) if out_channels == in_channels and stride == 1 else None
            self.rbr_dense = conv_bn(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=ksize,
                stride=stride,
                padding=padding,
                groups=groups)
            self.rbr_1x1 = conv_bn(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=stride,
                padding=padding_11,
                groups=groups)

    def forward(self, inputs):
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(
            self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))

    # Optional. This improves the accuracy and facilitates quantization.
    #   1. Cancel the original weight decay on rbr_dense.conv.weight and
    #      rbr_1x1.conv.weight.
    #   2. Use like this:
    #        loss = criterion(....)
    #        for every RepVGGBlock blk:
    #            loss += weight_decay_coefficient * 0.5 * blk.get_custom_L2()
    #        optimizer.zero_grad()
    #        loss.backward()
    def get_custom_L2(self):
        K3 = self.rbr_dense.conv.weight
        K1 = self.rbr_1x1.conv.weight
        t3 = (self.rbr_dense.bn.weight /
              (self.rbr_dense.bn.running_var +
               self.rbr_dense.bn.eps).sqrt()).reshape(-1, 1, 1, 1).detach()
        t1 = (self.rbr_1x1.bn.weight /
              (self.rbr_1x1.bn.running_var +
               self.rbr_1x1.bn.eps).sqrt()).reshape(-1, 1, 1, 1).detach()

        # The L2 loss of the "circle" of weights in the 3x3 kernel
        # (everything except the central point). Use regular L2 on them.
        l2_loss_circle = (K3**2).sum() - (K3[:, :, 1:2, 1:2]**2).sum()
        # The equivalent resultant central point of the 3x3 kernel.
        eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1
        # Normalize for an L2 coefficient comparable to regular L2.
        l2_loss_eq_kernel = (eq_kernel**2 / (t3**2 + t1**2)).sum()
        return l2_loss_eq_kernel + l2_loss_circle
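
    # A concrete sketch of the recipe above (illustrative only, not part of
    # the original file). It assumes `model`, `criterion`, `optimizer`,
    # `images`, `targets` and a weight-decay coefficient `wd` are provided by
    # the training script, and that weight decay has already been removed from
    # the rbr_dense / rbr_1x1 conv weights in the optimizer's param groups:
    #
    #   loss = criterion(model(images), targets)
    #   for m in model.modules():
    #       if isinstance(m, RepVGGBlock) and not m.deploy:
    #           loss = loss + wd * 0.5 * m.get_custom_L2()
    #   optimizer.zero_grad()
    #   loss.backward()
    #   optimizer.step()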

    # This func derives the equivalent kernel and bias in a DIFFERENTIABLE way.
    # You can get the equivalent kernel and bias at any time and do whatever
    # you want, for example, apply some penalties or constraints during
    # training, just like you do to the other models.
    # May be useful for quantization or pruning.
    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(
            kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            # conv + bn branch
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            # Identity branch: express the BN as a 3x3 conv with an identity kernel.
            assert isinstance(branch, nn.BatchNorm2d)
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
                                        dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(
                    branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        # Fold BN into the conv: w' = w * gamma / std, b' = beta - mean * gamma / std.
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        # Fuse the three training-time branches into a single 3x3 conv.
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(
            in_channels=self.rbr_dense.conv.in_channels,
            out_channels=self.rbr_dense.conv.out_channels,
            kernel_size=self.rbr_dense.conv.kernel_size,
            stride=self.rbr_dense.conv.stride,
            padding=self.rbr_dense.conv.padding,
            dilation=self.rbr_dense.conv.dilation,
            groups=self.rbr_dense.conv.groups,
            bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True
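

def repvgg_model_convert(model, do_copy=True):
    '''Switch every reparameterizable block in ``model`` to its deploy form.

    Note: this helper is a minimal illustrative sketch (mirroring the
    ``repvgg_model_convert`` utility of the upstream RepVGG repository) and is
    not part of the original EasyCV module; it only assumes that deployable
    blocks expose ``switch_to_deploy()`` as defined above.
    '''
    import copy
    if do_copy:
        model = copy.deepcopy(model)
    for module in model.modules():
        if hasattr(module, 'switch_to_deploy'):
            module.switch_to_deploy()
    return model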


class ConvBNAct(nn.Module):
    '''Normal Conv with BN and a configurable activation (ReLU or SiLU)'''

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 groups=1,
                 bias=False,
                 act='relu'):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)
        if act == 'relu':
            self.act = nn.ReLU()
        elif act == 'silu':
            self.act = nn.SiLU()

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        return self.act(self.conv(x))


class ConvBNReLU(ConvBNAct):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 groups=1,
                 bias=False):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            bias=bias,
            act='relu')


class ConvBNSiLU(ConvBNAct):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 groups=1,
                 bias=False):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            bias=bias,
            act='silu')


class MT_SPPF(nn.Module):
    '''Simplified SPPF with ReLU activation'''

    def __init__(self, in_channels, out_channels, kernel_size=5):
        super().__init__()
        c_ = in_channels // 2  # hidden channels
        self.cv1 = ConvBNReLU(in_channels, c_, 1, 1)
        self.cv2 = ConvBNReLU(c_ * 4, out_channels, 1, 1)
        self.maxpool = nn.MaxPool2d(
            kernel_size=kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, x):
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            y1 = self.maxpool(x)
            y2 = self.maxpool(y1)
            return self.cv2(torch.cat([x, y1, y2, self.maxpool(y2)], 1))


class RepVGGYOLOX(nn.Module):
    '''
    RepVGG with MT_SPPF to build an efficient YOLOX backbone.
    '''

    def __init__(
        self,
        in_channels=3,
        depth=1.0,
        width=1.0,
    ):
        super().__init__()
        num_repeat_backbone = [1, 6, 12, 18, 6]
        channels_list_backbone = [64, 128, 256, 512, 1024]
        num_repeat_neck = [12, 12, 12, 12]
        channels_list_neck = [256, 128, 128, 256, 256, 512]
        num_repeats = [(max(round(i * depth), 1) if i > 1 else i)
                       for i in (num_repeat_backbone + num_repeat_neck)]
        channels_list = [
            make_divisible(i * width, 8)
            for i in (channels_list_backbone + channels_list_neck)
        ]
        assert channels_list is not None
        assert num_repeats is not None

        self.stage0 = RepVGGBlock(
            in_channels=in_channels,
            out_channels=channels_list[0],
            ksize=3,
            stride=2)
        self.stage1 = self._make_stage(channels_list[0], channels_list[1],
                                       num_repeats[1])
        self.stage2 = self._make_stage(channels_list[1], channels_list[2],
                                       num_repeats[2])
        self.stage3 = self._make_stage(channels_list[2], channels_list[3],
                                       num_repeats[3])
        self.stage4 = self._make_stage(
            channels_list[3], channels_list[4], num_repeats[4], add_ppf=True)

    def _make_stage(self,
                    in_channels,
                    out_channels,
                    repeat,
                    stride=2,
                    add_ppf=False):
        blocks = []
        blocks.append(
            RepVGGBlock(in_channels, out_channels, ksize=3, stride=stride))
        for i in range(repeat):
            blocks.append(RepVGGBlock(out_channels, out_channels))
        if add_ppf:
            blocks.append(MT_SPPF(out_channels, out_channels, kernel_size=5))
        return nn.Sequential(*blocks)

    def forward(self, x):
        outputs = []
        x = self.stage0(x)
        x = self.stage1(x)
        x = self.stage2(x)
        outputs.append(x)
        x = self.stage3(x)
        outputs.append(x)
        x = self.stage4(x)
        outputs.append(x)
        return tuple(outputs)
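

if __name__ == '__main__':
    # Minimal smoke test (illustrative sketch, not part of the original file):
    # build the backbone, run a dummy forward pass, then check that fusing all
    # RepVGGBlocks into deploy mode leaves the eval-mode features unchanged.
    import copy

    model = RepVGGYOLOX(in_channels=3, depth=1.0, width=1.0).eval()
    dummy = torch.randn(1, 3, 640, 640)
    with torch.no_grad():
        feats = model(dummy)
    # Output strides are 8 / 16 / 32, i.e. 80x80, 40x40 and 20x20 feature maps
    # for a 640x640 input.
    print([tuple(f.shape) for f in feats])

    deploy_model = copy.deepcopy(model)
    for m in deploy_model.modules():
        if isinstance(m, RepVGGBlock):
            m.switch_to_deploy()
    with torch.no_grad():
        deploy_feats = deploy_model(dummy)
    # Reparameterization is exact up to floating point error in eval mode.
    print(max((a - b).abs().max().item() for a, b in zip(feats, deploy_feats)))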