From deee5d61d7622cb3adf84a07efea6e4bd79651c3 Mon Sep 17 00:00:00 2001
From: louzan
Date: Tue, 16 Jun 2020 14:37:03 +0800
Subject: [PATCH] add mobilenetv2

---
 mmcls/models/backbones/__init__.py         |  13 +-
 mmcls/models/backbones/mobilenet_v2.py     | 265 +++++++++++++++++++++
 mmcls/models/backbones/shufflenet_v1.py    |  10 +-
 mmcls/models/backbones/shufflenet_v2.py    |  10 +-
 tests/test_backbones/test_mobilenet_v2.py  | 255 ++++++++++++++++++++
 tests/test_backbones/test_shufflenet_v1.py |  58 ++---
 tests/test_backbones/test_shufflenet_v2.py |  48 ++--
 7 files changed, 590 insertions(+), 69 deletions(-)
 create mode 100644 mmcls/models/backbones/mobilenet_v2.py
 create mode 100644 tests/test_backbones/test_mobilenet_v2.py

diff --git a/mmcls/models/backbones/__init__.py b/mmcls/models/backbones/__init__.py
index 2dbc49798..6b999624b 100644
--- a/mmcls/models/backbones/__init__.py
+++ b/mmcls/models/backbones/__init__.py
@@ -1,13 +1,10 @@
+from .mobilenet_v2 import MobileNetV2
 from .resnet import ResNet, ResNetV1d
 from .resnext import ResNeXt
-from .shufflenet_v1 import ShuffleNetv1
-from .shufflenet_v2 import ShuffleNetv2
+from .shufflenet_v1 import ShuffleNetV1
+from .shufflenet_v2 import ShuffleNetV2
 
 __all__ = [
-    'ResNet',
-    'ResNeXt',
-    'ResNetV1d',
-    'ResNetV1d',
-    'ShuffleNetv1',
-    'ShuffleNetv2',
+    'ResNet', 'ResNeXt', 'ResNetV1d', 'ShuffleNetV1', 'ShuffleNetV2',
+    'MobileNetV2'
 ]
diff --git a/mmcls/models/backbones/mobilenet_v2.py b/mmcls/models/backbones/mobilenet_v2.py
new file mode 100644
index 000000000..15b750ea8
--- /dev/null
+++ b/mmcls/models/backbones/mobilenet_v2.py
@@ -0,0 +1,265 @@
+import logging
+
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.cnn import ConvModule, constant_init, kaiming_init
+from mmcv.runner import load_checkpoint
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmcls.models.utils import make_divisible
+from ..builder import BACKBONES
+from .base_backbone import BaseBackbone
+
+
+class InvertedResidual(nn.Module):
+    """InvertedResidual block for MobileNetV2.
+
+    Args:
+        inplanes (int): The input channels of the InvertedResidual block.
+        planes (int): The output channels of the InvertedResidual block.
+        stride (int): Stride of the middle (first) 3x3 convolution.
+        expand_ratio (int): Adjusts the number of channels of the hidden
+            layer in InvertedResidual by this amount.
+        conv_cfg (dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU6').
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save
+            some memory while slowing down the training speed.
+            Default: False.
+
+    Returns:
+        Tensor: The output tensor.
+    """
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 stride,
+                 expand_ratio,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU6'),
+                 with_cp=False):
+        super(InvertedResidual, self).__init__()
+        self.stride = stride
+        assert stride in [1, 2], f'stride must be in [1, 2]. ' \
+            f'But received {stride}.'
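+        # A block is a 1x1 expansion conv (skipped when expand_ratio == 1),
+        # a 3x3 depthwise conv carrying the stride, and a 1x1 linear
+        # projection; see the layer construction below.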
+        self.with_cp = with_cp
+        self.use_res_connect = self.stride == 1 and inplanes == planes
+        hidden_dim = int(round(inplanes * expand_ratio))
+
+        layers = []
+        if expand_ratio != 1:
+            layers.append(
+                ConvModule(
+                    in_channels=inplanes,
+                    out_channels=hidden_dim,
+                    kernel_size=1,
+                    stride=1,
+                    padding=0,
+                    conv_cfg=conv_cfg,
+                    norm_cfg=norm_cfg,
+                    act_cfg=act_cfg))
+        layers.extend([
+            ConvModule(
+                in_channels=hidden_dim,
+                out_channels=hidden_dim,
+                kernel_size=3,
+                stride=stride,
+                padding=1,
+                groups=hidden_dim,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=act_cfg),
+            ConvModule(
+                in_channels=hidden_dim,
+                out_channels=planes,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                conv_cfg=conv_cfg,
+                norm_cfg=norm_cfg,
+                act_cfg=None)
+        ])
+        self.conv = nn.Sequential(*layers)
+
+    def forward(self, x):
+
+        def _inner_forward(x):
+            if self.use_res_connect:
+                return x + self.conv(x)
+            else:
+                return self.conv(x)
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        return out
+
+
+@BACKBONES.register_module()
+class MobileNetV2(BaseBackbone):
+    """MobileNetV2 backbone.
+
+    Args:
+        widen_factor (float): Width multiplier, multiply number of
+            channels in each layer by this amount. Default: 1.0.
+        out_indices (None or Sequence[int]): Output from which stages.
+            Default: None.
+        frozen_stages (int): Stages to be frozen (all param fixed).
+            Default: -1, which means not freezing any parameters.
+        conv_cfg (dict): Config dict for convolution layer.
+            Default: None, which means using conv2d.
+        norm_cfg (dict): Config dict for normalization layer.
+            Default: dict(type='BN').
+        act_cfg (dict): Config dict for activation layer.
+            Default: dict(type='ReLU6').
+        norm_eval (bool): Whether to set norm layers to eval mode, namely,
+            freeze running stats (mean and var). Note: Effect on Batch Norm
+            and its variants only. Default: False.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save
+            some memory while slowing down the training speed.
+            Default: False.
+    """
+
+    # Parameters to build layers. 4 parameters are needed to construct a
+    # layer, from left to right: expand_ratio, channel, num_blocks, stride.
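+    # For example, [6, 24, 2, 2] stacks two InvertedResidual blocks with
+    # expand_ratio 6 and 24 base output channels (scaled by widen_factor
+    # and rounded by make_divisible), the first block using stride 2.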
+    arch_settings = [[1, 16, 1, 1], [6, 24, 2, 2], [6, 32, 3, 2],
+                     [6, 64, 4, 2], [6, 96, 3, 1], [6, 160, 3, 2],
+                     [6, 320, 1, 1]]
+
+    def __init__(self,
+                 widen_factor=1.,
+                 out_indices=None,
+                 frozen_stages=-1,
+                 conv_cfg=None,
+                 norm_cfg=dict(type='BN'),
+                 act_cfg=dict(type='ReLU6'),
+                 norm_eval=False,
+                 with_cp=False):
+        super(MobileNetV2, self).__init__()
+        self.widen_factor = widen_factor
+        self.out_indices = out_indices
+        if out_indices is not None:
+            assert max(out_indices) < len(self.arch_settings)
+        self.frozen_stages = frozen_stages
+        assert frozen_stages < len(self.arch_settings)
+        self.conv_cfg = conv_cfg
+        self.norm_cfg = norm_cfg
+        self.act_cfg = act_cfg
+        self.norm_eval = norm_eval
+        self.with_cp = with_cp
+
+        self.inplanes = make_divisible(32 * widen_factor, 8)
+
+        self.conv1 = ConvModule(
+            in_channels=3,
+            out_channels=self.inplanes,
+            kernel_size=3,
+            stride=2,
+            padding=1,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+        self.inverted_res_layers = []
+
+        for i, layer_cfg in enumerate(self.arch_settings):
+            expand_ratio, channel, num_blocks, stride = layer_cfg
+            planes = make_divisible(channel * widen_factor, 8)
+            inverted_res_layer = self.make_layer(
+                planes=planes,
+                num_blocks=num_blocks,
+                stride=stride,
+                expand_ratio=expand_ratio)
+            layer_name = f'layer{i + 1}'
+            self.add_module(layer_name, inverted_res_layer)
+            self.inverted_res_layers.append(layer_name)
+
+        if widen_factor > 1.0:
+            self.out_channel = int(1280 * widen_factor)
+        else:
+            self.out_channel = 1280
+
+        self.conv2 = ConvModule(
+            in_channels=self.inplanes,
+            out_channels=self.out_channel,
+            kernel_size=1,
+            stride=1,
+            padding=0,
+            conv_cfg=self.conv_cfg,
+            norm_cfg=self.norm_cfg,
+            act_cfg=self.act_cfg)
+
+    def make_layer(self, planes, num_blocks, stride, expand_ratio):
+        """Stack InvertedResidual blocks to build a layer for MobileNetV2.
+
+        Args:
+            planes (int): The number of output channels of the blocks.
+            num_blocks (int): The number of blocks to stack.
+            stride (int): The stride of the first block.
+            expand_ratio (int): Expand the number of channels of the
+                hidden layer in InvertedResidual by this ratio.
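+
+        Returns:
+            nn.Sequential: The stacked InvertedResidual blocks.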
+ """ + layers = [] + for i in range(num_blocks): + if i >= 1: + stride = 1 + layers.append( + InvertedResidual( + self.inplanes, + planes, + stride, + expand_ratio=expand_ratio, + conv_cfg=self.conv_cfg, + norm_cfg=self.norm_cfg, + act_cfg=self.act_cfg, + with_cp=self.with_cp)) + self.inplanes = planes + + return nn.Sequential(*layers) + + def init_weights(self, pretrained=None): + if pretrained is None: + for m in self.modules(): + if isinstance(m, nn.Conv2d): + kaiming_init(m) + elif isinstance(m, (_BatchNorm, nn.GroupNorm)): + constant_init(m, 1) + else: + raise TypeError('pretrained must be a str or None') + + def forward(self, x): + x = self.conv1(x) + + outs = [] + for i, layer_name in enumerate(self.inverted_res_layers): + inverted_res_layer = getattr(self, layer_name) + x = inverted_res_layer(x) + if self.out_indices is not None and i in self.out_indices: + outs.append(x) + + x = self.conv2(x) + + if self.out_indices is None: + return x + else: + return tuple(outs) + + def _freeze_stages(self): + if self.frozen_stages >= 0: + for param in self.conv1.parameters(): + param.requires_grad = False + for i in range(1, self.frozen_stages + 1): + layer = getattr(self, f'layer{i}') + layer.eval() + for param in layer.parameters(): + param.requires_grad = False + + def train(self, mode=True): + super(MobileNetV2, self).train(mode) + self._freeze_stages() + if mode and self.norm_eval: + for m in self.modules(): + if isinstance(m, _BatchNorm): + m.eval() diff --git a/mmcls/models/backbones/shufflenet_v1.py b/mmcls/models/backbones/shufflenet_v1.py index edc2fe3d9..4e6fda236 100644 --- a/mmcls/models/backbones/shufflenet_v1.py +++ b/mmcls/models/backbones/shufflenet_v1.py @@ -6,6 +6,7 @@ from mmcv.cnn import (ConvModule, build_activation_layer, build_conv_layer, from torch.nn.modules.batchnorm import _BatchNorm from mmcls.models.utils import channel_shuffle, make_divisible +from ..builder import BACKBONES from .base_backbone import BaseBackbone @@ -139,8 +140,9 @@ class ShuffleUnit(nn.Module): return out -class ShuffleNetv1(BaseBackbone): - """ShuffleNetv1 backbone. +@BACKBONES.register_module() +class ShuffleNetV1(BaseBackbone): + """ShuffleNetV1 backbone. Args: groups (int, optional): The number of groups to be used in grouped 1x1 @@ -174,7 +176,7 @@ class ShuffleNetv1(BaseBackbone): act_cfg=dict(type='ReLU'), norm_eval=False, with_cp=False): - super(ShuffleNetv1, self).__init__() + super(ShuffleNetV1, self).__init__() self.stage_blocks = [3, 7, 3] self.groups = groups @@ -294,7 +296,7 @@ class ShuffleNetv1(BaseBackbone): return tuple(outs) def train(self, mode=True): - super(ShuffleNetv1, self).train(mode) + super(ShuffleNetV1, self).train(mode) self._freeze_stages() if mode and self.norm_eval: for m in self.modules(): diff --git a/mmcls/models/backbones/shufflenet_v2.py b/mmcls/models/backbones/shufflenet_v2.py index 8be5563a4..db8d06e34 100644 --- a/mmcls/models/backbones/shufflenet_v2.py +++ b/mmcls/models/backbones/shufflenet_v2.py @@ -5,6 +5,7 @@ from mmcv.cnn import ConvModule, constant_init, kaiming_init from torch.nn.modules.batchnorm import _BatchNorm from mmcls.models.utils import channel_shuffle +from ..builder import BACKBONES from .base_backbone import BaseBackbone @@ -125,8 +126,9 @@ class InvertedResidual(nn.Module): return out -class ShuffleNetv2(BaseBackbone): - """ShuffleNetv2 backbone. +@BACKBONES.register_module() +class ShuffleNetV2(BaseBackbone): + """ShuffleNetV2 backbone. 
 
     Args:
         groups (int): The number of groups to be used in grouped 1x1
@@ -160,7 +162,7 @@
                  act_cfg=dict(type='ReLU'),
                  norm_eval=False,
                  with_cp=False):
-        super(ShuffleNetv2, self).__init__()
+        super(ShuffleNetV2, self).__init__()
         self.stage_blocks = [4, 8, 4]
         self.groups = groups
         self.out_indices = out_indices
@@ -273,7 +275,7 @@ class ShuffleNetv2(BaseBackbone):
         return tuple(outs)
 
     def train(self, mode=True):
-        super(ShuffleNetv2, self).train(mode)
+        super(ShuffleNetV2, self).train(mode)
         self._freeze_stages()
         if mode and self.norm_eval:
             for m in self.modules():
diff --git a/tests/test_backbones/test_mobilenet_v2.py b/tests/test_backbones/test_mobilenet_v2.py
new file mode 100644
index 000000000..3abd9488e
--- /dev/null
+++ b/tests/test_backbones/test_mobilenet_v2.py
@@ -0,0 +1,255 @@
+import pytest
+import torch
+from torch.nn.modules import GroupNorm
+from torch.nn.modules.batchnorm import _BatchNorm
+
+from mmcls.models.backbones import MobileNetV2
+from mmcls.models.backbones.mobilenet_v2 import InvertedResidual
+
+
+def is_block(modules):
+    """Check if is a MobileNetV2 InvertedResidual block."""
+    if isinstance(modules, (InvertedResidual, )):
+        return True
+    return False
+
+
+def is_norm(modules):
+    """Check if is one of the norms."""
+    if isinstance(modules, (GroupNorm, _BatchNorm)):
+        return True
+    return False
+
+
+def check_norm_state(modules, train_state):
+    """Check if norm layer is in correct train state."""
+    for mod in modules:
+        if isinstance(mod, _BatchNorm):
+            if mod.training != train_state:
+                return False
+    return True
+
+
+def test_mobilenetv2_invertedresidual():
+
+    with pytest.raises(AssertionError):
+        # stride must be in [1, 2]
+        InvertedResidual(16, 24, stride=3, expand_ratio=6)
+
+    # Test InvertedResidual with stride=1
+    block = InvertedResidual(16, 24, stride=1, expand_ratio=6)
+    x = torch.randn(1, 16, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size((1, 24, 56, 56))
+
+    # Test InvertedResidual with expand_ratio=1
+    block = InvertedResidual(16, 16, stride=1, expand_ratio=1)
+    assert len(block.conv) == 2
+
+    # Test InvertedResidual with use_res_connect
+    block = InvertedResidual(16, 16, stride=1, expand_ratio=6)
+    x = torch.randn(1, 16, 56, 56)
+    x_out = block(x)
+    assert block.use_res_connect is True
+    assert x_out.shape == torch.Size((1, 16, 56, 56))
+
+    # Test InvertedResidual with stride=2
+    block = InvertedResidual(16, 24, stride=2, expand_ratio=6)
+    x = torch.randn(1, 16, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size((1, 24, 28, 28))
+
+    # Test InvertedResidual with checkpoint forward
+    block = InvertedResidual(16, 24, stride=1, expand_ratio=6, with_cp=True)
+    assert block.with_cp
+    x = torch.randn(1, 16, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size((1, 24, 56, 56))
+
+    # Test InvertedResidual with act_cfg=dict(type='ReLU')
+    block = InvertedResidual(
+        16, 24, stride=1, expand_ratio=6, act_cfg=dict(type='ReLU'))
+    x = torch.randn(1, 16, 56, 56)
+    x_out = block(x)
+    assert x_out.shape == torch.Size((1, 24, 56, 56))
+
+
+def test_mobilenetv2_backbone():
+    with pytest.raises(TypeError):
+        # pretrained must be a str or None
+        model = MobileNetV2()
+        model.init_weights(pretrained=0)
+
+    with pytest.raises(AssertionError):
+        # frozen_stages must be less than 7
+        MobileNetV2(frozen_stages=8)
+
+    with pytest.raises(AssertionError):
+        # the max value in out_indices must be less than 7
+        MobileNetV2(out_indices=[8])
+
+    # Test MobileNetV2 with first stage frozen
+    frozen_stages = 1
+    model = MobileNetV2(frozen_stages=frozen_stages)
+    model.init_weights()
+    model.train()
+
+    for mod in model.conv1.modules():
+        for param in mod.parameters():
+            assert param.requires_grad is False
+    for i in range(1, frozen_stages + 1):
+        layer = getattr(model, f'layer{i}')
+        for mod in layer.modules():
+            if isinstance(mod, _BatchNorm):
+                assert mod.training is False
+        for param in layer.parameters():
+            assert param.requires_grad is False
+
+    # Test MobileNetV2 with norm_eval=True
+    model = MobileNetV2(norm_eval=True)
+    model.init_weights()
+    model.train()
+
+    assert check_norm_state(model.modules(), False)
+
+    # Test MobileNetV2 forward with widen_factor=1.0
+    model = MobileNetV2(widen_factor=1.0, out_indices=range(0, 7))
+    model.init_weights()
+    model.train()
+
+    assert check_norm_state(model.modules(), True)
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 7
+    assert feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 24, 56, 56))
+    assert feat[2].shape == torch.Size((1, 32, 28, 28))
+    assert feat[3].shape == torch.Size((1, 64, 14, 14))
+    assert feat[4].shape == torch.Size((1, 96, 14, 14))
+    assert feat[5].shape == torch.Size((1, 160, 7, 7))
+    assert feat[6].shape == torch.Size((1, 320, 7, 7))
+
+    # Test MobileNetV2 forward with widen_factor=0.5
+    model = MobileNetV2(widen_factor=0.5, out_indices=range(0, 7))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 7
+    assert feat[0].shape == torch.Size((1, 8, 112, 112))
+    assert feat[1].shape == torch.Size((1, 16, 56, 56))
+    assert feat[2].shape == torch.Size((1, 16, 28, 28))
+    assert feat[3].shape == torch.Size((1, 32, 14, 14))
+    assert feat[4].shape == torch.Size((1, 48, 14, 14))
+    assert feat[5].shape == torch.Size((1, 80, 7, 7))
+    assert feat[6].shape == torch.Size((1, 160, 7, 7))
+
+    # Test MobileNetV2 forward with widen_factor=2.0
+    model = MobileNetV2(widen_factor=2.0, out_indices=None)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat.shape == torch.Size((1, 2560, 7, 7))
+
+    # Test MobileNetV2 forward with out_indices=None
+    model = MobileNetV2(widen_factor=1.0, out_indices=None)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert feat.shape == torch.Size((1, 1280, 7, 7))
+
+    # Test MobileNetV2 forward with act_cfg=dict(type='ReLU')
+    model = MobileNetV2(
+        widen_factor=1.0, act_cfg=dict(type='ReLU'), out_indices=range(0, 7))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 7
+    assert feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 24, 56, 56))
+    assert feat[2].shape == torch.Size((1, 32, 28, 28))
+    assert feat[3].shape == torch.Size((1, 64, 14, 14))
+    assert feat[4].shape == torch.Size((1, 96, 14, 14))
+    assert feat[5].shape == torch.Size((1, 160, 7, 7))
+    assert feat[6].shape == torch.Size((1, 320, 7, 7))
+
+    # Test MobileNetV2 with BatchNorm forward
+    model = MobileNetV2(widen_factor=1.0, out_indices=range(0, 7))
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, _BatchNorm)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 7
+    assert feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 24, 56, 56))
+    assert feat[2].shape == torch.Size((1, 32, 28, 28))
+    assert feat[3].shape == torch.Size((1, 64, 14, 14))
+    assert feat[4].shape == torch.Size((1, 96, 14, 14))
+    assert feat[5].shape == torch.Size((1, 160, 7, 7))
+    assert feat[6].shape == torch.Size((1, 320, 7, 7))
+
+    # Test MobileNetV2 with GroupNorm forward
+    model = MobileNetV2(
+        widen_factor=1.0,
+        norm_cfg=dict(type='GN', num_groups=2, requires_grad=True),
+        out_indices=range(0, 7))
+    for m in model.modules():
+        if is_norm(m):
+            assert isinstance(m, GroupNorm)
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 7
+    assert feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 24, 56, 56))
+    assert feat[2].shape == torch.Size((1, 32, 28, 28))
+    assert feat[3].shape == torch.Size((1, 64, 14, 14))
+    assert feat[4].shape == torch.Size((1, 96, 14, 14))
+    assert feat[5].shape == torch.Size((1, 160, 7, 7))
+    assert feat[6].shape == torch.Size((1, 320, 7, 7))
+
+    # Test MobileNetV2 with layers 1, 3, 5 out forward
+    model = MobileNetV2(widen_factor=1.0, out_indices=(0, 2, 4))
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 3
+    assert feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 32, 28, 28))
+    assert feat[2].shape == torch.Size((1, 96, 14, 14))
+
+    # Test MobileNetV2 with checkpoint forward
+    model = MobileNetV2(
+        widen_factor=1.0, with_cp=True, out_indices=range(0, 7))
+    for m in model.modules():
+        if is_block(m):
+            assert m.with_cp
+    model.init_weights()
+    model.train()
+
+    imgs = torch.randn(1, 3, 224, 224)
+    feat = model(imgs)
+    assert len(feat) == 7
+    assert feat[0].shape == torch.Size((1, 16, 112, 112))
+    assert feat[1].shape == torch.Size((1, 24, 56, 56))
+    assert feat[2].shape == torch.Size((1, 32, 28, 28))
+    assert feat[3].shape == torch.Size((1, 64, 14, 14))
+    assert feat[4].shape == torch.Size((1, 96, 14, 14))
+    assert feat[5].shape == torch.Size((1, 160, 7, 7))
+    assert feat[6].shape == torch.Size((1, 320, 7, 7))
diff --git a/tests/test_backbones/test_shufflenet_v1.py b/tests/test_backbones/test_shufflenet_v1.py
index 4b138e1a8..79674ca75 100644
--- a/tests/test_backbones/test_shufflenet_v1.py
+++ b/tests/test_backbones/test_shufflenet_v1.py
@@ -3,7 +3,7 @@ import torch
 from torch.nn.modules import GroupNorm
 from torch.nn.modules.batchnorm import _BatchNorm
 
-from mmcls.models.backbones import ShuffleNetv1
+from mmcls.models.backbones import ShuffleNetV1
 from mmcls.models.backbones.shufflenet_v1 import ShuffleUnit
 
 
@@ -66,30 +66,30 @@ def test_shufflenetv1_backbone():
 
     with pytest.raises(ValueError):
         # frozen_stages must be in range(-1, 4)
-        ShuffleNetv1(frozen_stages=10)
+        ShuffleNetV1(frozen_stages=10)
 
     with pytest.raises(ValueError):
         # the item in out_indices must be in range(0, 4)
-        ShuffleNetv1(out_indices=[5])
+        ShuffleNetV1(out_indices=[5])
 
     with pytest.raises(ValueError):
         # groups must be in [1, 2, 3, 4, 8]
-        ShuffleNetv1(groups=10)
+        ShuffleNetV1(groups=10)
 
     with pytest.raises(TypeError):
         # pretrained must be str or None
-        model = ShuffleNetv1()
+        model = ShuffleNetV1()
         model.init_weights(pretrained=1)
 
-    # Test ShuffleNetv1 norm state
-    model = ShuffleNetv1()
+    # Test ShuffleNetV1 norm state
+    model = ShuffleNetV1()
     model.init_weights()
     model.train()
     assert check_norm_state(model.modules(), True)
 
-    # Test ShuffleNetv1 with first stage frozen
+    # Test ShuffleNetV1 with first stage frozen
     frozen_stages = 1
-    model = ShuffleNetv1(frozen_stages=frozen_stages)
+    model = ShuffleNetV1(frozen_stages=frozen_stages)
     model.init_weights()
     model.train()
     for param in model.conv1.parameters():
@@ -102,8 +102,8 @@ def test_shufflenetv1_backbone():
         for param in layer.parameters():
             assert param.requires_grad is False
 
-    # Test ShuffleNetv1 forward with groups=1
-    model = ShuffleNetv1(groups=1)
+    # Test ShuffleNetV1 forward with groups=1
+    model = ShuffleNetV1(groups=1)
     model.init_weights()
     model.train()
 
@@ -118,8 +118,8 @@ def test_shufflenetv1_backbone():
     assert feat[1].shape == torch.Size((1, 288, 14, 14))
     assert feat[2].shape == torch.Size((1, 576, 7, 7))
 
-    # Test ShuffleNetv1 forward with groups=2
-    model = ShuffleNetv1(groups=2)
+    # Test ShuffleNetV1 forward with groups=2
+    model = ShuffleNetV1(groups=2)
     model.init_weights()
     model.train()
 
@@ -134,8 +134,8 @@ def test_shufflenetv1_backbone():
     assert feat[1].shape == torch.Size((1, 400, 14, 14))
     assert feat[2].shape == torch.Size((1, 800, 7, 7))
 
-    # Test ShuffleNetv1 forward with groups=3
-    model = ShuffleNetv1(groups=3)
+    # Test ShuffleNetV1 forward with groups=3
+    model = ShuffleNetV1(groups=3)
     model.init_weights()
     model.train()
 
@@ -150,8 +150,8 @@ def test_shufflenetv1_backbone():
     assert feat[1].shape == torch.Size((1, 480, 14, 14))
     assert feat[2].shape == torch.Size((1, 960, 7, 7))
 
-    # Test ShuffleNetv1 forward with groups=4
-    model = ShuffleNetv1(groups=4)
+    # Test ShuffleNetV1 forward with groups=4
+    model = ShuffleNetV1(groups=4)
     model.init_weights()
     model.train()
 
@@ -166,8 +166,8 @@ def test_shufflenetv1_backbone():
     assert feat[1].shape == torch.Size((1, 544, 14, 14))
     assert feat[2].shape == torch.Size((1, 1088, 7, 7))
 
-    # Test ShuffleNetv1 forward with groups=8
-    model = ShuffleNetv1(groups=8)
+    # Test ShuffleNetV1 forward with groups=8
+    model = ShuffleNetV1(groups=8)
    model.init_weights()
    model.train()
 
@@ -182,8 +182,8 @@ def test_shufflenetv1_backbone():
     assert feat[1].shape == torch.Size((1, 768, 14, 14))
     assert feat[2].shape == torch.Size((1, 1536, 7, 7))
 
-    # Test ShuffleNetv1 forward with GroupNorm forward
-    model = ShuffleNetv1(
+    # Test ShuffleNetV1 forward with GroupNorm forward
+    model = ShuffleNetV1(
         groups=3, norm_cfg=dict(type='GN', num_groups=2, requires_grad=True))
     model.init_weights()
     model.train()
@@ -199,8 +199,8 @@ def test_shufflenetv1_backbone():
     assert feat[1].shape == torch.Size((1, 480, 14, 14))
     assert feat[2].shape == torch.Size((1, 960, 7, 7))
 
-    # Test ShuffleNetv1 forward with layers 1, 2 forward
-    model = ShuffleNetv1(groups=3, out_indices=(1, 2))
+    # Test ShuffleNetV1 forward with layers 1, 2 forward
+    model = ShuffleNetV1(groups=3, out_indices=(1, 2))
     model.init_weights()
     model.train()
 
@@ -214,8 +214,8 @@ def test_shufflenetv1_backbone():
     assert feat[0].shape == torch.Size((1, 480, 14, 14))
     assert feat[1].shape == torch.Size((1, 960, 7, 7))
 
-    # Test ShuffleNetv1 forward with layers 2 forward
-    model = ShuffleNetv1(groups=3, out_indices=(2, ))
+    # Test ShuffleNetV1 forward with layers 2 forward
+    model = ShuffleNetV1(groups=3, out_indices=(2, ))
     model.init_weights()
     model.train()
 
@@ -228,14 +228,14 @@ def test_shufflenetv1_backbone():
     assert isinstance(feat, torch.Tensor)
     assert feat.shape == torch.Size((1, 960, 7, 7))
 
-    # Test ShuffleNetv1 forward with checkpoint forward
-    model = ShuffleNetv1(groups=3, with_cp=True)
+    # Test ShuffleNetV1 forward with checkpoint forward
+    model = ShuffleNetV1(groups=3, with_cp=True)
     for m in model.modules():
         if is_block(m):
             assert m.with_cp
 
-    # Test ShuffleNetv1 with norm_eval
-    model = ShuffleNetv1(norm_eval=True)
+    # Test ShuffleNetV1 with norm_eval
+    model = ShuffleNetV1(norm_eval=True)
     model.init_weights()
     model.train()
diff --git a/tests/test_backbones/test_shufflenet_v2.py b/tests/test_backbones/test_shufflenet_v2.py
index 7b0c6e864..48c58dc34 100644
--- a/tests/test_backbones/test_shufflenet_v2.py
+++ b/tests/test_backbones/test_shufflenet_v2.py
@@ -3,7 +3,7 @@ import torch
 from torch.nn.modules import GroupNorm
 from torch.nn.modules.batchnorm import _BatchNorm
 
-from mmcls.models.backbones import ShuffleNetv2
+from mmcls.models.backbones import ShuffleNetV2
 from mmcls.models.backbones.shufflenet_v2 import InvertedResidual
 
 
@@ -59,26 +59,26 @@ def test_shufflenetv2_backbone():
 
     with pytest.raises(ValueError):
         # widen_factor must be in [0.5, 1.0, 1.5, 2.0]
-        ShuffleNetv2(widen_factor=3.0)
+        ShuffleNetV2(widen_factor=3.0)
 
     with pytest.raises(AssertionError):
         # frozen_stages must be in [0, 1, 2]
-        ShuffleNetv2(widen_factor=3.0, frozen_stages=3)
+        ShuffleNetV2(widen_factor=3.0, frozen_stages=3)
 
     with pytest.raises(TypeError):
         # pretrained must be str or None
-        model = ShuffleNetv2()
+        model = ShuffleNetV2()
         model.init_weights(pretrained=1)
 
-    # Test ShuffleNetv2 norm state
-    model = ShuffleNetv2()
+    # Test ShuffleNetV2 norm state
+    model = ShuffleNetV2()
     model.init_weights()
     model.train()
     assert check_norm_state(model.modules(), True)
 
-    # Test ShuffleNetv2 with first stage frozen
+    # Test ShuffleNetV2 with first stage frozen
     frozen_stages = 1
-    model = ShuffleNetv2(frozen_stages=frozen_stages)
+    model = ShuffleNetV2(frozen_stages=frozen_stages)
     model.init_weights()
     model.train()
     for param in model.conv1.parameters():
@@ -91,15 +91,15 @@ def test_shufflenetv2_backbone():
         for param in layer.parameters():
             assert param.requires_grad is False
 
-    # Test ShuffleNetv2 with norm_eval
-    model = ShuffleNetv2(norm_eval=True)
+    # Test ShuffleNetV2 with norm_eval
+    model = ShuffleNetV2(norm_eval=True)
     model.init_weights()
     model.train()
 
     assert check_norm_state(model.modules(), False)
 
-    # Test ShuffleNetv2 forward with widen_factor=0.5
-    model = ShuffleNetv2(widen_factor=0.5)
+    # Test ShuffleNetV2 forward with widen_factor=0.5
+    model = ShuffleNetV2(widen_factor=0.5)
     model.init_weights()
     model.train()
 
@@ -114,8 +114,8 @@ def test_shufflenetv2_backbone():
     assert feat[1].shape == torch.Size((1, 96, 14, 14))
     assert feat[2].shape == torch.Size((1, 192, 7, 7))
 
-    # Test ShuffleNetv2 forward with widen_factor=1.0
-    model = ShuffleNetv2(widen_factor=1.0)
+    # Test ShuffleNetV2 forward with widen_factor=1.0
+    model = ShuffleNetV2(widen_factor=1.0)
     model.init_weights()
     model.train()
 
@@ -130,8 +130,8 @@ def test_shufflenetv2_backbone():
     assert feat[1].shape == torch.Size((1, 232, 14, 14))
     assert feat[2].shape == torch.Size((1, 464, 7, 7))
 
-    # Test ShuffleNetv2 forward with widen_factor=1.5
-    model = ShuffleNetv2(widen_factor=1.5)
+    # Test ShuffleNetV2 forward with widen_factor=1.5
+    model = ShuffleNetV2(widen_factor=1.5)
     model.init_weights()
     model.train()
 
@@ -146,8 +146,8 @@ def test_shufflenetv2_backbone():
     assert feat[1].shape == torch.Size((1, 352, 14, 14))
     assert feat[2].shape == torch.Size((1, 704, 7, 7))
 
-    # Test ShuffleNetv2 forward with widen_factor=2.0
-    model = ShuffleNetv2(widen_factor=2.0)
+    # Test ShuffleNetV2 forward with widen_factor=2.0
+    model = ShuffleNetV2(widen_factor=2.0)
     model.init_weights()
     model.train()
 
@@ -162,8 +162,8 @@ def test_shufflenetv2_backbone():
     assert feat[1].shape == torch.Size((1, 488, 14, 14))
     assert feat[2].shape == torch.Size((1, 976, 7, 7))
 
-    # Test ShuffleNetv2 forward with layers 3 forward
-    model = ShuffleNetv2(widen_factor=1.0, out_indices=(2, ))
+    # Test ShuffleNetV2 forward with layers 3 forward
+    model = ShuffleNetV2(widen_factor=1.0, out_indices=(2, ))
     model.init_weights()
     model.train()
 
@@ -176,8 +176,8 @@ def test_shufflenetv2_backbone():
     assert isinstance(feat, torch.Tensor)
     assert feat.shape == torch.Size((1, 464, 7, 7))
 
-    # Test ShuffleNetv2 forward with layers 1 2 forward
-    model = ShuffleNetv2(widen_factor=1.0, out_indices=(1, 2))
+    # Test ShuffleNetV2 forward with layers 1 2 forward
+    model = ShuffleNetV2(widen_factor=1.0, out_indices=(1, 2))
     model.init_weights()
     model.train()
 
@@ -191,8 +191,8 @@ def test_shufflenetv2_backbone():
     assert feat[0].shape == torch.Size((1, 232, 14, 14))
     assert feat[1].shape == torch.Size((1, 464, 7, 7))
 
-    # Test ShuffleNetv2 forward with checkpoint forward
-    model = ShuffleNetv2(widen_factor=1.0, with_cp=True)
+    # Test ShuffleNetV2 forward with checkpoint forward
+    model = ShuffleNetV2(widen_factor=1.0, with_cp=True)
     for m in model.modules():
         if is_block(m):
             assert m.with_cp
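
Usage note (not part of the patch): a minimal sketch of how the new backbone
can be exercised once the patch is applied. It only uses the classes and
arguments introduced above; the expected shapes mirror the unit tests.

    import torch

    from mmcls.models.backbones import MobileNetV2

    # Request all seven stage outputs, as the tests above do.
    model = MobileNetV2(widen_factor=1.0, out_indices=range(0, 7))
    model.init_weights()
    model.eval()

    with torch.no_grad():
        feats = model(torch.randn(1, 3, 224, 224))

    # Seven feature maps, from (1, 16, 112, 112) down to (1, 320, 7, 7).
    for feat in feats:
        print(tuple(feat.shape))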