diff --git a/mmcls/models/backbones/shufflenet_v1.py b/mmcls/models/backbones/shufflenet_v1.py
index 3518c061..0f2d82fb 100644
--- a/mmcls/models/backbones/shufflenet_v1.py
+++ b/mmcls/models/backbones/shufflenet_v1.py
@@ -4,14 +4,30 @@ from collections import OrderedDict
 import torch
 import torch.nn as nn
 import torch.utils.checkpoint as cp
+from mmcv.cnn.weight_init import constant_init, kaiming_init
 from mmcv.runner import load_checkpoint
 
 from .base_backbone import BaseBackbone
-from .weight_init import constant_init, kaiming_init
 
 
 def conv3x3(inplanes, planes, stride=1, padding=1, bias=False, groups=1):
-    """3x3 convolution with padding
+    """3x3 convolution
+
+    Applies a 2D convolution with the kernel_size 3x3 over an input signal
+    composed of several input planes.
+
+    Args:
+        inplanes (int): Number of channels of the input feature maps.
+        planes (int): Number of channels produced by the convolution.
+        stride (int or tuple, optional): Stride of the convolution.
+            Default is 1.
+        padding (int, optional): Amount of implicit zero-padding added to
+            both sides of the input for each dimension.
+            Default is 1.
+        bias (bool, optional): Whether to add a learnable bias to the output.
+            Default is False.
+        groups (int, optional): Number of blocked connections from input
+            channels to output channels. Default is 1.
     """
     return nn.Conv2d(
         inplanes,
@@ -24,33 +40,61 @@ def conv3x3(inplanes, planes, stride=1, padding=1, bias=False, groups=1):
 
 
 def conv1x1(inplanes, planes, groups=1):
-    """1x1 convolution with padding
-
-    Normal pointwise convolution when groups == 1
-
-    Grouped pointwise convolution when groups > 1
+    """1x1 convolution
+
+    Applies a 2D convolution with the kernel_size 1x1 over an input signal
+    composed of several input planes.
+
+    Args:
+        inplanes (int): Number of channels of the input feature maps.
+        planes (int): Number of channels produced by the convolution.
+        groups (int, optional): Number of blocked connections from input
+            channels to output channels. Default is 1.
     """
     return nn.Conv2d(inplanes, planes, kernel_size=1, groups=groups, stride=1)
 
 
 def channel_shuffle(x, groups):
-    batchsize, num_channels, height, width = x.data.size()
-    assert (num_channels % groups == 0)
+    """Channel Shuffle operation
+
+    This function enables cross-group information flow for multiple group
+    convolution layers.
+
+    Args:
+        x (Tensor): The input tensor.
+        groups (int): The number of groups to divide the input tensor
+            in the channel dimension.
+
+    Returns:
+        Tensor: The output tensor after the channel shuffle operation.
+
+    """
+    batchsize, num_channels, height, width = x.size()
+    assert (num_channels % groups == 0), 'num_channels should ' \
+        'be divisible by groups'
     channels_per_group = num_channels // groups
 
-    # reshape
     x = x.view(batchsize, groups, channels_per_group, height, width)
-
-    # transpose
-    # - contiguous() required if transpose() is used before view().
-    #   See https://github.com/pytorch/pytorch/issues/764
    x = torch.transpose(x, 1, 2).contiguous()
-
-    # flatten
    x = x.view(batchsize, -1, height, width)
 
    return x
 
 
 def _make_divisible(v, divisor, min_value=None):
+    """Make divisible function
+
+    This function ensures that all layers have a channel number that is
+    divisible by divisor.
+
+    Args:
+        v (int): The original channel number.
+        divisor (int): The divisor to fully divide the channel number.
+        min_value (int, optional): The minimum value of the output channel.
+
+    Returns:
+        new_v (int): The modified output channel number.
+    """
     if min_value is None:
         min_value = divisor
     new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
@@ -61,6 +105,26 @@ def _make_divisible(v, divisor, min_value=None):
 
 
 class ShuffleUnit(nn.Module):
+    """ShuffleUnit block.
+
+    ShuffleNet unit with pointwise group convolution (GConv) and channel
+    shuffle.
+
+    Args:
+        inplanes (int): The input channels of the ShuffleUnit.
+        planes (int): The output channels of the ShuffleUnit.
+        groups (int, optional): The number of groups to be used in grouped 1x1
+            convolutions in each ShuffleUnit.
+        first_block (bool, optional): Whether it is the first ShuffleUnit of a
+            sequential stage. If True, use the grouped 1x1 convolution.
+        combine (str, optional): The way to combine the input and output
+            branches, either 'add' or 'concat'.
+        with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+            will save some memory while slowing down the training speed.
+
+    Returns:
+        Tensor: The output tensor.
+    """
 
     def __init__(self,
                  inplanes,
@@ -69,7 +133,6 @@ class ShuffleUnit(nn.Module):
                  first_block=True,
                  combine='add',
                  with_cp=False):
-
         super(ShuffleUnit, self).__init__()
         self.inplanes = inplanes
         self.planes = planes
@@ -82,18 +145,15 @@ class ShuffleUnit(nn.Module):
         if self.combine == 'add':
             self.depthwise_stride = 1
             self._combine_func = self._add
+            assert inplanes == planes, 'inplanes must be equal to ' \
+                'planes when combine is add.'
         elif self.combine == 'concat':
             self.depthwise_stride = 2
             self._combine_func = self._concat
             self.planes -= self.inplanes
         else:
-            raise ValueError("Cannot combine tensors with \"{}\" "
-                             "Only \"add\" and \"concat\" are "
-                             "supported".format(self.combine))
-
-        if combine == 'add':
-            assert inplanes == planes, \
-                'inplanes must be equal to outplanes when combine is add'
+            raise ValueError(f'Cannot combine tensors with {self.combine}. '
+                             'Only "add" and "concat" are supported.')
 
         self.first_1x1_groups = self.groups if first_block else 1
         self.g_conv_1x1_compress = self._make_grouped_conv1x1(
@@ -108,15 +168,16 @@ class ShuffleUnit(nn.Module):
             self.bottleneck_channels,
             stride=self.depthwise_stride,
             groups=self.bottleneck_channels)
-        self.bn_after_depthwise = \
-            nn.BatchNorm2d(self.bottleneck_channels)
-        self.g_conv_1x1_expand = \
-            self._make_grouped_conv1x1(self.bottleneck_channels,
-                                       self.planes,
-                                       self.groups,
-                                       batch_norm=True,
-                                       relu=False)
+        self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels)
+
+        self.g_conv_1x1_expand = self._make_grouped_conv1x1(
+            self.bottleneck_channels,
+            self.planes,
+            self.groups,
+            batch_norm=True,
+            relu=False)
+
         self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
         self.relu = nn.ReLU(inplace=True)
@@ -183,10 +244,11 @@ class ShuffleNetv1(BaseBackbone):
     """ShuffleNetv1 backbone.
 
     Args:
-        groups (int): number of groups to be used in grouped
-            1x1 convolutions in each ShuffleUnit. Default is 3 for best
-            performance according to original paper.
-        widen_factor (float): Config of widen_factor.
+        groups (int, optional): The number of groups to be used in grouped 1x1
+            convolutions in each ShuffleUnit. Default is 3 for the best
+            performance according to the original paper.
+        widen_factor (float, optional): Width multiplier, which adjusts the
+            number of channels in each layer by this amount. Default is 1.0.
         out_indices (Sequence[int]): Output from which stages.
         frozen_stages (int): Stages to be frozen (all param fixed). -1 means
             not freezing any parameters.
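
For reviewers, here is a quick way to check the `channel_shuffle` semantics documented above. This is a minimal sketch, not part of the patch; the import path assumes this patch is applied, and the tensor sizes are illustrative only.

```python
import torch

from mmcls.models.backbones.shufflenet_v1 import channel_shuffle

# 1 sample, 12 channels numbered 0..11, 1x1 spatial size.
x = torch.arange(12).view(1, 12, 1, 1)
out = channel_shuffle(x, groups=3)

# The 12 channels are split into 3 groups of 4 and interleaved across
# groups, so the channel order becomes 0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11.
assert out.shape == x.shape
assert out.view(-1).tolist() == [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
```
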
@@ -208,6 +270,15 @@ class ShuffleNetv1(BaseBackbone):
         super(ShuffleNetv1, self).__init__()
         blocks = [3, 7, 3]
         self.groups = groups
+
+        for index in out_indices:
+            if index not in range(0, 4):
+                raise ValueError('the item in out_indices must be in '
+                                 f'range(0, 4). But received {index}')
+
+        if frozen_stages not in [-1, 1, 2, 3]:
+            raise ValueError('frozen_stages must be in [-1, 1, 2, 3]. '
+                             f'But received {frozen_stages}')
         self.out_indices = out_indices
         self.frozen_stages = frozen_stages
         self.bn_eval = bn_eval
@@ -225,8 +296,9 @@ class ShuffleNetv1(BaseBackbone):
         elif groups == 8:
             channels = [384, 768, 1536]
         else:
-            raise ValueError("{} groups is not supported for "
-                             "1x1 Grouped Convolutions".format(groups))
+            raise ValueError(f'{groups} groups is not supported for 1x1 '
+                             'Grouped Convolutions')
+
         channels = [_make_divisible(ch * widen_factor, 8) for ch in channels]
 
         self.inplanes = int(24 * widen_factor)
@@ -252,6 +324,20 @@ class ShuffleNetv1(BaseBackbone):
             raise TypeError('pretrained must be a str or None')
 
     def _make_layer(self, outplanes, blocks, first_block=True, with_cp=False):
+        """Stack ShuffleUnit blocks to make a layer.
+
+        Args:
+            outplanes (int): Number of output channels.
+            blocks (int): Number of blocks to be built.
+            first_block (bool, optional): Whether it is the first ShuffleUnit
+                of a sequential stage. If True, use the grouped 1x1
+                convolution.
+            with_cp (bool, optional): Use checkpoint or not. Using checkpoint
+                will save some memory while slowing down the training speed.
+
+        Returns:
+            nn.Sequential: The layer made of stacked ShuffleUnit blocks.
+        """
         layers = []
         for i in range(blocks):
             if i == 0:
@@ -272,6 +358,7 @@ class ShuffleNetv1(BaseBackbone):
                     first_block=True,
                     combine='add',
                     with_cp=with_cp))
+            self.inplanes = outplanes
 
         return nn.Sequential(*layers)
@@ -311,7 +398,7 @@ class ShuffleNetv1(BaseBackbone):
         for param in self.conv1.parameters():
             param.requires_grad = False
         for i in range(1, self.frozen_stages + 1):
-            mod = getattr(self, 'layer{}'.format(i))
-            mod.eval()
-            for param in mod.parameters():
+            layer = getattr(self, f'layer{i}')
+            layer.eval()
+            for param in layer.parameters():
                 param.requires_grad = False
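
The rest of the patch deletes the backbone-local `weight_init.py` (below) in favor of the mmcv implementations, whose signatures match the removed copies. A minimal sketch of the two helpers this backbone still uses, with illustrative layers (the specific modules here are not part of this patch):

```python
import torch.nn as nn
from mmcv.cnn import constant_init, kaiming_init

# Illustrative layers, mirroring how init_weights typically applies these
# helpers to conv and norm modules in mmcls backbones.
conv = nn.Conv2d(3, 24, kernel_size=3)
bn = nn.BatchNorm2d(24)

kaiming_init(conv)    # Kaiming-normal weights (mode='fan_out'), zero bias
constant_init(bn, 1)  # weight filled with 1, bias filled with 0
```
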
diff --git a/mmcls/models/backbones/weight_init.py b/mmcls/models/backbones/weight_init.py
deleted file mode 100644
index e06e6cca..00000000
--- a/mmcls/models/backbones/weight_init.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Copyright (c) Open-MMLab. All rights reserved.
-import numpy as np
-import torch.nn as nn
-
-
-def constant_init(module, val, bias=0):
-    if hasattr(module, 'weight') and module.weight is not None:
-        nn.init.constant_(module.weight, val)
-    if hasattr(module, 'bias') and module.bias is not None:
-        nn.init.constant_(module.bias, bias)
-
-
-def xavier_init(module, gain=1, bias=0, distribution='normal'):
-    assert distribution in ['uniform', 'normal']
-    if distribution == 'uniform':
-        nn.init.xavier_uniform_(module.weight, gain=gain)
-    else:
-        nn.init.xavier_normal_(module.weight, gain=gain)
-    if hasattr(module, 'bias') and module.bias is not None:
-        nn.init.constant_(module.bias, bias)
-
-
-def normal_init(module, mean=0, std=1, bias=0):
-    nn.init.normal_(module.weight, mean, std)
-    if hasattr(module, 'bias') and module.bias is not None:
-        nn.init.constant_(module.bias, bias)
-
-
-def uniform_init(module, a=0, b=1, bias=0):
-    nn.init.uniform_(module.weight, a, b)
-    if hasattr(module, 'bias') and module.bias is not None:
-        nn.init.constant_(module.bias, bias)
-
-
-def kaiming_init(module,
-                 a=0,
-                 mode='fan_out',
-                 nonlinearity='relu',
-                 bias=0,
-                 distribution='normal'):
-    assert distribution in ['uniform', 'normal']
-    if distribution == 'uniform':
-        nn.init.kaiming_uniform_(
-            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
-    else:
-        nn.init.kaiming_normal_(
-            module.weight, a=a, mode=mode, nonlinearity=nonlinearity)
-    if hasattr(module, 'bias') and module.bias is not None:
-        nn.init.constant_(module.bias, bias)
-
-
-def caffe2_xavier_init(module, bias=0):
-    # `XavierFill` in Caffe2 corresponds to `kaiming_uniform_` in PyTorch
-    # Acknowledgment to FAIR's internal code
-    kaiming_init(
-        module,
-        a=1,
-        mode='fan_in',
-        nonlinearity='leaky_relu',
-        distribution='uniform')
-
-
-def bias_init_with_prob(prior_prob):
-    """ initialize conv/fc bias value according to giving probablity"""
-    bias_init = float(-np.log((1 - prior_prob) / prior_prob))
-    return bias_init
diff --git a/tests/test_backbone.py b/tests/test_backbone.py
index 8b2d584f..0bd69ba1 100644
--- a/tests/test_backbone.py
+++ b/tests/test_backbone.py
@@ -37,7 +37,7 @@ def test_shufflenetv1_shuffleuint():
         ShuffleUnit(24, 16, groups=3, first_block=True, combine='test')
 
     with pytest.raises(ValueError):
-        # in_channels must be divisible by groups
+        # inplanes must be divisible by groups
         ShuffleUnit(64, 64, groups=3, first_block=True, combine='add')
 
     with pytest.raises(AssertionError):
@@ -66,6 +66,14 @@ def test_shufflenetv1_shuffleuint():
 
 
 def test_shufflenetv1_backbone():
+    with pytest.raises(ValueError):
+        # frozen_stages must be in [-1, 1, 2, 3]
+        ShuffleNetv1(frozen_stages=10)
+
+    with pytest.raises(ValueError):
+        # each item in out_indices must be in range(0, 4)
+        ShuffleNetv1(out_indices=[5])
+
    with pytest.raises(ValueError):
        # groups must in [1, 2, 3, 4, 8]
        ShuffleNetv1(groups=10)
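
As a complement to the new error-path tests, a forward-pass smoke test could look like the sketch below. It is not part of this patch: the import path and the out_indices argument are assumptions taken from the test file above, and the channel counts [240, 480, 960] follow from groups=3 with widen_factor=1.0 in the constructor.

```python
import torch

from mmcls.models.backbones import ShuffleNetv1

# groups=3 selects stage channels [240, 480, 960] (see the constructor);
# out_indices=(0, 1, 2) requests the feature map of every stage.
model = ShuffleNetv1(groups=3, out_indices=(0, 1, 2))
model.init_weights()
model.train()

imgs = torch.randn(1, 3, 224, 224)
feats = model(imgs)
assert [feat.shape[1] for feat in feats] == [240, 480, 960]
```
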