From 8cc1fdef52ef2fd632a84b739cdbfc8bc2442255 Mon Sep 17 00:00:00 2001 From: takuoko Date: Fri, 4 Nov 2022 16:36:18 +0900 Subject: [PATCH] [Enhancement] RepVGG for YOLOX-PAI for dev-1.x. (#1126) --- mmcls/models/backbones/repvgg.py | 103 ++++++++++++++++-- .../test_models/test_backbones/test_repvgg.py | 78 +++++++++++-- 2 files changed, 158 insertions(+), 23 deletions(-) diff --git a/mmcls/models/backbones/repvgg.py b/mmcls/models/backbones/repvgg.py index dd78d5f2..51a760bc 100644 --- a/mmcls/models/backbones/repvgg.py +++ b/mmcls/models/backbones/repvgg.py @@ -2,9 +2,11 @@ import torch import torch.nn.functional as F import torch.utils.checkpoint as cp -from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer +from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer, + build_norm_layer) from mmengine.model import BaseModule, Sequential from mmengine.utils.dl_utils.parrots_wrapper import _BatchNorm +from torch import nn from mmcls.registry import MODELS from ..utils.se_layer import SELayer @@ -254,6 +256,51 @@ class RepVGGBlock(BaseModule): return tmp_conv3x3 +class MTSPPF(BaseModule): + """MTSPPF block for YOLOX-PAI RepVGG backbone. + + Args: + in_channels (int): The input channels of the block. + out_channels (int): The output channels of the block. + norm_cfg (dict): dictionary to construct and config norm layer. + Default: dict(type='BN'). + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + kernel_size (int): Kernel size of pooling. Default: 5. 
+ """ + + def __init__(self, + in_channels, + out_channels, + norm_cfg=dict(type='BN'), + act_cfg=dict(type='ReLU'), + kernel_size=5): + super().__init__() + hidden_features = in_channels // 2 # hidden channels + self.conv1 = ConvModule( + in_channels, + hidden_features, + 1, + stride=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.conv2 = ConvModule( + hidden_features * 4, + out_channels, + 1, + stride=1, + norm_cfg=norm_cfg, + act_cfg=act_cfg) + self.maxpool = nn.MaxPool2d( + kernel_size=kernel_size, stride=1, padding=kernel_size // 2) + + def forward(self, x): + x = self.conv1(x) + y1 = self.maxpool(x) + y2 = self.maxpool(y1) + return self.conv2(torch.cat([x, y1, y2, self.maxpool(y2)], 1)) + + @MODELS.register_module() class RepVGG(BaseBackbone): """RepVGG backbone. @@ -262,17 +309,22 @@ class RepVGG(BaseBackbone): `_ Args: - arch (str | dict): The parameter of RepVGG. - If it's a dict, it should contain the following keys: - + arch (str | dict): RepVGG architecture. If use string, + choose from 'A0', 'A1', 'A2', 'B0', 'B1', 'B1g2', 'B1g4', 'B2' + , 'B2g2', 'B2g4', 'B3', 'B3g2', 'B3g4' or 'D2se'. If use dict, + it should have below keys: - num_blocks (Sequence[int]): Number of blocks in each stage. - width_factor (Sequence[float]): Width deflator in each stage. - group_layer_map (dict | None): RepVGG Block that declares the need to apply group convolution. - - se_cfg (dict | None): Se Layer config + - se_cfg (dict | None): Se Layer config. + - stem_channels (int, optional): The stem channels, the final + stem channels will be + ``min(stem_channels, base_channels*width_factor[0])``. + If not set here, 64 is used by default in the code. in_channels (int): Number of input image channels. Default: 3. - base_channels (int): Base channels of RepVGG backbone, work - with width_factor together. Default: 64. + base_channels (int): Base channels of RepVGG backbone, work with + width_factor together. Defaults to 64. out_indices (Sequence[int]): Output from which stages. 
Default: (3, ). strides (Sequence[int]): Strides of the first block of each stage. Default: (2, 2, 2, 2). @@ -292,6 +344,7 @@ class RepVGG(BaseBackbone): norm_eval (bool): Whether to set norm layers to eval mode, namely, freeze running stats (mean and var). Note: Effect on Batch Norm and its variants only. Default: False. + add_ppf (bool): Whether to use the MTSPPF block. Default: False. init_cfg (dict or list[dict], optional): Initialization config dict. """ @@ -323,7 +376,8 @@ class RepVGG(BaseBackbone): num_blocks=[4, 6, 16, 1], width_factor=[1, 1, 1, 2.5], group_layer_map=None, - se_cfg=None), + se_cfg=None, + stem_channels=64), 'B1': dict( num_blocks=[4, 6, 16, 1], @@ -383,7 +437,14 @@ class RepVGG(BaseBackbone): num_blocks=[8, 14, 24, 1], width_factor=[2.5, 2.5, 2.5, 5], group_layer_map=None, - se_cfg=dict(ratio=16, divisor=1)) + se_cfg=dict(ratio=16, divisor=1)), + 'yolox-pai-small': + dict( + num_blocks=[3, 5, 7, 3], + width_factor=[1, 1, 1, 1], + group_layer_map=None, + se_cfg=None, + stem_channels=32), } def __init__(self, @@ -400,6 +461,7 @@ class RepVGG(BaseBackbone): with_cp=False, deploy=False, norm_eval=False, + add_ppf=False, init_cfg=[ dict(type='Kaiming', layer=['Conv2d']), dict( @@ -427,9 +489,9 @@ class RepVGG(BaseBackbone): if arch['se_cfg'] is not None: assert isinstance(arch['se_cfg'], dict) + self.base_channels = base_channels self.arch = arch self.in_channels = in_channels - self.base_channels = base_channels self.out_indices = out_indices self.strides = strides self.dilations = dilations @@ -441,7 +503,12 @@ class RepVGG(BaseBackbone): self.with_cp = with_cp self.norm_eval = norm_eval - channels = min(64, int(base_channels * self.arch['width_factor'][0])) + # defaults to 64 to prevent BC-breaking if stem_channels + # not in arch dict; + # the stem channels should not be larger than that of stage1. 
+ channels = min( + arch.get('stem_channels', 64), + int(self.base_channels * self.arch['width_factor'][0])) self.stem = RepVGGBlock( self.in_channels, channels, @@ -459,7 +526,7 @@ class RepVGG(BaseBackbone): num_blocks = self.arch['num_blocks'][i] stride = self.strides[i] dilation = self.dilations[i] - out_channels = int(base_channels * 2**i * + out_channels = int(self.base_channels * 2**i * self.arch['width_factor'][i]) stage, next_create_block_idx = self._make_stage( @@ -471,6 +538,16 @@ class RepVGG(BaseBackbone): channels = out_channels + if add_ppf: + self.ppf = MTSPPF( + out_channels, + out_channels, + norm_cfg=norm_cfg, + act_cfg=act_cfg, + kernel_size=5) + else: + self.ppf = nn.Identity() + def _make_stage(self, in_channels, out_channels, num_blocks, stride, dilation, next_create_block_idx, init_cfg): strides = [stride] + [1] * (num_blocks - 1) @@ -507,6 +584,8 @@ class RepVGG(BaseBackbone): for i, stage_name in enumerate(self.stages): stage = getattr(self, stage_name) x = stage(x) + if i + 1 == len(self.stages): + x = self.ppf(x) if i in self.out_indices: outs.append(x) diff --git a/tests/test_models/test_backbones/test_repvgg.py b/tests/test_models/test_backbones/test_repvgg.py index fe618364..7ac066ac 100644 --- a/tests/test_models/test_backbones/test_repvgg.py +++ b/tests/test_models/test_backbones/test_repvgg.py @@ -202,18 +202,36 @@ def test_repvgg_backbone(): # Test RepVGG forward with layer 3 forward model = RepVGG('A0', out_indices=(3, )) model.init_weights() - model.train() + model.eval() for m in model.modules(): if is_norm(m): assert isinstance(m, _BatchNorm) - imgs = torch.randn(1, 3, 224, 224) + imgs = torch.randn(1, 3, 32, 32) feat = model(imgs) assert isinstance(feat, tuple) assert len(feat) == 1 assert isinstance(feat[0], torch.Tensor) - assert feat[0].shape == torch.Size((1, 1280, 7, 7)) + assert feat[0].shape == torch.Size((1, 1280, 1, 1)) + + # Test with custom arch + cfg = dict( + num_blocks=[3, 5, 7, 3], + width_factor=[1, 1, 1, 1], 
+ group_layer_map=None, + se_cfg=None, + stem_channels=16) + model = RepVGG(arch=cfg, out_indices=(3, )) + model.eval() + assert model.stem.out_channels == min(16, 64 * 1) + + imgs = torch.randn(1, 3, 32, 32) + feat = model(imgs) + assert isinstance(feat, tuple) + assert len(feat) == 1 + assert isinstance(feat[0], torch.Tensor) + assert feat[0].shape == torch.Size((1, 512, 1, 1)) # Test RepVGG forward model_test_settings = [ @@ -233,7 +251,7 @@ def test_repvgg_backbone(): dict(model_name='D2se', out_sizes=(160, 320, 640, 2560)) ] - choose_models = ['A0', 'B1', 'B1g2', 'D2se'] + choose_models = ['A0', 'B1', 'B1g2'] # Test RepVGG model forward for model_test_setting in model_test_settings: if model_test_setting['model_name'] not in choose_models: @@ -241,23 +259,23 @@ def test_repvgg_backbone(): model = RepVGG( model_test_setting['model_name'], out_indices=(0, 1, 2, 3)) model.init_weights() + model.eval() # Test Norm for m in model.modules(): if is_norm(m): assert isinstance(m, _BatchNorm) - model.train() - imgs = torch.randn(1, 3, 224, 224) + imgs = torch.randn(1, 3, 32, 32) feat = model(imgs) assert feat[0].shape == torch.Size( - (1, model_test_setting['out_sizes'][0], 56, 56)) + (1, model_test_setting['out_sizes'][0], 8, 8)) assert feat[1].shape == torch.Size( - (1, model_test_setting['out_sizes'][1], 28, 28)) + (1, model_test_setting['out_sizes'][1], 4, 4)) assert feat[2].shape == torch.Size( - (1, model_test_setting['out_sizes'][2], 14, 14)) + (1, model_test_setting['out_sizes'][2], 2, 2)) assert feat[3].shape == torch.Size( - (1, model_test_setting['out_sizes'][3], 7, 7)) + (1, model_test_setting['out_sizes'][3], 1, 1)) # Test eval of "train" mode and "deploy" mode gap = nn.AdaptiveAvgPool2d(output_size=(1)) @@ -275,11 +293,49 @@ def test_repvgg_backbone(): torch.allclose(feat[i], feat_deploy[i]) torch.allclose(pred, pred_deploy) + # Test RepVGG forward with add_ppf + model = RepVGG('A0', out_indices=(3, ), add_ppf=True) + model.init_weights() + model.train() + 
+ for m in model.modules(): + if is_norm(m): + assert isinstance(m, _BatchNorm) + + imgs = torch.randn(1, 3, 64, 64) + feat = model(imgs) + assert isinstance(feat, tuple) + assert len(feat) == 1 + assert isinstance(feat[0], torch.Tensor) + assert feat[0].shape == torch.Size((1, 1280, 2, 2)) + + # Test RepVGG forward with 'stem_channels' not in arch + arch = dict( + num_blocks=[2, 4, 14, 1], + width_factor=[0.75, 0.75, 0.75, 2.5], + group_layer_map=None, + se_cfg=None) + model = RepVGG(arch, add_ppf=True) + model.stem.in_channels = min(64, 64 * 0.75) + model.init_weights() + model.train() + + for m in model.modules(): + if is_norm(m): + assert isinstance(m, _BatchNorm) + + imgs = torch.randn(1, 3, 64, 64) + feat = model(imgs) + assert isinstance(feat, tuple) + assert len(feat) == 1 + assert isinstance(feat[0], torch.Tensor) + assert feat[0].shape == torch.Size((1, 1280, 2, 2)) + def test_repvgg_load(): # Test output before and load from deploy checkpoint model = RepVGG('A1', out_indices=(0, 1, 2, 3)) - inputs = torch.randn((1, 3, 224, 224)) + inputs = torch.randn((1, 3, 32, 32)) ckpt_path = os.path.join(tempfile.gettempdir(), 'ckpt.pth') model.switch_to_deploy() model.eval()