From a54265179418fae23970f59c422d2d61c4391bf5 Mon Sep 17 00:00:00 2001
From: louzan
Date: Thu, 28 May 2020 11:44:43 +0800
Subject: [PATCH] add shufflenet_v1

---
 mmcls/models/backbones/shufflenet_v1.py | 311 ++++++++++++++++++++++++
 1 file changed, 311 insertions(+)
 create mode 100644 mmcls/models/backbones/shufflenet_v1.py

diff --git a/mmcls/models/backbones/shufflenet_v1.py b/mmcls/models/backbones/shufflenet_v1.py
new file mode 100644
index 00000000..9bd7cec4
--- /dev/null
+++ b/mmcls/models/backbones/shufflenet_v1.py
@@ -0,0 +1,311 @@
+import logging
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint as cp
+from mmcv.runner import load_checkpoint
+
+from .base_backbone import BaseBackbone
+from .weight_init import constant_init, kaiming_init
+
+
+def conv3x3(inplanes, planes, stride=1, padding=1, bias=False, groups=1):
+    """3x3 convolution with padding."""
+    return nn.Conv2d(
+        inplanes,
+        planes,
+        kernel_size=3,
+        stride=stride,
+        padding=padding,
+        bias=bias,
+        groups=groups)
+
+
+def conv1x1(inplanes, planes, groups=1):
+    """1x1 convolution.
+
+    - Normal pointwise convolution when groups == 1
+    - Grouped pointwise convolution when groups > 1
+    """
+    return nn.Conv2d(
+        inplanes,
+        planes,
+        kernel_size=1,
+        groups=groups,
+        stride=1)
+
+
+def channel_shuffle(x, groups):
+    """Shuffle channels so that information flows across groups."""
+    batchsize, num_channels, height, width = x.size()
+    assert num_channels % groups == 0, \
+        'num_channels should be divisible by groups'
+    channels_per_group = num_channels // groups
+
+    # reshape
+    x = x.view(batchsize, groups, channels_per_group, height, width)
+
+    # transpose
+    # - contiguous() required if transpose() is used before view().
+    #   See https://github.com/pytorch/pytorch/issues/764
+    x = torch.transpose(x, 1, 2).contiguous()
+
+    # flatten
+    x = x.view(batchsize, -1, height, width)
+
+    return x
+
+
+def _make_divisible(v, divisor, min_value=None):
+    """Round v to the nearest multiple of divisor, never going more than
+    10% below v."""
+    if min_value is None:
+        min_value = divisor
+    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
+    # Make sure that round down does not go down by more than 10%.
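+    # e.g. _make_divisible(235, 8) rounds to 232 and keeps it, while
+    # _make_divisible(10, 8) first rounds to 8, which is below 0.9 * 10,
+    # so the divisor is added back and 16 is returned.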
+    if new_v < 0.9 * v:
+        new_v += divisor
+    return new_v
+
+
+class ShuffleUnit(nn.Module):
+    """ShuffleUnit block.
+
+    Args:
+        inplanes (int): Number of input channels.
+        planes (int): Number of output channels.
+        groups (int): Number of groups for the grouped 1x1 convolutions.
+        first_block (bool): Whether the first 1x1 convolution is grouped.
+            Set to False for the very first unit of the network, whose
+            input only has 24 channels.
+        combine (str): How to combine the residual and the branch output,
+            either 'add' (stride 1) or 'concat' (stride 2 downsampling).
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save
+            some memory while slowing down the training speed.
+    """
+
+    def __init__(self,
+                 inplanes,
+                 planes,
+                 groups=3,
+                 first_block=True,
+                 combine='add',
+                 with_cp=False):
+        super(ShuffleUnit, self).__init__()
+        self.inplanes = inplanes
+        self.planes = planes
+        self.first_block = first_block
+        self.combine = combine
+        self.groups = groups
+        self.bottleneck_channels = self.planes // 4
+        self.with_cp = with_cp
+
+        if self.combine == 'add':
+            self.depthwise_stride = 1
+            self._combine_func = self._add
+        elif self.combine == 'concat':
+            self.depthwise_stride = 2
+            self._combine_func = self._concat
+            self.planes -= self.inplanes
+        else:
+            raise ValueError('Cannot combine tensors with "{}". '
+                             'Only "add" and "concat" are '
+                             'supported'.format(self.combine))
+
+        self.first_1x1_groups = self.groups if first_block else 1
+        self.g_conv_1x1_compress = self._make_grouped_conv1x1(
+            self.inplanes,
+            self.bottleneck_channels,
+            self.first_1x1_groups,
+            batch_norm=True,
+            relu=True)
+
+        self.depthwise_conv3x3 = conv3x3(self.bottleneck_channels,
+                                         self.bottleneck_channels,
+                                         stride=self.depthwise_stride,
+                                         groups=self.bottleneck_channels)
+        self.bn_after_depthwise = nn.BatchNorm2d(self.bottleneck_channels)
+
+        self.g_conv_1x1_expand = self._make_grouped_conv1x1(
+            self.bottleneck_channels,
+            self.planes,
+            self.groups,
+            batch_norm=True,
+            relu=False)
+
+        self.avgpool = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
+        self.relu = nn.ReLU(inplace=True)
+
+    @staticmethod
+    def _add(x, out):
+        # residual connection
+        return x + out
+
+    @staticmethod
+    def _concat(x, out):
+        # concatenate along channel axis
+        return torch.cat((x, out), 1)
+
+    @staticmethod
+    def _make_grouped_conv1x1(inplanes, planes, groups,
+                              batch_norm=True, relu=False):
+        modules = OrderedDict()
+
+        conv = conv1x1(inplanes, planes, groups=groups)
+        modules['conv1x1'] = conv
+
+        if batch_norm:
+            modules['batch_norm'] = nn.BatchNorm2d(planes)
+        if relu:
+            modules['relu'] = nn.ReLU()
+        if len(modules) > 1:
+            return nn.Sequential(modules)
+        else:
+            return conv
+
+    def forward(self, x):
+
+        def _inner_forward(x):
+            residual = x
+
+            if self.combine == 'concat':
+                residual = self.avgpool(residual)
+
+            out = self.g_conv_1x1_compress(x)
+            out = channel_shuffle(out, self.groups)
+            out = self.depthwise_conv3x3(out)
+            out = self.bn_after_depthwise(out)
+            out = self.g_conv_1x1_expand(out)
+
+            out = self._combine_func(residual, out)
+
+            return out
+
+        if self.with_cp and x.requires_grad:
+            out = cp.checkpoint(_inner_forward, x)
+        else:
+            out = _inner_forward(x)
+
+        out = self.relu(out)
+
+        return out
+
+
+class ShuffleNetv1(BaseBackbone):
+    """ShuffleNetv1 backbone.
+
+    Args:
+        groups (int): Number of groups to be used in grouped 1x1
+            convolutions in each ShuffleUnit. Default is 3 for best
+            performance according to the original paper.
+        widen_factor (float): Width multiplier that scales the number of
+            channels in each stage. Default: 1.0.
+        out_indices (Sequence[int]): Output from which stages.
+        frozen_stages (int): Stages to be frozen (all params fixed). -1 means
+            not freezing any parameters.
+        bn_eval (bool): Whether to set BN layers as eval mode, namely, freeze
+            running stats (mean and var).
+        bn_frozen (bool): Whether to freeze weight and bias of BN layers.
+        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
+            memory while slowing down the training speed.
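+
+    Example:
+        >>> # illustrative usage only; output shapes assume a 224x224 input
+        >>> # and the default groups=3 channel configuration
+        >>> import torch
+        >>> model = ShuffleNetv1(groups=3, out_indices=(0, 1, 2, 3))
+        >>> model.init_weights()
+        >>> inputs = torch.rand(1, 3, 224, 224)
+        >>> level_outputs = model(inputs)
+        >>> for level_out in level_outputs:
+        ...     print(tuple(level_out.shape))
+        (1, 24, 56, 56)
+        (1, 240, 28, 28)
+        (1, 480, 14, 14)
+        (1, 960, 7, 7)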
+    """
+
+    def __init__(self,
+                 groups=3,
+                 widen_factor=1.0,
+                 out_indices=(0, 1, 2, 3),
+                 frozen_stages=-1,
+                 bn_eval=True,
+                 bn_frozen=False,
+                 with_cp=False):
+        super(ShuffleNetv1, self).__init__()
+        # number of ShuffleUnits per stage; the first unit of each stage
+        # is the stride-2 'concat' unit
+        blocks = [4, 8, 4]
+        self.groups = groups
+        self.out_indices = out_indices
+        self.frozen_stages = frozen_stages
+        self.bn_eval = bn_eval
+        self.bn_frozen = bn_frozen
+        self.with_cp = with_cp
+
+        if groups == 1:
+            channels = [144, 288, 576]
+        elif groups == 2:
+            channels = [200, 400, 800]
+        elif groups == 3:
+            channels = [240, 480, 960]
+        elif groups == 4:
+            channels = [272, 544, 1088]
+        elif groups == 8:
+            channels = [384, 768, 1536]
+        else:
+            raise ValueError('{} groups is not supported for '
+                             '1x1 grouped convolutions'.format(groups))
+        channels = [_make_divisible(ch * widen_factor, 8) for ch in channels]
+
+        self.inplanes = int(24 * widen_factor)
+        self.conv1 = conv3x3(3, self.inplanes, stride=2)
+        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+        self.layer2 = self._make_layer(channels[0], blocks[0],
+                                       first_block=False, with_cp=with_cp)
+        self.layer3 = self._make_layer(channels[1], blocks[1],
+                                       with_cp=with_cp)
+        self.layer4 = self._make_layer(channels[2], blocks[2],
+                                       with_cp=with_cp)
+
+    def init_weights(self, pretrained=None):
+        if isinstance(pretrained, str):
+            logger = logging.getLogger()
+            load_checkpoint(self, pretrained, strict=False, logger=logger)
+        elif pretrained is None:
+            for m in self.modules():
+                if isinstance(m, nn.Conv2d):
+                    kaiming_init(m)
+                elif isinstance(m, nn.BatchNorm2d):
+                    constant_init(m, 1)
+        else:
+            raise TypeError('pretrained must be a str or None')
+
+    def _make_layer(self,
+                    outplanes,
+                    blocks,
+                    first_block=True,
+                    with_cp=False):
+        layers = []
+        for i in range(blocks):
+            if i == 0:
+                layers.append(ShuffleUnit(self.inplanes, outplanes,
+                                          groups=self.groups,
+                                          first_block=first_block,
+                                          combine='concat',
+                                          with_cp=with_cp))
+            else:
+                layers.append(ShuffleUnit(self.inplanes, outplanes,
+                                          groups=self.groups,
+                                          first_block=True,
+                                          combine='add',
+                                          with_cp=with_cp))
+            self.inplanes = outplanes
+
+        return nn.Sequential(*layers)
+
+    def forward(self, x):
+        x = self.conv1(x)
+        x = self.maxpool(x)
+        outs = []
+        if 0 in self.out_indices:
+            outs.append(x)
+        x = self.layer2(x)
+        if 1 in self.out_indices:
+            outs.append(x)
+        x = self.layer3(x)
+        if 2 in self.out_indices:
+            outs.append(x)
+        x = self.layer4(x)
+        if 3 in self.out_indices:
+            outs.append(x)
+
+        if len(outs) == 1:
+            return outs[0]
+        else:
+            return tuple(outs)
+
+    def train(self, mode=True):
+        super(ShuffleNetv1, self).train(mode)
+        if self.bn_eval:
+            for m in self.modules():
+                if isinstance(m, nn.BatchNorm2d):
+                    m.eval()
+                    if self.bn_frozen:
+                        for params in m.parameters():
+                            params.requires_grad = False
+        if mode and self.frozen_stages >= 0:
+            for param in self.conv1.parameters():
+                param.requires_grad = False
+            for i in range(1, self.frozen_stages + 1):
+                # stages are stored as layer2, layer3 and layer4
+                mod = getattr(self, 'layer{}'.format(i + 1))
+                mod.eval()
+                for param in mod.parameters():
+                    param.requires_grad = False
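
A quick way to sanity-check the channel_shuffle operation added above (a standalone sketch, not part of the patch; it assumes the new module is importable as mmcls.models.backbones.shufflenet_v1): with 6 channels split into 3 groups, the channel order 0..5 becomes 0, 2, 4, 1, 3, 5, i.e. one channel from each group in turn.

    import torch

    from mmcls.models.backbones.shufflenet_v1 import channel_shuffle

    # a single 1x1 "image" whose value at channel i is simply i
    x = torch.arange(6, dtype=torch.float32).view(1, 6, 1, 1)
    y = channel_shuffle(x, groups=3)
    print(y.view(-1).tolist())  # [0.0, 2.0, 4.0, 1.0, 3.0, 5.0]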