from __future__ import absolute_import
from __future__ import division

import torch
from torch import nn
from torch.nn import functional as F

try:
    # NOTE(review): torchvision is not referenced anywhere in this file. The
    # import is kept for compatibility, but a missing torchvision no longer
    # makes this module unimportable.
    import torchvision  # noqa: F401
except ImportError:
    torchvision = None

__all__ = ['ShuffleNet']


class ChannelShuffle(nn.Module):
    """Channel shuffle operation from ShuffleNet.

    Splits the channel axis into ``num_groups`` groups and interleaves them,
    so the channels produced by one group convolution are spread across all
    groups of the next group convolution.

    Args:
        num_groups (int): number of groups. The input's channel count must be
            divisible by it, otherwise the reshaping ``view`` fails.
    """

    def __init__(self, num_groups):
        super(ChannelShuffle, self).__init__()
        self.g = num_groups

    def forward(self, x):
        # Input is unpacked as a 4-D (batch, channels, height, width) tensor.
        b, c, h, w = x.size()
        n = c // self.g
        # reshape to (b, g, n, h, w): one slot per group
        x = x.view(b, self.g, n, h, w)
        # swap the group and within-group axes so the groups interleave
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        # flatten back to (b, c, h, w)
        x = x.view(b, c, h, w)
        return x


class Bottleneck(nn.Module):
    """ShuffleNet unit.

    Residual branch: 1x1 (group) conv -> BN -> ReLU -> channel shuffle ->
    3x3 depthwise conv (carrying ``stride``) -> BN -> 1x1 group conv -> BN.
    It is combined with a shortcut branch as follows:

    * ``stride == 1``: identity shortcut, the two branches are added
      (requires ``in_channels == out_channels``).
    * ``stride == 2``: 3x3 average-pooled shortcut, the two branches are
      concatenated along the channel axis. The residual branch therefore
      only produces ``out_channels - in_channels`` channels.

    A ReLU is applied to the combined output.

    Args:
        in_channels (int): number of input channels.
        out_channels (int): number of output channels of the whole unit.
        stride (int): 1 or 2.
        num_groups (int): number of groups used by the group convolutions
            and by the channel shuffle.
        group_conv1x1 (bool): whether the first 1x1 conv is grouped.
            ``ShuffleNet`` disables it for the first unit of stage 2.
    """

    def __init__(self, in_channels, out_channels, stride, num_groups, group_conv1x1=True):
        super(Bottleneck, self).__init__()
        assert stride in [1, 2], "Warning: stride must be either 1 or 2"
        self.stride = stride
        # Bottleneck width is a quarter of the unit's full output width. It is
        # computed before the stride-2 adjustment below.
        mid_channels = out_channels // 4
        if stride == 2:
            # The pooled shortcut is concatenated, so the residual branch only
            # has to produce the remaining channels.
            out_channels -= in_channels
        # group conv is not applied to first conv1x1 at stage 2
        num_groups_conv1x1 = num_groups if group_conv1x1 else 1
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, groups=num_groups_conv1x1,
                               bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.shuffle1 = ChannelShuffle(num_groups)
        # 3x3 depthwise conv (groups == channels). It carries the unit's stride.
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, stride=stride, padding=1,
                               groups=mid_channels, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.Conv2d(mid_channels, out_channels, 1, groups=num_groups, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        if stride == 2:
            self.shortcut = nn.AvgPool2d(3, stride=2, padding=1)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.shuffle1(out)
        # No ReLU after the depthwise conv's BN.
        out = self.bn2(self.conv2(out))
        out = self.bn3(self.conv3(out))
        if self.stride == 2:
            res = self.shortcut(x)
            out = F.relu(torch.cat([res, out], 1))
        else:
            out = F.relu(x + out)
        return out


# configuration of (num_groups: #out_channels) based on Table 1 in the paper
# (output channels of stage 2, stage 3 and stage 4 respectively)
cfg = {
    1: [144, 288, 576],
    2: [200, 400, 800],
    3: [240, 480, 960],
    4: [272, 544, 1088],
    8: [384, 768, 1536],
}


class ShuffleNet(nn.Module):
    """ShuffleNet

    Reference:
        Zhang et al. ShuffleNet: An Extremely Efficient Convolutional Neural
        Network for Mobile Devices. CVPR 2018.

    Args:
        num_classes (int): output size of the linear classifier.
        loss (set of str, optional): training loss configuration, either
            ``{'xent'}`` (used when ``None``, the default) or
            ``{'xent', 'htri'}``.
        num_groups (int): number of groups ``g``. Must be a key of ``cfg``
            (1, 2, 3, 4 or 8). Default: 3.
        **kwargs: accepted and ignored.

    Attributes:
        feat_dim (int): size of the pooled feature vector, i.e. the number of
            stage-4 output channels.
    """

    def __init__(self, num_classes, loss=None, num_groups=3, **kwargs):
        super(ShuffleNet, self).__init__()
        # Give every model its own set instead of sharing one mutable default.
        self.loss = {'xent'} if loss is None else loss
        if num_groups not in cfg:
            raise KeyError("Unsupported num_groups: {} (expected one of {})".format(
                num_groups, sorted(cfg)))
        stage_channels = cfg[num_groups]

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 24, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(24),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
        )
        # Stages of 4, 8 and 4 units. Construction order and submodule indices
        # are unchanged, so parameter names (e.g. "stage3.7.conv1.weight")
        # stay the same.
        self.stage2 = self._make_stage(24, stage_channels[0], 4, num_groups,
                                       group_conv1x1=False)
        self.stage3 = self._make_stage(stage_channels[0], stage_channels[1], 8, num_groups)
        self.stage4 = self._make_stage(stage_channels[1], stage_channels[2], 4, num_groups)
        self.classifier = nn.Linear(stage_channels[2], num_classes)
        self.feat_dim = stage_channels[2]

    @staticmethod
    def _make_stage(in_channels, out_channels, num_units, num_groups, group_conv1x1=True):
        """Build a stage: one stride-2 unit followed by ``num_units - 1`` stride-1 units.

        ``group_conv1x1`` applies only to the first (stride-2) unit. All
        subsequent units use a grouped first 1x1 conv.
        """
        layers = [Bottleneck(in_channels, out_channels, 2, num_groups,
                             group_conv1x1=group_conv1x1)]
        for _ in range(num_units - 1):
            layers.append(Bottleneck(out_channels, out_channels, 1, num_groups))
        return nn.Sequential(*layers)

    def forward(self, x):
        x = self.conv1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        # Global average pooling over the spatial dims, then flatten to
        # (batch, feat_dim).
        x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1)
        if not self.training:
            # Inference: return the pooled feature vector only.
            return x
        y = self.classifier(x)
        if self.loss == {'xent'}:
            return y
        elif self.loss == {'xent', 'htri'}:
            return y, x
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))