fix norm_cfg&rewrite

lixiaojie 2020-06-14 01:23:40 +08:00
parent 940c955523
commit fb3934fd2c
2 changed files with 178 additions and 103 deletions

View File

@@ -1,11 +1,8 @@
import logging
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer,
constant_init, kaiming_init)
from mmcv.runner import load_checkpoint
from torch.nn.modules.batchnorm import _BatchNorm
from .base_backbone import BaseBackbone
@@ -38,15 +35,15 @@ def channel_shuffle(x, groups):
return x
def _make_divisible(v, divisor, min_value=None):
def make_divisible(value, divisor, min_value=None):
""" Make divisible function.
This function ensures that all layers have a channel number that is
divisible by divisor.
Args:
v (int): The original channel number
divisor (int): The divisor to fully divide the channel number
value (int): The original channel number.
divisor (int): The divisor to fully divide the channel number.
min_value (int, optional): The minimum value of the output channel.
Returns:
@@ -55,11 +52,11 @@ def _make_divisible(v, divisor, min_value=None):
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
# Make sure that rounding down does not reduce the value by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
if new_value < 0.9 * value:
new_value += divisor
return new_value
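For a quick sanity check of the renamed helper, the two calls exercised by the new test_make_divisible below work out as follows (the values follow directly from the arithmetic above; this snippet is illustrative, not part of the commit):

make_divisible(34, 8)   # int(34 + 8 / 2) // 8 * 8 = 32; 32 >= 0.9 * 34, so 32 is returned
make_divisible(10, 8)   # int(10 + 8 / 2) // 8 * 8 = 8;  8 < 0.9 * 10, so 8 + 8 = 16 is returned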
class ShuffleUnit(nn.Module):
@@ -79,7 +76,8 @@ class ShuffleUnit(nn.Module):
branches.
conv_cfg (dict): Config dict for convolution layer. Default: None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer. Default: None.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
@@ -96,7 +94,7 @@ class ShuffleUnit(nn.Module):
first_block=True,
combine='add',
conv_cfg=None,
norm_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
with_cp=False):
super(ShuffleUnit, self).__init__()
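The reason the norm_cfg default moves from None to dict(type='BN') is that mmcv's ConvModule skips normalization entirely when norm_cfg is None, so the old default silently dropped the BatchNorm layers. A minimal sketch of how such a config dict is resolved (assuming mmcv.cnn.build_norm_layer behaves as in current mmcv releases; not taken from this diff):

from mmcv.cnn import build_norm_layer
import torch.nn as nn

# dict(type='BN') resolves to an nn.BatchNorm2d with the given number of features
name, bn = build_norm_layer(dict(type='BN'), 24)
assert isinstance(bn, nn.BatchNorm2d)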
@@ -203,7 +201,8 @@ class ShuffleNetv1(BaseBackbone):
not freezing any parameters.
conv_cfg (dict): Config dict for convolution layer. Default: None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer. Default: None.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
norm_eval (bool): Whether to set norm layers to eval mode, namely,
@@ -219,12 +218,12 @@ class ShuffleNetv1(BaseBackbone):
out_indices=(0, 1, 2, 3),
frozen_stages=-1,
conv_cfg=None,
norm_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
norm_eval=True,
with_cp=False):
super(ShuffleNetv1, self).__init__()
blocks = [3, 7, 3]
self.stage_blocks = [3, 7, 3]
self.groups = groups
for indice in out_indices:
@@ -257,7 +256,7 @@ class ShuffleNetv1(BaseBackbone):
raise ValueError(f'{groups} groups is not supported for 1x1 '
f'Grouped Convolutions')
channels = [_make_divisible(ch * widen_factor, 8) for ch in channels]
channels = [make_divisible(ch * widen_factor, 8) for ch in channels]
self.inplanes = int(24 * widen_factor)
@@ -272,78 +271,13 @@ class ShuffleNetv1(BaseBackbone):
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(
channels[0], blocks[0], first_block=False)
self.layer2 = self._make_layer(channels[1], blocks[1])
self.layer3 = self._make_layer(channels[2], blocks[2])
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')
def _make_layer(self, outplanes, blocks, first_block=True):
""" Stack ShuffleUnit blocks to make a layer.
Args:
outplanes: Number of output channels.
blocks: Number of blocks to be built.
first_block (bool, optional): Whether is the first ShuffleUnit of a
sequential ShuffleUnits. If True, use the grouped 1x1
convolution.
Returns:
Module: A module consisting of several ShuffleUnit blocks.
"""
layers = []
for i in range(blocks):
first_block = first_block if i == 0 else True
combine_mode = 'concat' if i == 0 else 'add'
layers.append(
ShuffleUnit(
self.inplanes,
outplanes,
groups=self.groups,
first_block=first_block,
combine=combine_mode,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
with_cp=self.with_cp))
self.inplanes = outplanes
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool(x)
outs = []
x = self.layer1(x)
if 0 in self.out_indices:
outs.append(x)
x = self.layer2(x)
if 1 in self.out_indices:
outs.append(x)
x = self.layer3(x)
if 2 in self.out_indices:
outs.append(x)
outs.append(x)
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
self.layers = []
for i, num_blocks in enumerate(self.stage_blocks):
first_block = False if i == 0 else True
layer = self.make_layer(channels[i], num_blocks, first_block)
layer_name = f'layer{i + 1}'
self.add_module(layer_name, layer)
self.layers.append(layer_name)
def _freeze_stages(self):
if self.frozen_stages >= 0:
@@ -355,6 +289,61 @@ class ShuffleNetv1(BaseBackbone):
for param in layer.parameters():
param.requires_grad = False
def init_weights(self, pretrained=None):
if pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')
def make_layer(self, planes, num_blocks, first_block=True):
""" Stack ShuffleUnit blocks to make a layer.
Args:
planes (int): Number of output channels of the block.
num_blocks (int): Number of blocks to build.
first_block (bool, optional): Whether it is the first ShuffleUnit of a
sequence of ShuffleUnits. If True, use the grouped 1x1
convolution.
"""
layers = []
for i in range(num_blocks):
first_block = first_block if i == 0 else True
combine_mode = 'concat' if i == 0 else 'add'
layers.append(
ShuffleUnit(
self.inplanes,
planes,
groups=self.groups,
first_block=first_block,
combine=combine_mode,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
with_cp=self.with_cp))
self.inplanes = planes
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool(x)
outs = []
for i, layer_name in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x)
if i in self.out_indices:
outs.append(x)
if len(outs) == 1:
return outs[0]
else:
return tuple(outs)
def train(self, mode=True):
super(ShuffleNetv1, self).train(mode)
self._freeze_stages()
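A short usage sketch of the rewritten backbone; the expected shapes are copied from the groups=3 cases in the updated tests below, and the snippet itself is illustrative rather than part of the commit:

import torch
from mmcls.models.backbones import ShuffleNetv1

model = ShuffleNetv1(groups=3, out_indices=(0, 1, 2))
model.init_weights()
model.train()

x = torch.randn(1, 3, 224, 224)
outs = model(x)
# tuple with one tensor per requested stage:
# (1, 240, 28, 28), (1, 480, 14, 14), (1, 960, 7, 7)

# with a single out index, forward returns the tensor directly
single = ShuffleNetv1(groups=3, out_indices=(2, ))
single.init_weights()
feat = single(x)   # torch.Size([1, 960, 7, 7])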

View File

@@ -4,7 +4,7 @@ from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.backbones import ShuffleNetv1
from mmcls.models.backbones.shufflenet_v1 import ShuffleUnit
from mmcls.models.backbones.shufflenet_v1 import ShuffleUnit, make_divisible
def is_block(modules):
@@ -30,6 +30,14 @@ def check_norm_state(modules, train_state):
return True
def test_make_divisible():
# test min_value is None
make_divisible(34, 8, None)
# test new_value < 0.9 * value
make_divisible(10, 8, None)
def test_shufflenetv1_shuffleuint():
with pytest.raises(ValueError):
@@ -59,7 +67,9 @@ def test_shufflenetv1_shuffleuint():
# Test ShuffleUnit with checkpoint forward
block = ShuffleUnit(
24, 24, groups=3, first_block=True, combine='add', with_cp=True)
assert block.with_cp
x = torch.randn(1, 24, 56, 56)
x.requires_grad = True
x_out = block(x)
assert x_out.shape == torch.Size((1, 24, 56, 56))
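The ShuffleUnit body itself is elided in this hunk; the with_cp branch exercised here normally wraps an inner forward with torch.utils.checkpoint so that activations are recomputed during backward. A self-contained sketch of that pattern (ToyUnit and its layers are hypothetical stand-ins, not code from this commit):

import torch
import torch.nn as nn
import torch.utils.checkpoint as cp

class ToyUnit(nn.Module):
    # toy stand-in for ShuffleUnit, only to show the with_cp branch
    def __init__(self, with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.conv = nn.Conv2d(24, 24, 3, padding=1)

    def forward(self, x):

        def _inner_forward(x):
            return self.conv(x)

        if self.with_cp and x.requires_grad:
            # trade compute for memory: recompute in the backward pass
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)

x = torch.randn(1, 24, 56, 56, requires_grad=True)
assert ToyUnit(with_cp=True)(x).shape == (1, 24, 56, 56)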
@@ -78,6 +88,11 @@ def test_shufflenetv1_backbone():
# groups must in [1, 2, 3, 4, 8]
ShuffleNetv1(groups=10)
with pytest.raises(TypeError):
# pretrained must be str or None
model = ShuffleNetv1()
model.init_weights(pretrained=1)
# Test ShuffleNetv1 norm state
model = ShuffleNetv1()
model.init_weights()
@@ -100,6 +115,38 @@ def test_shufflenetv1_backbone():
for param in layer.parameters():
assert param.requires_grad is False
# Test ShuffleNetv1 forward with groups=1
model = ShuffleNetv1(groups=1)
model.init_weights()
model.train()
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size((1, 144, 28, 28))
assert feat[1].shape == torch.Size((1, 288, 14, 14))
assert feat[2].shape == torch.Size((1, 576, 7, 7))
# Test ShuffleNetv1 forward with groups=2
model = ShuffleNetv1(groups=2)
model.init_weights()
model.train()
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size((1, 200, 28, 28))
assert feat[1].shape == torch.Size((1, 400, 14, 14))
assert feat[2].shape == torch.Size((1, 800, 7, 7))
# Test ShuffleNetv1 forward with groups=3
model = ShuffleNetv1(groups=3)
model.init_weights()
@@ -111,11 +158,42 @@ def test_shufflenetv1_backbone():
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert len(feat) == 3
assert feat[0].shape == torch.Size((1, 240, 28, 28))
assert feat[1].shape == torch.Size((1, 480, 14, 14))
assert feat[2].shape == torch.Size((1, 960, 7, 7))
assert feat[3].shape == torch.Size((1, 960, 7, 7))
# Test ShuffleNetv1 forward with groups=4
model = ShuffleNetv1(groups=4)
model.init_weights()
model.train()
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size((1, 272, 28, 28))
assert feat[1].shape == torch.Size((1, 544, 14, 14))
assert feat[2].shape == torch.Size((1, 1088, 7, 7))
# Test ShuffleNetv1 forward with groups=8
model = ShuffleNetv1(groups=8)
model.init_weights()
model.train()
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size((1, 384, 28, 28))
assert feat[1].shape == torch.Size((1, 768, 14, 14))
assert feat[2].shape == torch.Size((1, 1536, 7, 7))
# Test ShuffleNetv1 forward with GroupNorm
model = ShuffleNetv1(
@@ -129,11 +207,10 @@ def test_shufflenetv1_backbone():
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert len(feat) == 3
assert feat[0].shape == torch.Size((1, 240, 28, 28))
assert feat[1].shape == torch.Size((1, 480, 14, 14))
assert feat[2].shape == torch.Size((1, 960, 7, 7))
assert feat[3].shape == torch.Size((1, 960, 7, 7))
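The constructor call for the GroupNorm case is truncated by the hunk boundary above; a hedged sketch of what such a norm_cfg override looks like (the exact kwargs are an assumption for illustration, not read from this diff):

# hypothetical GN config; the stage channels here (24, 240, 480, 960) are all divisible by 2
model = ShuffleNetv1(groups=3, norm_cfg=dict(type='GN', num_groups=2))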
# Test ShuffleNetv1 forward with layers 1 and 2
model = ShuffleNetv1(groups=3, out_indices=(1, 2))
@@ -144,15 +221,14 @@ def test_shufflenetv1_backbone():
if is_norm(m):
assert isinstance(m, _BatchNorm)
imgs = torch.randn(1, 3, 224, 224)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert len(feat) == 2
assert feat[0].shape == torch.Size((1, 480, 14, 14))
assert feat[1].shape == torch.Size((1, 960, 7, 7))
assert feat[2].shape == torch.Size((1, 960, 7, 7))
# Test ShuffleNetv1 forward with checkpoint forward
model = ShuffleNetv1(groups=3, with_cp=True)
# Test ShuffleNetv1 forward with layer 2 only
model = ShuffleNetv1(groups=3, out_indices=(2,))
model.init_weights()
model.train()
@@ -162,8 +238,18 @@ def test_shufflenetv1_backbone():
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size((1, 240, 28, 28))
assert feat[1].shape == torch.Size((1, 480, 14, 14))
assert feat[2].shape == torch.Size((1, 960, 7, 7))
assert feat[3].shape == torch.Size((1, 960, 7, 7))
assert isinstance(feat, torch.Tensor)
assert feat.shape == torch.Size((1, 960, 7, 7))
# Test ShuffleNetv1 with checkpoint forward
model = ShuffleNetv1(groups=3, with_cp=True)
for m in model.modules():
if is_block(m):
assert m.with_cp
# Test ShuffleNetv1 with norm_eval
model = ShuffleNetv1(norm_eval=True)
model.init_weights()
model.train()
assert check_norm_state(model.modules(), False)