pull/2/head
lixiaojie 2020-06-14 00:46:22 +08:00
parent 75858a3d3e
commit e3e980d84e
3 changed files with 305 additions and 240 deletions
mmcls/models/backbones
tests/test_backbones

View File

@ -1,118 +1,151 @@
import logging
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.runner import load_checkpoint
from mmcv.cnn import ConvModule, constant_init, kaiming_init
from torch.nn.modules.batchnorm import _BatchNorm
from .base_backbone import BaseBackbone
from .weight_init import constant_init, kaiming_init
def conv_bn(inp, oup, stride):
return nn.Sequential(
nn.Conv2d(inp, oup, 3, stride, 1, bias=False), nn.BatchNorm2d(oup),
nn.ReLU(inplace=True))
def conv_1x1_bn(inp, oup):
return nn.Sequential(
nn.Conv2d(inp, oup, 1, 1, 0, bias=False), nn.BatchNorm2d(oup),
nn.ReLU(inplace=True))
def channel_shuffle(x, groups):
batchsize, num_channels, height, width = x.data.size()
assert (num_channels % groups == 0)
""" Channel Shuffle operation.
This function enables cross-group information flow for multiple group
convolution layers.
Args:
x (Tensor): The input tensor.
groups (int): The number of groups to divide the input tensor
in channel dimension.
Returns:
Tensor: The output tensor after channel shuffle operation.
"""
batchsize, num_channels, height, width = x.size()
assert (num_channels % groups == 0), ('num_channels should be '
'divisible by groups')
channels_per_group = num_channels // groups
# reshape
x = x.view(batchsize, groups, channels_per_group, height, width)
# transpose
# - contiguous() required if transpose() is used before view().
# See https://github.com/pytorch/pytorch/issues/764
x = torch.transpose(x, 1, 2).contiguous()
# flatten
x = x.view(batchsize, -1, height, width)
return x
def _make_divisible(v, divisor, min_value=None):
def make_divisible(value, divisor, min_value=None):
""" Make divisible function.
This function ensures that all layers have a channel number that is
divisible by divisor.
Args:
value (int): The original channel number.
divisor (int): The divisor to fully divide the channel number.
min_value (int, optional): the minimum value of the output channel.
Returns:
int: The modified output channel number
"""
if min_value is None:
min_value = divisor
new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
new_value = max(min_value, int(value + divisor / 2) // divisor * divisor)
# Make sure that round down does not go down by more than 10%.
if new_v < 0.9 * v:
new_v += divisor
return new_v
if new_value < 0.9 * value:
new_value += divisor
return new_value
class InvertedResidual(nn.Module):
"""InvertedResidual block for ShuffleNetV2 backbone.
def __init__(self, inp, oup, stride, with_cp=False):
Args:
inplanes (int): The input channels of the block.
planes (int): The output channels of the block.
stride (int): stride of the 3x3 convolution layer. Default: 1
conv_cfg (dict): Config dict for convolution layer. Default: None,
which means using conv2d.
norm_cfg (dict): Config dict for normalization layer. Default: None.
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
with_cp (bool, optional): Use checkpoint or not. Using checkpoint
will save some memory while slowing down the training speed.
Returns:
Tensor: The output tensor.
"""
def __init__(self,
inplanes,
planes,
stride=1,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
with_cp=False):
super(InvertedResidual, self).__init__()
if not (1 <= stride <= 2):
raise ValueError('illegal stride value')
self.stride = stride
self.with_cp = with_cp
branch_features = oup // 2
assert (self.stride != 1) or (inp == branch_features << 1)
branch_features = planes // 2
assert (self.stride != 1) or (inplanes == branch_features << 1)
if self.stride > 1:
self.branch1 = nn.Sequential(
self.depthwise_conv(
inp, inp, kernel_size=3, stride=self.stride, padding=1),
nn.BatchNorm2d(inp),
nn.Conv2d(
inp,
ConvModule(
inplanes,
inplanes,
kernel_size=3,
stride=self.stride,
padding=1,
groups=inplanes,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None),
ConvModule(
inplanes,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
)
else:
self.branch1 = nn.Sequential()
self.branch2 = nn.Sequential(
nn.Conv2d(
inp if (self.stride > 1) else branch_features,
ConvModule(
inplanes if (self.stride > 1) else branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
self.depthwise_conv(
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
branch_features,
branch_features,
kernel_size=3,
stride=self.stride,
padding=1),
nn.BatchNorm2d(branch_features),
nn.Conv2d(
padding=1,
groups=branch_features,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None),
ConvModule(
branch_features,
branch_features,
kernel_size=1,
stride=1,
padding=0,
bias=False),
nn.BatchNorm2d(branch_features),
nn.ReLU(inplace=True),
)
@staticmethod
def depthwise_conv(i, o, kernel_size, stride=1, padding=0, bias=False):
return nn.Conv2d(
i, o, kernel_size, stride, padding, bias=bias, groups=i)
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
def forward(self, x):
@ -139,37 +172,48 @@ class ShuffleNetv2(BaseBackbone):
"""ShuffleNetv2 backbone.
Args:
groups (int): number of groups to be used in grouped
1x1 convolutions in each ShuffleUnit. Default is 3 for best
performance according to original paper.
widen_factor (float): Config of widen_factor.
groups (int): The number of groups to be used in grouped 1x1
convolutions in each ShuffleUnit. Default: 3.
widen_factor (float): Width multiplier - adjusts number of
channels in each layer by this amount. Default: 1.0.
out_indices (Sequence[int]): Output from which stages.
frozen_stages (int): Stages to be frozen (all param fixed). -1 means
not freezing any parameters.
bn_eval (bool): Whether to set nn.BatchNorm2d layers as eval mode,
namely, freeze
running stats (mean and var).
bn_frozen (bool): Whether to freeze weight and bias of
nn.BatchNorm2d layers.
Default: (0, 1, 2, 3).
frozen_stages (int): Stages to be frozen (all param fixed).
Default: -1, which means not freezing any parameters.
conv_cfg (dict): Config dict for convolution layer.
Default: None, which means using conv2d.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU').
norm_eval (bool): Whether to set norm layers to eval mode, namely,
freeze running stats (mean and var). Note: Effect on Batch Norm
and its variants only. Default: True.
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
memory while slowing down the training speed.
memory while slowing down the training speed. Default: False.
"""
def __init__(self,
groups=3,
widen_factor=1.0,
out_indices=(0, 1, 2, 3),
out_indices=(0, 1, 2),
frozen_stages=-1,
bn_eval=True,
bn_frozen=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
norm_eval=True,
with_cp=False):
super(ShuffleNetv2, self).__init__()
blocks = [4, 8, 4]
self.stage_blocks = [4, 8, 4]
self.groups = groups
self.out_indices = out_indices
assert max(out_indices) < len(self.stage_blocks)
self.frozen_stages = frozen_stages
self.bn_eval = bn_eval
self.bn_frozen = bn_frozen
assert frozen_stages < len(self.stage_blocks)
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.norm_eval = norm_eval
self.with_cp = with_cp
if widen_factor == 0.5:
@ -181,64 +225,94 @@ class ShuffleNetv2(BaseBackbone):
elif widen_factor == 2.0:
channels = [244, 488, 976, 2048]
else:
raise ValueError("""{} groups is not supported for
1x1 Grouped Convolutions""".format(groups))
raise ValueError(f'widen_factor must in [0.5, 1.0, 1.5, 2.0]. '
f'But received {widen_factor}.')
self.inplanes = 24
self.conv1 = ConvModule(
in_channels=3,
out_channels=self.inplanes,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.inplanes = channels[0]
self.conv1 = conv_bn(3, self.inplanes, 2)
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
self.layer1 = self._make_layer(channels[1], blocks[0], with_cp=with_cp)
self.layer2 = self._make_layer(channels[2], blocks[1], with_cp=with_cp)
self.layer3 = self._make_layer(channels[3], blocks[2], with_cp=with_cp)
self.layers = []
for i, num_blocks in enumerate(self.stage_blocks):
layer = self._make_layer(channels[i], num_blocks)
layer_name = f'layer{i + 1}'
self.add_module(layer_name, layer)
self.layers.append(layer_name)
self.inplanes = channels[i]
self.conv_out = conv_1x1_bn(self.inplanes, channels[-1])
output_channels = channels[-1]
self.conv2 = ConvModule(
in_channels=self.inplanes,
out_channels=output_channels,
kernel_size=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def _make_layer(self, planes, num_blocks):
""" Stack blocks to make a layer.
Args:
planes (int): planes of block.
num_blocks (int): number of blocks.
"""
layers = []
for i in range(num_blocks):
stride = 2 if i == 0 else 1
layers.append(
InvertedResidual(
inplanes=self.inplanes,
planes=planes,
stride=stride,
conv_cfg=self.conv_cfg,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
with_cp=self.with_cp))
self.inplanes = planes
return nn.Sequential(*layers)
def _freeze_stages(self):
if self.frozen_stages >= 0:
for m in [self.conv1]:
for param in m.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
m = getattr(self, f'layer{i}')
m.eval()
for param in m.parameters():
param.requires_grad = False
def init_weights(self, pretrained=None):
if isinstance(pretrained, str):
logger = logging.getLogger()
load_checkpoint(self, pretrained, strict=False, logger=logger)
elif pretrained is None:
if pretrained is None:
for m in self.modules():
if isinstance(m, nn.Conv2d):
kaiming_init(m)
elif isinstance(m, nn.BatchNorm2d):
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
constant_init(m, 1)
else:
raise TypeError('pretrained must be a str or None')
def _make_layer(self, outplanes, blocks, with_cp):
layers = []
for i in range(blocks):
if i == 0:
layers.append(
InvertedResidual(
self.inplanes, outplanes, stride=2, with_cp=with_cp))
else:
layers.append(
InvertedResidual(
self.inplanes, outplanes, stride=1, with_cp=with_cp))
self.inplanes = outplanes
return nn.Sequential(*layers)
def forward(self, x):
x = self.conv1(x)
x = self.maxpool(x)
outs = []
x = self.layer1(x)
if 0 in self.out_indices:
outs.append(x)
x = self.layer2(x)
if 1 in self.out_indices:
outs.append(x)
x = self.layer3(x)
if 2 in self.out_indices:
outs.append(x)
x = self.conv_out(x)
outs.append(x)
for i, layer_name in enumerate(self.layers):
layer = getattr(self, layer_name)
x = layer(x)
if i in self.out_indices:
outs.append(x)
if len(outs) == 1:
return outs[0]
@ -247,18 +321,8 @@ class ShuffleNetv2(BaseBackbone):
def train(self, mode=True):
super(ShuffleNetv2, self).train(mode)
if self.bn_eval:
self._freeze_stages()
if mode and self.norm_eval:
for m in self.modules():
if isinstance(m, nn.BatchNorm2d):
m.eval()
if self.bn_frozen:
for params in m.parameters():
params.requires_grad = False
if mode and self.frozen_stages >= 0:
for param in self.conv1.parameters():
param.requires_grad = False
for i in range(1, self.frozen_stages + 1):
mod = getattr(self, 'layer{}'.format(i))
mod.eval()
for param in mod.parameters():
param.requires_grad = False

View File

@ -1,66 +0,0 @@
# Copyright (c) Open-MMLab. All rights reserved.
import numpy as np
import torch.nn as nn
def constant_init(module, val, bias=0):
if hasattr(module, 'weight') and module.weight is not None:
nn.init.constant_(module.weight, val)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def xavier_init(module, gain=1, bias=0, distribution='normal'):
assert distribution in ['uniform', 'normal']
if distribution == 'uniform':
nn.init.xavier_uniform_(module.weight, gain=gain)
else:
nn.init.xavier_normal_(module.weight, gain=gain)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def normal_init(module, mean=0, std=1, bias=0):
nn.init.normal_(module.weight, mean, std)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def uniform_init(module, a=0, b=1, bias=0):
nn.init.uniform_(module.weight, a, b)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def kaiming_init(module,
a=0,
mode='fan_out',
nonlinearity='relu',
bias=0,
distribution='normal'):
assert distribution in ['uniform', 'normal']
if distribution == 'uniform':
nn.init.kaiming_uniform_(
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
else:
nn.init.kaiming_normal_(
module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
if hasattr(module, 'bias') and module.bias is not None:
nn.init.constant_(module.bias, bias)
def caffe2_xavier_init(module, bias=0):
# `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
# Acknowledgment to FAIR's internal code
kaiming_init(
module,
a=1,
mode='fan_in',
nonlinearity='leaky_relu',
distribution='uniform')
def bias_init_with_prob(prior_prob):
""" initialize conv/fc bias value according to giving probablity"""
bias_init = float(-np.log((1 - prior_prob) / prior_prob))
return bias_init

View File

@ -4,7 +4,9 @@ from torch.nn.modules import GroupNorm
from torch.nn.modules.batchnorm import _BatchNorm
from mmcls.models.backbones import ShuffleNetv2
from mmcls.models.backbones.shufflenet_v2 import InvertedResidual
from mmcls.models.backbones.shufflenet_v2 import (InvertedResidual,
channel_shuffle,
make_divisible)
def is_block(modules):
@ -30,35 +32,57 @@ def check_norm_state(modules, train_state):
return True
def test_shufflenetv2_invertedresidual():
def test_channel_shuffle():
x = torch.randn(1, 24, 56, 56)
with pytest.raises(AssertionError):
# num_channels should be divisible by groups
channel_shuffle(x, 7)
with pytest.raises(ValueError):
# stride must be in [1, 2]
InvertedResidual(24, 16, stride=3)
def test_make_divisible():
# test min_value is None
make_divisible(34, 8, None)
# test new_value < 0.9 * value
make_divisible(10, 8, None)
def test_shufflenetv2_invertedresidual():
with pytest.raises(AssertionError):
# when stride==1, 16 == branch_features << 1
InvertedResidual(24, 64, stride=1)
InvertedResidual(24, 32, stride=1)
# Test InvertedResidual forward
block = InvertedResidual(24, 64, stride=2)
block = InvertedResidual(24, 48, stride=2)
x = torch.randn(1, 24, 56, 56)
x_out = block(x)
assert x_out.shape == torch.Size([1, 64, 28, 28])
assert x_out.shape == torch.Size((1, 48, 28, 28))
# Test InvertedResidual with checkpoint forward
block = InvertedResidual(24, 24, stride=1, with_cp=True)
x = torch.randn(1, 24, 56, 56)
block = InvertedResidual(48, 48, stride=1, with_cp=True)
assert block.with_cp
x = torch.randn(1, 48, 56, 56)
x.requires_grad = True
x_out = block(x)
assert x_out.shape == torch.Size([1, 24, 56, 56])
assert x_out.shape == torch.Size((1, 48, 56, 56))
def test_ShuffleNetv2_backbone():
def test_shufflenetv2_backbone():
with pytest.raises(ValueError):
# groups must in 0.5, 1.0, 1.5, 2.0]
ShuffleNetv2(widen_factor=3.0)
with pytest.raises(AssertionError):
# frozen_stages must in [0, 1, 2]
ShuffleNetv2(widen_factor=3.0, frozen_stages=3)
with pytest.raises(TypeError):
# pretrained must be str or None
model = ShuffleNetv2()
model.init_weights(pretrained=1)
# Test ShuffleNetv2 norm state
model = ShuffleNetv2()
model.init_weights()
@ -81,19 +105,28 @@ def test_ShuffleNetv2_backbone():
for param in layer.parameters():
assert param.requires_grad is False
# Test ShuffleNetv2 with bn frozen
model = ShuffleNetv2(bn_frozen=True)
# Test ShuffleNetv2 with norm_eval
model = ShuffleNetv2(norm_eval=True)
model.init_weights()
model.train()
for i in range(1, 4):
layer = getattr(model, f'layer{i}')
assert check_norm_state(model.modules(), False)
for mod in layer.modules():
if isinstance(mod, _BatchNorm):
assert mod.training is False
for params in mod.parameters():
params.requires_grad = False
# Test ShuffleNetv2 forward with widen_factor=0.5
model = ShuffleNetv2(widen_factor=0.5)
model.init_weights()
model.train()
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size((1, 48, 28, 28))
assert feat[1].shape == torch.Size((1, 96, 14, 14))
assert feat[2].shape == torch.Size((1, 192, 7, 7))
# Test ShuffleNetv2 forward with widen_factor=1.0
model = ShuffleNetv2(widen_factor=1.0)
@ -106,11 +139,56 @@ def test_ShuffleNetv2_backbone():
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size([1, 232, 28, 28])
assert feat[1].shape == torch.Size([1, 464, 14, 14])
assert feat[2].shape == torch.Size([1, 1024, 7, 7])
assert feat[3].shape == torch.Size([1, 1024, 7, 7])
assert len(feat) == 3
assert feat[0].shape == torch.Size((1, 116, 28, 28))
assert feat[1].shape == torch.Size((1, 232, 14, 14))
assert feat[2].shape == torch.Size((1, 464, 7, 7))
# Test ShuffleNetv2 forward with widen_factor=1.5
model = ShuffleNetv2(widen_factor=1.5)
model.init_weights()
model.train()
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size((1, 176, 28, 28))
assert feat[1].shape == torch.Size((1, 352, 14, 14))
assert feat[2].shape == torch.Size((1, 704, 7, 7))
# Test ShuffleNetv2 forward with widen_factor=2.0
model = ShuffleNetv2(widen_factor=2.0)
model.init_weights()
model.train()
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size((1, 244, 28, 28))
assert feat[1].shape == torch.Size((1, 488, 14, 14))
assert feat[2].shape == torch.Size((1, 976, 7, 7))
# Test ShuffleNetv2 forward with layers 3 forward
model = ShuffleNetv2(widen_factor=1.0, out_indices=(2, ))
model.init_weights()
model.train()
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert isinstance(feat, torch.Tensor)
assert feat.shape == torch.Size((1, 464, 7, 7))
# Test ShuffleNetv2 forward with layers 1 2 forward
model = ShuffleNetv2(widen_factor=1.0, out_indices=(1, 2))
@ -123,23 +201,12 @@ def test_ShuffleNetv2_backbone():
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 3
assert feat[0].shape == torch.Size([1, 464, 14, 14])
assert feat[1].shape == torch.Size([1, 1024, 7, 7])
assert len(feat) == 2
assert feat[0].shape == torch.Size((1, 232, 14, 14))
assert feat[1].shape == torch.Size((1, 464, 7, 7))
# Test ShuffleNetv2 forward with checkpoint forward
model = ShuffleNetv2(widen_factor=1.0, with_cp=True)
model.init_weights()
model.train()
for m in model.modules():
if is_norm(m):
assert isinstance(m, _BatchNorm)
imgs = torch.randn(1, 3, 224, 224)
feat = model(imgs)
assert len(feat) == 4
assert feat[0].shape == torch.Size([1, 232, 28, 28])
assert feat[1].shape == torch.Size([1, 464, 14, 14])
assert feat[2].shape == torch.Size([1, 1024, 7, 7])
assert feat[3].shape == torch.Size([1, 1024, 7, 7])
if is_block(m):
assert m.with_cp