# borrow some code from https://github.com/DingXiaoH/RepVGG/repvgg.py (MIT License)
import warnings

import numpy as np
import torch
import torch.nn as nn

from easycv.models.utils.ops import make_divisible


def conv_bn(in_channels, out_channels, kernel_size, stride, padding, groups=1):
    '''Basic cell for rep-style block, including conv and bn'''
    result = nn.Sequential()
    result.add_module(
        'conv',
        nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=False))
    result.add_module('bn', nn.BatchNorm2d(num_features=out_channels))
    return result


class RepVGGBlock(nn.Module):
    """
    Basic block of RepVGG.
    A multi-branch block used during training that can be reparameterized into
    a single 3x3 convolution for efficient inference (deploy=True); see
    switch_to_deploy().
    Usage: RepVGGBlock(in_channels, out_channels, ksize=3, stride=stride)
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 ksize=3,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros',
                 deploy=False,
                 act=None):
        super(RepVGGBlock, self).__init__()
        self.deploy = deploy
        self.groups = groups
        self.in_channels = in_channels

        assert ksize == 3
        assert padding == 1

        padding_11 = padding - ksize // 2

        self.nonlinearity = nn.ReLU()
        self.se = nn.Identity()

        if deploy:
            self.rbr_reparam = nn.Conv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=ksize,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
                padding_mode=padding_mode)

        else:
            self.rbr_identity = nn.BatchNorm2d(
                num_features=in_channels
            ) if out_channels == in_channels and stride == 1 else None
            self.rbr_dense = conv_bn(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=ksize,
                stride=stride,
                padding=padding,
                groups=groups)
            self.rbr_1x1 = conv_bn(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                stride=stride,
                padding=padding_11,
                groups=groups)

    def forward(self, inputs):
        if hasattr(self, 'rbr_reparam'):
            return self.nonlinearity(self.se(self.rbr_reparam(inputs)))

        if self.rbr_identity is None:
            id_out = 0
        else:
            id_out = self.rbr_identity(inputs)

        return self.nonlinearity(
            self.se(self.rbr_dense(inputs) + self.rbr_1x1(inputs) + id_out))

    # Optional. This improves the accuracy and facilitates quantization.
    # 1. Cancel the original weight decay on rbr_dense.conv.weight and rbr_1x1.conv.weight.
    # 2. Use it like this:
    #     loss = criterion(....)
    #     for every RepVGGBlock blk:
    #         loss += weight_decay_coefficient * 0.5 * blk.get_custom_L2()
    #     optimizer.zero_grad()
    #     loss.backward()
    def get_custom_L2(self):
        K3 = self.rbr_dense.conv.weight
        K1 = self.rbr_1x1.conv.weight
        t3 = (self.rbr_dense.bn.weight /
              ((self.rbr_dense.bn.running_var +
                self.rbr_dense.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()
        t1 = (self.rbr_1x1.bn.weight /
              ((self.rbr_1x1.bn.running_var +
                self.rbr_1x1.bn.eps).sqrt())).reshape(-1, 1, 1, 1).detach()

        # The L2 loss of the "circle" of weights in the 3x3 kernel. Use regular L2 on them.
        l2_loss_circle = (K3**2).sum() - (K3[:, :, 1:2, 1:2]**2).sum()
        # The equivalent resultant central point of the 3x3 kernel.
        eq_kernel = K3[:, :, 1:2, 1:2] * t3 + K1 * t1
        # Normalize for an L2 coefficient comparable to regular L2.
        l2_loss_eq_kernel = (eq_kernel**2 / (t3**2 + t1**2)).sum()
        return l2_loss_eq_kernel + l2_loss_circle

    # This function derives the equivalent kernel and bias in a DIFFERENTIABLE way.
    # You can get the equivalent kernel and bias at any time and do whatever you
    # want with them, for example, apply penalties or constraints during training,
    # just like you would with any other model.
    # May be useful for quantization or pruning.
    def get_equivalent_kernel_bias(self):
        kernel3x3, bias3x3 = self._fuse_bn_tensor(self.rbr_dense)
        kernel1x1, bias1x1 = self._fuse_bn_tensor(self.rbr_1x1)
        kernelid, biasid = self._fuse_bn_tensor(self.rbr_identity)
        return kernel3x3 + self._pad_1x1_to_3x3_tensor(
            kernel1x1) + kernelid, bias3x3 + bias1x1 + biasid
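
    # A minimal sketch of using the differentiable equivalent kernel for an extra
    # penalty during training (an illustration only; `blk` and `penalty_coefficient`
    # are hypothetical names, not part of this file):
    #
    #     kernel, bias = blk.get_equivalent_kernel_bias()
    #     loss = loss + penalty_coefficient * kernel.pow(2).sum()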

    def _pad_1x1_to_3x3_tensor(self, kernel1x1):
        if kernel1x1 is None:
            return 0
        else:
            return torch.nn.functional.pad(kernel1x1, [1, 1, 1, 1])

    def _fuse_bn_tensor(self, branch):
        if branch is None:
            return 0, 0
        if isinstance(branch, nn.Sequential):
            kernel = branch.conv.weight
            running_mean = branch.bn.running_mean
            running_var = branch.bn.running_var
            gamma = branch.bn.weight
            beta = branch.bn.bias
            eps = branch.bn.eps
        else:
            assert isinstance(branch, nn.BatchNorm2d)
            # The identity branch has no conv, so build an equivalent 3x3 kernel
            # with a 1 at the center of each channel's own (grouped) input slot.
            if not hasattr(self, 'id_tensor'):
                input_dim = self.in_channels // self.groups
                kernel_value = np.zeros((self.in_channels, input_dim, 3, 3),
                                        dtype=np.float32)
                for i in range(self.in_channels):
                    kernel_value[i, i % input_dim, 1, 1] = 1
                self.id_tensor = torch.from_numpy(kernel_value).to(
                    branch.weight.device)
            kernel = self.id_tensor
            running_mean = branch.running_mean
            running_var = branch.running_var
            gamma = branch.weight
            beta = branch.bias
            eps = branch.eps
        # Fold BN into the conv: w' = w * gamma / std, b' = beta - mean * gamma / std
        std = (running_var + eps).sqrt()
        t = (gamma / std).reshape(-1, 1, 1, 1)
        return kernel * t, beta - running_mean * gamma / std

    def switch_to_deploy(self):
        if hasattr(self, 'rbr_reparam'):
            return
        kernel, bias = self.get_equivalent_kernel_bias()
        self.rbr_reparam = nn.Conv2d(
            in_channels=self.rbr_dense.conv.in_channels,
            out_channels=self.rbr_dense.conv.out_channels,
            kernel_size=self.rbr_dense.conv.kernel_size,
            stride=self.rbr_dense.conv.stride,
            padding=self.rbr_dense.conv.padding,
            dilation=self.rbr_dense.conv.dilation,
            groups=self.rbr_dense.conv.groups,
            bias=True)
        self.rbr_reparam.weight.data = kernel
        self.rbr_reparam.bias.data = bias
        for para in self.parameters():
            para.detach_()
        self.__delattr__('rbr_dense')
        self.__delattr__('rbr_1x1')
        if hasattr(self, 'rbr_identity'):
            self.__delattr__('rbr_identity')
        if hasattr(self, 'id_tensor'):
            self.__delattr__('id_tensor')
        self.deploy = True
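

# A minimal sanity check of the structural reparameterization above (a sketch,
# assuming random weights and eval mode so that BN uses its running statistics):
#
#     block = RepVGGBlock(in_channels=32, out_channels=32, ksize=3, stride=1)
#     block.eval()
#     x = torch.randn(1, 32, 16, 16)
#     y_train = block(x)
#     block.switch_to_deploy()
#     y_deploy = block(x)
#     # The fused single 3x3 conv should reproduce the multi-branch output.
#     assert torch.allclose(y_train, y_deploy, atol=1e-5)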


class ConvBNAct(nn.Module):
    '''Conv2d + BatchNorm2d followed by an activation (ReLU or SiLU)'''

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 groups=1,
                 bias=False,
                 act='relu'):
        super().__init__()
        padding = kernel_size // 2
        self.conv = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            groups=groups,
            bias=bias,
        )
        self.bn = nn.BatchNorm2d(out_channels)

        if act == 'relu':
            self.act = nn.ReLU()
        elif act == 'silu':
            self.act = nn.SiLU()
        else:
            raise ValueError(f'Unsupported activation type: {act}')

    def forward(self, x):
        return self.act(self.bn(self.conv(x)))

    def forward_fuse(self, x):
        return self.act(self.conv(x))
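

# forward_fuse above is meant for inference after the BatchNorm has been folded
# into self.conv. A minimal folding sketch (an assumption, not a utility shipped
# in this file; it relies on the default bias=False conv), using the same algebra
# as RepVGGBlock._fuse_bn_tensor:
#
#     def fuse_conv_bn_(m):
#         std = (m.bn.running_var + m.bn.eps).sqrt()
#         scale = (m.bn.weight / std).reshape(-1, 1, 1, 1)
#         fused = nn.Conv2d(
#             m.conv.in_channels, m.conv.out_channels, m.conv.kernel_size,
#             stride=m.conv.stride, padding=m.conv.padding,
#             groups=m.conv.groups, bias=True)
#         fused.weight.data = m.conv.weight * scale
#         fused.bias.data = m.bn.bias - m.bn.running_mean * m.bn.weight / std
#         m.conv, m.bn = fused, nn.Identity()
#         m.forward = m.forward_fuse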


class ConvBNReLU(ConvBNAct):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 groups=1,
                 bias=False):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            bias=bias,
            act='relu')


class ConvBNSiLU(ConvBNAct):

    def __init__(self,
                 in_channels,
                 out_channels,
                 kernel_size,
                 stride,
                 groups=1,
                 bias=False):
        super().__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            groups=groups,
            bias=bias,
            act='silu')


class MT_SPPF(nn.Module):
    '''Simplified SPPF with ReLU activation'''

    def __init__(self, in_channels, out_channels, kernel_size=5):
        super().__init__()
        c_ = in_channels // 2  # hidden channels
        self.cv1 = ConvBNReLU(in_channels, c_, 1, 1)
        self.cv2 = ConvBNReLU(c_ * 4, out_channels, 1, 1)
        self.maxpool = nn.MaxPool2d(
            kernel_size=kernel_size, stride=1, padding=kernel_size // 2)

    def forward(self, x):
        x = self.cv1(x)
        with warnings.catch_warnings():
            warnings.simplefilter('ignore')
            y1 = self.maxpool(x)
            y2 = self.maxpool(y1)
            return self.cv2(torch.cat([x, y1, y2, self.maxpool(y2)], 1))
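

# A small shape check for MT_SPPF (a sketch; the three chained 5x5 max-pools give
# effective receptive fields of 5, 9 and 13, matching a parallel SPP layer):
#
#     sppf = MT_SPPF(in_channels=1024, out_channels=1024)
#     feat = torch.randn(1, 1024, 20, 20)
#     assert sppf(feat).shape == (1, 1024, 20, 20)  # spatial size is preserved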


class RepVGGYOLOX(nn.Module):
    '''
    RepVGG with MT_SPPF to build an efficient YOLOX backbone.
    '''

    def __init__(
        self,
        in_channels=3,
        depth=1.0,
        width=1.0,
    ):
        super().__init__()
        num_repeat_backbone = [1, 6, 12, 18, 6]
        channels_list_backbone = [64, 128, 256, 512, 1024]
        num_repeat_neck = [12, 12, 12, 12]
        channels_list_neck = [256, 128, 128, 256, 256, 512]
        num_repeats = [(max(round(i * depth), 1) if i > 1 else i)
                       for i in (num_repeat_backbone + num_repeat_neck)]

        channels_list = [
            make_divisible(i * width, 8)
            for i in (channels_list_backbone + channels_list_neck)
        ]

        assert channels_list is not None
        assert num_repeats is not None

        self.stage0 = RepVGGBlock(
            in_channels=in_channels,
            out_channels=channels_list[0],
            ksize=3,
            stride=2)
        self.stage1 = self._make_stage(channels_list[0], channels_list[1],
                                       num_repeats[1])
        self.stage2 = self._make_stage(channels_list[1], channels_list[2],
                                       num_repeats[2])
        self.stage3 = self._make_stage(channels_list[2], channels_list[3],
                                       num_repeats[3])
        self.stage4 = self._make_stage(
            channels_list[3], channels_list[4], num_repeats[4], add_ppf=True)

    def _make_stage(self,
                    in_channels,
                    out_channels,
                    repeat,
                    stride=2,
                    add_ppf=False):
        blocks = []
        blocks.append(
            RepVGGBlock(in_channels, out_channels, ksize=3, stride=stride))
        for i in range(repeat):
            blocks.append(RepVGGBlock(out_channels, out_channels))
        if add_ppf:
            blocks.append(MT_SPPF(out_channels, out_channels, kernel_size=5))

        return nn.Sequential(*blocks)

    def forward(self, x):
        outputs = []
        x = self.stage0(x)
        x = self.stage1(x)
        x = self.stage2(x)
        outputs.append(x)
        x = self.stage3(x)
        outputs.append(x)
        x = self.stage4(x)
        outputs.append(x)
        return tuple(outputs)
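

if __name__ == '__main__':
    # A quick smoke test of the backbone (a sketch, assuming the default
    # depth/width multipliers; channel counts follow channels_list above).
    model = RepVGGYOLOX(in_channels=3, depth=1.0, width=1.0)
    model.eval()
    dummy = torch.randn(1, 3, 640, 640)
    with torch.no_grad():
        feats = model(dummy)
    # Expected: three feature maps at strides 8, 16 and 32, i.e. shapes
    # (1, 256, 80, 80), (1, 512, 40, 40) and (1, 1024, 20, 20).
    for feat in feats:
        print(feat.shape)

    # Optionally fold every RepVGGBlock into a single 3x3 conv for deployment.
    for m in model.modules():
        if isinstance(m, RepVGGBlock):
            m.switch_to_deploy()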