use ConvModule/build_*_layer from mmcv & use *_cfg & add _freeze_stages

parent 38a1c2533a
commit 72bbfbc489
@@ -3,57 +3,14 @@ import logging
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
-from mmcv.cnn.bricks import ConvModule
-from mmcv.cnn.weight_init import constant_init, kaiming_init
+from mmcv.cnn import (ConvModule, build_conv_layer, build_activation_layer,
+                      constant_init, kaiming_init)
 from mmcv.runner import load_checkpoint
 from torch.nn.modules.batchnorm import _BatchNorm

 from .base_backbone import BaseBackbone


-def conv3x3(inplanes, planes, stride=1, padding=1, bias=False, groups=1):
-    """3x3 convolution
-
-    Applies a 2D convolution with the kernel_size 3x3 over an input signal
-    composed of several input planes.
-
-    Args:
-        inplanes (int): Number of channels of the input feature maps.
-        planes (int): Number of channels produced by the convolution.
-        stride (int or tuple, optional): Stride of the convolution.
-            Default: 1.
-        padding (int): Controls the amount of implicit zero-paddings on both
-            sides for padding number of points for each dimension.
-            Default: 1.
-        bias (bool, optional): Whether to add a learnable bias to the output.
-            Default: False.
-        groups (int, optional): Number of blocked connections from input
-            channels to output channels. Default: 1.
-    """
-    return nn.Conv2d(
-        inplanes,
-        planes,
-        kernel_size=3,
-        stride=stride,
-        padding=padding,
-        bias=bias,
-        groups=groups)
-
-
-def conv1x1(inplanes, planes, groups=1):
-    """1x1 convolution
-
-    Applies a 2D convolution with the kernel_size 1x1 over an input signal
-    composed of several input planes.
-
-    Args:
-        inplanes (int): Number of channels of the input feature maps.
-        planes (int): Number of channels produced by the convolution.
-        groups (int, optional): Number of blocked connections from input
-            channels to output channels. Default: 1.
-    """
-    return nn.Conv2d(inplanes, planes, kernel_size=1, groups=groups, stride=1)
-
-
 def channel_shuffle(x, groups):
     """ Channel Shuffle operation
@@ -119,6 +76,11 @@ class ShuffleUnit(nn.Module):
             sequential ShuffleUnits. If True, use the grouped 1x1 convolution.
         combine (str, optional): The ways to combine the input and output
             branches.
+        conv_cfg (dict): Config dict for convolution layer. Default: None,
+            which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU').
         with_cp (bool, optional): Use checkpoint or not. Using checkpoint
             will save some memory while slowing down the training speed.
@@ -132,6 +94,9 @@ class ShuffleUnit(nn.Module):
                  groups=3,
                  first_block=True,
                  combine='add',
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 act_cfg=dict(type='ReLU'),
                  with_cp=False):
         super(ShuffleUnit, self).__init__()
         self.inplanes = inplanes
@@ -160,25 +125,33 @@ class ShuffleUnit(nn.Module):
             in_channels=self.inplanes,
             out_channels=self.bottleneck_channels,
             kernel_size=1,
-            groups=self.first_1x1_groups)
+            groups=self.first_1x1_groups,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=act_cfg)

-        self.depthwise_conv3x3 = conv3x3(
-            self.bottleneck_channels,
-            self.bottleneck_channels,
+        self.depthwise_conv3x3_bn = ConvModule(
+            in_channels=self.bottleneck_channels,
+            out_channels=self.bottleneck_channels,
+            kernel_size=3,
             stride=self.depthwise_stride,
-            groups=self.bottleneck_channels)
-
-        self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels)
+            padding=1,
+            groups=self.bottleneck_channels,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
+            act_cfg=None)

         self.g_conv_1x1_expand = ConvModule(
             in_channels=self.bottleneck_channels,
             out_channels=self.planes,
             kernel_size=1,
             groups=self.groups,
+            conv_cfg=conv_cfg,
+            norm_cfg=norm_cfg,
             act_cfg=None)

         self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
-        self.relu = nn.ReLU(inplace=True)
+        self.act = build_activation_layer(act_cfg)

     @staticmethod
     def _add(x, out):
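
For context on the refactor above: mmcv's ConvModule bundles conv, norm, and activation, each built from a config dict, which is what lets the separate conv3x3 + nn.BatchNorm2d pair collapse into the single depthwise_conv3x3_bn block. A minimal sketch of the API (channel sizes are illustrative):

    from mmcv.cnn import ConvModule, build_activation_layer

    # conv_cfg=None falls back to nn.Conv2d; norm_cfg/act_cfg pick the
    # norm and activation layers by their registered 'type' names, and
    # act_cfg=None skips the activation entirely, which is how the unit
    # defers its ReLU until after the residual combine.
    dw = ConvModule(
        in_channels=60,
        out_channels=60,
        kernel_size=3,
        stride=2,
        padding=1,
        groups=60,
        conv_cfg=None,
        norm_cfg=dict(type='BN'),
        act_cfg=None)

    act = build_activation_layer(dict(type='ReLU'))  # a plain nn.ReLU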
@@ -200,12 +173,11 @@ class ShuffleUnit(nn.Module):

             out = self.g_conv_1x1_compress(x)
             out = channel_shuffle(out, self.groups)
-            out = self.depthwise_conv3x3(out)
-            out = self.bn_after_depthwise(out)
+            out = self.depthwise_conv3x3_bn(out)
             out = self.g_conv_1x1_expand(out)

             out = self._combine_func(residual, out)

+            out = self.act(out)
             return out

         if self.with_cp and x.requires_grad:
@@ -213,7 +185,6 @@ class ShuffleUnit(nn.Module):
         else:
             out = _inner_forward(x)

-        out = self.relu(out)
-
         return out

@@ -230,9 +201,15 @@ class ShuffleNetv1(BaseBackbone):
         out_indices (Sequence[int]): Output from which stages.
         frozen_stages (int): Stages to be frozen (all param fixed). -1 means
             not freezing any parameters.
-        bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
-            running stats (mean and var).
-        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+        conv_cfg (dict): Config dict for convolution layer. Default: None,
+            which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer. Default: None.
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU').
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only.
         with_cp (bool): Use checkpoint or not. Using checkpoint will save some
             memory while slowing down the training speed.
     """
@@ -242,8 +219,10 @@ class ShuffleNetv1(BaseBackbone):
                  widen_factor=1.0,
                  out_indices=(0, 1, 2, 3),
                  frozen_stages=-1,
-                 bn_eval=True,
-                 bn_frozen=False,
+                 conv_cfg=None,
+                 norm_cfg=None,
+                 act_cfg=dict(type='ReLU'),
+                 norm_eval=True,
                  with_cp=False):
         super(ShuffleNetv1, self).__init__()
         blocks = [3, 7, 3]
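
With this signature, layer types become a config choice instead of hard-coded modules. A usage sketch (argument values are illustrative; the GN variant mirrors the test added at the bottom of this diff):

    model = ShuffleNetv1(
        groups=3,
        norm_cfg=dict(type='GN', num_groups=2, requires_grad=True),
        act_cfg=dict(type='ReLU'),
        frozen_stages=1,
        norm_eval=True)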
@@ -254,13 +233,15 @@ class ShuffleNetv1(BaseBackbone):
                 raise ValueError(f'the item in out_indices must in '
                                  f'range(0, 4). But received {indice}')

-        if frozen_stages not in [-1, 1, 2, 3]:
-            raise ValueError(f'frozen_stages must in [-1, 1, 2, 3]. '
+        if frozen_stages not in range(-1, 4):
+            raise ValueError(f'frozen_stages must in range(-1, 4). '
                              f'But received {frozen_stages}')
         self.out_indices = out_indices
         self.frozen_stages = frozen_stages
-        self.bn_eval = bn_eval
-        self.bn_frozen = bn_frozen
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.norm_eval = norm_eval
         self.with_cp = with_cp

         if groups == 1:
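
Note the subtle widening in the check above: x in range(-1, 4) is true for the integers -1 through 3, so unlike the old list [-1, 1, 2, 3] it also accepts frozen_stages=0 (freeze only the stem conv, no stages):

    assert 0 in range(-1, 4)         # newly accepted
    assert 0 not in [-1, 1, 2, 3]    # rejected by the old check
    assert 10 not in range(-1, 4)    # still raises the ValueError above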
@@ -280,13 +261,22 @@ class ShuffleNetv1(BaseBackbone):
         channels = [_make_divisible(ch * widen_factor, 8) for ch in channels]

         self.inplanes = int(24 * widen_factor)
-        self.conv1 = conv3x3(3, self.inplanes, stride=2)
+
+        self.conv1 = build_conv_layer(
+            self.conv_cfg,
+            in_channels=3,
+            out_channels=self.inplanes,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            bias=False)

         self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

-        self.layer1 = self._make_layer(
-            channels[0], blocks[0], first_block=False, with_cp=with_cp)
-        self.layer2 = self._make_layer(channels[1], blocks[1], with_cp=with_cp)
-        self.layer3 = self._make_layer(channels[2], blocks[2], with_cp=with_cp)
+        self.layer1 = self._make_layer(channels[0], blocks[0],
+                                       first_block=False)
+        self.layer2 = self._make_layer(channels[1], blocks[1])
+        self.layer3 = self._make_layer(channels[2], blocks[2])

     def init_weights(self, pretrained=None):
         if isinstance(pretrained, str):
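
build_conv_layer is the lower-level counterpart of ConvModule used for the stem: it builds only the conv layer, taking the config dict as its first positional argument. A sketch of the equivalence with the removed helper (standard mmcv behaviour; 24 stands in for self.inplanes at widen_factor=1.0):

    from mmcv.cnn import build_conv_layer

    # cfg=None falls back to nn.Conv2d, so this matches the replaced
    # conv3x3(3, inplanes, stride=2) call, with padding and bias explicit
    conv1 = build_conv_layer(
        None, in_channels=3, out_channels=24, kernel_size=3,
        stride=2, padding=1, bias=False)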
@@ -301,7 +291,7 @@ class ShuffleNetv1(BaseBackbone):
         else:
             raise TypeError('pretrained must be a str or None')

-    def _make_layer(self, outplanes, blocks, first_block=True, with_cp=False):
+    def _make_layer(self, outplanes, blocks, first_block=True):
         """ Stack n bottleneck modules where n is inferred from the depth of
         the network.

@@ -311,31 +301,23 @@ class ShuffleNetv1(BaseBackbone):
             first_block (bool, optional): Whether it is the first ShuffleUnit
                 of a sequential ShuffleUnits. If True, use the grouped 1x1
                 convolution.
-            with_cp (bool, optional): Use checkpoint or not. Using checkpoint
-                will save some memory while slowing down the training speed.

         Returns: a Module consisting of n sequential ShuffleUnits.
         """
         layers = []
         for i in range(blocks):
-            if i == 0:
-                layers.append(
-                    ShuffleUnit(
-                        self.inplanes,
-                        outplanes,
-                        groups=self.groups,
-                        first_block=first_block,
-                        combine='concat',
-                        with_cp=with_cp))
-            else:
-                layers.append(
-                    ShuffleUnit(
-                        self.inplanes,
-                        outplanes,
-                        groups=self.groups,
-                        first_block=True,
-                        combine='add',
-                        with_cp=with_cp))
+            first_block = first_block if i == 0 else True
+            combine_mode = 'concat' if i == 0 else 'add'
+            layers.append(
+                ShuffleUnit(
+                    self.inplanes,
+                    outplanes,
+                    groups=self.groups,
+                    first_block=first_block,
+                    combine=combine_mode,
+                    conv_cfg=self.conv_cfg,
+                    norm_cfg=self.norm_cfg,
+                    act_cfg=self.act_cfg,
+                    with_cp=self.with_cp))
             self.inplanes = outplanes

@@ -363,20 +345,20 @@ class ShuffleNetv1(BaseBackbone):
         else:
             return tuple(outs)

-    def train(self, mode=True):
-        super(ShuffleNetv1, self).train(mode)
-        if self.bn_eval:
-            for m in self.modules():
-                if isinstance(m, nn.BatchNorm2d):
-                    m.eval()
-                    if self.bn_frozen:
-                        for params in m.parameters():
-                            params.requires_grad = False
-        if mode and self.frozen_stages >= 0:
+    def _freeze_stages(self):
+        if self.frozen_stages >= 0:
             for param in self.conv1.parameters():
                 param.requires_grad = False
             for i in range(1, self.frozen_stages + 1):
                 layer = getattr(self, f'layer{i}')
                 layer.eval()
                 for param in layer.parameters():
                     param.requires_grad = False
+
+    def train(self, mode=True):
+        super(ShuffleNetv1, self).train(mode)
+        self._freeze_stages()
+        if mode and self.norm_eval:
+            for m in self.modules():
+                if isinstance(m, _BatchNorm):
+                    m.eval()
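
Splitting the freezing logic into _freeze_stages() and calling it from train() means frozen stages are re-frozen on every switch back to train mode, while norm_eval separately keeps BatchNorm statistics fixed. A behaviour sketch (mirroring the frozen-stages test below):

    model = ShuffleNetv1(frozen_stages=1)
    model.init_weights()
    model.train()  # super().train() unfreezes, _freeze_stages() re-freezes

    assert all(not p.requires_grad for p in model.conv1.parameters())
    assert all(not p.requires_grad for p in model.layer1.parameters())
    assert model.layer1.training is False  # frozen stage stays in eval mode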
@@ -67,11 +67,11 @@ def test_shufflenetv1_shuffleuint():
 def test_shufflenetv1_backbone():

     with pytest.raises(ValueError):
-        # frozen_stages must in [-1, 1, 2, 3]
+        # frozen_stages must in range(-1, 4)
         ShuffleNetv1(frozen_stages=10)

     with pytest.raises(ValueError):
-        # the item in out_indices must in [0, 1, 2, 3]
+        # the item in out_indices must in range(0, 4)
        ShuffleNetv1(out_indices=[5])

     with pytest.raises(ValueError):
@@ -100,20 +100,6 @@ def test_shufflenetv1_backbone():
         for param in layer.parameters():
             assert param.requires_grad is False

-    # Test ShuffleNetv1 with bn frozen
-    model = ShuffleNetv1(bn_frozen=True)
-    model.init_weights()
-    model.train()
-
-    for i in range(1, 4):
-        layer = getattr(model, f'layer{i}')
-
-        for mod in layer.modules():
-            if isinstance(mod, _BatchNorm):
-                assert mod.training is False
-                for params in mod.parameters():
-                    params.requires_grad = False
-
     # Test ShuffleNetv1 forward with groups=3
     model = ShuffleNetv1(groups=3)
     model.init_weights()
@@ -131,6 +117,25 @@ def test_shufflenetv1_backbone():
     assert feat[2].shape == torch.Size([1, 960, 7, 7])
     assert feat[3].shape == torch.Size([1, 960, 7, 7])

+    # Test ShuffleNetv1 forward with GroupNorm
+    model = ShuffleNetv1(groups=3,
+                         norm_cfg=dict(type='GN', num_groups=2,
+                                       requires_grad=True))
+    model.init_weights()
+    model.train()
+
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, GroupNorm)
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 4
+    assert feat[0].shape == torch.Size([1, 240, 28, 28])
+    assert feat[1].shape == torch.Size([1, 480, 14, 14])
+    assert feat[2].shape == torch.Size([1, 960, 7, 7])
+    assert feat[3].shape == torch.Size([1, 960, 7, 7])
+
     # Test ShuffleNetv1 with layers 1, 2 forward
     model = ShuffleNetv1(groups=3, out_indices=(1, 2))
     model.init_weights()