# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule, Sequential
from mmcv.utils.parrots_wrapper import _BatchNorm

from ..builder import BACKBONES
from ..utils.se_layer import SELayer
from .base_backbone import BaseBackbone


class RepVGGBlock(BaseModule):
    """RepVGG block for RepVGG backbone.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): Stride of the 3x3 and 1x1 convolution layer. Default: 1.
        padding (int): Padding of the 3x3 convolution layer. Default: 1.
        dilation (int): Dilation of the 3x3 convolution layer. Default: 1.
        groups (int): Groups of the 3x3 and 1x1 convolution layer. Default: 1.
        padding_mode (str): Padding mode of the 3x3 convolution layer.
            Default: 'zeros'.
        se_cfg (None or dict): The configuration of the SE module.
            Default: None.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict, optional): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Dictionary to construct and config the norm layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        deploy (bool): Whether to switch the model structure to
            deployment mode. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None
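
    Example:
        >>> # Illustrative sketch: after reparameterization, the deploy-mode
        >>> # block should reproduce the eval-mode multi-branch output.
        >>> import torch
        >>> block = RepVGGBlock(64, 64)
        >>> block = block.eval()  # fusion uses the BN running statistics
        >>> x = torch.rand(1, 64, 56, 56)
        >>> y_train = block(x)
        >>> block.switch_to_deploy()
        >>> y_deploy = block(x)
        >>> torch.allclose(y_train, y_deploy, atol=1e-5)
        True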
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 padding=1,
                 dilation=1,
                 groups=1,
                 padding_mode='zeros',
                 se_cfg=None,
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 deploy=False,
                 init_cfg=None):
        super(RepVGGBlock, self).__init__(init_cfg)

        assert se_cfg is None or isinstance(se_cfg, dict)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.se_cfg = se_cfg
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.deploy = deploy

        if deploy:
            self.branch_reparam = build_conv_layer(
                conv_cfg,
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=3,
                stride=stride,
                padding=padding,
                dilation=dilation,
                groups=groups,
                bias=True,
                padding_mode=padding_mode)
        else:
            # Judge whether the input and output shapes are the same.
            # If so, add a normalized identity shortcut; this also requires
            # padding == dilation, so that the 3x3 branch preserves the
            # spatial size.
            if out_channels == in_channels and stride == 1 and \
                    padding == dilation:
                self.branch_norm = build_norm_layer(norm_cfg, in_channels)[1]
            else:
                self.branch_norm = None

            self.branch_3x3 = self.create_conv_bn(
                kernel_size=3,
                dilation=dilation,
                padding=padding,
            )
            self.branch_1x1 = self.create_conv_bn(kernel_size=1)

            if se_cfg is not None:
                self.se_layer = SELayer(channels=out_channels, **se_cfg)
            else:
                self.se_layer = None

        self.act = build_activation_layer(act_cfg)

    def create_conv_bn(self, kernel_size, dilation=1, padding=0):
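        """Create a conv-BN branch with the block's stride and groups.

        The convolution is bias-free since the following BatchNorm supplies
        the affine shift, which also keeps the BN fusion in
        ``reparameterize`` exact.
        """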
        conv_bn = Sequential()
        conv_bn.add_module(
            'conv',
            build_conv_layer(
                self.conv_cfg,
                in_channels=self.in_channels,
                out_channels=self.out_channels,
                kernel_size=kernel_size,
                stride=self.stride,
                dilation=dilation,
                padding=padding,
                groups=self.groups,
                bias=False))
        conv_bn.add_module(
            'norm',
            build_norm_layer(self.norm_cfg, num_features=self.out_channels)[1])

        return conv_bn

    def forward(self, x):

        def _inner_forward(inputs):
            if self.deploy:
                return self.branch_reparam(inputs)

            if self.branch_norm is None:
                branch_norm_out = 0
            else:
                branch_norm_out = self.branch_norm(inputs)
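
            # Sum the parallel branches (3x3 conv-BN, 1x1 conv-BN and, when
            # present, the identity BN); `reparameterize` later collapses
            # this sum into a single 3x3 convolution.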
            inner_out = self.branch_3x3(inputs) + self.branch_1x1(
                inputs) + branch_norm_out

            if self.se_cfg is not None:
                inner_out = self.se_layer(inner_out)

            return inner_out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.act(out)

        return out

    def switch_to_deploy(self):
        """Switch the model structure from training mode to deployment mode."""
        if self.deploy:
            return
        assert self.norm_cfg['type'] == 'BN', \
            "Switch is not allowed when norm_cfg['type'] != 'BN'."

        reparam_weight, reparam_bias = self.reparameterize()
        self.branch_reparam = build_conv_layer(
            self.conv_cfg,
            self.in_channels,
            self.out_channels,
            kernel_size=3,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            groups=self.groups,
            bias=True)
        self.branch_reparam.weight.data = reparam_weight
        self.branch_reparam.bias.data = reparam_bias

        for param in self.parameters():
            param.detach_()
        delattr(self, 'branch_3x3')
        delattr(self, 'branch_1x1')
        delattr(self, 'branch_norm')

        self.deploy = True

    def reparameterize(self):
        """Fuse all the parameters of all branches.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: Parameters after fusion of all
                branches. The first element is the weight and the second is
                the bias.
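
        Note:
            Convolution is linear, so the three branches can be merged
            exactly: each conv-BN branch is fused into a single conv, the
            fused 1x1 kernel is zero-padded to 3x3, the identity BN branch
            is rewritten as a 3x3 conv with an identity kernel, and the
            resulting weights and biases are summed.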
        """
        weight_3x3, bias_3x3 = self._fuse_conv_bn(self.branch_3x3)
        weight_1x1, bias_1x1 = self._fuse_conv_bn(self.branch_1x1)
        # Pad the fused conv1x1 weight to a conv3x3 weight.
        weight_1x1 = F.pad(weight_1x1, [1, 1, 1, 1], value=0)

        weight_norm, bias_norm = 0, 0
        if self.branch_norm:
            tmp_conv_bn = self._norm_to_conv3x3(self.branch_norm)
            weight_norm, bias_norm = self._fuse_conv_bn(tmp_conv_bn)

        return (weight_3x3 + weight_1x1 + weight_norm,
                bias_3x3 + bias_1x1 + bias_norm)

    def _fuse_conv_bn(self, branch):
        """Fuse the parameters in a branch with a conv and bn.

        Args:
            branch (mmcv.runner.Sequential): A branch with conv and bn.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: The parameters obtained after
                fusing the parameters of conv and bn in one branch.
                The first element is the weight and the second is the bias.
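
        Note:
            With BN running statistics ``mean`` and ``var``, affine
            parameters ``gamma`` and ``beta``, and ``std = sqrt(var + eps)``,
            the fusion computes::

                fused_weight = (gamma / std) * conv_weight
                fused_bias = beta - gamma * mean / std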
        """
        if branch is None:
            return 0, 0
        conv_weight = branch.conv.weight
        running_mean = branch.norm.running_mean
        running_var = branch.norm.running_var
        gamma = branch.norm.weight
        beta = branch.norm.bias
        eps = branch.norm.eps

        std = (running_var + eps).sqrt()
        fused_weight = (gamma / std).reshape(-1, 1, 1, 1) * conv_weight
        fused_bias = -running_mean * gamma / std + beta

        return fused_weight, fused_bias

    def _norm_to_conv3x3(self, branch_norm):
        """Convert a norm layer to a conv3x3-bn sequence.

        Args:
            branch_norm (nn.BatchNorm2d): A branch only with bn in the block.

        Returns:
            mmcv.runner.Sequential: A sequential with conv3x3 and bn.
        """
        input_dim = self.in_channels // self.groups
        conv_weight = torch.zeros((self.in_channels, input_dim, 3, 3),
                                  dtype=branch_norm.weight.dtype)

        for i in range(self.in_channels):
            conv_weight[i, i % input_dim, 1, 1] = 1
        conv_weight = conv_weight.to(branch_norm.weight.device)

        tmp_conv3x3 = self.create_conv_bn(kernel_size=3)
        tmp_conv3x3.conv.weight.data = conv_weight
        tmp_conv3x3.norm = branch_norm
        return tmp_conv3x3


@BACKBONES.register_module()
class RepVGG(BaseBackbone):
    """RepVGG backbone.

    A PyTorch impl of: `RepVGG: Making VGG-style ConvNets Great Again
    <https://arxiv.org/abs/2101.03697>`_

    Args:
        arch (str | dict): The parameter of RepVGG.
            If it's a dict, it should contain the following keys:

            - num_blocks (Sequence[int]): Number of blocks in each stage.
            - width_factor (Sequence[float]): Width multiplier in each stage.
            - group_layer_map (dict | None): RepVGG blocks that declare
              the need to apply group convolution.
            - se_cfg (dict | None): SE layer config.
        in_channels (int): Number of input image channels. Default: 3.
        base_channels (int): Base channels of RepVGG backbone, works
            with width_factor together. Default: 64.
        out_indices (Sequence[int]): Output from which stages. Default: (3, ).
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: (2, 2, 2, 2).
        dilations (Sequence[int]): Dilation of each stage.
            Default: (1, 1, 1, 1).
        frozen_stages (int): Stages to be frozen (all param fixed). -1 means
            not freezing any parameters. Default: -1.
        conv_cfg (dict | None): The config dict for conv layers. Default: None.
        norm_cfg (dict): The config dict for norm layers.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        deploy (bool): Whether to switch the model structure to deployment
            mode. Default: False.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        init_cfg (dict or list[dict], optional): Initialization config dict.
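
    Example:
        >>> # Illustrative sketch; the shapes below assume a 224x224 input.
        >>> import torch
        >>> model = RepVGG(arch='A0', out_indices=(0, 1, 2, 3))
        >>> model = model.eval()
        >>> x = torch.rand(1, 3, 224, 224)
        >>> outs = model(x)
        >>> [tuple(out.shape) for out in outs]
        [(1, 48, 56, 56), (1, 96, 28, 28), (1, 192, 14, 14), (1, 1280, 7, 7)]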
    """

    groupwise_layers = [2, 4, 6, 8, 10, 12, 14, 16, 18, 20, 22, 24, 26]
    g2_layer_map = {layer: 2 for layer in groupwise_layers}
    g4_layer_map = {layer: 4 for layer in groupwise_layers}

    arch_settings = {
        'A0':
        dict(
            num_blocks=[2, 4, 14, 1],
            width_factor=[0.75, 0.75, 0.75, 2.5],
            group_layer_map=None,
            se_cfg=None),
        'A1':
        dict(
            num_blocks=[2, 4, 14, 1],
            width_factor=[1, 1, 1, 2.5],
            group_layer_map=None,
            se_cfg=None),
        'A2':
        dict(
            num_blocks=[2, 4, 14, 1],
            width_factor=[1.5, 1.5, 1.5, 2.75],
            group_layer_map=None,
            se_cfg=None),
        'B0':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[1, 1, 1, 2.5],
            group_layer_map=None,
            se_cfg=None),
        'B1':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2, 2, 2, 4],
            group_layer_map=None,
            se_cfg=None),
        'B1g2':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2, 2, 2, 4],
            group_layer_map=g2_layer_map,
            se_cfg=None),
        'B1g4':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2, 2, 2, 4],
            group_layer_map=g4_layer_map,
            se_cfg=None),
        'B2':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2.5, 2.5, 2.5, 5],
            group_layer_map=None,
            se_cfg=None),
        'B2g2':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2.5, 2.5, 2.5, 5],
            group_layer_map=g2_layer_map,
            se_cfg=None),
        'B2g4':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[2.5, 2.5, 2.5, 5],
            group_layer_map=g4_layer_map,
            se_cfg=None),
        'B3':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[3, 3, 3, 5],
            group_layer_map=None,
            se_cfg=None),
        'B3g2':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[3, 3, 3, 5],
            group_layer_map=g2_layer_map,
            se_cfg=None),
        'B3g4':
        dict(
            num_blocks=[4, 6, 16, 1],
            width_factor=[3, 3, 3, 5],
            group_layer_map=g4_layer_map,
            se_cfg=None),
        'D2se':
        dict(
            num_blocks=[8, 14, 24, 1],
            width_factor=[2.5, 2.5, 2.5, 5],
            group_layer_map=None,
            se_cfg=dict(ratio=16, divisor=1))
    }

    def __init__(self,
                 arch,
                 in_channels=3,
                 base_channels=64,
                 out_indices=(3, ),
                 strides=(2, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False,
                 deploy=False,
                 norm_eval=False,
                 init_cfg=[
                     dict(type='Kaiming', layer=['Conv2d']),
                     dict(
                         type='Constant',
                         val=1,
                         layer=['_BatchNorm', 'GroupNorm'])
                 ]):
        super(RepVGG, self).__init__(init_cfg)

        if isinstance(arch, str):
            assert arch in self.arch_settings, \
                f'"arch": "{arch}" is not one of the arch_settings'
            arch = self.arch_settings[arch]
        elif not isinstance(arch, dict):
            raise TypeError('Expect "arch" to be either a string '
                            f'or a dict, got {type(arch)}')

        assert len(arch['num_blocks']) == len(
            arch['width_factor']) == len(strides) == len(dilations)
        assert max(out_indices) < len(arch['num_blocks'])
        if arch['group_layer_map'] is not None:
            assert max(arch['group_layer_map'].keys()) <= sum(
                arch['num_blocks'])

        if arch['se_cfg'] is not None:
            assert isinstance(arch['se_cfg'], dict)

        self.arch = arch
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.out_indices = out_indices
        self.strides = strides
        self.dilations = dilations
        self.deploy = deploy
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
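
        # The stem width is capped at 64 channels regardless of
        # width_factor, following the original RepVGG design.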
        channels = min(64, int(base_channels * self.arch['width_factor'][0]))
        self.stem = RepVGGBlock(
            self.in_channels,
            channels,
            stride=2,
            se_cfg=arch['se_cfg'],
            with_cp=with_cp,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            deploy=deploy)
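
        # `next_create_block_idx` is a global block index counted across
        # stages; `group_layer_map` keys refer to this index when assigning
        # grouped convolutions to specific blocks.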
        next_create_block_idx = 1
        self.stages = []
        for i in range(len(arch['num_blocks'])):
            num_blocks = self.arch['num_blocks'][i]
            stride = self.strides[i]
            dilation = self.dilations[i]
            out_channels = int(base_channels * 2**i *
                               self.arch['width_factor'][i])

            stage, next_create_block_idx = self._make_stage(
                channels, out_channels, num_blocks, stride, dilation,
                next_create_block_idx, init_cfg)
            stage_name = f'stage_{i + 1}'
            self.add_module(stage_name, stage)
            self.stages.append(stage_name)

            channels = out_channels

    def _make_stage(self, in_channels, out_channels, num_blocks, stride,
                    dilation, next_create_block_idx, init_cfg):
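        # Only the first block of a stage downsamples; the rest use stride 1.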
        strides = [stride] + [1] * (num_blocks - 1)
        dilations = [dilation] * num_blocks

        blocks = []
        for i in range(num_blocks):
            groups = self.arch['group_layer_map'].get(
                next_create_block_idx,
                1) if self.arch['group_layer_map'] is not None else 1
            blocks.append(
                RepVGGBlock(
                    in_channels,
                    out_channels,
                    stride=strides[i],
                    padding=dilations[i],
                    dilation=dilations[i],
                    groups=groups,
                    se_cfg=self.arch['se_cfg'],
                    with_cp=self.with_cp,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    deploy=self.deploy,
                    init_cfg=init_cfg))
            in_channels = out_channels
            next_create_block_idx += 1

        return Sequential(*blocks), next_create_block_idx

    def forward(self, x):
        x = self.stem(x)
        outs = []
        for i, stage_name in enumerate(self.stages):
            stage = getattr(self, stage_name)
            x = stage(x)
            if i in self.out_indices:
                outs.append(x)

        return tuple(outs)

    def _freeze_stages(self):
        if self.frozen_stages >= 0:
            self.stem.eval()
            for param in self.stem.parameters():
                param.requires_grad = False
        for i in range(self.frozen_stages):
            stage = getattr(self, f'stage_{i + 1}')
            stage.eval()
            for param in stage.parameters():
                param.requires_grad = False

    def train(self, mode=True):
        super(RepVGG, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()

    def switch_to_deploy(self):
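        """Switch all RepVGG blocks in the backbone to deployment mode.

        The conversion is one-way and in-place: afterwards every block only
        holds a single fused 3x3 convolution (``branch_reparam``).
        """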
        for m in self.modules():
            if isinstance(m, RepVGGBlock):
                m.switch_to_deploy()
        self.deploy = True