Source code for torchreid.models.xception

from __future__ import absolute_import
from __future__ import division

__all__ = ['xception']

import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.model_zoo as model_zoo
from torch.nn import init


# Metadata for pretrained checkpoints, keyed first by model name and then by
# a weights identifier (here only 'imagenet').
# In the code visible in this file only the 'url' entry is read (by
# ``xception`` when ``pretrained=True``). The remaining fields are not
# referenced here; they presumably describe the preprocessing the checkpoint
# expects -- verify against the data pipeline before relying on them.
pretrained_settings = {
    'xception': {
        'imagenet': {
            'url': 'http://data.lip6.fr/cadene/pretrainedmodels/xception-43020ad28.pth',
            'input_space': 'RGB',
            'input_size': [3, 299, 299],
            'input_range': [0, 1],
            'mean': [0.5, 0.5, 0.5],
            'std': [0.5, 0.5, 0.5],
            'num_classes': 1000,
            'scale': 0.8975 # The resize parameter of the validation transform should be 333, and make sure to center crop at 299x299
        }
    }
}


class SeparableConv2d(nn.Module):
    """Depthwise convolution followed by a 1x1 pointwise convolution.

    Args:
        in_channels (int): number of channels of the input feature map.
        out_channels (int): number of channels produced by the pointwise
            convolution.
        kernel_size: kernel size of the depthwise convolution (forwarded to
            ``nn.Conv2d``).
        stride: stride of the depthwise convolution.
        padding: padding of the depthwise convolution.
        dilation: dilation of the depthwise convolution.
        bias (bool): whether both convolutions get a bias term.
    """

    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1, padding=0, dilation=1, bias=False):
        super(SeparableConv2d, self).__init__()

        # groups == in_channels: every input channel is filtered separately.
        self.conv1 = nn.Conv2d(
            in_channels,
            in_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=in_channels,
            bias=bias,
        )
        # The 1x1 convolution mixes the per-channel outputs into out_channels.
        self.pointwise = nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            bias=bias,
        )

    def forward(self, x):
        """Run the depthwise convolution, then the pointwise one, on ``x``."""
        return self.pointwise(self.conv1(x))


class Block(nn.Module):
    """Stack of ReLU -> SeparableConv2d(3x3) -> BatchNorm units with a skip path.

    Args:
        in_filters (int): channels of the block input.
        out_filters (int): channels of the block output.
        reps (int): number of separable-convolution units in the stack.
        strides (int): when not 1, a ``MaxPool2d(3, strides, 1)`` is appended
            to the stack and the skip convolution uses the same stride.
        start_with_relu (bool): if False the leading ReLU of the stack is
            dropped; if True it is replaced by a non-inplace ReLU.
        grow_first (bool): if True the first unit changes the channel count
            from ``in_filters`` to ``out_filters``; otherwise the last unit
            does.
    """

    def __init__(self, in_filters, out_filters, reps, strides=1, start_with_relu=True, grow_first=True):
        super(Block, self).__init__()

        # A 1x1 projection (+ BN) is needed on the skip path whenever the
        # main path changes the channel count or the spatial resolution.
        same_shape = out_filters == in_filters and strides == 1
        if same_shape:
            self.skip = None
        else:
            self.skip = nn.Conv2d(in_filters, out_filters, 1, stride=strides, bias=False)
            self.skipbn = nn.BatchNorm2d(out_filters)

        self.relu = nn.ReLU(inplace=True)

        # (input channels, output channels) of every unit, in execution order.
        widen = (in_filters, out_filters)
        if grow_first:
            plan = [widen] + [(out_filters, out_filters)] * (reps - 1)
        else:
            plan = [(in_filters, in_filters)] * (reps - 1) + [widen]

        layers = []
        for c_in, c_out in plan:
            layers.append(self.relu)
            layers.append(SeparableConv2d(c_in, c_out, 3, stride=1, padding=1, bias=False))
            layers.append(nn.BatchNorm2d(c_out))

        if start_with_relu:
            # The first activation must not be in-place: forward() reuses the
            # block input for the skip branch after the stack has run.
            layers[0] = nn.ReLU(inplace=False)
        else:
            del layers[0]

        if strides != 1:
            layers.append(nn.MaxPool2d(3, strides, 1))
        self.rep = nn.Sequential(*layers)

    def forward(self, inp):
        """Return the stack output added to the (optionally projected) input."""
        x = self.rep(inp)

        if self.skip is None:
            shortcut = inp
        else:
            shortcut = self.skipbn(self.skip(inp))

        x += shortcut
        return x


