from __future__ import division, absolute_import

import torch
import torch.utils.model_zoo as model_zoo
from torch import nn
from torch.nn import functional as F

__all__ = ['shufflenet']

model_urls = {
    # training epoch = 90, top1 = 61.8
    'imagenet':
    'https://mega.nz/#!RDpUlQCY!tr_5xBEkelzDjveIYBBcGcovNCOrgfiJO9kiidz9fZM',
}

class ChannelShuffle(nn.Module):
    """Channel shuffle operation from ShuffleNet.

    Rearranges channels so that information flows across groups
    after a grouped 1x1 convolution.
    """

    def __init__(self, num_groups):
        super(ChannelShuffle, self).__init__()
        self.g = num_groups

    def forward(self, x):
        b, c, h, w = x.size()
        n = c // self.g
        # reshape: (b, c, h, w) -> (b, g, c//g, h, w)
        x = x.view(b, self.g, n, h, w)
        # transpose the group and channel dimensions
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        # flatten back to (b, c, h, w)
        x = x.view(b, c, h, w)
        return x

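
# Illustration (a sketch, not part of the original code): with c = 6 channels
# and num_groups = 3, the channels [0, 1, 2, 3, 4, 5] are grouped as
# [[0, 1], [2, 3], [4, 5]] and come out interleaved as [0, 2, 4, 1, 3, 5], so
# the following grouped convolution sees channels from every group:
#
#     x = torch.arange(6.).view(1, 6, 1, 1)
#     ChannelShuffle(num_groups=3)(x).view(-1)  # tensor([0., 2., 4., 1., 3., 5.])
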
class Bottleneck(nn.Module):
    """ShuffleNet unit: grouped 1x1 conv -> channel shuffle -> 3x3 depthwise
    conv -> grouped 1x1 conv, with an identity (stride 1) or average-pooling
    (stride 2) shortcut."""

    def __init__(
        self,
        in_channels,
        out_channels,
        stride,
        num_groups,
        group_conv1x1=True
    ):
        super(Bottleneck, self).__init__()
        assert stride in [1, 2], 'stride must be either 1 or 2'
        self.stride = stride
        mid_channels = out_channels // 4
        if stride == 2:
            # the shortcut is concatenated, so reserve channels for it
            out_channels -= in_channels
        # group conv is not applied to the first conv1x1 of stage 2
        num_groups_conv1x1 = num_groups if group_conv1x1 else 1
        self.conv1 = nn.Conv2d(
            in_channels,
            mid_channels,
            1,
            groups=num_groups_conv1x1,
            bias=False
        )
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.shuffle1 = ChannelShuffle(num_groups)
        # 3x3 depthwise convolution (groups == channels)
        self.conv2 = nn.Conv2d(
            mid_channels,
            mid_channels,
            3,
            stride=stride,
            padding=1,
            groups=mid_channels,
            bias=False
        )
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.Conv2d(
            mid_channels, out_channels, 1, groups=num_groups, bias=False
        )
        self.bn3 = nn.BatchNorm2d(out_channels)
        if stride == 2:
            self.shortcut = nn.AvgPool2d(3, stride=2, padding=1)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.shuffle1(out)
        out = self.bn2(self.conv2(out))
        out = self.bn3(self.conv3(out))
        if self.stride == 2:
            res = self.shortcut(x)
            # stride-2 unit: concatenate shortcut and residual branches
            out = F.relu(torch.cat([res, out], 1))
        else:
            # stride-1 unit: element-wise addition
            out = F.relu(x + out)
        return out

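
# Sketch of the stride-2 channel arithmetic (illustrative numbers taken from
# the cfg table below, not from the original file): the stage-2 entry unit
# maps 24 input channels to 240 output channels by producing 240 - 24 = 216
# channels in the residual branch and concatenating the 24-channel
# average-pooled shortcut:
#
#     block = Bottleneck(24, 240, stride=2, num_groups=3, group_conv1x1=False)
#     block(torch.randn(1, 24, 56, 56)).shape  # torch.Size([1, 240, 28, 28])
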
# configuration of (num_groups: #out_channels) based on Table 1 in the paper
cfg = {
    1: [144, 288, 576],
    2: [200, 400, 800],
    3: [240, 480, 960],
    4: [272, 544, 1088],
    8: [384, 768, 1536],
}

class ShuffleNet(nn.Module):
    """ShuffleNet.

    Reference:
        Zhang et al. ShuffleNet: An Extremely Efficient Convolutional Neural
        Network for Mobile Devices. CVPR 2018.

    Public keys:
        - ``shufflenet``: ShuffleNet (groups=3).
    """

    def __init__(self, num_classes, loss='softmax', num_groups=3, **kwargs):
        super(ShuffleNet, self).__init__()
        self.loss = loss

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 24, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(24),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
        )

        self.stage2 = nn.Sequential(
            Bottleneck(
                24, cfg[num_groups][0], 2, num_groups, group_conv1x1=False
            ),
            Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups),
            Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups),
            Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups),
        )

        self.stage3 = nn.Sequential(
            Bottleneck(cfg[num_groups][0], cfg[num_groups][1], 2, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
        )

        self.stage4 = nn.Sequential(
            Bottleneck(cfg[num_groups][1], cfg[num_groups][2], 2, num_groups),
            Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups),
            Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups),
            Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups),
        )

        self.classifier = nn.Linear(cfg[num_groups][2], num_classes)
        self.feat_dim = cfg[num_groups][2]

    def forward(self, x):
        x = self.conv1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        # global average pooling -> feature vector of size feat_dim
        x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1)

        if not self.training:
            return x

        y = self.classifier(x)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, x
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))

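
# How the two loss modes behave (a sketch; the 751 classes and input size are
# only illustrative, following the common 256x128 person-crop convention):
#
#     net = ShuffleNet(num_classes=751, loss='triplet')
#     imgs = torch.randn(8, 3, 256, 128)
#     net.train()
#     logits, feats = net(imgs)  # logits: (8, 751), feats: (8, 960)
#     net.eval()
#     feats = net(imgs)          # (8, 960) globally average-pooled features
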
def init_pretrained_weights(model, model_url):
    """Initializes model with pretrained weights.

    Layers that don't match pretrained layers in name or size are kept unchanged.
    """
    pretrain_dict = model_zoo.load_url(model_url)
    model_dict = model.state_dict()
    pretrain_dict = {
        k: v
        for k, v in pretrain_dict.items()
        if k in model_dict and model_dict[k].size() == v.size()
    }
    model_dict.update(pretrain_dict)
    model.load_state_dict(model_dict)

def shufflenet(num_classes, loss='softmax', pretrained=True, **kwargs):
    model = ShuffleNet(num_classes, loss, **kwargs)
    if pretrained:
        # init_pretrained_weights(model, model_urls['imagenet'])
        import warnings
        warnings.warn(
            'The imagenet pretrained weights need to be manually downloaded from {}'
            .format(model_urls['imagenet'])
        )
    return model
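
# A minimal smoke test (not part of the original file): it builds the model via
# the factory function without pretrained weights and checks the eval-mode
# feature output.
if __name__ == '__main__':
    net = shufflenet(num_classes=1000, pretrained=False)
    net.eval()
    with torch.no_grad():
        feats = net(torch.randn(2, 3, 256, 128))
    print(feats.shape)  # expected: torch.Size([2, 960])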