[Enhancement] RepVGG for YOLOX-PAI for dev-1.x. (#1126)
parent d05cbbcf9b
commit 8cc1fdef52
```diff
@@ -2,9 +2,11 @@
 import torch
+import torch.nn.functional as F
 import torch.utils.checkpoint as cp
-from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
+from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
+                      build_norm_layer)
 from mmengine.model import BaseModule, Sequential
 from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm
 from torch import nn
 
 from mmcls.registry import MODELS
 from ..utils.se_layer import SELayer
@@ -254,6 +256,51 @@ class RepVGGBlock(BaseModule):
         return tmp_conv3x3
 
 
+class MTSPPF(BaseModule):
+    """MTSPPF block for YOLOX-PAI RepVGG backbone.
+
+    Args:
+        in_channels (int): The input channels of the block.
+        out_channels (int): The output channels of the block.
+        norm_cfg (dict): dictionary to construct and config norm layer.
+            Default: dict(type='BN').
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU').
+        kernel_size (int): Kernel size of pooling. Default: 5.
+    """
+
+    def __init__(self,
+                 in_channels,
+                 out_channels,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU'),
+                 kernel_size=5):
+        super().__init__()
+        hidden_features = in_channels // 2  # hidden channels
+        self.conv1 = ConvModule(
+            in_channels,
+            hidden_features,
+            1,
+            stride=1,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        self.conv2 = ConvModule(
+            hidden_features * 4,
+            out_channels,
+            1,
+            stride=1,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)
+        self.maxpool = nn.MaxPool2d(
+            kernel_size=kernel_size, stride=1, padding=kernel_size // 2)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        y1 = self.maxpool(x)
+        y2 = self.maxpool(y1)
+        return self.conv2(torch.cat([x, y1, y2, self.maxpool(y2)], 1))
+
+
 @MODELS.register_module()
 class RepVGG(BaseBackbone):
     """RepVGG backbone.
```
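The new MTSPPF block is an SPPF-style pooling pyramid: a 1x1 conv halves the channels, three chained 5x5 max-poolings reuse each other's output, and their concatenation (four times the hidden width) is projected back by a second 1x1 conv. A minimal usage sketch, assuming the block lives in `mmcls/models/backbones/repvgg.py` as the imports above suggest:

```python
import torch

from mmcls.models.backbones.repvgg import MTSPPF  # module path assumed from the diff

# hidden_features = 256 // 2 = 128; after concatenating the input with three
# chained poolings, conv2 sees 128 * 4 = 512 channels and maps them to 256.
block = MTSPPF(in_channels=256, out_channels=256)
block.eval()

x = torch.randn(1, 256, 8, 8)
out = block(x)
assert out.shape == (1, 256, 8, 8)  # stride-1 pooling keeps the spatial size
```

Chaining the poolings (`y2 = maxpool(y1)`) approximates 5x5/9x9/13x13 receptive fields while only ever paying for a 5x5 kernel, which is the usual SPPF trick.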
```diff
@@ -262,17 +309,22 @@ class RepVGG(BaseBackbone):
     <https://arxiv.org/abs/2101.03697>`_
 
     Args:
-        arch (str | dict): The parameter of RepVGG.
-            If it's a dict, it should contain the following keys:
-
+        arch (str | dict): RepVGG architecture. If using a string, it must
+            be one of 'A0', 'A1', 'A2', 'B0', 'B1', 'B1g2', 'B1g4', 'B2',
+            'B2g2', 'B2g4', 'B3', 'B3g2', 'B3g4' or 'D2se'. If using a
+            dict, it should contain the following keys:
             - num_blocks (Sequence[int]): Number of blocks in each stage.
             - width_factor (Sequence[float]): Width deflator in each stage.
             - group_layer_map (dict | None): RepVGG Block that declares
                 the need to apply group convolution.
-            - se_cfg (dict | None): Se Layer config
+            - se_cfg (dict | None): SELayer config.
+            - stem_channels (int, optional): The stem channels, the final
+                stem channels will be
+                ``min(stem_channels, base_channels*width_factor[0])``.
+                If not set here, 64 is used by default in the code.
         in_channels (int): Number of input image channels. Default: 3.
-        base_channels (int): Base channels of RepVGG backbone, work
-            with width_factor together. Default: 64.
+        base_channels (int): Base channels of RepVGG backbone, which works
+            together with width_factor. Defaults to 64.
         out_indices (Sequence[int]): Output from which stages. Default: (3, ).
         strides (Sequence[int]): Strides of the first block of each stage.
             Default: (2, 2, 2, 2).
@@ -292,6 +344,7 @@ class RepVGG(BaseBackbone):
         norm_eval (bool): Whether to set norm layers to eval mode, namely,
             freeze running stats (mean and var). Note: Effect on Batch Norm
             and its variants only. Default: False.
+        add_ppf (bool): Whether to use the MTSPPF block. Default: False.
         init_cfg (dict or list[dict], optional): Initialization config dict.
     """
 
@@ -323,7 +376,8 @@ class RepVGG(BaseBackbone):
             num_blocks=[4, 6, 16, 1],
             width_factor=[1, 1, 1, 2.5],
             group_layer_map=None,
-            se_cfg=None),
+            se_cfg=None,
+            stem_channels=64),
         'B1':
         dict(
             num_blocks=[4, 6, 16, 1],
@@ -383,7 +437,14 @@ class RepVGG(BaseBackbone):
             num_blocks=[8, 14, 24, 1],
             width_factor=[2.5, 2.5, 2.5, 5],
             group_layer_map=None,
-            se_cfg=dict(ratio=16, divisor=1))
+            se_cfg=dict(ratio=16, divisor=1)),
+        'yolox-pai-small':
+        dict(
+            num_blocks=[3, 5, 7, 3],
+            width_factor=[1, 1, 1, 1],
+            group_layer_map=None,
+            se_cfg=None,
+            stem_channels=32),
     }
 
     def __init__(self,
```
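With the `arch_settings` entry above, the new architecture can be selected by name. A short sketch (the `mmcls.models` export path is the standard one, but treat it as an assumption):

```python
import torch

from mmcls.models import RepVGG

# 'yolox-pai-small' keeps width_factor at 1 everywhere, caps the stem at
# 32 channels via the new 'stem_channels' key, and pairs naturally with
# the MTSPPF block enabled through add_ppf=True.
model = RepVGG(arch='yolox-pai-small', out_indices=(3, ), add_ppf=True)
model.eval()

x = torch.randn(1, 3, 64, 64)
feat = model(x)[0]
# Overall stride 32 (stem /2, four stages /2 each); channels 64 * 2**3 * 1.
assert feat.shape == (1, 512, 2, 2)
```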
```diff
@@ -400,6 +461,7 @@ class RepVGG(BaseBackbone):
                  with_cp=False,
                  deploy=False,
                  norm_eval=False,
+                 add_ppf=False,
                  init_cfg=[
                      dict(type='Kaiming', layer=['Conv2d']),
                      dict(
@@ -427,9 +489,9 @@ class RepVGG(BaseBackbone):
         if arch['se_cfg'] is not None:
             assert isinstance(arch['se_cfg'], dict)
 
-        self.base_channels = base_channels
         self.arch = arch
         self.in_channels = in_channels
+        self.base_channels = base_channels
         self.out_indices = out_indices
         self.strides = strides
         self.dilations = dilations
@@ -441,7 +503,12 @@ class RepVGG(BaseBackbone):
         self.with_cp = with_cp
         self.norm_eval = norm_eval
 
-        channels = min(64, int(base_channels * self.arch['width_factor'][0]))
+        # Defaults to 64 to prevent BC-breaking if 'stem_channels' is not
+        # in the arch dict; the stem channels should not be larger than
+        # that of stage1.
+        channels = min(
+            arch.get('stem_channels', 64),
+            int(self.base_channels * self.arch['width_factor'][0]))
         self.stem = RepVGGBlock(
             self.in_channels,
             channels,
```
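The stem-channel rule introduced in this hunk is easiest to see with a few worked values; a standalone sketch (not library code):

```python
def stem_channels(arch, base_channels=64):
    # Default to 64 for backward compatibility, but never let the stem
    # grow wider than stage 1 (base_channels * width_factor[0]).
    return min(arch.get('stem_channels', 64),
               int(base_channels * arch['width_factor'][0]))

assert stem_channels(dict(width_factor=[1, 1, 1, 1], stem_channels=32)) == 32
assert stem_channels(dict(width_factor=[0.75, 0.75, 0.75, 2.5])) == 48
assert stem_channels(dict(width_factor=[2.5, 2.5, 2.5, 5])) == 64
```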
```diff
@@ -459,7 +526,7 @@ class RepVGG(BaseBackbone):
             num_blocks = self.arch['num_blocks'][i]
             stride = self.strides[i]
             dilation = self.dilations[i]
-            out_channels = int(base_channels * 2**i *
+            out_channels = int(self.base_channels * 2**i *
                                self.arch['width_factor'][i])
 
             stage, next_create_block_idx = self._make_stage(
@@ -471,6 +538,16 @@ class RepVGG(BaseBackbone):
 
             channels = out_channels
 
+        if add_ppf:
+            self.ppf = MTSPPF(
+                out_channels,
+                out_channels,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg,
+                kernel_size=5)
+        else:
+            self.ppf = nn.Identity()
+
     def _make_stage(self, in_channels, out_channels, num_blocks, stride,
                     dilation, next_create_block_idx, init_cfg):
         strides = [stride] + [1] * (num_blocks - 1)
@@ -507,6 +584,8 @@ class RepVGG(BaseBackbone):
         for i, stage_name in enumerate(self.stages):
             stage = getattr(self, stage_name)
             x = stage(x)
+            if i + 1 == len(self.stages):
+                x = self.ppf(x)
             if i in self.out_indices:
                 outs.append(x)
```
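Note that `self.ppf` wraps only the final stage's output in `forward`, and with the default `add_ppf=False` it is an `nn.Identity()`, so existing configs and checkpoints see unchanged behavior. A one-line sanity sketch of that fallback:

```python
import torch
from torch import nn

ppf = nn.Identity()            # the add_ppf=False branch from the diff
x = torch.randn(1, 512, 2, 2)
assert torch.equal(ppf(x), x)  # the last stage's output passes through as-is
```

The remaining hunks update the RepVGG unit tests to cover the new options and to run on smaller inputs.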
```diff
@@ -202,18 +202,36 @@ def test_repvgg_backbone():
     # Test RepVGG forward with layer 3 forward
     model = RepVGG('A0', out_indices=(3, ))
     model.init_weights()
-    model.train()
+    model.eval()
 
     for m in model.modules():
         if is_norm(m):
             assert isinstance(m, _BatchNorm)
 
-    imgs = torch.randn(1, 3, 224, 224)
+    imgs = torch.randn(1, 3, 32, 32)
     feat = model(imgs)
     assert isinstance(feat, tuple)
     assert len(feat) == 1
     assert isinstance(feat[0], torch.Tensor)
-    assert feat[0].shape == torch.Size((1, 1280, 7, 7))
+    assert feat[0].shape == torch.Size((1, 1280, 1, 1))
+
+    # Test with custom arch
+    cfg = dict(
+        num_blocks=[3, 5, 7, 3],
+        width_factor=[1, 1, 1, 1],
+        group_layer_map=None,
+        se_cfg=None,
+        stem_channels=16)
+    model = RepVGG(arch=cfg, out_indices=(3, ))
+    model.eval()
+    assert model.stem.out_channels == min(16, 64 * 1)
+
+    imgs = torch.randn(1, 3, 32, 32)
+    feat = model(imgs)
+    assert isinstance(feat, tuple)
+    assert len(feat) == 1
+    assert isinstance(feat[0], torch.Tensor)
+    assert feat[0].shape == torch.Size((1, 512, 1, 1))
 
     # Test RepVGG forward
     model_test_settings = [
@@ -233,7 +251,7 @@ def test_repvgg_backbone():
         dict(model_name='D2se', out_sizes=(160, 320, 640, 2560))
     ]
 
-    choose_models = ['A0', 'B1', 'B1g2', 'D2se']
+    choose_models = ['A0', 'B1', 'B1g2']
     # Test RepVGG model forward
     for model_test_setting in model_test_settings:
         if model_test_setting['model_name'] not in choose_models:
@@ -241,23 +259,23 @@ def test_repvgg_backbone():
         model = RepVGG(
             model_test_setting['model_name'], out_indices=(0, 1, 2, 3))
         model.init_weights()
+        model.eval()
 
         # Test Norm
         for m in model.modules():
             if is_norm(m):
                 assert isinstance(m, _BatchNorm)
 
-        model.train()
-        imgs = torch.randn(1, 3, 224, 224)
+        imgs = torch.randn(1, 3, 32, 32)
         feat = model(imgs)
         assert feat[0].shape == torch.Size(
-            (1, model_test_setting['out_sizes'][0], 56, 56))
+            (1, model_test_setting['out_sizes'][0], 8, 8))
         assert feat[1].shape == torch.Size(
-            (1, model_test_setting['out_sizes'][1], 28, 28))
+            (1, model_test_setting['out_sizes'][1], 4, 4))
         assert feat[2].shape == torch.Size(
-            (1, model_test_setting['out_sizes'][2], 14, 14))
+            (1, model_test_setting['out_sizes'][2], 2, 2))
         assert feat[3].shape == torch.Size(
-            (1, model_test_setting['out_sizes'][3], 7, 7))
+            (1, model_test_setting['out_sizes'][3], 1, 1))
 
         # Test eval of "train" mode and "deploy" mode
         gap = nn.AdaptiveAvgPool2d(output_size=(1))
@@ -275,11 +293,49 @@ def test_repvgg_backbone():
             torch.allclose(feat[i], feat_deploy[i])
         torch.allclose(pred, pred_deploy)
+
+    # Test RepVGG forward with add_ppf
+    model = RepVGG('A0', out_indices=(3, ), add_ppf=True)
+    model.init_weights()
+    model.train()
+
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, _BatchNorm)
+
+    imgs = torch.randn(1, 3, 64, 64)
+    feat = model(imgs)
+    assert isinstance(feat, tuple)
+    assert len(feat) == 1
+    assert isinstance(feat[0], torch.Tensor)
+    assert feat[0].shape == torch.Size((1, 1280, 2, 2))
+
+    # Test RepVGG forward with 'stem_channels' not in arch
+    arch = dict(
+        num_blocks=[2, 4, 14, 1],
+        width_factor=[0.75, 0.75, 0.75, 2.5],
+        group_layer_map=None,
+        se_cfg=None)
+    model = RepVGG(arch, add_ppf=True)
+    model.stem.in_channels = min(64, 64 * 0.75)
+    model.init_weights()
+    model.train()
+
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, _BatchNorm)
+
+    imgs = torch.randn(1, 3, 64, 64)
+    feat = model(imgs)
+    assert isinstance(feat, tuple)
+    assert len(feat) == 1
+    assert isinstance(feat[0], torch.Tensor)
+    assert feat[0].shape == torch.Size((1, 1280, 2, 2))
+
 
 def test_repvgg_load():
     # Test output before and load from deploy checkpoint
     model = RepVGG('A1', out_indices=(0, 1, 2, 3))
-    inputs = torch.randn((1, 3, 224, 224))
+    inputs = torch.randn((1, 3, 32, 32))
     ckpt_path = os.path.join(tempfile.gettempdir(), 'ckpt.pth')
     model.switch_to_deploy()
     model.eval()
```
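The shrunken test inputs (224 -> 32/64) keep the suite cheap without changing what it verifies; in particular, the train/deploy reparameterization check still hinges on `switch_to_deploy()` producing numerically equivalent outputs. A condensed sketch of that equivalence, assuming the standard `mmcls.models` export:

```python
import torch

from mmcls.models import RepVGG

model = RepVGG('A0', out_indices=(3, ))
model.eval()  # freeze BN statistics before comparing the two forms

x = torch.randn(1, 3, 32, 32)
with torch.no_grad():
    feat_train = model(x)[0]
    model.switch_to_deploy()  # fuse the 3x3, 1x1 and identity branches
    feat_deploy = model(x)[0]

# The fused single-conv model should reproduce the multi-branch output.
assert torch.allclose(feat_train, feat_deploy, atol=1e-5)
```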