class InvertedResidual(nn.Module):
    """InvertedResidual block for ShuffleNetV2 backbone.

    The block has two data paths, selected by ``stride``:

    * ``stride > 1``: the whole input is fed to both ``branch1`` and
      ``branch2`` and the two results are concatenated.
    * otherwise: the input is split into two halves along the channel
      dimension; the first half is passed through untouched and the second
      half goes through ``branch2`` before concatenation.

    In both cases the concatenated tensor is channel-shuffled with 2 groups.

    Args:
        inplanes (int): The input channels of the block.
        planes (int): The output channels of the block. Each branch
            produces ``planes // 2`` channels.
        stride (int): Stride of the 3x3 convolution layer. Default: 1
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.

    Raises:
        AssertionError: If ``stride`` is 1 and ``inplanes`` differs from
            ``(planes // 2) * 2``.
    """

    def __init__(self,
                 inplanes,
                 planes,
                 stride=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        super(InvertedResidual, self).__init__()
        self.stride = stride
        self.with_cp = with_cp

        branch_features = planes // 2
        # The stride-1 path splits the input in two and concatenates one half
        # with the output of branch2, so the channel count must be preserved.
        # The converse check ("mismatching channels imply stride != 1") is the
        # same condition and could never fire, so it is not repeated here.
        if self.stride == 1:
            assert inplanes == branch_features * 2, (
                f'inplanes ({inplanes}) should equal to branch_features * 2 '
                f'({branch_features * 2}) when stride is 1')

        if self.stride > 1:
            # Down-sampling shortcut: depthwise 3x3 (no activation) followed
            # by a pointwise 1x1 convolution.
            self.branch1 = nn.Sequential(
                ConvModule(
                    inplanes,
                    inplanes,
                    kernel_size=3,
                    stride=self.stride,
                    padding=1,
                    groups=inplanes,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=None),
                ConvModule(
                    inplanes,
                    branch_features,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg),
            )

        # Main branch: pointwise 1x1 -> depthwise 3x3 (no activation) ->
        # pointwise 1x1. It sees the full input when down-sampling and only
        # one half of the channels otherwise.
        self.branch2 = nn.Sequential(
            ConvModule(
                inplanes if (self.stride > 1) else branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                branch_features,
                branch_features,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                groups=branch_features,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None),
            ConvModule(
                branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))

    def forward(self, x):
        """Run the block on ``x``.

        Args:
            x (Tensor): Input tensor; channels are expected on dim 1.

        Returns:
            Tensor: Concatenation of the two paths along dim 1 after a
            2-group channel shuffle.
        """

        def _inner_forward(x):
            if self.stride > 1:
                out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
            else:
                x1, x2 = x.chunk(2, dim=1)
                out = torch.cat((x1, self.branch2(x2)), dim=1)

            out = channel_shuffle(out, 2)

            return out

        # Checkpointing is only used when gradients flow through the input.
        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
class ShuffleNetv2(BaseBackbone):
    """ShuffleNetv2 backbone.

    The network is a stem (3x3 stride-2 convolution followed by a 3x3
    stride-2 max pooling) and three stages of ``InvertedResidual`` blocks
    (4, 8 and 4 blocks). The first block of every stage uses stride 2.

    Args:
        groups (int): Stored on the instance but not read by any code in
            this class; kept so existing callers/configs keep working.
            Default: 3.
        widen_factor (float): Width multiplier selecting the per-stage
            channel numbers. Must be one of 0.5, 1.0, 1.5 or 2.0.
            Default: 1.0.
        out_indices (Sequence[int]): Output from which stages.
            Default: (0, 1, 2).
        frozen_stages (int): Stages to be frozen (all param fixed).
            -1 means not freezing any parameters, 0 freezes the parameters
            of the stem convolution only, and ``k > 0`` additionally
            freezes the first ``k`` stages. Default: -1.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.

    Raises:
        AssertionError: If ``max(out_indices)`` or ``frozen_stages`` is not
            smaller than the number of stages.
        ValueError: If ``widen_factor`` is not one of the supported values.
    """

    def __init__(self,
                 groups=3,
                 widen_factor=1.0,
                 out_indices=(0, 1, 2),
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 norm_eval=False,
                 with_cp=False):
        super(ShuffleNetv2, self).__init__()
        self.stage_blocks = [4, 8, 4]
        self.groups = groups
        self.out_indices = out_indices
        assert max(out_indices) < len(self.stage_blocks)
        self.frozen_stages = frozen_stages
        assert frozen_stages < len(self.stage_blocks)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        # channels[0:3] are the stage widths, channels[3] is the width of
        # the trailing 1x1 convolution (conv2).
        if widen_factor == 0.5:
            channels = [48, 96, 192, 1024]
        elif widen_factor == 1.0:
            channels = [116, 232, 464, 1024]
        elif widen_factor == 1.5:
            channels = [176, 352, 704, 1024]
        elif widen_factor == 2.0:
            channels = [244, 488, 976, 2048]
        else:
            raise ValueError('widen_factor must be in [0.5, 1.0, 1.5, 2.0]. '
                             f'But received {widen_factor}')

        # self.inplanes tracks the running channel count while the network
        # is assembled; _make_layer updates it after every block.
        self.inplanes = 24
        self.conv1 = ConvModule(
            in_channels=3,
            out_channels=self.inplanes,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

        self.layers = nn.ModuleList()
        for i, num_blocks in enumerate(self.stage_blocks):
            layer = self._make_layer(channels[i], num_blocks)
            self.layers.append(layer)

        # NOTE(review): conv2 is constructed here but is not applied in
        # forward(); it is kept so parameter names stay unchanged. Confirm
        # whether it should be part of the forward pass.
        output_channels = channels[-1]
        self.conv2 = ConvModule(
            in_channels=self.inplanes,
            out_channels=output_channels,
            kernel_size=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)

    def _make_layer(self, planes, num_blocks):
        """Stack blocks to make a layer.

        The first block uses stride 2, all following blocks use stride 1.
        ``self.inplanes`` is set to ``planes`` after each block is built.

        Args:
            planes (int): planes of the block.
            num_blocks (int): number of blocks.

        Returns:
            nn.Sequential: The stacked ``InvertedResidual`` blocks.
        """
        layers = []
        for i in range(num_blocks):
            stride = 2 if i == 0 else 1
            layers.append(
                InvertedResidual(
                    inplanes=self.inplanes,
                    planes=planes,
                    stride=stride,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg,
                    with_cp=self.with_cp))
            self.inplanes = planes

        return nn.Sequential(*layers)

    def _freeze_stages(self):
        """Disable gradients for the stem and the first frozen stages."""
        if self.frozen_stages >= 0:
            for param in self.conv1.parameters():
                param.requires_grad = False

        # range() is empty for frozen_stages <= 0, so no stage is touched.
        for i in range(self.frozen_stages):
            m = self.layers[i]
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize the weights of the backbone.

        Args:
            pretrained (str | None): Checkpoint identifier handed to
                ``mmcv.runner.load_checkpoint``. If None, convolutions are
                Kaiming-initialized and norm layers are set to constant 1.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            # Imported locally so the module-level imports stay untouched.
            import logging

            from mmcv.runner import load_checkpoint

            logger = logging.getLogger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    kaiming_init(m)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None. But received '
                            f'{type(pretrained)}')

    def forward(self, x):
        """Compute the selected stage outputs.

        Args:
            x (Tensor): Input batch fed to the stem convolution.

        Returns:
            Tensor | tuple[Tensor]: The single stage output when exactly
            one stage index is selected, otherwise a tuple with the outputs
            of all selected stages in stage order.
        """
        x = self.conv1(x)
        x = self.maxpool(x)

        outs = []
        for i, layer in enumerate(self.layers):
            x = layer(x)
            if i in self.out_indices:
                outs.append(x)

        if len(outs) == 1:
            return outs[0]
        else:
            return tuple(outs)

    def train(self, mode=True):
        """Switch train/eval mode while honouring frozen stages and norm_eval.

        Args:
            mode (bool): True for training mode, False for evaluation mode.
        """
        super(ShuffleNetv2, self).train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # _BatchNorm covers every batch-norm variant, matching the
                # check used in init_weights.
                if isinstance(m, _BatchNorm):
                    m.eval()
def is_norm(modules):
    """Check if is one of the norms."""
    return isinstance(modules, (GroupNorm, _BatchNorm))


def check_norm_state(modules, train_state):
    """Check if norm layer is in correct train state."""
    return all(mod.training == train_state for mod in modules
               if isinstance(mod, _BatchNorm))


def _forward_in_train_mode(model):
    """Init ``model``, switch it to train mode and run one 224x224 image.

    Also checks that every norm layer of the model is a batch norm.
    """
    model.init_weights()
    model.train()

    for m in model.modules():
        if is_norm(m):
            assert isinstance(m, _BatchNorm)

    imgs = torch.randn(1, 3, 224, 224)
    return model(imgs)


def test_shufflenetv2_invertedresidual():

    with pytest.raises(AssertionError):
        # when stride==1, inplanes should be equal to planes // 2 * 2
        # (equivalently: when they differ, stride must not be 1)
        InvertedResidual(24, 32, stride=1)

    # Test InvertedResidual forward
    block = InvertedResidual(24, 48, stride=2)
    x = torch.randn(1, 24, 56, 56)
    x_out = block(x)
    assert x_out.shape == torch.Size((1, 48, 28, 28))

    # Test InvertedResidual with checkpoint forward
    block = InvertedResidual(48, 48, stride=1, with_cp=True)
    assert block.with_cp
    x = torch.randn(1, 48, 56, 56)
    x.requires_grad = True
    x_out = block(x)
    assert x_out.shape == torch.Size((1, 48, 56, 56))


def test_shufflenetv2_backbone():

    with pytest.raises(ValueError):
        # widen_factor must be in [0.5, 1.0, 1.5, 2.0]
        ShuffleNetv2(widen_factor=3.0)

    with pytest.raises(AssertionError):
        # frozen_stages must be smaller than the number of stages (3).
        # A valid widen_factor is used so that only this check can fail.
        ShuffleNetv2(frozen_stages=3)

    # Build the model outside the raises-block so that only init_weights
    # is expected to raise.
    model = ShuffleNetv2()
    with pytest.raises(TypeError):
        # pretrained must be str or None
        model.init_weights(pretrained=1)

    # Test ShuffleNetv2 norm state
    model = ShuffleNetv2()
    model.init_weights()
    model.train()
    assert check_norm_state(model.modules(), True)

    # Test ShuffleNetv2 with first stage frozen
    frozen_stages = 1
    model = ShuffleNetv2(frozen_stages=frozen_stages)
    model.init_weights()
    model.train()
    for param in model.conv1.parameters():
        assert param.requires_grad is False
    for i in range(0, frozen_stages):
        layer = model.layers[i]
        for mod in layer.modules():
            if isinstance(mod, _BatchNorm):
                assert mod.training is False
        for param in layer.parameters():
            assert param.requires_grad is False

    # Test ShuffleNetv2 with norm_eval
    model = ShuffleNetv2(norm_eval=True)
    model.init_weights()
    model.train()

    assert check_norm_state(model.modules(), False)

    # Test ShuffleNetv2 forward for every supported widen_factor
    stage_channels = {
        0.5: (48, 96, 192),
        1.0: (116, 232, 464),
        1.5: (176, 352, 704),
        2.0: (244, 488, 976),
    }
    for widen_factor, channels in stage_channels.items():
        model = ShuffleNetv2(widen_factor=widen_factor)
        feat = _forward_in_train_mode(model)
        assert len(feat) == 3
        for out, num_channels, size in zip(feat, channels, (28, 14, 7)):
            assert out.shape == torch.Size((1, num_channels, size, size))

    # Test ShuffleNetv2 forward with layers 3 forward
    model = ShuffleNetv2(widen_factor=1.0, out_indices=(2, ))
    feat = _forward_in_train_mode(model)
    assert isinstance(feat, torch.Tensor)
    assert feat.shape == torch.Size((1, 464, 7, 7))

    # Test ShuffleNetv2 forward with layers 1 2 forward
    model = ShuffleNetv2(widen_factor=1.0, out_indices=(1, 2))
    feat = _forward_in_train_mode(model)
    assert len(feat) == 2
    assert feat[0].shape == torch.Size((1, 232, 14, 14))
    assert feat[1].shape == torch.Size((1, 464, 7, 7))

    # Test ShuffleNetv2 forward with checkpoint forward
    model = ShuffleNetv2(widen_factor=1.0, with_cp=True)
    for m in model.modules():
        if isinstance(m, InvertedResidual):
            assert m.with_cp