class Xception(nn.Module):
    """Xception.

    Reference:
        Chollet. Xception: Deep Learning with Depthwise Separable
        Convolutions. CVPR 2017.

    Public keys:
        - ``xception``: Xception.

    Args:
        num_classes (int): output size of the final classifier layer.
        loss (str): ``'softmax'`` or ``'triplet'``; selects what ``forward``
            returns in training mode.
        fc_dims (list or tuple, optional): widths of extra fully connected
            layers inserted between pooling and classifier. ``None`` or an
            empty sequence means no extra layers.
        dropout_p (float, optional): dropout probability used inside the
            extra fully connected layers; ``None`` disables dropout.
        **kwargs: accepted and ignored.
    """

    def __init__(self, num_classes, loss, fc_dims=None, dropout_p=None, **kwargs):
        super(Xception, self).__init__()
        self.loss = loss

        # Stem.
        self.conv1 = nn.Conv2d(3, 32, 3, 2, 0, bias=False)
        self.bn1 = nn.BatchNorm2d(32)

        self.conv2 = nn.Conv2d(32, 64, 3, bias=False)
        self.bn2 = nn.BatchNorm2d(64)

        # Blocks that change width and resolution.
        self.block1 = Block(64, 128, 2, 2, start_with_relu=False, grow_first=True)
        self.block2 = Block(128, 256, 2, 2, start_with_relu=True, grow_first=True)
        self.block3 = Block(256, 728, 2, 2, start_with_relu=True, grow_first=True)

        # Blocks that keep 728 channels and stride 1.
        self.block4 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block5 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block6 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block7 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)

        self.block8 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block9 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block10 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)
        self.block11 = Block(728, 728, 3, 1, start_with_relu=True, grow_first=True)

        # Final strided block followed by two separable convolutions.
        self.block12 = Block(728, 1024, 2, 2, start_with_relu=True, grow_first=False)

        self.conv3 = SeparableConv2d(1024, 1536, 3, 1, 1)
        self.bn3 = nn.BatchNorm2d(1536)

        self.conv4 = SeparableConv2d(1536, 2048, 3, 1, 1)
        self.bn4 = nn.BatchNorm2d(2048)

        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        self.feature_dim = 2048
        # _construct_fc_layer updates self.feature_dim to the width that the
        # classifier below must consume.
        self.fc = self._construct_fc_layer(fc_dims, 2048, dropout_p)
        self.classifier = nn.Linear(self.feature_dim, num_classes)

        self._init_params()

    def _construct_fc_layer(self, fc_dims, input_dim, dropout_p=None):
        """Constructs fully connected layer.

        Also sets ``self.feature_dim`` to the width of the produced feature
        vector (``input_dim`` when no layer is built, else ``fc_dims[-1]``).

        Args:
            fc_dims (list or tuple): dimensions of fc layers; if None or
                empty, no fc layers are constructed
            input_dim (int): input dimension
            dropout_p (float): dropout probability, if None, dropout is unused

        Returns:
            nn.Sequential or None: the stacked Linear/BatchNorm1d/ReLU
            (/Dropout) layers, or None when no layer is requested.
        """
        if fc_dims is None:
            self.feature_dim = input_dim
            return None

        assert isinstance(fc_dims, (list, tuple)), 'fc_dims must be either list or tuple, but got {}'.format(type(fc_dims))

        # An empty sequence asks for zero layers; without this guard the
        # ``fc_dims[-1]`` lookup below would raise IndexError.
        if not fc_dims:
            self.feature_dim = input_dim
            return None

        layers = []
        for dim in fc_dims:
            layers.append(nn.Linear(input_dim, dim))
            layers.append(nn.BatchNorm1d(dim))
            layers.append(nn.ReLU(inplace=True))
            if dropout_p is not None:
                layers.append(nn.Dropout(p=dropout_p))
            input_dim = dim

        self.feature_dim = fc_dims[-1]

        return nn.Sequential(*layers)

    def _init_params(self):
        """Initialise conv, batch-norm and linear parameters of all submodules."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, (nn.BatchNorm2d, nn.BatchNorm1d)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)

    def featuremaps(self, input):
        """Run the convolutional trunk and return the final feature map."""
        x = self.conv1(input)
        x = self.bn1(x)
        x = F.relu(x, inplace=True)

        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x, inplace=True)

        x = self.block1(x)
        x = self.block2(x)
        x = self.block3(x)
        x = self.block4(x)
        x = self.block5(x)
        x = self.block6(x)
        x = self.block7(x)
        x = self.block8(x)
        x = self.block9(x)
        x = self.block10(x)
        x = self.block11(x)
        x = self.block12(x)

        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x, inplace=True)

        x = self.conv4(x)
        x = self.bn4(x)
        x = F.relu(x, inplace=True)
        return x

    def forward(self, x):
        """Compute features and, in training mode, classifier outputs.

        Returns:
            In eval mode, the pooled (and optionally fc-transformed) feature
            vector ``v``. In training mode, the classifier output ``y`` when
            ``loss == 'softmax'`` or the tuple ``(y, v)`` when
            ``loss == 'triplet'``.

        Raises:
            KeyError: in training mode, if ``self.loss`` is neither
                ``'softmax'`` nor ``'triplet'``.
        """
        f = self.featuremaps(x)
        v = self.global_avgpool(f)
        v = v.view(v.size(0), -1)

        if self.fc is not None:
            v = self.fc(v)

        if not self.training:
            return v

        y = self.classifier(v)

        if self.loss == 'softmax':
            return y
        elif self.loss == 'triplet':
            return y, v
        else:
            raise KeyError('Unsupported loss: {}'.format(self.loss))
def init_pretrained_weights(model, model_url):
    """Initialize models with pretrained weights.

    Layers that don't match with pretrained layers in name or size are kept
    unchanged.

    Args:
        model (nn.Module): model whose parameters are overwritten in place.
        model_url (str): location of the checkpoint, passed to
            ``model_zoo.load_url``.
    """
    downloaded = model_zoo.load_url(model_url)
    current = model.state_dict()

    # Keep only entries whose name exists in the model with an equal shape.
    compatible = {}
    for name, tensor in downloaded.items():
        if name in current and current[name].size() == tensor.size():
            compatible[name] = tensor

    current.update(compatible)
    model.load_state_dict(current)


def xception(num_classes, loss='softmax', pretrained=True, **kwargs):
    """Build an ``Xception`` model without extra fc layers or dropout.

    Args:
        num_classes (int): output size of the classifier.
        loss (str): loss identifier forwarded to ``Xception``.
        pretrained (bool): when true, matching weights are loaded from the
            'imagenet' URL listed in ``pretrained_settings``.
        **kwargs: forwarded to ``Xception``.

    Returns:
        Xception: the constructed (and optionally pre-initialised) model.
    """
    model = Xception(num_classes, loss, fc_dims=None, dropout_p=None, **kwargs)
    if not pretrained:
        return model

    # NOTE(review): the configured checkpoint URL uses plain http.
    weights_url = pretrained_settings['xception']['imagenet']['url']
    init_pretrained_weights(model, weights_url)
    return model