update mobilenetv2; add imagenet/reid weights

pull/133/head
KaiyangZhou 2019-03-05 10:56:39 +00:00
parent eb680de371
commit 7ff331f871
3 changed files with 152 additions and 53 deletions

View File

@ -14,6 +14,8 @@
| se_resnet50_fc512<sup>:dog:</sup> | 27.1 | xent | (256, 128) | [91.9 (75.8)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/se_resnet50_fc512_market_xent.zip) | [81.5 (63.7)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/se_resnet50_fc512_duke_xent.zip) | [71.1 (39.8)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/se_resnet50_fc512_msmt_xent.zip) |
| shufflenet<sup>:dog:</sup> | 0.9 | xent | (256, 128) | [84.1(64.1)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/shufflenet_market_xent.zip) | [73.4(51.9)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/shufflenet_duke_xent.zip) | [51.3(24.2)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/shufflenet_msmt_xent.zip) |
| squeezenet1_0_fc512<sup>:dog:</sup> | 1.0 | xent | (256, 128) | [79.3 (52.2)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/squeezenet1_0_fc512_market_xent.zip) | [66.6 (42.6)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/squeezenet1_0_fc512_duke_xent.zip) | [44.1 (17.1)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/squeezenet1_0_fc512_msmt_xent.zip) |
| mobilenetv2_1dot0<sup>:dog:</sup> | 2.2 | xent | (256, 128) | [85.6 (67.3)](http://eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/mobilenetv2_1dot0_market.pth.tar) | [74.2 (54.7)](http://eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/mobilenetv2_1dot0_duke.pth.tar) | [57.4 (29.3)](http://eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/mobilenetv2_1dot0_msmt.pth.tar) |
| mobilenetv2_1dot4<sup>:dog:</sup> | 4.3 | xent | (256, 128) | [87.0 (68.5)](http://eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/mobilenetv2_1dot4_market.pth.tar) | [76.2 (55.8)](http://eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/mobilenetv2_1dot4_duke.pth.tar) | [60.1 (31.5)](http://eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/mobilenetv2_1dot4_msmt.pth.tar) |
| resnet50mid<sup>:dog:</sup> | 27.7 | xent | (256, 128) | [90.2 (76.0)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/resnet50mid_market_xent.zip) | [81.6 (64.0)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/resnet50mid_duke_xent.zip) | [69.0 (38.0)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/resnet50mid_msmt_xent.zip) |
| mlfn<sup>:dog:</sup> | 32.5 | xent | (256, 128) | [90.1 (74.3)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/mlfn_market_xent.zip) | [81.1 (63.2)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/mlfn_duke_xent.zip) | [66.4 (37.2)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/mlfn_msmt_xent.zip) |
| hacnn | 3.7 | xent | (160, 64) | [90.9 (75.6)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/hacnn_market_xent.zip) | [80.1 (63.2)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/hacnn_duke_xent.zip) | [64.7 (37.2)](http://www.eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/image-models/hacnn_msmt_xent.zip) |

View File

@ -38,7 +38,8 @@ __model_factory = {
'xception': xception,
# lightweight models
'nasnsetmobile': nasnetamobile,
'mobilenetv2': MobileNetV2,
'mobilenetv2_1dot0': mobilenetv2_1dot0,
'mobilenetv2_1dot4': mobilenetv2_1dot4,
'shufflenet': shufflenet,
'squeezenet1_0': squeezenet1_0,
'squeezenet1_0_fc512': squeezenet1_0_fc512,

View File

@ -4,10 +4,18 @@ from __future__ import division
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
import torch.utils.model_zoo as model_zoo
__all__ = ['MobileNetV2']
__all__ = ['mobilenetv2_1dot0', 'mobilenetv2_1dot4']
model_urls = {
# 1.0: top-1 71.3
'mobilenetv2_1dot0': 'http://eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/imagenet-pretrained/mobilenetv2_1.0-0f5d2d8f.pth',
# 1.4: top-1 73.9
'mobilenetv2_1dot4': 'http://eecs.qmul.ac.uk/~kz303/deep-person-reid/model-zoo/imagenet-pretrained/mobilenetv2_1.4-4d0d3520.pth',
}
class ConvBlock(nn.Module):
@ -33,7 +41,7 @@ class ConvBlock(nn.Module):
class Bottleneck(nn.Module):
def __init__(self, in_channels, out_channels, expansion_factor, stride):
def __init__(self, in_channels, out_channels, expansion_factor, stride=1):
super(Bottleneck, self).__init__()
mid_channels = in_channels * expansion_factor
self.use_residual = stride == 1 and in_channels == out_channels
@ -54,7 +62,6 @@ class Bottleneck(nn.Module):
return m
class MobileNetV2(nn.Module):
"""
MobileNetV2
@ -62,67 +69,156 @@ class MobileNetV2(nn.Module):
Reference:
Sandler et al. MobileNetV2: Inverted Residuals and Linear Bottlenecks. CVPR 2018.
"""
def __init__(self, num_classes, loss={'xent'}, **kwargs):
def __init__(self, num_classes, width_mult=1, loss={'xent'}, fc_dims=None, dropout_p=None, **kwargs):
super(MobileNetV2, self).__init__()
self.loss = loss
self.in_channels = int(32 * width_mult)
self.feature_dim = int(1280 * width_mult) if width_mult > 1 else 1280
self.conv1 = ConvBlock(3, 32, 3, s=2, p=1)
self.block2 = Bottleneck(32, 16, 1, 1)
self.block3 = nn.Sequential(
Bottleneck(16, 24, 6, 2),
Bottleneck(24, 24, 6, 1),
)
self.block4 = nn.Sequential(
Bottleneck(24, 32, 6, 2),
Bottleneck(32, 32, 6, 1),
Bottleneck(32, 32, 6, 1),
)
self.block5 = nn.Sequential(
Bottleneck(32, 64, 6, 2),
Bottleneck(64, 64, 6, 1),
Bottleneck(64, 64, 6, 1),
Bottleneck(64, 64, 6, 1),
)
self.block6 = nn.Sequential(
Bottleneck(64, 96, 6, 1),
Bottleneck(96, 96, 6, 1),
Bottleneck(96, 96, 6, 1),
)
self.block7 = nn.Sequential(
Bottleneck(96, 160, 6, 2),
Bottleneck(160, 160, 6, 1),
Bottleneck(160, 160, 6, 1),
)
self.block8 = Bottleneck(160, 320, 6, 1)
self.conv9 = ConvBlock(320, 1280, 1)
self.classifier = nn.Linear(1280, num_classes)
self.feat_dim = 1280
# construct layers
self.conv1 = ConvBlock(3, self.in_channels, 3, s=2, p=1)
self.conv2 = self._make_layer(Bottleneck, 1, int(16 * width_mult), 1, 1)
self.conv3 = self._make_layer(Bottleneck, 6, int(24 * width_mult), 2, 2)
self.conv4 = self._make_layer(Bottleneck, 6, int(32 * width_mult), 3, 2)
self.conv5 = self._make_layer(Bottleneck, 6, int(64 * width_mult), 4, 2)
self.conv6 = self._make_layer(Bottleneck, 6, int(96 * width_mult), 3, 1)
self.conv7 = self._make_layer(Bottleneck, 6, int(160 * width_mult), 3, 2)
self.conv8 = self._make_layer(Bottleneck, 6, int(320 * width_mult), 1, 1)
self.conv9 = ConvBlock(self.in_channels, self.feature_dim, 1)
self.global_avgpool = nn.AdaptiveAvgPool2d(1)
self.fc = self._construct_fc_layer(fc_dims, self.feature_dim, dropout_p)
self.classifier = nn.Linear(self.feature_dim, num_classes)
self._init_params()
def _make_layer(self, block, t, c, n, s):
# t: expansion factor
# c: output channels
# n: number of blocks
# s: stride for first layer
layers = []
layers.append(block(self.in_channels, c, t, s))
self.in_channels = c
for i in range(1, n):
layers.append(block(self.in_channels, c, t))
return nn.Sequential(*layers)
def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
"""
Construct fully connected layer
- fc_dims (list or tuple): dimensions of fc layers, if None,
no fc layers are constructed
- input_dim (int): input dimension
- dropout_p (float): dropout probability, if None, dropout is unused
"""
if fc_dims is None:
self.feature_dim = input_dim
return None
assert isinstance(fc_dims, (list, tuple)), 'fc_dims must be either list or tuple, but got {}'.format(type(fc_dims))
layers = []
for dim in fc_dims:
layers.append(nn.Linear(input_dim, dim))
layers.append(nn.BatchNorm1d(dim))
layers.append(nn.ReLU(inplace=True))
if dropout_p is not None:
layers.append(nn.Dropout(p=dropout_p))
input_dim = dim
self.feature_dim = fc_dims[-1]
return nn.Sequential(*layers)
def _init_params(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def featuremaps(self, x):
x = self.conv1(x)
x = self.block2(x)
x = self.block3(x)
x = self.block4(x)
x = self.block5(x)
x = self.block6(x)
x = self.block7(x)
x = self.block8(x)
x = self.conv2(x)
x = self.conv3(x)
x = self.conv4(x)
x = self.conv5(x)
x = self.conv6(x)
x = self.conv7(x)
x = self.conv8(x)
x = self.conv9(x)
return x
def forward(self, x):
x = self.featuremaps(x)
x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1)
x = F.dropout(x, training=self.training)
f = self.featuremaps(x)
v = self.global_avgpool(f)
v = v.view(v.size(0), -1)
if self.fc is not None:
v = self.fc(v)
if not self.training:
return x
y = self.classifier(x)
return v
y = self.classifier(v)
if self.loss == {'xent'}:
return y
elif self.loss == {'xent', 'htri'}:
return y, x
return y, v
else:
raise KeyError('Unsupported loss: {}'.format(self.loss))
raise KeyError("Unsupported loss: {}".format(self.loss))
def init_pretrained_weights(model, model_url):
"""
Initialize model with pretrained weights.
Layers that don't match with pretrained layers in name or size are kept unchanged.
"""
pretrain_dict = model_zoo.load_url(model_url)
model_dict = model.state_dict()
pretrain_dict = {k: v for k, v in pretrain_dict.items() if k in model_dict and model_dict[k].size() == v.size()}
model_dict.update(pretrain_dict)
model.load_state_dict(model_dict)
print('Initialized model with pretrained weights from {}'.format(model_url))
def mobilenetv2_1dot0(num_classes, loss, pretrained=True, **kwargs):
model = MobileNetV2(
num_classes,
loss=loss,
width_mult=1,
fc_dims=None,
dropout_p=None,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['mobilenetv2_1dot0'])
return model
def mobilenetv2_1dot4(num_classes, loss, pretrained=True, **kwargs):
model = MobileNetV2(
num_classes,
loss=loss,
width_mult=1.4,
fc_dims=None,
dropout_p=None,
**kwargs
)
if pretrained:
init_pretrained_weights(model, model_urls['mobilenetv2_1dot4'])
return model