# Xception (Chollet, CVPR 2017) implemented in PyTorch.
from __future__ import absolute_import
import torch
from torch import nn
from torch.nn import functional as F
import torchvision

__all__ = ['Xception']
class ConvBlock(nn.Module):
    """Basic convolutional block: conv (bias discarded) + batch norm + relu6.

    Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):
        in_c (int): number of input channels.
        out_c (int): number of output channels.
        k (int or tuple): kernel size.
        s (int or tuple): stride.
        p (int or tuple): padding.
        g (int): number of blocked connections from input channels
            to output channels (default: 1).
    """

    def __init__(self, in_c, out_c, k, s=1, p=0, g=1):
        super(ConvBlock, self).__init__()
        # bias is redundant because batch norm follows immediately
        self.conv = nn.Conv2d(
            in_c, out_c, k, stride=s, padding=p, bias=False, groups=g
        )
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return F.relu6(out)
class SepConv(nn.Module):
    """Depthwise separable convolution: depthwise 3x3 followed by a
    pointwise 1x1 projection, each with batch normalization and no
    nonlinearity inside the module."""

    def __init__(self, in_channels, out_channels):
        super(SepConv, self).__init__()
        # depthwise 3x3: one filter per input channel (groups=in_channels)
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels,
                in_channels,
                3,
                padding=1,
                groups=in_channels,
                bias=False,
            ),
            nn.BatchNorm2d(in_channels),
        )
        # pointwise 1x1: mixes channels, maps to out_channels
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        out = self.conv1(x)
        out = self.conv2(out)
        return out
class EntryFLow(nn.Module):
    """Xception entry flow: a stem of two ConvBlocks followed by three
    downsampling residual stages, each pairing a separable-conv branch
    with a strided 1x1 conv shortcut.

    Args:
        nchannels (list): five channel widths, one per stage.
    """

    @staticmethod
    def _residual_branch(in_c, out_c, pre_relu):
        """Main branch: [relu] -> sepconv -> relu -> sepconv -> relu -> maxpool.

        The first stage omits the leading ReLU; module order (and hence
        state_dict indices) matches the original layout.
        """
        layers = []
        if pre_relu:
            layers.append(nn.ReLU())
        layers += [
            SepConv(in_c, out_c),
            nn.ReLU(),
            SepConv(out_c, out_c),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
        ]
        return nn.Sequential(*layers)

    @staticmethod
    def _shortcut(in_c, out_c):
        """Shortcut branch: strided 1x1 conv + batch norm."""
        return nn.Sequential(
            nn.Conv2d(in_c, out_c, 1, stride=2, bias=False),
            nn.BatchNorm2d(out_c),
        )

    def __init__(self, nchannels):
        super(EntryFLow, self).__init__()
        # stem
        self.conv1 = ConvBlock(3, nchannels[0], 3, s=2, p=1)
        self.conv2 = ConvBlock(nchannels[0], nchannels[1], 3, p=1)

        # three residual downsampling stages
        self.conv3 = self._residual_branch(nchannels[1], nchannels[2], pre_relu=False)
        self.conv3s = self._shortcut(nchannels[1], nchannels[2])

        self.conv4 = self._residual_branch(nchannels[2], nchannels[3], pre_relu=True)
        self.conv4s = self._shortcut(nchannels[2], nchannels[3])

        self.conv5 = self._residual_branch(nchannels[3], nchannels[4], pre_relu=True)
        self.conv5s = self._shortcut(nchannels[3], nchannels[4])

    def forward(self, x):
        out = self.conv2(self.conv1(x))
        out = self.conv3(out) + self.conv3s(out)
        out = self.conv4(out) + self.conv4s(out)
        out = self.conv5(out) + self.conv5s(out)
        return out
