From 72bbfbc489d367c74a89c553a1ea46f4eded104c Mon Sep 17 00:00:00 2001 From: lixiaojie Date: Thu, 11 Jun 2020 15:55:17 +0800 Subject: [PATCH] use ConvModule/build_*_layer from mmcv & use *_cfg & add _freeze_stages --- mmcls/models/backbones/shufflenet_v1.py | 200 +++++++++++------------- tests/test_backbone.py | 37 +++-- 2 files changed, 112 insertions(+), 125 deletions(-) diff --git a/mmcls/models/backbones/shufflenet_v1.py b/mmcls/models/backbones/shufflenet_v1.py index 44b01326..a0d66129 100644 --- a/mmcls/models/backbones/shufflenet_v1.py +++ b/mmcls/models/backbones/shufflenet_v1.py @@ -3,57 +3,14 @@ import logging import torch import torch.nn as nn import torch.utils.checkpoint as cp -from mmcv.cnn.bricks import ConvModule -from mmcv.cnn.weight_init import constant_init, kaiming_init +from mmcv.cnn import (ConvModule, build_conv_layer, build_activation_layer, + constant_init, kaiming_init) from mmcv.runner import load_checkpoint +from torch.nn.modules.batchnorm import _BatchNorm from .base_backbone import BaseBackbone -def conv3x3(inplanes, planes, stride=1, padding=1, bias=False, groups=1): - """3x3 convolution - - Applies a 2D convolution with the kernel_size 3x3 over an input signal - composed of several input planes. - - Args: - inplanes (int): Number of channels of the feature maps. - planes (int): Number of channels produced by the convolution. - stride (int or tuple, optional): Stride of the convolution. - Default is 1. - padding (int): Controls the amount of implicit zero-paddings on both - sides for padding number of points for each dimension. - Default is 1. - bias (bool, optional): Whether to add a learnable bias to the output. - Default is True - groups (int, optional): Number of blocked connections from input - channels to output channels. Default is 1 - """ - return nn.Conv2d( - inplanes, - planes, - kernel_size=3, - stride=stride, - padding=padding, - bias=bias, - groups=groups) - - -def conv1x1(inplanes, planes, groups=1): - """1x1 convolution - - Applies a 2D convolution with the kernel_size 1x1 over an input signal - composed of several input planes. - - Args: - inplanes (int): Number of channels of the input feature maps. - planes (int): Number of channels produced by the convolution. - groups (int, optional): Number of blocked connections from input - channels to output channels. Default: 1 - """ - return nn.Conv2d(inplanes, planes, kernel_size=1, groups=groups, stride=1) - - def channel_shuffle(x, groups): """ Channel Shuffle operation @@ -119,6 +76,11 @@ class ShuffleUnit(nn.Module): sequential ShuffleUnits. If True, use the grouped 1x1 convolution. combine (str, optional): The ways to combine the input and output branches. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). with_cp (bool, optional): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. @@ -132,6 +94,9 @@ class ShuffleUnit(nn.Module): groups=3, first_block=True, combine='add', + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), with_cp=False): super(ShuffleUnit, self).__init__() self.inplanes = inplanes @@ -160,25 +125,33 @@ class ShuffleUnit(nn.Module): in_channels=self.inplanes, out_channels=self.bottleneck_channels, kernel_size=1, - groups=self.first_1x1_groups) + groups=self.first_1x1_groups, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=act_cfg) - self.depthwise_conv3x3 = conv3x3( - self.bottleneck_channels, - self.bottleneck_channels, + self.depthwise_conv3x3_bn = ConvModule( + in_channels=self.bottleneck_channels, + out_channels=self.bottleneck_channels, + kernel_size=3, stride=self.depthwise_stride, - groups=self.bottleneck_channels) - - self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels) + padding=1, + groups=self.bottleneck_channels, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, + act_cfg=None) self.g_conv_1x1_expand = ConvModule( in_channels=self.bottleneck_channels, out_channels=self.planes, kernel_size=1, groups=self.groups, + conv_cfg=conv_cfg, + norm_cfg=norm_cfg, act_cfg=None) self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1) - self.relu = nn.ReLU(inplace=True) + self.act = build_activation_layer(act_cfg) @staticmethod def _add(x, out): @@ -200,12 +173,11 @@ class ShuffleUnit(nn.Module): out = self.g_conv_1x1_compress(x) out = channel_shuffle(out, self.groups) - out = self.depthwise_conv3x3(out) - out = self.bn_after_depthwise(out) + out = self.depthwise_conv3x3_bn(out) out = self.g_conv_1x1_expand(out) out = self._combine_func(residual, out) - + out = self.act(out) return out if self.with_cp and x.requires_grad: @@ -213,7 +185,6 @@ class ShuffleUnit(nn.Module): else: out = _inner_forward(x) - out = self.relu(out) return out @@ -230,9 +201,15 @@ class ShuffleNetv1(BaseBackbone): out_indices (Sequence[int]): Output from which stages. frozen_stages (int): Stages to be frozen (all param fixed). -1 means not freezing any parameters. - bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze - running stats (mean and var). - bn_frozen (bool): Whether to freeze weight and bias of BN layers. + conv_cfg (dict): Config dict for convolution layer. Default: None, + which means using conv2d. + norm_cfg (dict): Config dict for normalization layer. Default: None. + act_cfg (dict): Config dict for activation layer. + Default: dict(type='ReLU'). + norm_eval (bool): Whether to set norm layers to eval mode, namely, + freeze running stats (mean and var). Note: Effect on Batch Norm + and its variants only. + with_cp (bool): Use checkpoint or not. Using checkpoint will save some memory while slowing down the training speed. """ @@ -242,8 +219,10 @@ class ShuffleNetv1(BaseBackbone): widen_factor=1.0, out_indices=(0, 1, 2, 3), frozen_stages=-1, - bn_eval=True, - bn_frozen=False, + conv_cfg=None, + norm_cfg=None, + act_cfg=dict(type='ReLU'), + norm_eval=True, with_cp=False): super(ShuffleNetv1, self).__init__() blocks = [3, 7, 3] @@ -254,13 +233,15 @@ class ShuffleNetv1(BaseBackbone): raise ValueError(f'the item in out_indices must in ' f'range(0, 4). But received {indice}') - if frozen_stages not in [-1, 1, 2, 3]: - raise ValueError(f'frozen_stages must in [-1, 1, 2, 3]. ' + if frozen_stages not in range(-1, 4): + raise ValueError(f'frozen_stages must in range(-1, 4). ' f'But received {frozen_stages}') self.out_indices = out_indices self.frozen_stages = frozen_stages - self.bn_eval = bn_eval - self.bn_frozen = bn_frozen + self.conv_cfg = conv_cfg + self.norm_cfg = norm_cfg + self.act_cfg = act_cfg + self.norm_eval = norm_eval self.with_cp = with_cp if groups == 1: @@ -280,13 +261,22 @@ class ShuffleNetv1(BaseBackbone): channels = [_make_divisible(ch * widen_factor, 8) for ch in channels] self.inplanes = int(24 * widen_factor) - self.conv1 = conv3x3(3, self.inplanes, stride=2) + + self.conv1 = build_conv_layer( + self.conv_cfg, + in_channels=3, + out_channels=self.inplanes, + kernel_size=3, + stride=2, + padding=1, + bias=False) + self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1) - self.layer1 = self._make_layer( - channels[0], blocks[0], first_block=False, with_cp=with_cp) - self.layer2 = self._make_layer(channels[1], blocks[1], with_cp=with_cp) - self.layer3 = self._make_layer(channels[2], blocks[2], with_cp=with_cp) + self.layer1 = self._make_layer(channels[0], blocks[0], + first_block=False) + self.layer2 = self._make_layer(channels[1], blocks[1]) + self.layer3 = self._make_layer(channels[2], blocks[2]) def init_weights(self, pretrained=None): if isinstance(pretrained, str): @@ -301,7 +291,7 @@ class ShuffleNetv1(BaseBackbone): else: raise TypeError('pretrained must be a str or None') - def _make_layer(self, outplanes, blocks, first_block=True, with_cp=False): + def _make_layer(self, outplanes, blocks, first_block=True): """ Stack n bottleneck modules where n is inferred from the depth of the network. @@ -311,31 +301,23 @@ class ShuffleNetv1(BaseBackbone): first_block (bool, optional): Whether is the first ShuffleUnit of a sequential ShuffleUnits. If True, use the grouped 1x1 convolution. - with_cp (bool, optional): Use checkpoint or not. Using checkpoint - will save some memory while slowing down the training speed. - Returns: a Module consisting of n sequential ShuffleUnits. """ layers = [] for i in range(blocks): - if i == 0: - layers.append( - ShuffleUnit( - self.inplanes, - outplanes, - groups=self.groups, - first_block=first_block, - combine='concat', - with_cp=with_cp)) - else: - layers.append( - ShuffleUnit( - self.inplanes, - outplanes, - groups=self.groups, - first_block=True, - combine='add', - with_cp=with_cp)) + first_block = first_block if i == 0 else True + combine_mode = 'concat' if i == 0 else 'add' + layers.append( + ShuffleUnit( + self.inplanes, + outplanes, + groups=self.groups, + first_block=first_block, + combine=combine_mode, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) self.inplanes = outplanes @@ -363,20 +345,20 @@ class ShuffleNetv1(BaseBackbone): else: return tuple(outs) - def train(self, mode=True): - super(ShuffleNetv1, self).train(mode) - if self.bn_eval: - for m in self.modules(): - if isinstance(m, nn.BatchNorm2d): - m.eval() - if self.bn_frozen: - for params in m.parameters(): - params.requires_grad = False - if mode and self.frozen_stages >= 0: + def _freeze_stages(self): + if self.frozen_stages >= 0: for param in self.conv1.parameters(): param.requires_grad = False - for i in range(1, self.frozen_stages + 1): - layer = getattr(self, f'layer{i}') - layer.eval() - for param in layer.parameters(): - param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(ShuffleNetv1, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() \ No newline at end of file diff --git a/tests/test_backbone.py b/tests/test_backbone.py index 0bd69ba1..c1863973 100644 --- a/tests/test_backbone.py +++ b/tests/test_backbone.py @@ -67,11 +67,11 @@ def test_shufflenetv1_shuffleuint(): def test_shufflenetv1_backbone(): with pytest.raises(ValueError): - # frozen_stages must in [-1, 1, 2, 3] + # frozen_stages must in range(-1, 4) ShuffleNetv1(frozen_stages=10) with pytest.raises(ValueError): - # the item in out_indices must in [0, 1, 2, 3] + # the item in out_indices must in range(0, 4) ShuffleNetv1(out_indices=[5]) with pytest.raises(ValueError): @@ -100,20 +100,6 @@ def test_shufflenetv1_backbone(): for param in layer.parameters(): assert param.requires_grad is False - # Test ShuffleNetv1 with bn frozen - model = ShuffleNetv1(bn_frozen=True) - model.init_weights() - model.train() - - for i in range(1, 4): - layer = getattr(model, f'layer{i}') - - for mod in layer.modules(): - if isinstance(mod, _BatchNorm): - assert mod.training is False - for params in mod.parameters(): - params.requires_grad = False - # Test ShuffleNetv1 forward with groups=3 model = ShuffleNetv1(groups=3) model.init_weights() @@ -131,6 +117,25 @@ def test_shufflenetv1_backbone(): assert feat[2].shape == torch.Size([1, 960, 7, 7]) assert feat[3].shape == torch.Size([1, 960, 7, 7]) + # Test ShuffleNetv1 forward with GroupNorm forward + model = ShuffleNetv1(groups=3, + norm_cfg=dict(type='GN', num_groups=2, + requires_grad=True)) + model.init_weights() + model.train() + + for m in model.modules(): + if is_norm(m): + assert isinstance(m, GroupNorm) + + imgs = torch.randn(1, 3, 224, 224) + feat = model(imgs) + assert len(feat) == 4 + assert feat[0].shape == torch.Size([1, 240, 28, 28]) + assert feat[1].shape == torch.Size([1, 480, 14, 14]) + assert feat[2].shape == torch.Size([1, 960, 7, 7]) + assert feat[3].shape == torch.Size([1, 960, 7, 7]) + # Test ShuffleNetv1 forward with layers 1, 2 forward model = ShuffleNetv1(groups=3, out_indices=(1, 2)) model.init_weights()