use ConvModule/build_*_layer from mmcv & use *_cfg & add _freeze_stages

pull/2/head
lixiaojie 2020-06-11 15:55:17 +08:00
parent 38a1c2533a
commit 72bbfbc489
2 changed files with 112 additions and 125 deletions

View File

@ -3,57 +3,14 @@ import logging
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn.bricks import ConvModule
from mmcv.cnn.weight_init import constant_init, kaiming_init
from mmcv.cnn import (ConvModule, build_conv_layer, build_activation_layer,
constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from .base_backbone import BaseBackbone
def conv3x3(inplanes, planes, stride=1, padding=1, bias=False, groups=1):
"""3x3 convolution
Applies a 2D convolution with the kernel_size 3x3 over an input signal
composed of several input planes.
Args:
inplanes (int): Number of channels of the feature maps.
planes (int): Number of channels produced by the convolution.
stride (int or tuple, optional): Stride of the convolution.
Default is 1.
padding (int): Controls the amount of implicit zero-paddings on both
sides for padding number of points for each dimension.
Default is 1.
bias (bool, optional): Whether to add a learnable bias to the output.
Default is True
groups (int, optional): Number of blocked connections from input
channels to output channels. Default is 1
"""
return nn.Conv2d(
inplanes,
planes,
kernel_size=3,
stride=stride,
padding=padding,
bias=bias,
groups=groups)
def conv1x1(inplanes, planes, groups=1):
"""1x1 convolution
Applies a 2D convolution with the kernel_size 1x1 over an input signal
composed of several input planes.
Args:
inplanes (int): Number of channels of the input feature maps.
planes (int): Number of channels produced by the convolution.
groups (int, optional): Number of blocked connections from input
channels to output channels. Default: 1
"""
return nn.Conv2d(inplanes, planes, kernel_size=1, groups=groups, stride=1)
def channel_shuffle(x, groups):
""" Channel Shuffle operation
@ -119,6 +76,11 @@ class ShuffleUnit(nn.Module):
sequential ShuffleUnits. If True, use the grouped 1x1 convolution.
combine (str, optional): The ways to combine the input and output
branches.
conv_cfg (dict): Config dict for convolution layer. Default: None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer. Default: None.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
@ -132,6 +94,9 @@ class ShuffleUnit(nn.Module):
groups=3,
first_block=True,
combine='add',
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
with_cp=False):
super(ShuffleUnit, self).__init__()
self.inplanes = inplanes
@ -160,25 +125,33 @@ class ShuffleUnit(nn.Module):
in_channels=self.inplanes,
out_channels=self.bottleneck_channels,
kernel_size=1,
groups=self.first_1x1_groups)
groups=self.first_1x1_groups,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.depthwise_conv3x3 = conv3x3(
self.bottleneck_channels,
self.bottleneck_channels,
self.depthwise_conv3x3_bn = ConvModule(
in_channels=self.bottleneck_channels,
out_channels=self.bottleneck_channels,
kernel_size=3,
stride=self.depthwise_stride,
groups=self.bottleneck_channels)
self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels)
padding=1,
groups=self.bottleneck_channels,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None)
self.g_conv_1x1_expand = ConvModule(
in_channels=self.bottleneck_channels,
out_channels=self.planes,
kernel_size=1,
groups=self.groups,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None)
self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
self.relu = nn.ReLU(inplace=True)
self.act = build_activation_layer(act_cfg)
@staticmethod
def _add(x, out):
@ -200,12 +173,11 @@ class ShuffleUnit(nn.Module):
out = self.g_conv_1x1_compress(x)
out = channel_shuffle(out, self.groups)
out = self.depthwise_conv3x3(out)
out = self.bn_after_depthwise(out)
out = self.depthwise_conv3x3_bn(out)
out = self.g_conv_1x1_expand(out)
out = self._combine_func(residual, out)
out = self.act(out)
return out
if self.with_cp and x.requires_grad:
@ -213,7 +185,6 @@ class ShuffleUnit(nn.Module):
else:
out = _inner_forward(x)
out = self.relu(out)
return out
@ -230,9 +201,15 @@ class ShuffleNetv1(BaseBackbone):
out_indices (Sequence[int]): Output from which stages.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
running stats (mean and var).
bn_frozen (bool): Whether to freeze weight and bias of BN layers.
conv_cfg (dict): Config dict for convolution layer. Default: None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer. Default: None.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
"""
@ -242,8 +219,10 @@ class ShuffleNetv1(BaseBackbone):
widen_factor=1.0,
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
bn_eval=True,
bn_frozen=False,
conv_cfg=None,
norm_cfg=None,
act_cfg=dict(type='ReLU'),
norm_eval=True,
with_cp=False):
super(ShuffleNetv1, self).__init__()
blocks = [3, 7, 3]
@ -254,13 +233,15 @@ class ShuffleNetv1(BaseBackbone):
raise ValueError(f'the item in out_indices must in '
f'range(0, 4). But received {indice}')
if frozen_stages not in [-1, 1, 2, 3]:
raise ValueError(f'frozen_stages must in [-1, 1, 2, 3]. '
if frozen_stages not in range(-1, 4):
raise ValueError(f'frozen_stages must in range(-1, 4). '
f'But received {frozen_stages}')
self.out_indices = out_indices
self.frozen_stages = frozen_stages
self.bn_eval = bn_eval
self.bn_frozen = bn_frozen
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
if groups == 1:
@ -280,13 +261,22 @@ class ShuffleNetv1(BaseBackbone):
channels = [_make_divisible(ch * widen_factor, 8) for ch in channels]
self.inplanes = int(24 * widen_factor)
self.conv1 = conv3x3(3, self.inplanes, stride=2)
self.conv1 = build_conv_layer(
self.conv_cfg,
in_channels=3,
out_channels=self.inplanes,
kernel_size=3,
stride=2,
padding=1,
bias=False)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(
channels[0], blocks[0], first_block=False, with_cp=with_cp)
self.layer2 = self._make_layer(channels[1], blocks[1], with_cp=with_cp)
self.layer3 = self._make_layer(channels[2], blocks[2], with_cp=with_cp)
self.layer1 = self._make_layer(channels[0], blocks[0],
first_block=False)
self.layer2 = self._make_layer(channels[1], blocks[1])
self.layer3 = self._make_layer(channels[2], blocks[2])
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
@ -301,7 +291,7 @@ class ShuffleNetv1(BaseBackbone):
else:
raise TypeError('pretrained must be a str or None')
def _make_layer(self, outplanes, blocks, first_block=True, with_cp=False):
def _make_layer(self, outplanes, blocks, first_block=True):
""" Stack n bottleneck modules where n is inferred from the depth of
the network.
@ -311,31 +301,23 @@ class ShuffleNetv1(BaseBackbone):
first_block (bool, optional): Whether is the first ShuffleUnit of a
sequential ShuffleUnits. If True, use the grouped 1x1
convolution.
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Returns: a Module consisting of n sequential ShuffleUnits.
"""
layers = []
for i in range(blocks):
if i == 0:
layers.append(
ShuffleUnit(
self.inplanes,
outplanes,
groups=self.groups,
first_block=first_block,
combine='concat',
with_cp=with_cp))
else:
layers.append(
ShuffleUnit(
self.inplanes,
outplanes,
groups=self.groups,
first_block=True,
combine='add',
with_cp=with_cp))
first_block = first_block if i == 0 else True
combine_mode = 'concat' if i == 0 else 'add'
layers.append(
ShuffleUnit(
self.inplanes,
outplanes,
groups=self.groups,
first_block=first_block,
combine=combine_mode,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
with_cp=self.with_cp))
self.inplanes = outplanes
@ -363,20 +345,20 @@ class ShuffleNetv1(BaseBackbone):
else:
return tuple(outs)
def train(self, mode=True):
super(ShuffleNetv1, self).train(mode)
if self.bn_eval:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
if self.bn_frozen:
for params in m.parameters():
params.requires_grad = False
if mode and self.frozen_stages >= 0:
def _freeze_stages(self):
if self.frozen_stages >= 0:
for param in self.conv1.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
layer = getattr(self, f'layer{i}')
layer.eval()
for param in layer.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
layer = getattr(self, f'layer{i}')
layer.eval()
for param in layer.parameters():
param.requires_grad = False
def train(self, mode=True):
super(ShuffleNetv1, self).train(mode)
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, _BatchNorm):
m.eval()