class MidFlowBlock(nn.Module):
    """One middle-flow block: three relu -> separable-conv stages applied
    in sequence (note: no residual connection in this implementation)."""

    def __init__(self, in_channels, out_channels):
        super(MidFlowBlock, self).__init__()
        self.conv1 = SepConv(in_channels, out_channels)
        self.conv2 = SepConv(out_channels, out_channels)
        self.conv3 = SepConv(out_channels, out_channels)

    def forward(self, x):
        out = x
        # pre-activation: ReLU precedes each separable conv
        for conv in (self.conv1, self.conv2, self.conv3):
            out = conv(F.relu(out))
        return out
class MidFlow(nn.Module):
    """Stack of ``num_repeats`` MidFlowBlock modules; only the first block
    maps in_channels -> out_channels, the rest keep out_channels."""

    def __init__(self, in_channels, out_channels, num_repeats):
        super(MidFlow, self).__init__()
        self.blocks = self._make_layer(in_channels, out_channels, num_repeats)

    def _make_layer(self, in_channels, out_channels, num):
        blocks = [
            MidFlowBlock(in_channels if i == 0 else out_channels, out_channels)
            for i in range(num)
        ]
        return nn.Sequential(*blocks)

    def forward(self, x):
        return self.blocks(x)
class ExitFlow(nn.Module):
    """Xception exit flow: a downsampling residual stage followed by two
    separable convs and global average pooling to a flat feature vector.

    Args:
        in_channels (int): channels coming out of the middle flow.
        nchannels (list): four channel widths for conv1..conv4.
    """

    def __init__(self, in_channels, nchannels):
        super(ExitFlow, self).__init__()
        # residual stage (main branch)
        self.conv1 = SepConv(in_channels, nchannels[0])
        self.conv2 = SepConv(nchannels[0], nchannels[1])
        # residual stage (shortcut branch): strided 1x1 conv + batch norm
        self.conv2s = nn.Sequential(
            nn.Conv2d(in_channels, nchannels[1], 1, stride=2, bias=False),
            nn.BatchNorm2d(nchannels[1]),
        )
        # tail
        self.conv3 = SepConv(nchannels[1], nchannels[2])
        self.conv4 = SepConv(nchannels[2], nchannels[3])

    def forward(self, x):
        # main branch: relu -> sepconv -> relu -> sepconv -> strided maxpool
        main = self.conv2(F.relu(self.conv1(F.relu(x))))
        main = F.max_pool2d(main, 3, stride=2, padding=1)
        out = main + self.conv2s(x)
        out = F.relu(self.conv3(out))
        out = F.relu(self.conv4(out))
        # global average pool over the remaining spatial extent, then flatten
        out = F.avg_pool2d(out, out.size()[2:])
        return out.view(out.size(0), -1)
class Xception(nn.Module):
    """Xception.

    Reference:
        Chollet. Xception: Deep Learning with Depthwise Separable Convolutions. CVPR 2017.

    Args:
        num_classes (int): number of classes for the final classifier.
        loss (set): which loss heads the forward pass should serve;
            {'xent'} returns logits only, {'xent', 'htri'} returns
            (logits, features).
        num_mid_flows (int): number of repeated middle-flow blocks
            (default: 8, as in the paper).
    """

    def __init__(self, num_classes, loss={'xent'}, num_mid_flows=8, **kwargs):
        super(Xception, self).__init__()
        self.loss = loss

        self.entryflow = EntryFLow(nchannels=[32, 64, 128, 256, 728])
        # Bug fix: `num_mid_flows` was previously ignored (8 was hard-coded).
        self.midflow = MidFlow(728, 728, num_mid_flows)
        self.exitflow = ExitFlow(728, nchannels=[728, 1024, 1536, 2048])
        self.classifier = nn.Linear(2048, num_classes)
        # dimensionality of the feature vector produced by the exit flow
        self.feat_dim = 2048

    def forward(self, x):
        x = self.entryflow(x)
        x = self.midflow(x)
        x = self.exitflow(x)

        # at test time, return the 2048-d feature embedding instead of logits
        if not self.training:
            return x

        y = self.classifier(x)

        if self.loss == {'xent'}:
            return y
        elif self.loss == {'xent', 'htri'}:
            return y, x
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))