deep-person-reid/torchreid/models/xception.py

264 lines
8.7 KiB
Python
Raw Normal View History

2018-07-04 17:32:43 +08:00
from __future__ import absolute_import
from __future__ import division
2018-07-02 17:33:10 +08:00
2018-10-28 00:46:48 +08:00
import math
2018-07-02 17:33:10 +08:00
import torch
2018-10-28 00:46:48 +08:00
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn import init
2018-07-02 17:33:10 +08:00
2018-10-28 00:46:48 +08:00
__all__ = ['xception']
2018-07-02 17:33:10 +08:00
2018-10-28 00:46:48 +08:00
# Download location and preprocessing metadata for the published checkpoints,
# indexed first by architecture name and then by pretraining dataset.
pretrained_settings = {
    'xception': {
        'imagenet': dict(
            url='http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
            input_space='RGB',
            input_size=[3, 299, 299],
            input_range=[0, 1],
            mean=[0.5, 0.5, 0.5],
            std=[0.5, 0.5, 0.5],
            num_classes=1000,
            # The resize parameter of the validation transform should be 333,
            # and make sure to center crop at 299x299.
            scale=0.8975,
        ),
    },
}
2018-07-02 17:33:10 +08:00
2018-10-28 00:46:48 +08:00
class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    The first convolution uses ``groups=in_channels`` and keeps the channel
    count; the second one is a 1x1 convolution mapping ``in_channels`` to
    ``out_channels``. ``kernel_size``, ``stride``, ``padding`` and
    ``dilation`` apply to the first convolution only; ``bias`` applies to both.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()

        # Per-channel spatial filtering (one group per input channel).
        self.conv1 = nn.Conv2d(
            in_channels=in_channels,
            out_channels=in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        # 1x1 convolution that mixes channels and sets the output width.
        self.pointwise = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )

    def forward(self, x):
        """Apply the depthwise convolution, then the pointwise convolution."""
        return self.pointwise(self.conv1(x))
2018-07-02 17:33:10 +08:00
2018-10-28 00:46:48 +08:00
class Block(nn.Module):
    """Residual stack of (ReLU, SeparableConv2d 3x3, BatchNorm2d) units.

    Args:
        in_filters (int): channels of the block input.
        out_filters (int): channels of the block output.
        reps (int): the stack holds one channel-changing unit plus
            ``reps - 1`` width-preserving units.
        strides (int): when not 1, a ``MaxPool2d(3, strides, 1)`` is appended
            to the stack and the skip path uses a strided 1x1 convolution.
        start_with_relu (bool): if False the leading ReLU is dropped.
        grow_first (bool): if True the channel-changing unit comes first,
            otherwise it comes last.
    """

    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super(Block, self).__init__()

        # A 1x1 projection is only needed on the skip path when the main path
        # changes the channel count or the spatial resolution.
        needs_projection = strides != 1 or out_filters != in_filters
        if needs_projection:
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)
        else:
            self.skip = None

        self.relu = nn.ReLU(inplace=True)

        def unit(n_in, n_out):
            # One (shared ReLU, separable 3x3 conv, batch norm) triple.
            return [
                self.relu,
                SeparableConv2d(n_in, n_out, 3, stride=1, padding=1, bias=False),
                nn.BatchNorm2d(n_out),
            ]

        layers = []
        width = in_filters
        if grow_first:
            layers.extend(unit(in_filters, out_filters))
            width = out_filters
        for _ in range(reps - 1):
            layers.extend(unit(width, width))
        if not grow_first:
            layers.extend(unit(in_filters, out_filters))

        if start_with_relu:
            # The leading activation is not in-place, so `inp` is left intact
            # for the skip path in forward().
            layers[0] = nn.ReLU(inplace=False)
        else:
            layers = layers[1:]

        if strides != 1:
            layers.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*layers)

    def forward(self, inp):
        """Return ``rep(inp)`` plus the (optionally projected) input."""
        x = self.rep(inp)

        if self.skip is None:
            shortcut = inp
        else:
            shortcut = self.skipbn(self.skip(inp))

        x += shortcut
        return x
