# MobileNetV2 model definition (PyTorch).
from __future__ import absolute_import
from __future__ import division

import torch
from torch import nn
from torch.nn import functional as F
import torchvision

__all__ = ['MobileNetV2']
class ConvBlock(nn.Module):
    """Convolution (bias omitted) -> batch norm -> ReLU6.

    Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):
        in_c (int): number of input channels.
        out_c (int): number of output channels.
        k (int or tuple): kernel size.
        s (int or tuple): stride.
        p (int or tuple): padding.
        g (int): number of blocked connections from input channels
                 to output channels (default: 1).
    """

    def __init__(self, in_c, out_c, k, s=1, p=0, g=1):
        super(ConvBlock, self).__init__()
        # Bias is redundant here: BatchNorm's affine shift absorbs it.
        self.conv = nn.Conv2d(
            in_c, out_c, k, stride=s, padding=p, groups=g, bias=False
        )
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        return F.relu6(out)
class Bottleneck(nn.Module):
    """MobileNetV2 inverted-residual bottleneck.

    1x1 expansion -> 3x3 depthwise -> 1x1 linear projection, with an
    identity shortcut added only when both the spatial resolution and
    the channel count are preserved.
    """

    def __init__(self, in_channels, out_channels, expansion_factor, stride):
        super(Bottleneck, self).__init__()
        hidden = in_channels * expansion_factor
        # 1x1 pointwise expansion.
        self.conv1 = ConvBlock(in_channels, hidden, 1)
        # 3x3 depthwise convolution (groups == channels).
        self.dwconv2 = ConvBlock(hidden, hidden, 3, stride, 1, g=hidden)
        # 1x1 linear projection: no activation after the final BN.
        self.conv3 = nn.Sequential(
            nn.Conv2d(hidden, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )
        self.use_residual = (stride == 1) and (in_channels == out_channels)

    def forward(self, x):
        out = self.conv3(self.dwconv2(self.conv1(x)))
        return x + out if self.use_residual else out
class MobileNetV2(nn.Module):
    """MobileNetV2 backbone with a linear classification head.

    In eval mode ``forward`` returns the pooled 1280-d feature vector;
    in train mode it returns logits (and features as well for the
    ``{'xent', 'htri'}`` loss combination).

    Reference:
        Sandler et al. MobileNetV2: Inverted Residuals and Linear
        Bottlenecks. CVPR 2018.
    """

    def __init__(self, num_classes, loss={'xent'}, **kwargs):
        super(MobileNetV2, self).__init__()
        self.loss = loss

        # Stem: 3x3 stride-2 convolution.
        self.conv1 = ConvBlock(3, 32, 3, s=2, p=1)
        # Inverted-residual stages; args are (in, out, expansion, stride).
        self.block2 = Bottleneck(32, 16, 1, 1)
        self.block3 = nn.Sequential(
            Bottleneck(16, 24, 6, 2),
            Bottleneck(24, 24, 6, 1),
        )
        self.block4 = nn.Sequential(
            Bottleneck(24, 32, 6, 2),
            Bottleneck(32, 32, 6, 1),
            Bottleneck(32, 32, 6, 1),
        )
        self.block5 = nn.Sequential(
            Bottleneck(32, 64, 6, 2),
            Bottleneck(64, 64, 6, 1),
            Bottleneck(64, 64, 6, 1),
            Bottleneck(64, 64, 6, 1),
        )
        self.block6 = nn.Sequential(
            Bottleneck(64, 96, 6, 1),
            Bottleneck(96, 96, 6, 1),
            Bottleneck(96, 96, 6, 1),
        )
        self.block7 = nn.Sequential(
            Bottleneck(96, 160, 6, 2),
            Bottleneck(160, 160, 6, 1),
            Bottleneck(160, 160, 6, 1),
        )
        self.block8 = Bottleneck(160, 320, 6, 1)
        # Head: 1x1 conv to the 1280-d embedding, then linear classifier.
        self.conv9 = ConvBlock(320, 1280, 1)
        self.classifier = nn.Linear(1280, num_classes)
        self.feat_dim = 1280

    def featuremaps(self, x):
        """Run the convolutional trunk; returns a (B, 1280, H', W') map."""
        stages = (
            self.conv1, self.block2, self.block3, self.block4,
            self.block5, self.block6, self.block7, self.block8,
            self.conv9,
        )
        for stage in stages:
            x = stage(x)
        return x

    def forward(self, x):
        f = self.featuremaps(x)
        # Global average pool over the remaining spatial extent -> (B, 1280).
        v = F.avg_pool2d(f, f.size()[2:]).view(f.size(0), -1)
        v = F.dropout(v, training=self.training)

        # Inference path: expose the embedding, not the logits.
        if not self.training:
            return v

        logits = self.classifier(v)

        if self.loss == {'xent'}:
            return logits
        if self.loss == {'xent', 'htri'}:
            return logits, v
        raise KeyError("Unsupported loss: {}".format(self.loss))