mirror of
https://github.com/open-mmlab/mmclassification.git
synced 2025-06-03 21:53:55 +08:00
Merge branch 'dev_shufflenetv2' into 'master'
add shufflenetv2 See merge request open-mmlab/mmclassification!11
This commit is contained in:
commit
9435eecd46
@ -1,5 +1,13 @@
|
||||
from .resnet import ResNet, ResNetV1d
|
||||
from .resnext import ResNeXt
|
||||
from .shufflenet_v1 import ShuffleNetv1
|
||||
from .shufflenet_v2 import ShuffleNetv2
|
||||
|
||||
__all__ = ['ResNet', 'ResNeXt', 'ResNetV1d', 'ShuffleNetv1']
|
||||
__all__ = [
|
||||
'ResNet',
|
||||
'ResNeXt',
|
||||
'ResNetV1d',
|
||||
'ResNetV1d',
|
||||
'ShuffleNetv1',
|
||||
'ShuffleNetv2',
|
||||
]
|
||||
|
281
mmcls/models/backbones/shufflenet_v2.py
Normal file
281
mmcls/models/backbones/shufflenet_v2.py
Normal file
@ -0,0 +1,281 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.utils.checkpoint as cp
|
||||
from mmcv.cnn import ConvModule, constant_init, kaiming_init
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmcls.models.utils import channel_shuffle
|
||||
from .base_backbone import BaseBackbone
|
||||
|
||||
|
||||
class InvertedResidual(nn.Module):
|
||||
"""InvertedResidual block for ShuffleNetV2 backbone.
|
||||
|
||||
Args:
|
||||
inplanes (int): The input channels of the block.
|
||||
planes (int): The output channels of the block.
|
||||
stride (int): Stride of the 3x3 convolution layer. Default: 1
|
||||
conv_cfg (dict): Config dict for convolution layer.
|
||||
Default: None, which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='ReLU').
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
|
||||
Returns:
|
||||
Tensor: The output tensor.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
inplanes,
|
||||
planes,
|
||||
stride=1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
with_cp=False):
|
||||
super(InvertedResidual, self).__init__()
|
||||
self.stride = stride
|
||||
self.with_cp = with_cp
|
||||
|
||||
branch_features = planes // 2
|
||||
if self.stride == 1:
|
||||
assert inplanes == branch_features * 2, (
|
||||
f'inplanes ({inplanes}) should equal to branch_features * 2 '
|
||||
f'({branch_features * 2}) when stride is 1')
|
||||
|
||||
if inplanes != branch_features * 2:
|
||||
assert self.stride != 1, (
|
||||
f'stride ({self.stride}) should not equal 1 when '
|
||||
f'inplanes != branch_features * 2')
|
||||
|
||||
if self.stride > 1:
|
||||
self.branch1 = nn.Sequential(
|
||||
ConvModule(
|
||||
inplanes,
|
||||
inplanes,
|
||||
kernel_size=3,
|
||||
stride=self.stride,
|
||||
padding=1,
|
||||
groups=inplanes,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None),
|
||||
ConvModule(
|
||||
inplanes,
|
||||
branch_features,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
)
|
||||
|
||||
self.branch2 = nn.Sequential(
|
||||
ConvModule(
|
||||
inplanes if (self.stride > 1) else branch_features,
|
||||
branch_features,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
branch_features,
|
||||
branch_features,
|
||||
kernel_size=3,
|
||||
stride=self.stride,
|
||||
padding=1,
|
||||
groups=branch_features,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None),
|
||||
ConvModule(
|
||||
branch_features,
|
||||
branch_features,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
|
||||
def forward(self, x):
|
||||
|
||||
def _inner_forward(x):
|
||||
if self.stride > 1:
|
||||
out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
|
||||
else:
|
||||
x1, x2 = x.chunk(2, dim=1)
|
||||
out = torch.cat((x1, self.branch2(x2)), dim=1)
|
||||
|
||||
out = channel_shuffle(out, 2)
|
||||
|
||||
return out
|
||||
|
||||
if self.with_cp and x.requires_grad:
|
||||
out = cp.checkpoint(_inner_forward, x)
|
||||
else:
|
||||
out = _inner_forward(x)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ShuffleNetv2(BaseBackbone):
|
||||
"""ShuffleNetv2 backbone.
|
||||
|
||||
Args:
|
||||
groups (int): The number of groups to be used in grouped 1x1
|
||||
convolutions in each InvertedResidual. Default: 3.
|
||||
widen_factor (float): Width multiplier - adjusts the number of
|
||||
channels in each layer by this amount. Default: 1.0.
|
||||
out_indices (Sequence[int]): Output from which stages.
|
||||
Default: (0, 1, 2, 3).
|
||||
frozen_stages (int): Stages to be frozen (all param fixed).
|
||||
Default: -1, which means not freezing any parameters.
|
||||
conv_cfg (dict): Config dict for convolution layer.
|
||||
Default: None, which means using conv2d.
|
||||
norm_cfg (dict): Config dict for normalization layer.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config dict for activation layer.
|
||||
Default: dict(type='ReLU').
|
||||
norm_eval (bool): Whether to set norm layers to eval mode, namely,
|
||||
freeze running stats (mean and var). Note: Effect on Batch Norm
|
||||
and its variants only. Default: False.
|
||||
with_cp (bool): Use checkpoint or not. Using checkpoint will save some
|
||||
memory while slowing down the training speed. Default: False.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
groups=3,
|
||||
widen_factor=1.0,
|
||||
out_indices=(0, 1, 2),
|
||||
frozen_stages=-1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
norm_eval=False,
|
||||
with_cp=False):
|
||||
super(ShuffleNetv2, self).__init__()
|
||||
self.stage_blocks = [4, 8, 4]
|
||||
self.groups = groups
|
||||
self.out_indices = out_indices
|
||||
assert max(out_indices) < len(self.stage_blocks)
|
||||
self.frozen_stages = frozen_stages
|
||||
assert frozen_stages < len(self.stage_blocks)
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.norm_eval = norm_eval
|
||||
self.with_cp = with_cp
|
||||
|
||||
if widen_factor == 0.5:
|
||||
channels = [48, 96, 192, 1024]
|
||||
elif widen_factor == 1.0:
|
||||
channels = [116, 232, 464, 1024]
|
||||
elif widen_factor == 1.5:
|
||||
channels = [176, 352, 704, 1024]
|
||||
elif widen_factor == 2.0:
|
||||
channels = [244, 488, 976, 2048]
|
||||
else:
|
||||
raise ValueError('widen_factor must be in [0.5, 1.0, 1.5, 2.0]. '
|
||||
f'But received {widen_factor}')
|
||||
|
||||
self.inplanes = 24
|
||||
self.conv1 = ConvModule(
|
||||
in_channels=3,
|
||||
out_channels=self.inplanes,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
for i, num_blocks in enumerate(self.stage_blocks):
|
||||
layer = self._make_layer(channels[i], num_blocks)
|
||||
self.layers.append(layer)
|
||||
|
||||
output_channels = channels[-1]
|
||||
self.conv2 = ConvModule(
|
||||
in_channels=self.inplanes,
|
||||
out_channels=output_channels,
|
||||
kernel_size=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def _make_layer(self, planes, num_blocks):
|
||||
""" Stack blocks to make a layer.
|
||||
|
||||
Args:
|
||||
planes (int): planes of the block.
|
||||
num_blocks (int): number of blocks.
|
||||
"""
|
||||
layers = []
|
||||
for i in range(num_blocks):
|
||||
stride = 2 if i == 0 else 1
|
||||
layers.append(
|
||||
InvertedResidual(
|
||||
inplanes=self.inplanes,
|
||||
planes=planes,
|
||||
stride=stride,
|
||||
conv_cfg=self.conv_cfg,
|
||||
norm_cfg=self.norm_cfg,
|
||||
act_cfg=self.act_cfg,
|
||||
with_cp=self.with_cp))
|
||||
self.inplanes = planes
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def _freeze_stages(self):
|
||||
if self.frozen_stages >= 0:
|
||||
for param in self.conv1.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
for i in range(self.frozen_stages):
|
||||
m = self.layers[i]
|
||||
m.eval()
|
||||
for param in m.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def init_weights(self, pretrained=None):
|
||||
if pretrained is None:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
kaiming_init(m)
|
||||
elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
|
||||
constant_init(m, 1)
|
||||
else:
|
||||
raise TypeError('pretrained must be a str or None. But received '
|
||||
f'{type(pretrained)}')
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
outs = []
|
||||
for i, layer in enumerate(self.layers):
|
||||
x = layer(x)
|
||||
if i in self.out_indices:
|
||||
outs.append(x)
|
||||
|
||||
if len(outs) == 1:
|
||||
return outs[0]
|
||||
else:
|
||||
return tuple(outs)
|
||||
|
||||
def train(self, mode=True):
|
||||
super(ShuffleNetv2, self).train(mode)
|
||||
self._freeze_stages()
|
||||
if mode and self.norm_eval:
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.BatchNorm2d):
|
||||
m.eval()
|
198
tests/test_backbones/test_shufflenet_v2.py
Normal file
198
tests/test_backbones/test_shufflenet_v2.py
Normal file
@ -0,0 +1,198 @@
|
||||
import pytest
|
||||
import torch
|
||||
from torch.nn.modules import GroupNorm
|
||||
from torch.nn.modules.batchnorm import _BatchNorm
|
||||
|
||||
from mmcls.models.backbones import ShuffleNetv2
|
||||
from mmcls.models.backbones.shufflenet_v2 import InvertedResidual
|
||||
|
||||
|
||||
def is_block(modules):
|
||||
"""Check if is ResNet building block."""
|
||||
if isinstance(modules, (InvertedResidual, )):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def is_norm(modules):
|
||||
"""Check if is one of the norms."""
|
||||
if isinstance(modules, (GroupNorm, _BatchNorm)):
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
def check_norm_state(modules, train_state):
|
||||
"""Check if norm layer is in correct train state."""
|
||||
for mod in modules:
|
||||
if isinstance(mod, _BatchNorm):
|
||||
if mod.training != train_state:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def test_shufflenetv2_invertedresidual():
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# when stride==1, inplanes should be equal to planes // 2 * 2
|
||||
InvertedResidual(24, 32, stride=1)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# when inplanes != planes // 2 * 2, stride should not be equal to 1.
|
||||
InvertedResidual(24, 32, stride=1)
|
||||
|
||||
# Test InvertedResidual forward
|
||||
block = InvertedResidual(24, 48, stride=2)
|
||||
x = torch.randn(1, 24, 56, 56)
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size((1, 48, 28, 28))
|
||||
|
||||
# Test InvertedResidual with checkpoint forward
|
||||
block = InvertedResidual(48, 48, stride=1, with_cp=True)
|
||||
assert block.with_cp
|
||||
x = torch.randn(1, 48, 56, 56)
|
||||
x.requires_grad = True
|
||||
x_out = block(x)
|
||||
assert x_out.shape == torch.Size((1, 48, 56, 56))
|
||||
|
||||
|
||||
def test_shufflenetv2_backbone():
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
# groups must be in 0.5, 1.0, 1.5, 2.0]
|
||||
ShuffleNetv2(widen_factor=3.0)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# frozen_stages must be in [0, 1, 2]
|
||||
ShuffleNetv2(widen_factor=3.0, frozen_stages=3)
|
||||
|
||||
with pytest.raises(TypeError):
|
||||
# pretrained must be str or None
|
||||
model = ShuffleNetv2()
|
||||
model.init_weights(pretrained=1)
|
||||
|
||||
# Test ShuffleNetv2 norm state
|
||||
model = ShuffleNetv2()
|
||||
model.init_weights()
|
||||
model.train()
|
||||
assert check_norm_state(model.modules(), True)
|
||||
|
||||
# Test ShuffleNetv2 with first stage frozen
|
||||
frozen_stages = 1
|
||||
model = ShuffleNetv2(frozen_stages=frozen_stages)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
for param in model.conv1.parameters():
|
||||
assert param.requires_grad is False
|
||||
for i in range(0, frozen_stages):
|
||||
layer = model.layers[i]
|
||||
for mod in layer.modules():
|
||||
if isinstance(mod, _BatchNorm):
|
||||
assert mod.training is False
|
||||
for param in layer.parameters():
|
||||
assert param.requires_grad is False
|
||||
|
||||
# Test ShuffleNetv2 with norm_eval
|
||||
model = ShuffleNetv2(norm_eval=True)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
assert check_norm_state(model.modules(), False)
|
||||
|
||||
# Test ShuffleNetv2 forward with widen_factor=0.5
|
||||
model = ShuffleNetv2(widen_factor=0.5)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
for m in model.modules():
|
||||
if is_norm(m):
|
||||
assert isinstance(m, _BatchNorm)
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == torch.Size((1, 48, 28, 28))
|
||||
assert feat[1].shape == torch.Size((1, 96, 14, 14))
|
||||
assert feat[2].shape == torch.Size((1, 192, 7, 7))
|
||||
|
||||
# Test ShuffleNetv2 forward with widen_factor=1.0
|
||||
model = ShuffleNetv2(widen_factor=1.0)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
for m in model.modules():
|
||||
if is_norm(m):
|
||||
assert isinstance(m, _BatchNorm)
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == torch.Size((1, 116, 28, 28))
|
||||
assert feat[1].shape == torch.Size((1, 232, 14, 14))
|
||||
assert feat[2].shape == torch.Size((1, 464, 7, 7))
|
||||
|
||||
# Test ShuffleNetv2 forward with widen_factor=1.5
|
||||
model = ShuffleNetv2(widen_factor=1.5)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
for m in model.modules():
|
||||
if is_norm(m):
|
||||
assert isinstance(m, _BatchNorm)
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == torch.Size((1, 176, 28, 28))
|
||||
assert feat[1].shape == torch.Size((1, 352, 14, 14))
|
||||
assert feat[2].shape == torch.Size((1, 704, 7, 7))
|
||||
|
||||
# Test ShuffleNetv2 forward with widen_factor=2.0
|
||||
model = ShuffleNetv2(widen_factor=2.0)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
for m in model.modules():
|
||||
if is_norm(m):
|
||||
assert isinstance(m, _BatchNorm)
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 3
|
||||
assert feat[0].shape == torch.Size((1, 244, 28, 28))
|
||||
assert feat[1].shape == torch.Size((1, 488, 14, 14))
|
||||
assert feat[2].shape == torch.Size((1, 976, 7, 7))
|
||||
|
||||
# Test ShuffleNetv2 forward with layers 3 forward
|
||||
model = ShuffleNetv2(widen_factor=1.0, out_indices=(2, ))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
for m in model.modules():
|
||||
if is_norm(m):
|
||||
assert isinstance(m, _BatchNorm)
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert isinstance(feat, torch.Tensor)
|
||||
assert feat.shape == torch.Size((1, 464, 7, 7))
|
||||
|
||||
# Test ShuffleNetv2 forward with layers 1 2 forward
|
||||
model = ShuffleNetv2(widen_factor=1.0, out_indices=(1, 2))
|
||||
model.init_weights()
|
||||
model.train()
|
||||
|
||||
for m in model.modules():
|
||||
if is_norm(m):
|
||||
assert isinstance(m, _BatchNorm)
|
||||
|
||||
imgs = torch.randn(1, 3, 224, 224)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 2
|
||||
assert feat[0].shape == torch.Size((1, 232, 14, 14))
|
||||
assert feat[1].shape == torch.Size((1, 464, 7, 7))
|
||||
|
||||
# Test ShuffleNetv2 forward with checkpoint forward
|
||||
model = ShuffleNetv2(widen_factor=1.0, with_cp=True)
|
||||
for m in model.modules():
|
||||
if is_block(m):
|
||||
assert m.with_cp
|
Loading…
x
Reference in New Issue
Block a user