2018-07-02 17:33:10 +08:00
2018-10-28 00:46:48 +08:00
class Xception(nn.Module):
    """
    Xception

    Reference:
    Chollet. Xception: Deep Learning with Depthwise Separable Convolutions. CVPR 2017.

    Args:
    - num_classes (int): output size of the final classifier layer
    - loss (set): ``{'xent'}`` or ``{'xent', 'htri'}``; selects what
      ``forward`` returns in training mode
    - fc_dims (list or tuple): dimensions of optional fc layers placed
      between the pooled features and the classifier, None for no fc layers
    - dropout_p (float): dropout probability used in the fc layers, None
      disables dropout
    - kwargs: accepted for interface compatibility, not used here
    """

    def __init__(self, num_classes, loss, fc_dims=None, dropout_p=None, **kwargs):
        super(Xception, self).__init__()
        self.loss = loss

        # Stem: two unpadded 3x3 convolutions, the first one with stride 2.
        self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        # Three stride-2 blocks that widen 64 -> 128 -> 256 -> 728 channels.
        self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)

        # Eight identical stride-1 blocks at 728 channels. They are kept as
        # individually named attributes so the state-dict keys stay
        # ``block4`` ... ``block11``.
        self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)

        self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)

        # Final stride-2 block, followed by two separable convolutions that
        # widen 1024 -> 1536 -> 2048 channels.
        self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)

        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.feature_dim = 2048
        # _construct_fc_layer overwrites self.feature_dim with the width that
        # the classifier below must accept.
        self.fc = self._construct_fc_layer(fc_dims, 2048, dropout_p)
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        self._init_params()

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        """
        Construct fully connected layer

        - fc_dims (list or tuple): dimensions of fc layers, if None or empty,
                                   no fc layers are constructed
        - input_dim (int): input dimension
        - dropout_p (float): dropout probability, if None, dropout is unused

        Sets ``self.feature_dim`` to the output width of the constructed
        layers (``input_dim`` when none are constructed) and returns an
        ``nn.Sequential``, or None when no fc layers are constructed.
        """
        if fc_dims is None:
            self.feature_dim = input_dim
            return None

        assert isinstance(fc_dims, (list, tuple)), 'fc_dims must be either list or tuple, but got {}'.format(type(fc_dims))

        # An empty sequence used to fail with IndexError on fc_dims[-1];
        # treat it the same as "no fc layers".
        if len(fc_dims) == 0:
            self.feature_dim = input_dim
            return None

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def _init_params(self):
        """Re-initialise every conv, batch-norm and linear layer of the model."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def featuremaps(self, input):
        """Run the convolutional trunk and return the 2048-channel feature map."""
        x = self.conv1(input)
        x = self.bn1(x)
        x = F.relu(x, inplace=True)

        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x, inplace=True)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x, inplace=True)

        x = self.conv4(x)
        x = self.bn4(x)
        x = F.relu(x, inplace=True)
        return x

    def forward(self, x):
        """
        Compute features and, in training mode, class scores.

        In eval mode the pooled (and, if configured, fc-projected) feature
        vector is returned. In training mode the classifier output ``y`` is
        returned for ``loss == {'xent'}`` and ``(y, v)`` for
        ``loss == {'xent', 'htri'}``; any other ``loss`` raises KeyError.
        """
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        # Flatten to one vector per sample.
        v = v.view(v.size(0), -1)

        if self.fc is not None:
            v = self.fc(v)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == {'xent'}:
            return y
        elif self.loss == {'xent', 'htri'}:
            return y, v
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))
2018-10-28 00:46:48 +08:00
def init_pretrained_weights(model, model_url):
    """
    Load the checkpoint at ``model_url`` into ``model``.

    Only checkpoint entries whose key exists in the model's state dict with
    an identical tensor size are copied; every other model entry keeps its
    current value.
    """
    downloaded = model_zoo.load_url(model_url)
    current = model.state_dict()

    # Keep just the entries that match the model in both name and size.
    compatible = {}
    for name, tensor in downloaded.items():
        if name in current and current[name].size() == tensor.size():
            compatible[name] = tensor

    current.update(compatible)
    model.load_state_dict(current)
    print('Initialized model with pretrained weights from {}'.format(model_url))
2018-10-28 00:46:48 +08:00
def xception(num_classes, loss, pretrained='imagenet', **kwargs):
    """
    Build an Xception model without extra fc layers or dropout.

    When ``pretrained`` equals 'imagenet', weights are loaded from the URL
    registered in ``pretrained_settings``; any other value skips loading.
    Extra keyword arguments are forwarded to the ``Xception`` constructor.
    """
    net = Xception(num_classes, loss, fc_dims=None, dropout_p=None, **kwargs)

    if pretrained != 'imagenet':
        return net

    imagenet_cfg = pretrained_settings['xception']['imagenet']
    init_pretrained_weights(net, imagenet_cfg['url'])
    return net