deep-person-reid/torchreid/models/shufflenet.py

162 lines
5.7 KiB
Python
Raw Normal View History

2018-07-04 17:32:43 +08:00
from __future__ import absolute_import
from __future__ import division
2018-07-02 17:33:10 +08:00
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
import torch.utils.model_zoo as model_zoo
2018-07-02 17:33:10 +08:00
__all__ = ['shufflenet']
model_urls = {
# training epoch = 90, top1 = 61.8
'imagenet': 'http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/imagenet-pretrained/shufflenet-bee1b265.pth.tar',
}
2018-07-02 17:33:10 +08:00
class ChannelShuffle(nn.Module):
def __init__(self, num_groups):
super(ChannelShuffle, self).__init__()
self.g = num_groups
def forward(self, x):
b, c, h, w = x.size()
2018-11-06 05:30:40 +08:00
n = c // self.g
2018-07-02 17:33:10 +08:00
# reshape
x = x.view(b, self.g, n, h, w)
# transpose
x = x.permute(0, 2, 1, 3, 4).contiguous()
# flatten
x = x.view(b, c, h, w)
return x
class Bottleneck(nn.Module):
def __init__(self, in_channels, out_channels, stride, num_groups, group_conv1x1=True):
super(Bottleneck, self).__init__()
assert stride in [1, 2], "Warning: stride must be either 1 or 2"
self.stride = stride
mid_channels = out_channels // 4
if stride == 2: out_channels -= in_channels
# group conv is not applied to first conv1x1 at stage 2
num_groups_conv1x1 = num_groups if group_conv1x1 else 1
self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, groups=num_groups_conv1x1, bias=False)
self.bn1 = nn.BatchNorm2d(mid_channels)
self.shuffle1 = ChannelShuffle(num_groups)
self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, stride=stride, padding=1, groups=mid_channels, bias=False)
self.bn2 = nn.BatchNorm2d(mid_channels)
self.conv3 = nn.Conv2d(mid_channels, out_channels, 1, groups=num_groups, bias=False)
self.bn3 = nn.BatchNorm2d(out_channels)
if stride == 2: self.shortcut = nn.AvgPool2d(3, stride=2, padding=1)
def forward(self, x):
out = F.relu(self.bn1(self.conv1(x)))
out = self.shuffle1(out)
out = self.bn2(self.conv2(out))
out = self.bn3(self.conv3(out))
if self.stride == 2:
res = self.shortcut(x)
out = F.relu(torch.cat([res, out], 1))
else:
out = F.relu(x + out)
return out
# configuration of (num_groups: #out_channels) based on Table 1 in the paper
cfg = {
1: [144, 288, 576],
2: [200, 400, 800],
3: [240, 480, 960],
4: [272, 544, 1088],
8: [384, 768, 1536],
}
class ShuffleNet(nn.Module):
2018-10-28 00:46:48 +08:00
"""
ShuffleNet
2018-07-02 17:33:10 +08:00
Reference:
Zhang et al. ShuffleNet: An Extremely Efficient Convolutional Neural
Network for Mobile Devices. CVPR 2018.
"""
def __init__(self, num_classes, loss={'xent'}, num_groups=3, **kwargs):
super(ShuffleNet, self).__init__()
self.loss = loss
self.conv1 = nn.Sequential(
nn.Conv2d(3, 24, 3, stride=2, padding=1, bias=False),
nn.BatchNorm2d(24),
nn.ReLU(),
nn.MaxPool2d(3, stride=2, padding=1),
)
self.stage2 = nn.Sequential(
Bottleneck(24, cfg[num_groups][0], 2, num_groups, group_conv1x1=False),
Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups),
Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups),
Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups),
)
self.stage3 = nn.Sequential(
Bottleneck(cfg[num_groups][0], cfg[num_groups][1], 2, num_groups),
Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
)
self.stage4 = nn.Sequential(
Bottleneck(cfg[num_groups][1], cfg[num_groups][2], 2, num_groups),
Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups),
Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups),
Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups),
)
self.classifier = nn.Linear(cfg[num_groups][2], num_classes)
self.feat_dim = cfg[num_groups][2]
def forward(self, x):
x = self.conv1(x)
x = self.stage2(x)
x = self.stage3(x)
x = self.stage4(x)
x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1)
if not self.training:
return x
y = self.classifier(x)
if self.loss == {'xent'}:
return y
elif self.loss == {'xent', 'htri'}:
return y, x
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
def init_pretrained_weights(model, model_url):
"""
Initialize model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
pretrain_dict = model_zoo.load_url(model_url)
model_dict = model.state_dict()
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
print("Initialized model with pretrained weights from {}".format(model_url))
def shufflenet(num_classes, loss, pretrained='imagenet', **kwargs):
model = ShuffleNet(num_classes, loss, **kwargs)
if pretrained == 'imagenet':
init_pretrained_weights(model, model_urls['imagenet'])
return model