View File

@ -67,11 +67,11 @@ def test_shufflenetv1_shuffleuint():
def test_shufflenetv1_backbone():
with pytest.raises(ValueError):
# frozen_stages must in [-1, 1, 2, 3]
# frozen_stages must in range(-1, 4)
ShuffleNetv1(frozen_stages=10)
with pytest.raises(ValueError):
# the item in out_indices must in [0, 1, 2, 3]
# the item in out_indices must in range(0, 4)
ShuffleNetv1(out_indices=[5])
with pytest.raises(ValueError):
@ -100,20 +100,6 @@ def test_shufflenetv1_backbone():
for param in layer.parameters():
assert param.requires_grad is False
# Test ShuffleNetv1 with bn frozen
model = ShuffleNetv1(bn_frozen=True)
model.init_weights()
model.train()
for i in range(1, 4):
layer = getattr(model, f'layer{i}')
for mod in layer.modules():
if isinstance(mod, _BatchNorm):
assert mod.training is False
for params in mod.parameters():
params.requires_grad = False
# Test ShuffleNetv1 forward with groups=3
model = ShuffleNetv1(groups=3)
model.init_weights()
@ -131,6 +117,25 @@ def test_shufflenetv1_backbone():
assert feat[2].shape == torch.Size([1, 960, 7, 7])
assert feat[3].shape == torch.Size([1, 960, 7, 7])
# Test ShuffleNetv1 forward with GroupNorm forward
model = ShuffleNetv1(groups=3,
norm_cfg=dict(type='GN', num_groups=2,
requires_grad=True))
model.init_weights()
model.train()
for m in model.modules():
if is_norm(m):
assert isinstance(m, GroupNorm)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size([1, 240, 28, 28])
assert feat[1].shape == torch.Size([1, 480, 14, 14])
assert feat[2].shape == torch.Size([1, 960, 7, 7])
assert feat[3].shape == torch.Size([1, 960, 7, 7])
# Test ShuffleNetv1 forward with layers 1, 2 forward
model = ShuffleNetv1(groups=3, out_indices=(1, 2))
model.init_weights()