restore models
parent
3d658f55c3
commit
74393c8aa6
|
@ -0,0 +1,47 @@
|
|||
from __future__ import absolute_import
|
||||
|
||||
from .resnet import *
|
||||
from .resnext import *
|
||||
from .seresnet import *
|
||||
from .densenet import *
|
||||
from .mudeep import *
|
||||
from .hacnn import *
|
||||
from .squeeze import *
|
||||
from .mobilenetv2 import *
|
||||
from .shufflenet import *
|
||||
from .xception import *
|
||||
from .inceptionv4 import *
|
||||
from .nasnet import *
|
||||
from .inceptionresnetv2 import *
|
||||
|
||||
|
||||
__model_factory = {
|
||||
'resnet50': ResNet50,
|
||||
'resnet101': ResNet101,
|
||||
'seresnet50': SEResNet50,
|
||||
'seresnet101': SEResNet101,
|
||||
'seresnext50': SEResNeXt50,
|
||||
'seresnext101': SEResNeXt101,
|
||||
'resnext101': ResNeXt101_32x4d,
|
||||
'resnet50m': ResNet50M,
|
||||
'densenet121': DenseNet121,
|
||||
'squeezenet': SqueezeNet,
|
||||
'mobilenetv2': MobileNetV2,
|
||||
'shufflenet': ShuffleNet,
|
||||
'xception': Xception,
|
||||
'inceptionv4': InceptionV4,
|
||||
'nasnsetmobile': NASNetAMobile,
|
||||
'inceptionresnetv2': InceptionResNetV2,
|
||||
'mudeep': MuDeep,
|
||||
'hacnn': HACNN,
|
||||
}
|
||||
|
||||
|
||||
def get_names():
|
||||
return __model_factory.keys()
|
||||
|
||||
|
||||
def init_model(name, *args, **kwargs):
|
||||
if name not in __model_factory.keys():
|
||||
raise KeyError("Unknown model: {}".format(name))
|
||||
return __model_factory[name](*args, **kwargs)
|
|
@ -0,0 +1,34 @@
|
|||
from __future__ import absolute_import, division
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
import torchvision
|
||||
|
||||
|
||||
__all__ = ['DenseNet121']
|
||||
|
||||
|
||||
class DenseNet121(nn.Module):
|
||||
def __init__(self, num_classes, loss={'xent'}, **kwargs):
|
||||
super(DenseNet121, self).__init__()
|
||||
self.loss = loss
|
||||
densenet121 = torchvision.models.densenet121(pretrained=True)
|
||||
self.base = densenet121.features
|
||||
self.classifier = nn.Linear(1024, num_classes)
|
||||
self.feat_dim = 1024
|
||||
|
||||
def forward(self, x):
|
||||
x = self.base(x)
|
||||
x = F.avg_pool2d(x, x.size()[2:])
|
||||
f = x.view(x.size(0), -1)
|
||||
if not self.training:
|
||||
return f
|
||||
y = self.classifier(f)
|
||||
|
||||
if self.loss == {'xent'}:
|
||||
return y
|
||||
elif self.loss == {'xent', 'htri'}:
|
||||
return y, f
|
||||
else:
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
|
@ -0,0 +1,377 @@
|
|||
from __future__ import absolute_import, division
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
import torchvision
|
||||
|
||||
|
||||
__all__ = ['HACNN']
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
"""Basic convolutional block:
|
||||
convolution + batch normalization + relu.
|
||||
|
||||
Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):
|
||||
- in_c (int): number of input channels.
|
||||
- out_c (int): number of output channels.
|
||||
- k (int or tuple): kernel size.
|
||||
- s (int or tuple): stride.
|
||||
- p (int or tuple): padding.
|
||||
"""
|
||||
def __init__(self, in_c, out_c, k, s=1, p=0):
|
||||
super(ConvBlock, self).__init__()
|
||||
self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
|
||||
self.bn = nn.BatchNorm2d(out_c)
|
||||
|
||||
def forward(self, x):
|
||||
return F.relu(self.bn(self.conv(x)))
|
||||
|
||||
|
||||
class InceptionA(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
- in_channels (int): number of input channels
|
||||
- out_channels (int): number of output channels AFTER concatenation
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(InceptionA, self).__init__()
|
||||
mid_channels = out_channels // 4
|
||||
|
||||
self.stream1 = nn.Sequential(
|
||||
ConvBlock(in_channels, mid_channels, 1),
|
||||
ConvBlock(mid_channels, mid_channels, 3, p=1),
|
||||
)
|
||||
self.stream2 = nn.Sequential(
|
||||
ConvBlock(in_channels, mid_channels, 1),
|
||||
ConvBlock(mid_channels, mid_channels, 3, p=1),
|
||||
)
|
||||
self.stream3 = nn.Sequential(
|
||||
ConvBlock(in_channels, mid_channels, 1),
|
||||
ConvBlock(mid_channels, mid_channels, 3, p=1),
|
||||
)
|
||||
self.stream4 = nn.Sequential(
|
||||
nn.AvgPool2d(3, stride=1, padding=1),
|
||||
ConvBlock(in_channels, mid_channels, 1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
s1 = self.stream1(x)
|
||||
s2 = self.stream2(x)
|
||||
s3 = self.stream3(x)
|
||||
s4 = self.stream4(x)
|
||||
y = torch.cat([s1, s2, s3, s4], dim=1)
|
||||
return y
|
||||
|
||||
|
||||
class InceptionB(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
- in_channels (int): number of input channels
|
||||
- out_channels (int): number of output channels AFTER concatenation
|
||||
"""
|
||||
def __init__(self, in_channels, out_channels):
|
||||
super(InceptionB, self).__init__()
|
||||
mid_channels = out_channels // 4
|
||||
|
||||
self.stream1 = nn.Sequential(
|
||||
ConvBlock(in_channels, mid_channels, 1),
|
||||
ConvBlock(mid_channels, mid_channels, 3, s=2, p=1),
|
||||
)
|
||||
self.stream2 = nn.Sequential(
|
||||
ConvBlock(in_channels, mid_channels, 1),
|
||||
ConvBlock(mid_channels, mid_channels, 3, p=1),
|
||||
ConvBlock(mid_channels, mid_channels, 3, s=2, p=1),
|
||||
)
|
||||
self.stream3 = nn.Sequential(
|
||||
nn.MaxPool2d(3, stride=2, padding=1),
|
||||
ConvBlock(in_channels, mid_channels*2, 1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
s1 = self.stream1(x)
|
||||
s2 = self.stream2(x)
|
||||
s3 = self.stream3(x)
|
||||
y = torch.cat([s1, s2, s3], dim=1)
|
||||
return y
|
||||
|
||||
|
||||
class SpatialAttn(nn.Module):
|
||||
"""Spatial Attention (Sec. 3.1.I.1)"""
|
||||
def __init__(self):
|
||||
super(SpatialAttn, self).__init__()
|
||||
self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
|
||||
self.conv2 = ConvBlock(1, 1, 1)
|
||||
|
||||
def forward(self, x):
|
||||
# global cross-channel averaging
|
||||
x = x.mean(1, keepdim=True)
|
||||
# 3-by-3 conv
|
||||
x = self.conv1(x)
|
||||
# bilinear resizing
|
||||
x = F.upsample(x, (x.size(2)*2, x.size(3)*2), mode='bilinear', align_corners=True)
|
||||
# scaling conv
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
|
||||
class ChannelAttn(nn.Module):
|
||||
"""Channel Attention (Sec. 3.1.I.2)"""
|
||||
def __init__(self, in_channels, reduction_rate=16):
|
||||
super(ChannelAttn, self).__init__()
|
||||
assert in_channels%reduction_rate == 0
|
||||
self.conv1 = ConvBlock(in_channels, in_channels // reduction_rate, 1)
|
||||
self.conv2 = ConvBlock(in_channels // reduction_rate, in_channels, 1)
|
||||
|
||||
def forward(self, x):
|
||||
# squeeze operation (global average pooling)
|
||||
x = F.avg_pool2d(x, x.size()[2:])
|
||||
# excitation operation (2 conv layers)
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
return x
|
||||
|
||||
|
||||
class SoftAttn(nn.Module):
|
||||
"""Soft Attention (Sec. 3.1.I)
|
||||
Aim: Spatial Attention + Channel Attention
|
||||
Output: attention maps with shape identical to input.
|
||||
"""
|
||||
def __init__(self, in_channels):
|
||||
super(SoftAttn, self).__init__()
|
||||
self.spatial_attn = SpatialAttn()
|
||||
self.channel_attn = ChannelAttn(in_channels)
|
||||
self.conv = ConvBlock(in_channels, in_channels, 1)
|
||||
|
||||
def forward(self, x):
|
||||
y_spatial = self.spatial_attn(x)
|
||||
y_channel = self.channel_attn(x)
|
||||
y = y_spatial * y_channel
|
||||
y = F.sigmoid(self.conv(y))
|
||||
return y
|
||||
|
||||
|
||||
class HardAttn(nn.Module):
|
||||
"""Hard Attention (Sec. 3.1.II)"""
|
||||
def __init__(self, in_channels):
|
||||
super(HardAttn, self).__init__()
|
||||
self.fc = nn.Linear(in_channels, 4*2)
|
||||
self.init_params()
|
||||
|
||||
def init_params(self):
|
||||
self.fc.weight.data.zero_()
|
||||
self.fc.bias.data.copy_(torch.tensor([0, -0.75, 0, -0.25, 0, 0.25, 0, 0.75], dtype=torch.float))
|
||||
|
||||
def forward(self, x):
|
||||
# squeeze operation (global average pooling)
|
||||
x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
|
||||
# predict transformation parameters
|
||||
theta = F.tanh(self.fc(x))
|
||||
theta = theta.view(-1, 4, 2)
|
||||
return theta
|
||||
|
||||
|
||||
class HarmAttn(nn.Module):
|
||||
"""Harmonious Attention (Sec. 3.1)"""
|
||||
def __init__(self, in_channels):
|
||||
super(HarmAttn, self).__init__()
|
||||
self.soft_attn = SoftAttn(in_channels)
|
||||
self.hard_attn = HardAttn(in_channels)
|
||||
|
||||
def forward(self, x):
|
||||
y_soft_attn = self.soft_attn(x)
|
||||
theta = self.hard_attn(x)
|
||||
return y_soft_attn, theta
|
||||
|
||||
|
||||
class HACNN(nn.Module):
|
||||
"""
|
||||
Harmonious Attention Convolutional Neural Network
|
||||
|
||||
Reference:
|
||||
Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.
|
||||
|
||||
Args:
|
||||
- num_classes (int): number of classes to predict
|
||||
- nchannels (list): number of channels AFTER concatenation
|
||||
- feat_dim (int): feature dimension for a single stream
|
||||
- learn_region (bool): whether to learn region features (i.e. local branch)
|
||||
"""
|
||||
def __init__(self, num_classes, loss={'xent'}, nchannels=[128, 256, 384], feat_dim=512, learn_region=True, use_gpu=True, **kwargs):
|
||||
super(HACNN, self).__init__()
|
||||
self.loss = loss
|
||||
self.learn_region = learn_region
|
||||
self.use_gpu = use_gpu
|
||||
|
||||
self.conv = ConvBlock(3, 32, 3, s=2, p=1)
|
||||
|
||||
# Construct Inception + HarmAttn blocks
|
||||
# ============== Block 1 ==============
|
||||
self.inception1 = nn.Sequential(
|
||||
InceptionA(32, nchannels[0]),
|
||||
InceptionB(nchannels[0], nchannels[0]),
|
||||
)
|
||||
self.ha1 = HarmAttn(nchannels[0])
|
||||
|
||||
# ============== Block 2 ==============
|
||||
self.inception2 = nn.Sequential(
|
||||
InceptionA(nchannels[0], nchannels[1]),
|
||||
InceptionB(nchannels[1], nchannels[1]),
|
||||
)
|
||||
self.ha2 = HarmAttn(nchannels[1])
|
||||
|
||||
# ============== Block 3 ==============
|
||||
self.inception3 = nn.Sequential(
|
||||
InceptionA(nchannels[1], nchannels[2]),
|
||||
InceptionB(nchannels[2], nchannels[2]),
|
||||
)
|
||||
self.ha3 = HarmAttn(nchannels[2])
|
||||
|
||||
self.fc_global = nn.Sequential(
|
||||
nn.Linear(nchannels[2], feat_dim),
|
||||
nn.BatchNorm1d(feat_dim),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.classifier_global = nn.Linear(feat_dim, num_classes)
|
||||
|
||||
if self.learn_region:
|
||||
self.init_scale_factors()
|
||||
self.local_conv1 = InceptionB(32, nchannels[0])
|
||||
self.local_conv2 = InceptionB(nchannels[0], nchannels[1])
|
||||
self.local_conv3 = InceptionB(nchannels[1], nchannels[2])
|
||||
self.fc_local = nn.Sequential(
|
||||
nn.Linear(nchannels[2]*4, feat_dim),
|
||||
nn.BatchNorm1d(feat_dim),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.classifier_local = nn.Linear(feat_dim, num_classes)
|
||||
self.feat_dim = feat_dim * 2
|
||||
else:
|
||||
self.feat_dim = feat_dim
|
||||
|
||||
def init_scale_factors(self):
|
||||
# initialize scale factors (s_w, s_h) for four regions
|
||||
self.scale_factors = []
|
||||
self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float))
|
||||
self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float))
|
||||
self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float))
|
||||
self.scale_factors.append(torch.tensor([[1, 0], [0, 0.25]], dtype=torch.float))
|
||||
|
||||
def stn(self, x, theta):
|
||||
"""Perform spatial transform
|
||||
- x: (batch, channel, height, width)
|
||||
- theta: (batch, 2, 3)
|
||||
"""
|
||||
grid = F.affine_grid(theta, x.size())
|
||||
x = F.grid_sample(x, grid)
|
||||
return x
|
||||
|
||||
def transform_theta(self, theta_i, region_idx):
|
||||
"""Transform theta to include (s_w, s_h),
|
||||
resulting in (batch, 2, 3)"""
|
||||
scale_factors = self.scale_factors[region_idx]
|
||||
theta = torch.zeros(theta_i.size(0), 2, 3)
|
||||
theta[:,:,:2] = scale_factors
|
||||
theta[:,:,-1] = theta_i
|
||||
if self.use_gpu: theta = theta.cuda()
|
||||
return theta
|
||||
|
||||
def forward(self, x):
|
||||
assert x.size(2) == 160 and x.size(3) == 64, \
|
||||
"Input size does not match, expected (160, 64) but got ({}, {})".format(x.size(2), x.size(3))
|
||||
x = self.conv(x)
|
||||
|
||||
# ============== Block 1 ==============
|
||||
# global branch
|
||||
x1 = self.inception1(x)
|
||||
x1_attn, x1_theta = self.ha1(x1)
|
||||
x1_out = x1 * x1_attn
|
||||
# local branch
|
||||
if self.learn_region:
|
||||
x1_local_list = []
|
||||
for region_idx in range(4):
|
||||
x1_theta_i = x1_theta[:,region_idx,:]
|
||||
x1_theta_i = self.transform_theta(x1_theta_i, region_idx)
|
||||
x1_trans_i = self.stn(x, x1_theta_i)
|
||||
x1_trans_i = F.upsample(x1_trans_i, (24, 28), mode='bilinear', align_corners=True)
|
||||
x1_local_i = self.local_conv1(x1_trans_i)
|
||||
x1_local_list.append(x1_local_i)
|
||||
|
||||
# ============== Block 2 ==============
|
||||
# Block 2
|
||||
# global branch
|
||||
x2 = self.inception2(x1_out)
|
||||
x2_attn, x2_theta = self.ha2(x2)
|
||||
x2_out = x2 * x2_attn
|
||||
# local branch
|
||||
if self.learn_region:
|
||||
x2_local_list = []
|
||||
for region_idx in range(4):
|
||||
x2_theta_i = x2_theta[:,region_idx,:]
|
||||
x2_theta_i = self.transform_theta(x2_theta_i, region_idx)
|
||||
x2_trans_i = self.stn(x1_out, x2_theta_i)
|
||||
x2_trans_i = F.upsample(x2_trans_i, (12, 14), mode='bilinear', align_corners=True)
|
||||
x2_local_i = x2_trans_i + x1_local_list[region_idx]
|
||||
x2_local_i = self.local_conv2(x2_local_i)
|
||||
x2_local_list.append(x2_local_i)
|
||||
|
||||
# ============== Block 3 ==============
|
||||
# Block 3
|
||||
# global branch
|
||||
x3 = self.inception3(x2_out)
|
||||
x3_attn, x3_theta = self.ha3(x3)
|
||||
x3_out = x3 * x3_attn
|
||||
# local branch
|
||||
if self.learn_region:
|
||||
x3_local_list = []
|
||||
for region_idx in range(4):
|
||||
x3_theta_i = x3_theta[:,region_idx,:]
|
||||
x3_theta_i = self.transform_theta(x3_theta_i, region_idx)
|
||||
x3_trans_i = self.stn(x2_out, x3_theta_i)
|
||||
x3_trans_i = F.upsample(x3_trans_i, (6, 7), mode='bilinear', align_corners=True)
|
||||
x3_local_i = x3_trans_i + x2_local_list[region_idx]
|
||||
x3_local_i = self.local_conv3(x3_local_i)
|
||||
x3_local_list.append(x3_local_i)
|
||||
|
||||
# ============== Feature generation ==============
|
||||
# global branch
|
||||
x_global = F.avg_pool2d(x3_out, x3_out.size()[2:]).view(x3_out.size(0), x3_out.size(1))
|
||||
x_global = self.fc_global(x_global)
|
||||
# local branch
|
||||
if self.learn_region:
|
||||
x_local_list = []
|
||||
for region_idx in range(4):
|
||||
x_local_i = x3_local_list[region_idx]
|
||||
x_local_i = F.avg_pool2d(x_local_i, x_local_i.size()[2:]).view(x_local_i.size(0), -1)
|
||||
x_local_list.append(x_local_i)
|
||||
x_local = torch.cat(x_local_list, 1)
|
||||
x_local = self.fc_local(x_local)
|
||||
|
||||
if not self.training:
|
||||
# l2 normalization before concatenation
|
||||
if self.learn_region:
|
||||
x_global = x_global / x_global.norm(p=2, dim=1, keepdim=True)
|
||||
x_local = x_local / x_local.norm(p=2, dim=1, keepdim=True)
|
||||
return torch.cat([x_global, x_local], 1)
|
||||
else:
|
||||
return x_global
|
||||
|
||||
prelogits_global = self.classifier_global(x_global)
|
||||
if self.learn_region:
|
||||
prelogits_local = self.classifier_local(x_local)
|
||||
|
||||
if self.loss == {'xent'}:
|
||||
if self.learn_region:
|
||||
return (prelogits_global, prelogits_local)
|
||||
else:
|
||||
return prelogits_global
|
||||
|
||||
elif self.loss == {'xent', 'htri'}:
|
||||
if self.learn_region:
|
||||
return (prelogits_global, prelogits_local), (x_global, x_local)
|
||||
else:
|
||||
return prelogits_global, x_global
|
||||
|
||||
else:
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
|
@ -0,0 +1,385 @@
|
|||
from __future__ import absolute_import, division
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
"""
|
||||
Code imported from https://github.com/Cadene/pretrained-models.pytorch
|
||||
"""
|
||||
|
||||
|
||||
__all__ = ['InceptionResNetV2']
|
||||
|
||||
|
||||
pretrained_settings = {
|
||||
'inceptionresnetv2': {
|
||||
'imagenet': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 299, 299],
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.5, 0.5, 0.5],
|
||||
'std': [0.5, 0.5, 0.5],
|
||||
'num_classes': 1000
|
||||
},
|
||||
'imagenet+background': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionresnetv2-520b38e4.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 299, 299],
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.5, 0.5, 0.5],
|
||||
'std': [0.5, 0.5, 0.5],
|
||||
'num_classes': 1001
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class BasicConv2d(nn.Module):
|
||||
|
||||
def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
|
||||
super(BasicConv2d, self).__init__()
|
||||
self.conv = nn.Conv2d(in_planes, out_planes,
|
||||
kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, bias=False) # verify bias false
|
||||
self.bn = nn.BatchNorm2d(out_planes,
|
||||
eps=0.001, # value found in tensorflow
|
||||
momentum=0.1, # default pytorch value
|
||||
affine=True)
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class Mixed_5b(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Mixed_5b, self).__init__()
|
||||
|
||||
self.branch0 = BasicConv2d(192, 96, kernel_size=1, stride=1)
|
||||
|
||||
self.branch1 = nn.Sequential(
|
||||
BasicConv2d(192, 48, kernel_size=1, stride=1),
|
||||
BasicConv2d(48, 64, kernel_size=5, stride=1, padding=2)
|
||||
)
|
||||
|
||||
self.branch2 = nn.Sequential(
|
||||
BasicConv2d(192, 64, kernel_size=1, stride=1),
|
||||
BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
|
||||
BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
|
||||
)
|
||||
|
||||
self.branch3 = nn.Sequential(
|
||||
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
|
||||
BasicConv2d(192, 64, kernel_size=1, stride=1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.branch0(x)
|
||||
x1 = self.branch1(x)
|
||||
x2 = self.branch2(x)
|
||||
x3 = self.branch3(x)
|
||||
out = torch.cat((x0, x1, x2, x3), 1)
|
||||
return out
|
||||
|
||||
|
||||
class Block35(nn.Module):
|
||||
|
||||
def __init__(self, scale=1.0):
|
||||
super(Block35, self).__init__()
|
||||
|
||||
self.scale = scale
|
||||
|
||||
self.branch0 = BasicConv2d(320, 32, kernel_size=1, stride=1)
|
||||
|
||||
self.branch1 = nn.Sequential(
|
||||
BasicConv2d(320, 32, kernel_size=1, stride=1),
|
||||
BasicConv2d(32, 32, kernel_size=3, stride=1, padding=1)
|
||||
)
|
||||
|
||||
self.branch2 = nn.Sequential(
|
||||
BasicConv2d(320, 32, kernel_size=1, stride=1),
|
||||
BasicConv2d(32, 48, kernel_size=3, stride=1, padding=1),
|
||||
BasicConv2d(48, 64, kernel_size=3, stride=1, padding=1)
|
||||
)
|
||||
|
||||
self.conv2d = nn.Conv2d(128, 320, kernel_size=1, stride=1)
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.branch0(x)
|
||||
x1 = self.branch1(x)
|
||||
x2 = self.branch2(x)
|
||||
out = torch.cat((x0, x1, x2), 1)
|
||||
out = self.conv2d(out)
|
||||
out = out * self.scale + x
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class Mixed_6a(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Mixed_6a, self).__init__()
|
||||
|
||||
self.branch0 = BasicConv2d(320, 384, kernel_size=3, stride=2)
|
||||
|
||||
self.branch1 = nn.Sequential(
|
||||
BasicConv2d(320, 256, kernel_size=1, stride=1),
|
||||
BasicConv2d(256, 256, kernel_size=3, stride=1, padding=1),
|
||||
BasicConv2d(256, 384, kernel_size=3, stride=2)
|
||||
)
|
||||
|
||||
self.branch2 = nn.MaxPool2d(3, stride=2)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.branch0(x)
|
||||
x1 = self.branch1(x)
|
||||
x2 = self.branch2(x)
|
||||
out = torch.cat((x0, x1, x2), 1)
|
||||
return out
|
||||
|
||||
|
||||
class Block17(nn.Module):
|
||||
|
||||
def __init__(self, scale=1.0):
|
||||
super(Block17, self).__init__()
|
||||
|
||||
self.scale = scale
|
||||
|
||||
self.branch0 = BasicConv2d(1088, 192, kernel_size=1, stride=1)
|
||||
|
||||
self.branch1 = nn.Sequential(
|
||||
BasicConv2d(1088, 128, kernel_size=1, stride=1),
|
||||
BasicConv2d(128, 160, kernel_size=(1,7), stride=1, padding=(0,3)),
|
||||
BasicConv2d(160, 192, kernel_size=(7,1), stride=1, padding=(3,0))
|
||||
)
|
||||
|
||||
self.conv2d = nn.Conv2d(384, 1088, kernel_size=1, stride=1)
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.branch0(x)
|
||||
x1 = self.branch1(x)
|
||||
out = torch.cat((x0, x1), 1)
|
||||
out = self.conv2d(out)
|
||||
out = out * self.scale + x
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
class Mixed_7a(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Mixed_7a, self).__init__()
|
||||
|
||||
self.branch0 = nn.Sequential(
|
||||
BasicConv2d(1088, 256, kernel_size=1, stride=1),
|
||||
BasicConv2d(256, 384, kernel_size=3, stride=2)
|
||||
)
|
||||
|
||||
self.branch1 = nn.Sequential(
|
||||
BasicConv2d(1088, 256, kernel_size=1, stride=1),
|
||||
BasicConv2d(256, 288, kernel_size=3, stride=2)
|
||||
)
|
||||
|
||||
self.branch2 = nn.Sequential(
|
||||
BasicConv2d(1088, 256, kernel_size=1, stride=1),
|
||||
BasicConv2d(256, 288, kernel_size=3, stride=1, padding=1),
|
||||
BasicConv2d(288, 320, kernel_size=3, stride=2)
|
||||
)
|
||||
|
||||
self.branch3 = nn.MaxPool2d(3, stride=2)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.branch0(x)
|
||||
x1 = self.branch1(x)
|
||||
x2 = self.branch2(x)
|
||||
x3 = self.branch3(x)
|
||||
out = torch.cat((x0, x1, x2, x3), 1)
|
||||
return out
|
||||
|
||||
|
||||
class Block8(nn.Module):
|
||||
|
||||
def __init__(self, scale=1.0, noReLU=False):
|
||||
super(Block8, self).__init__()
|
||||
|
||||
self.scale = scale
|
||||
self.noReLU = noReLU
|
||||
|
||||
self.branch0 = BasicConv2d(2080, 192, kernel_size=1, stride=1)
|
||||
|
||||
self.branch1 = nn.Sequential(
|
||||
BasicConv2d(2080, 192, kernel_size=1, stride=1),
|
||||
BasicConv2d(192, 224, kernel_size=(1,3), stride=1, padding=(0,1)),
|
||||
BasicConv2d(224, 256, kernel_size=(3,1), stride=1, padding=(1,0))
|
||||
)
|
||||
|
||||
self.conv2d = nn.Conv2d(448, 2080, kernel_size=1, stride=1)
|
||||
if not self.noReLU:
|
||||
self.relu = nn.ReLU(inplace=False)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.branch0(x)
|
||||
x1 = self.branch1(x)
|
||||
out = torch.cat((x0, x1), 1)
|
||||
out = self.conv2d(out)
|
||||
out = out * self.scale + x
|
||||
if not self.noReLU:
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
def inceptionresnetv2(num_classes=1000, pretrained='imagenet'):
|
||||
r"""InceptionResNetV2 model architecture from the
|
||||
`"InceptionV4, Inception-ResNet..." <https://arxiv.org/abs/1602.07261>`_ paper.
|
||||
"""
|
||||
if pretrained:
|
||||
settings = pretrained_settings['inceptionresnetv2'][pretrained]
|
||||
assert num_classes == settings['num_classes'], \
|
||||
"num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
|
||||
|
||||
# both 'imagenet'&'imagenet+background' are loaded from same parameters
|
||||
model = InceptionResNetV2(num_classes=1001)
|
||||
model.load_state_dict(model_zoo.load_url(settings['url']))
|
||||
|
||||
if pretrained == 'imagenet':
|
||||
new_last_linear = nn.Linear(1536, 1000)
|
||||
new_last_linear.weight.data = model.last_linear.weight.data[1:]
|
||||
new_last_linear.bias.data = model.last_linear.bias.data[1:]
|
||||
model.last_linear = new_last_linear
|
||||
|
||||
model.input_space = settings['input_space']
|
||||
model.input_size = settings['input_size']
|
||||
model.input_range = settings['input_range']
|
||||
|
||||
model.mean = settings['mean']
|
||||
model.std = settings['std']
|
||||
else:
|
||||
model = InceptionResNetV2(num_classes=num_classes)
|
||||
return model
|
||||
|
||||
|
||||
##################### Model Definition #########################
|
||||
|
||||
|
||||
class InceptionResNetV2(nn.Module):
|
||||
def __init__(self, num_classes, loss={'xent'}, **kwargs):
|
||||
super(InceptionResNetV2, self).__init__()
|
||||
self.loss = loss
|
||||
# Modules
|
||||
self.conv2d_1a = BasicConv2d(3, 32, kernel_size=3, stride=2)
|
||||
self.conv2d_2a = BasicConv2d(32, 32, kernel_size=3, stride=1)
|
||||
self.conv2d_2b = BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1)
|
||||
self.maxpool_3a = nn.MaxPool2d(3, stride=2)
|
||||
self.conv2d_3b = BasicConv2d(64, 80, kernel_size=1, stride=1)
|
||||
self.conv2d_4a = BasicConv2d(80, 192, kernel_size=3, stride=1)
|
||||
self.maxpool_5a = nn.MaxPool2d(3, stride=2)
|
||||
self.mixed_5b = Mixed_5b()
|
||||
self.repeat = nn.Sequential(
|
||||
Block35(scale=0.17),
|
||||
Block35(scale=0.17),
|
||||
Block35(scale=0.17),
|
||||
Block35(scale=0.17),
|
||||
Block35(scale=0.17),
|
||||
Block35(scale=0.17),
|
||||
Block35(scale=0.17),
|
||||
Block35(scale=0.17),
|
||||
Block35(scale=0.17),
|
||||
Block35(scale=0.17)
|
||||
)
|
||||
self.mixed_6a = Mixed_6a()
|
||||
self.repeat_1 = nn.Sequential(
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10),
|
||||
Block17(scale=0.10)
|
||||
)
|
||||
self.mixed_7a = Mixed_7a()
|
||||
self.repeat_2 = nn.Sequential(
|
||||
Block8(scale=0.20),
|
||||
Block8(scale=0.20),
|
||||
Block8(scale=0.20),
|
||||
Block8(scale=0.20),
|
||||
Block8(scale=0.20),
|
||||
Block8(scale=0.20),
|
||||
Block8(scale=0.20),
|
||||
Block8(scale=0.20),
|
||||
Block8(scale=0.20)
|
||||
)
|
||||
self.block8 = Block8(noReLU=True)
|
||||
self.conv2d_7b = BasicConv2d(2080, 1536, kernel_size=1, stride=1)
|
||||
self.classifier = nn.Linear(1536, num_classes)
|
||||
self.feat_dim = 1536
|
||||
|
||||
self.init_params()
|
||||
|
||||
def init_params(self):
|
||||
"""Load ImageNet pretrained weights"""
|
||||
settings = pretrained_settings['inceptionresnetv2']['imagenet']
|
||||
pretrained_dict = model_zoo.load_url(settings['url'], map_location=None)
|
||||
model_dict = self.state_dict()
|
||||
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
|
||||
model_dict.update(pretrained_dict)
|
||||
self.load_state_dict(model_dict)
|
||||
|
||||
def features(self, input):
|
||||
x = self.conv2d_1a(input)
|
||||
x = self.conv2d_2a(x)
|
||||
x = self.conv2d_2b(x)
|
||||
x = self.maxpool_3a(x)
|
||||
x = self.conv2d_3b(x)
|
||||
x = self.conv2d_4a(x)
|
||||
x = self.maxpool_5a(x)
|
||||
x = self.mixed_5b(x)
|
||||
x = self.repeat(x)
|
||||
x = self.mixed_6a(x)
|
||||
x = self.repeat_1(x)
|
||||
x = self.mixed_7a(x)
|
||||
x = self.repeat_2(x)
|
||||
x = self.block8(x)
|
||||
x = self.conv2d_7b(x)
|
||||
x = F.avg_pool2d(x, x.size()[2:])
|
||||
x = x.view(x.size(0), -1)
|
||||
return x
|
||||
|
||||
def forward(self, input):
|
||||
x = self.features(input)
|
||||
|
||||
if not self.training:
|
||||
return x
|
||||
|
||||
y = self.classifier(x)
|
||||
|
||||
if self.loss == {'xent'}:
|
||||
return y
|
||||
elif self.loss == {'xent', 'htri'}:
|
||||
return y, x
|
||||
else:
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
|
@ -0,0 +1,368 @@
|
|||
from __future__ import absolute_import, division
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
import os
|
||||
import sys
|
||||
|
||||
|
||||
__all__ = ['InceptionV4']
|
||||
|
||||
"""
|
||||
Code imported from https://github.com/Cadene/pretrained-models.pytorch
|
||||
"""
|
||||
|
||||
|
||||
pretrained_settings = {
|
||||
'inceptionv4': {
|
||||
'imagenet': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 299, 299],
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.5, 0.5, 0.5],
|
||||
'std': [0.5, 0.5, 0.5],
|
||||
'num_classes': 1000
|
||||
},
|
||||
'imagenet+background': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/inceptionv4-8e4777a0.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 299, 299],
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.5, 0.5, 0.5],
|
||||
'std': [0.5, 0.5, 0.5],
|
||||
'num_classes': 1001
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
class BasicConv2d(nn.Module):
|
||||
|
||||
def __init__(self, in_planes, out_planes, kernel_size, stride, padding=0):
|
||||
super(BasicConv2d, self).__init__()
|
||||
self.conv = nn.Conv2d(in_planes, out_planes,
|
||||
kernel_size=kernel_size, stride=stride,
|
||||
padding=padding, bias=False) # verify bias false
|
||||
self.bn = nn.BatchNorm2d(out_planes,
|
||||
eps=0.001, # value found in tensorflow
|
||||
momentum=0.1, # default pytorch value
|
||||
affine=True)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.bn(x)
|
||||
x = self.relu(x)
|
||||
return x
|
||||
|
||||
|
||||
class Mixed_3a(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Mixed_3a, self).__init__()
|
||||
self.maxpool = nn.MaxPool2d(3, stride=2)
|
||||
self.conv = BasicConv2d(64, 96, kernel_size=3, stride=2)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.maxpool(x)
|
||||
x1 = self.conv(x)
|
||||
out = torch.cat((x0, x1), 1)
|
||||
return out
|
||||
|
||||
|
||||
class Mixed_4a(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Mixed_4a, self).__init__()
|
||||
|
||||
self.branch0 = nn.Sequential(
|
||||
BasicConv2d(160, 64, kernel_size=1, stride=1),
|
||||
BasicConv2d(64, 96, kernel_size=3, stride=1)
|
||||
)
|
||||
|
||||
self.branch1 = nn.Sequential(
|
||||
BasicConv2d(160, 64, kernel_size=1, stride=1),
|
||||
BasicConv2d(64, 64, kernel_size=(1,7), stride=1, padding=(0,3)),
|
||||
BasicConv2d(64, 64, kernel_size=(7,1), stride=1, padding=(3,0)),
|
||||
BasicConv2d(64, 96, kernel_size=(3,3), stride=1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.branch0(x)
|
||||
x1 = self.branch1(x)
|
||||
out = torch.cat((x0, x1), 1)
|
||||
return out
|
||||
|
||||
|
||||
class Mixed_5a(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Mixed_5a, self).__init__()
|
||||
self.conv = BasicConv2d(192, 192, kernel_size=3, stride=2)
|
||||
self.maxpool = nn.MaxPool2d(3, stride=2)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.conv(x)
|
||||
x1 = self.maxpool(x)
|
||||
out = torch.cat((x0, x1), 1)
|
||||
return out
|
||||
|
||||
|
||||
class Inception_A(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Inception_A, self).__init__()
|
||||
self.branch0 = BasicConv2d(384, 96, kernel_size=1, stride=1)
|
||||
|
||||
self.branch1 = nn.Sequential(
|
||||
BasicConv2d(384, 64, kernel_size=1, stride=1),
|
||||
BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1)
|
||||
)
|
||||
|
||||
self.branch2 = nn.Sequential(
|
||||
BasicConv2d(384, 64, kernel_size=1, stride=1),
|
||||
BasicConv2d(64, 96, kernel_size=3, stride=1, padding=1),
|
||||
BasicConv2d(96, 96, kernel_size=3, stride=1, padding=1)
|
||||
)
|
||||
|
||||
self.branch3 = nn.Sequential(
|
||||
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
|
||||
BasicConv2d(384, 96, kernel_size=1, stride=1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.branch0(x)
|
||||
x1 = self.branch1(x)
|
||||
x2 = self.branch2(x)
|
||||
x3 = self.branch3(x)
|
||||
out = torch.cat((x0, x1, x2, x3), 1)
|
||||
return out
|
||||
|
||||
|
||||
class Reduction_A(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Reduction_A, self).__init__()
|
||||
self.branch0 = BasicConv2d(384, 384, kernel_size=3, stride=2)
|
||||
|
||||
self.branch1 = nn.Sequential(
|
||||
BasicConv2d(384, 192, kernel_size=1, stride=1),
|
||||
BasicConv2d(192, 224, kernel_size=3, stride=1, padding=1),
|
||||
BasicConv2d(224, 256, kernel_size=3, stride=2)
|
||||
)
|
||||
|
||||
self.branch2 = nn.MaxPool2d(3, stride=2)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.branch0(x)
|
||||
x1 = self.branch1(x)
|
||||
x2 = self.branch2(x)
|
||||
out = torch.cat((x0, x1, x2), 1)
|
||||
return out
|
||||
|
||||
|
||||
class Inception_B(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Inception_B, self).__init__()
|
||||
self.branch0 = BasicConv2d(1024, 384, kernel_size=1, stride=1)
|
||||
|
||||
self.branch1 = nn.Sequential(
|
||||
BasicConv2d(1024, 192, kernel_size=1, stride=1),
|
||||
BasicConv2d(192, 224, kernel_size=(1,7), stride=1, padding=(0,3)),
|
||||
BasicConv2d(224, 256, kernel_size=(7,1), stride=1, padding=(3,0))
|
||||
)
|
||||
|
||||
self.branch2 = nn.Sequential(
|
||||
BasicConv2d(1024, 192, kernel_size=1, stride=1),
|
||||
BasicConv2d(192, 192, kernel_size=(7,1), stride=1, padding=(3,0)),
|
||||
BasicConv2d(192, 224, kernel_size=(1,7), stride=1, padding=(0,3)),
|
||||
BasicConv2d(224, 224, kernel_size=(7,1), stride=1, padding=(3,0)),
|
||||
BasicConv2d(224, 256, kernel_size=(1,7), stride=1, padding=(0,3))
|
||||
)
|
||||
|
||||
self.branch3 = nn.Sequential(
|
||||
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
|
||||
BasicConv2d(1024, 128, kernel_size=1, stride=1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.branch0(x)
|
||||
x1 = self.branch1(x)
|
||||
x2 = self.branch2(x)
|
||||
x3 = self.branch3(x)
|
||||
out = torch.cat((x0, x1, x2, x3), 1)
|
||||
return out
|
||||
|
||||
|
||||
class Reduction_B(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Reduction_B, self).__init__()
|
||||
|
||||
self.branch0 = nn.Sequential(
|
||||
BasicConv2d(1024, 192, kernel_size=1, stride=1),
|
||||
BasicConv2d(192, 192, kernel_size=3, stride=2)
|
||||
)
|
||||
|
||||
self.branch1 = nn.Sequential(
|
||||
BasicConv2d(1024, 256, kernel_size=1, stride=1),
|
||||
BasicConv2d(256, 256, kernel_size=(1,7), stride=1, padding=(0,3)),
|
||||
BasicConv2d(256, 320, kernel_size=(7,1), stride=1, padding=(3,0)),
|
||||
BasicConv2d(320, 320, kernel_size=3, stride=2)
|
||||
)
|
||||
|
||||
self.branch2 = nn.MaxPool2d(3, stride=2)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.branch0(x)
|
||||
x1 = self.branch1(x)
|
||||
x2 = self.branch2(x)
|
||||
out = torch.cat((x0, x1, x2), 1)
|
||||
return out
|
||||
|
||||
|
||||
class Inception_C(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(Inception_C, self).__init__()
|
||||
|
||||
self.branch0 = BasicConv2d(1536, 256, kernel_size=1, stride=1)
|
||||
|
||||
self.branch1_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
|
||||
self.branch1_1a = BasicConv2d(384, 256, kernel_size=(1,3), stride=1, padding=(0,1))
|
||||
self.branch1_1b = BasicConv2d(384, 256, kernel_size=(3,1), stride=1, padding=(1,0))
|
||||
|
||||
self.branch2_0 = BasicConv2d(1536, 384, kernel_size=1, stride=1)
|
||||
self.branch2_1 = BasicConv2d(384, 448, kernel_size=(3,1), stride=1, padding=(1,0))
|
||||
self.branch2_2 = BasicConv2d(448, 512, kernel_size=(1,3), stride=1, padding=(0,1))
|
||||
self.branch2_3a = BasicConv2d(512, 256, kernel_size=(1,3), stride=1, padding=(0,1))
|
||||
self.branch2_3b = BasicConv2d(512, 256, kernel_size=(3,1), stride=1, padding=(1,0))
|
||||
|
||||
self.branch3 = nn.Sequential(
|
||||
nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False),
|
||||
BasicConv2d(1536, 256, kernel_size=1, stride=1)
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
x0 = self.branch0(x)
|
||||
|
||||
x1_0 = self.branch1_0(x)
|
||||
x1_1a = self.branch1_1a(x1_0)
|
||||
x1_1b = self.branch1_1b(x1_0)
|
||||
x1 = torch.cat((x1_1a, x1_1b), 1)
|
||||
|
||||
x2_0 = self.branch2_0(x)
|
||||
x2_1 = self.branch2_1(x2_0)
|
||||
x2_2 = self.branch2_2(x2_1)
|
||||
x2_3a = self.branch2_3a(x2_2)
|
||||
x2_3b = self.branch2_3b(x2_2)
|
||||
x2 = torch.cat((x2_3a, x2_3b), 1)
|
||||
|
||||
x3 = self.branch3(x)
|
||||
|
||||
out = torch.cat((x0, x1, x2, x3), 1)
|
||||
return out
|
||||
|
||||
|
||||
class InceptionV4Base(nn.Module):
|
||||
|
||||
def __init__(self, num_classes=1001):
|
||||
super(InceptionV4Base, self).__init__()
|
||||
# Special attributs
|
||||
self.input_space = None
|
||||
self.input_size = (299, 299, 3)
|
||||
self.mean = None
|
||||
self.std = None
|
||||
# Modules
|
||||
self.features = nn.Sequential(
|
||||
BasicConv2d(3, 32, kernel_size=3, stride=2),
|
||||
BasicConv2d(32, 32, kernel_size=3, stride=1),
|
||||
BasicConv2d(32, 64, kernel_size=3, stride=1, padding=1),
|
||||
Mixed_3a(),
|
||||
Mixed_4a(),
|
||||
Mixed_5a(),
|
||||
Inception_A(),
|
||||
Inception_A(),
|
||||
Inception_A(),
|
||||
Inception_A(),
|
||||
Reduction_A(), # Mixed_6a
|
||||
Inception_B(),
|
||||
Inception_B(),
|
||||
Inception_B(),
|
||||
Inception_B(),
|
||||
Inception_B(),
|
||||
Inception_B(),
|
||||
Inception_B(),
|
||||
Reduction_B(), # Mixed_7a
|
||||
Inception_C(),
|
||||
Inception_C(),
|
||||
Inception_C()
|
||||
)
|
||||
self.avg_pool = nn.AvgPool2d(8, count_include_pad=False)
|
||||
self.last_linear = nn.Linear(1536, num_classes)
|
||||
|
||||
def logits(self, features):
|
||||
x = self.avg_pool(features)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.last_linear(x)
|
||||
return x
|
||||
|
||||
def forward(self, input):
|
||||
x = self.features(input)
|
||||
x = self.logits(x)
|
||||
return x
|
||||
|
||||
|
||||
def inceptionv4(num_classes=1000, pretrained='imagenet'):
|
||||
if pretrained:
|
||||
settings = pretrained_settings['inceptionv4'][pretrained]
|
||||
assert num_classes == settings['num_classes'], \
|
||||
"num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)
|
||||
|
||||
# both 'imagenet'&'imagenet+background' are loaded from same parameters
|
||||
model = InceptionV4Base(num_classes=1001)
|
||||
model.load_state_dict(model_zoo.load_url(settings['url']))
|
||||
|
||||
if pretrained == 'imagenet':
|
||||
new_last_linear = nn.Linear(1536, 1000)
|
||||
new_last_linear.weight.data = model.last_linear.weight.data[1:]
|
||||
new_last_linear.bias.data = model.last_linear.bias.data[1:]
|
||||
model.last_linear = new_last_linear
|
||||
|
||||
model.input_space = settings['input_space']
|
||||
model.input_size = settings['input_size']
|
||||
model.input_range = settings['input_range']
|
||||
model.mean = settings['mean']
|
||||
model.std = settings['std']
|
||||
else:
|
||||
model = InceptionV4Base(num_classes=num_classes)
|
||||
return model
|
||||
|
||||
|
||||
class InceptionV4(nn.Module):
|
||||
def __init__(self, num_classes, loss={'xent'}, **kwargs):
|
||||
super(InceptionV4, self).__init__()
|
||||
self.loss = loss
|
||||
base = inceptionv4()
|
||||
self.features = base.features
|
||||
self.classifier = nn.Linear(1536, num_classes)
|
||||
self.feat_dim = 1536
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = F.avg_pool2d(x, x.size()[2:])
|
||||
f = x.view(x.size(0), -1)
|
||||
if not self.training:
|
||||
return f
|
||||
y = self.classifier(f)
|
||||
|
||||
if self.loss == {'xent'}:
|
||||
return y
|
||||
elif self.loss == {'xent', 'htri'}:
|
||||
return y, f
|
||||
else:
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
|
@ -0,0 +1,123 @@
|
|||
from __future__ import absolute_import, division
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
import torchvision
|
||||
|
||||
|
||||
__all__ = ['MobileNetV2']
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
"""Basic convolutional block:
|
||||
convolution (bias discarded) + batch normalization + relu6.
|
||||
|
||||
Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):
|
||||
in_c (int): number of input channels.
|
||||
out_c (int): number of output channels.
|
||||
k (int or tuple): kernel size.
|
||||
s (int or tuple): stride.
|
||||
p (int or tuple): padding.
|
||||
g (int): number of blocked connections from input channels
|
||||
to output channels (default: 1).
|
||||
"""
|
||||
def __init__(self, in_c, out_c, k, s=1, p=0, g=1):
|
||||
super(ConvBlock, self).__init__()
|
||||
self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p, bias=False, groups=g)
|
||||
self.bn = nn.BatchNorm2d(out_c)
|
||||
|
||||
def forward(self, x):
|
||||
return F.relu6(self.bn(self.conv(x)))
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
def __init__(self, in_channels, out_channels, expansion_factor, stride):
|
||||
super(Bottleneck, self).__init__()
|
||||
mid_channels = in_channels * expansion_factor
|
||||
self.use_residual = stride == 1 and in_channels == out_channels
|
||||
self.conv1 = ConvBlock(in_channels, mid_channels, 1)
|
||||
self.dwconv2 = ConvBlock(mid_channels, mid_channels, 3, stride, 1, g=mid_channels)
|
||||
self.conv3 = nn.Sequential(
|
||||
nn.Conv2d(mid_channels, out_channels, 1, bias=False),
|
||||
nn.BatchNorm2d(out_channels),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
m = self.conv1(x)
|
||||
m = self.dwconv2(m)
|
||||
m = self.conv3(m)
|
||||
if self.use_residual:
|
||||
return x + m
|
||||
else:
|
||||
return m
|
||||
|
||||
|
||||
|
||||
class MobileNetV2(nn.Module):
|
||||
"""
|
||||
MobileNetV2
|
||||
|
||||
Reference:
|
||||
Sandler et al. MobileNetV2: Inverted Residuals and Linear Bottlenecks. CVPR 2018.
|
||||
"""
|
||||
def __init__(self, num_classes, loss={'xent'}, **kwargs):
|
||||
super(MobileNetV2, self).__init__()
|
||||
self.loss = loss
|
||||
|
||||
self.conv1 = ConvBlock(3, 32, 3, s=2, p=1)
|
||||
self.block2 = Bottleneck(32, 16, 1, 1)
|
||||
self.block3 = nn.Sequential(
|
||||
Bottleneck(16, 24, 6, 2),
|
||||
Bottleneck(24, 24, 6, 1),
|
||||
)
|
||||
self.block4 = nn.Sequential(
|
||||
Bottleneck(24, 32, 6, 2),
|
||||
Bottleneck(32, 32, 6, 1),
|
||||
Bottleneck(32, 32, 6, 1),
|
||||
)
|
||||
self.block5 = nn.Sequential(
|
||||
Bottleneck(32, 64, 6, 2),
|
||||
Bottleneck(64, 64, 6, 1),
|
||||
Bottleneck(64, 64, 6, 1),
|
||||
Bottleneck(64, 64, 6, 1),
|
||||
)
|
||||
self.block6 = nn.Sequential(
|
||||
Bottleneck(64, 96, 6, 1),
|
||||
Bottleneck(96, 96, 6, 1),
|
||||
Bottleneck(96, 96, 6, 1),
|
||||
)
|
||||
self.block7 = nn.Sequential(
|
||||
Bottleneck(96, 160, 6, 2),
|
||||
Bottleneck(160, 160, 6, 1),
|
||||
Bottleneck(160, 160, 6, 1),
|
||||
)
|
||||
self.block8 = Bottleneck(160, 320, 6, 1)
|
||||
self.conv9 = ConvBlock(320, 1280, 1)
|
||||
self.classifier = nn.Linear(1280, num_classes)
|
||||
self.feat_dim = 1280
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.block2(x)
|
||||
x = self.block3(x)
|
||||
x = self.block4(x)
|
||||
x = self.block5(x)
|
||||
x = self.block6(x)
|
||||
x = self.block7(x)
|
||||
x = self.block8(x)
|
||||
x = self.conv9(x)
|
||||
x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1)
|
||||
x = F.dropout(x, training=self.training)
|
||||
|
||||
if not self.training:
|
||||
return x
|
||||
|
||||
y = self.classifier(x)
|
||||
|
||||
if self.loss == {'xent'}:
|
||||
return y
|
||||
elif self.loss == {'xent', 'htri'}:
|
||||
return y, x
|
||||
else:
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
|
@ -0,0 +1,190 @@
|
|||
from __future__ import absolute_import, division
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
import torchvision
|
||||
|
||||
|
||||
__all__ = ['MuDeep']
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
"""Basic convolutional block:
|
||||
convolution + batch normalization + relu.
|
||||
|
||||
Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):
|
||||
in_c (int): number of input channels.
|
||||
out_c (int): number of output channels.
|
||||
k (int or tuple): kernel size.
|
||||
s (int or tuple): stride.
|
||||
p (int or tuple): padding.
|
||||
"""
|
||||
def __init__(self, in_c, out_c, k, s, p):
|
||||
super(ConvBlock, self).__init__()
|
||||
self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
|
||||
self.bn = nn.BatchNorm2d(out_c)
|
||||
|
||||
def forward(self, x):
|
||||
return F.relu(self.bn(self.conv(x)))
|
||||
|
||||
|
||||
class ConvLayers(nn.Module):
|
||||
"""Preprocessing layers."""
|
||||
def __init__(self):
|
||||
super(ConvLayers, self).__init__()
|
||||
self.conv1 = ConvBlock(3, 48, k=3, s=1, p=1)
|
||||
self.conv2 = ConvBlock(48, 96, k=3, s=1, p=1)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.maxpool(x)
|
||||
return x
|
||||
|
||||
|
||||
class MultiScaleA(nn.Module):
|
||||
"""Multi-scale stream layer A (Sec.3.1)"""
|
||||
def __init__(self):
|
||||
super(MultiScaleA, self).__init__()
|
||||
self.stream1 = nn.Sequential(
|
||||
ConvBlock(96, 96, k=1, s=1, p=0),
|
||||
ConvBlock(96, 24, k=3, s=1, p=1),
|
||||
)
|
||||
self.stream2 = nn.Sequential(
|
||||
nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
|
||||
ConvBlock(96, 24, k=1, s=1, p=0),
|
||||
)
|
||||
self.stream3 = ConvBlock(96, 24, k=1, s=1, p=0)
|
||||
self.stream4 = nn.Sequential(
|
||||
ConvBlock(96, 16, k=1, s=1, p=0),
|
||||
ConvBlock(16, 24, k=3, s=1, p=1),
|
||||
ConvBlock(24, 24, k=3, s=1, p=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
s1 = self.stream1(x)
|
||||
s2 = self.stream2(x)
|
||||
s3 = self.stream3(x)
|
||||
s4 = self.stream4(x)
|
||||
y = torch.cat([s1, s2, s3, s4], dim=1)
|
||||
return y
|
||||
|
||||
|
||||
class Reduction(nn.Module):
|
||||
"""Reduction layer (Sec.3.1)"""
|
||||
def __init__(self):
|
||||
super(Reduction, self).__init__()
|
||||
self.stream1 = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.stream2 = ConvBlock(96, 96, k=3, s=2, p=1)
|
||||
self.stream3 = nn.Sequential(
|
||||
ConvBlock(96, 48, k=1, s=1, p=0),
|
||||
ConvBlock(48, 56, k=3, s=1, p=1),
|
||||
ConvBlock(56, 64, k=3, s=2, p=1),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
s1 = self.stream1(x)
|
||||
s2 = self.stream2(x)
|
||||
s3 = self.stream3(x)
|
||||
y = torch.cat([s1, s2, s3], dim=1)
|
||||
return y
|
||||
|
||||
|
||||
class MultiScaleB(nn.Module):
|
||||
"""Multi-scale stream layer B (Sec.3.1)"""
|
||||
def __init__(self):
|
||||
super(MultiScaleB, self).__init__()
|
||||
self.stream1 = nn.Sequential(
|
||||
nn.AvgPool2d(kernel_size=3, stride=1, padding=1),
|
||||
ConvBlock(256, 256, k=1, s=1, p=0),
|
||||
)
|
||||
self.stream2 = nn.Sequential(
|
||||
ConvBlock(256, 64, k=1, s=1, p=0),
|
||||
ConvBlock(64, 128, k=(1, 3), s=1, p=(0, 1)),
|
||||
ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)),
|
||||
)
|
||||
self.stream3 = ConvBlock(256, 256, k=1, s=1, p=0)
|
||||
self.stream4 = nn.Sequential(
|
||||
ConvBlock(256, 64, k=1, s=1, p=0),
|
||||
ConvBlock(64, 64, k=(1, 3), s=1, p=(0, 1)),
|
||||
ConvBlock(64, 128, k=(3, 1), s=1, p=(1, 0)),
|
||||
ConvBlock(128, 128, k=(1, 3), s=1, p=(0, 1)),
|
||||
ConvBlock(128, 256, k=(3, 1), s=1, p=(1, 0)),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
s1 = self.stream1(x)
|
||||
s2 = self.stream2(x)
|
||||
s3 = self.stream3(x)
|
||||
s4 = self.stream4(x)
|
||||
return s1, s2, s3, s4
|
||||
|
||||
|
||||
class Fusion(nn.Module):
|
||||
"""Saliency-based learning fusion layer (Sec.3.2)"""
|
||||
def __init__(self):
|
||||
super(Fusion, self).__init__()
|
||||
self.a1 = nn.Parameter(torch.rand(1, 256, 1, 1))
|
||||
self.a2 = nn.Parameter(torch.rand(1, 256, 1, 1))
|
||||
self.a3 = nn.Parameter(torch.rand(1, 256, 1, 1))
|
||||
self.a4 = nn.Parameter(torch.rand(1, 256, 1, 1))
|
||||
|
||||
# We add an average pooling layer to reduce the spatial dimension
|
||||
# of feature maps, which differs from the original paper.
|
||||
self.avgpool = nn.AvgPool2d(kernel_size=4, stride=4, padding=0)
|
||||
|
||||
def forward(self, x1, x2, x3, x4):
|
||||
s1 = self.a1.expand_as(x1) * x1
|
||||
s2 = self.a2.expand_as(x2) * x2
|
||||
s3 = self.a3.expand_as(x3) * x3
|
||||
s4 = self.a4.expand_as(x4) * x4
|
||||
y = self.avgpool(s1 + s2 + s3 + s4)
|
||||
return y
|
||||
|
||||
|
||||
class MuDeep(nn.Module):
|
||||
"""Multiscale deep neural network.
|
||||
|
||||
Reference:
|
||||
Qian et al. Multi-scale Deep Learning Architectures for Person Re-identification. ICCV 2017.
|
||||
"""
|
||||
def __init__(self, num_classes, loss={'xent'}, **kwargs):
|
||||
super(MuDeep, self).__init__()
|
||||
self.loss = loss
|
||||
|
||||
self.block1 = ConvLayers()
|
||||
self.block2 = MultiScaleA()
|
||||
self.block3 = Reduction()
|
||||
self.block4 = MultiScaleB()
|
||||
self.block5 = Fusion()
|
||||
|
||||
# Due to this fully connected layer, input image has to be fixed
|
||||
# in shape, i.e. (3, 256, 128), such that the last convolutional feature
|
||||
# maps are of shape (256, 16, 8). If input shape is changed,
|
||||
# the input dimension of this layer has to be changed accordingly.
|
||||
self.fc = nn.Sequential(
|
||||
nn.Linear(256*16*8, 4096),
|
||||
nn.BatchNorm1d(4096),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.classifier = nn.Linear(4096, num_classes)
|
||||
self.feat_dim = 4096
|
||||
|
||||
def forward(self, x):
|
||||
x = self.block1(x)
|
||||
x = self.block2(x)
|
||||
x = self.block3(x)
|
||||
x = self.block4(x)
|
||||
x = self.block5(*x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
y = self.classifier(x)
|
||||
|
||||
if self.loss == {'xent'}:
|
||||
return y
|
||||
elif self.loss == {'xent', 'htri'}:
|
||||
return y, x
|
||||
else:
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
|
@ -0,0 +1,687 @@
|
|||
from __future__ import absolute_import, division
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.model_zoo as model_zoo
|
||||
import numpy as np
|
||||
|
||||
|
||||
"""
|
||||
NASNet Mobile
|
||||
Thanks to Anastasiia (https://github.com/DagnyT) for the great help, support and motivation!
|
||||
|
||||
|
||||
------------------------------------------------------------------------------------
|
||||
Architecture | Top-1 Acc | Top-5 Acc | Multiply-Adds | Params (M)
|
||||
------------------------------------------------------------------------------------
|
||||
| NASNet-A (4 @ 1056) | 74.08% | 91.74% | 564 M | 5.3 |
|
||||
------------------------------------------------------------------------------------
|
||||
# References:
|
||||
- [Learning Transferable Architectures for Scalable Image Recognition]
|
||||
(https://arxiv.org/abs/1707.07012)
|
||||
"""
|
||||
|
||||
|
||||
"""
|
||||
Code imported from https://github.com/Cadene/pretrained-models.pytorch
|
||||
"""
|
||||
|
||||
|
||||
pretrained_settings = {
|
||||
'nasnetamobile': {
|
||||
'imagenet': {
|
||||
#'url': 'https://github.com/veronikayurchuk/pretrained-models.pytorch/releases/download/v1.0/nasnetmobile-7e03cead.pth.tar',
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetamobile-7e03cead.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 224, 224], # resize 256
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.5, 0.5, 0.5],
|
||||
'std': [0.5, 0.5, 0.5],
|
||||
'num_classes': 1000
|
||||
},
|
||||
# 'imagenet+background': {
|
||||
# # 'url': 'http://data.lip6.fr/cadene/pretrainedmodels/nasnetalarge-a1897284.pth',
|
||||
# 'input_space': 'RGB',
|
||||
# 'input_size': [3, 224, 224], # resize 256
|
||||
# 'input_range': [0, 1],
|
||||
# 'mean': [0.5, 0.5, 0.5],
|
||||
# 'std': [0.5, 0.5, 0.5],
|
||||
# 'num_classes': 1001
|
||||
# }
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
__all__ = ['NASNetAMobile']
|
||||
|
||||
|
||||
class MaxPoolPad(nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super(MaxPoolPad, self).__init__()
|
||||
self.pad = nn.ZeroPad2d((1, 0, 1, 0))
|
||||
self.pool = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.pad(x)
|
||||
x = self.pool(x)
|
||||
x = x[:, :, 1:, 1:].contiguous()
|
||||
return x
|
||||
|
||||
|
||||
class AvgPoolPad(nn.Module):
|
||||
|
||||
def __init__(self, stride=2, padding=1):
|
||||
super(AvgPoolPad, self).__init__()
|
||||
self.pad = nn.ZeroPad2d((1, 0, 1, 0))
|
||||
self.pool = nn.AvgPool2d(3, stride=stride, padding=padding, count_include_pad=False)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.pad(x)
|
||||
x = self.pool(x)
|
||||
x = x[:, :, 1:, 1:].contiguous()
|
||||
return x
|
||||
|
||||
|
||||
class SeparableConv2d(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, dw_kernel, dw_stride, dw_padding, bias=False):
|
||||
super(SeparableConv2d, self).__init__()
|
||||
self.depthwise_conv2d = nn.Conv2d(in_channels, in_channels, dw_kernel,
|
||||
stride=dw_stride,
|
||||
padding=dw_padding,
|
||||
bias=bias,
|
||||
groups=in_channels)
|
||||
self.pointwise_conv2d = nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=bias)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.depthwise_conv2d(x)
|
||||
x = self.pointwise_conv2d(x)
|
||||
return x
|
||||
|
||||
|
||||
class BranchSeparables(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, name=None, bias=False):
|
||||
super(BranchSeparables, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.separable_1 = SeparableConv2d(in_channels, in_channels, kernel_size, stride, padding, bias=bias)
|
||||
self.bn_sep_1 = nn.BatchNorm2d(in_channels, eps=0.001, momentum=0.1, affine=True)
|
||||
self.relu1 = nn.ReLU()
|
||||
self.separable_2 = SeparableConv2d(in_channels, out_channels, kernel_size, 1, padding, bias=bias)
|
||||
self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
|
||||
self.name = name
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(x)
|
||||
if self.name == 'specific':
|
||||
x = nn.ZeroPad2d((1, 0, 1, 0))(x)
|
||||
x = self.separable_1(x)
|
||||
if self.name == 'specific':
|
||||
x = x[:, :, 1:, 1:].contiguous()
|
||||
|
||||
x = self.bn_sep_1(x)
|
||||
x = self.relu1(x)
|
||||
x = self.separable_2(x)
|
||||
x = self.bn_sep_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class BranchSeparablesStem(nn.Module):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, bias=False):
|
||||
super(BranchSeparablesStem, self).__init__()
|
||||
self.relu = nn.ReLU()
|
||||
self.separable_1 = SeparableConv2d(in_channels, out_channels, kernel_size, stride, padding, bias=bias)
|
||||
self.bn_sep_1 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
|
||||
self.relu1 = nn.ReLU()
|
||||
self.separable_2 = SeparableConv2d(out_channels, out_channels, kernel_size, 1, padding, bias=bias)
|
||||
self.bn_sep_2 = nn.BatchNorm2d(out_channels, eps=0.001, momentum=0.1, affine=True)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(x)
|
||||
x = self.separable_1(x)
|
||||
x = self.bn_sep_1(x)
|
||||
x = self.relu1(x)
|
||||
x = self.separable_2(x)
|
||||
x = self.bn_sep_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class BranchSeparablesReduction(BranchSeparables):
|
||||
|
||||
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, z_padding=1, bias=False):
|
||||
BranchSeparables.__init__(self, in_channels, out_channels, kernel_size, stride, padding, bias)
|
||||
self.padding = nn.ZeroPad2d((z_padding, 0, z_padding, 0))
|
||||
|
||||
def forward(self, x):
|
||||
x = self.relu(x)
|
||||
x = self.padding(x)
|
||||
x = self.separable_1(x)
|
||||
x = x[:, :, 1:, 1:].contiguous()
|
||||
x = self.bn_sep_1(x)
|
||||
x = self.relu1(x)
|
||||
x = self.separable_2(x)
|
||||
x = self.bn_sep_2(x)
|
||||
return x
|
||||
|
||||
|
||||
class CellStem0(nn.Module):
|
||||
def __init__(self, stem_filters, num_filters=42):
|
||||
super(CellStem0, self).__init__()
|
||||
self.num_filters = num_filters
|
||||
self.stem_filters = stem_filters
|
||||
self.conv_1x1 = nn.Sequential()
|
||||
self.conv_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_1x1.add_module('conv', nn.Conv2d(self.stem_filters, self.num_filters, 1, stride=1, bias=False))
|
||||
self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_filters, eps=0.001, momentum=0.1, affine=True))
|
||||
|
||||
self.comb_iter_0_left = BranchSeparables(self.num_filters, self.num_filters, 5, 2, 2)
|
||||
self.comb_iter_0_right = BranchSeparablesStem(self.stem_filters, self.num_filters, 7, 2, 3, bias=False)
|
||||
|
||||
self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
self.comb_iter_1_right = BranchSeparablesStem(self.stem_filters, self.num_filters, 7, 2, 3, bias=False)
|
||||
|
||||
self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
|
||||
self.comb_iter_2_right = BranchSeparablesStem(self.stem_filters, self.num_filters, 5, 2, 2, bias=False)
|
||||
|
||||
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
|
||||
self.comb_iter_4_left = BranchSeparables(self.num_filters, self.num_filters, 3, 1, 1, bias=False)
|
||||
self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.conv_1x1(x)
|
||||
|
||||
x_comb_iter_0_left = self.comb_iter_0_left(x1)
|
||||
x_comb_iter_0_right = self.comb_iter_0_right(x)
|
||||
x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
|
||||
|
||||
x_comb_iter_1_left = self.comb_iter_1_left(x1)
|
||||
x_comb_iter_1_right = self.comb_iter_1_right(x)
|
||||
x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
|
||||
|
||||
x_comb_iter_2_left = self.comb_iter_2_left(x1)
|
||||
x_comb_iter_2_right = self.comb_iter_2_right(x)
|
||||
x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
|
||||
|
||||
x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
|
||||
x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
|
||||
|
||||
x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
|
||||
x_comb_iter_4_right = self.comb_iter_4_right(x1)
|
||||
x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
|
||||
|
||||
x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
|
||||
return x_out
|
||||
|
||||
|
||||
class CellStem1(nn.Module):
|
||||
|
||||
def __init__(self, stem_filters, num_filters):
|
||||
super(CellStem1, self).__init__()
|
||||
self.num_filters = num_filters
|
||||
self.stem_filters = stem_filters
|
||||
self.conv_1x1 = nn.Sequential()
|
||||
self.conv_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_1x1.add_module('conv', nn.Conv2d(2*self.num_filters, self.num_filters, 1, stride=1, bias=False))
|
||||
self.conv_1x1.add_module('bn', nn.BatchNorm2d(self.num_filters, eps=0.001, momentum=0.1, affine=True))
|
||||
|
||||
self.relu = nn.ReLU()
|
||||
self.path_1 = nn.Sequential()
|
||||
self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
|
||||
self.path_1.add_module('conv', nn.Conv2d(self.stem_filters, self.num_filters//2, 1, stride=1, bias=False))
|
||||
self.path_2 = nn.ModuleList()
|
||||
self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
|
||||
self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
|
||||
self.path_2.add_module('conv', nn.Conv2d(self.stem_filters, self.num_filters//2, 1, stride=1, bias=False))
|
||||
|
||||
self.final_path_bn = nn.BatchNorm2d(self.num_filters, eps=0.001, momentum=0.1, affine=True)
|
||||
|
||||
self.comb_iter_0_left = BranchSeparables(self.num_filters, self.num_filters, 5, 2, 2, name='specific', bias=False)
|
||||
self.comb_iter_0_right = BranchSeparables(self.num_filters, self.num_filters, 7, 2, 3, name='specific', bias=False)
|
||||
|
||||
# self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
self.comb_iter_1_left = MaxPoolPad()
|
||||
self.comb_iter_1_right = BranchSeparables(self.num_filters, self.num_filters, 7, 2, 3, name='specific', bias=False)
|
||||
|
||||
# self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
|
||||
self.comb_iter_2_left = AvgPoolPad()
|
||||
self.comb_iter_2_right = BranchSeparables(self.num_filters, self.num_filters, 5, 2, 2, name='specific', bias=False)
|
||||
|
||||
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
|
||||
self.comb_iter_4_left = BranchSeparables(self.num_filters, self.num_filters, 3, 1, 1, name='specific', bias=False)
|
||||
# self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
self.comb_iter_4_right = MaxPoolPad()
|
||||
|
||||
def forward(self, x_conv0, x_stem_0):
|
||||
x_left = self.conv_1x1(x_stem_0)
|
||||
|
||||
x_relu = self.relu(x_conv0)
|
||||
# path 1
|
||||
x_path1 = self.path_1(x_relu)
|
||||
# path 2
|
||||
x_path2 = self.path_2.pad(x_relu)
|
||||
x_path2 = x_path2[:, :, 1:, 1:]
|
||||
x_path2 = self.path_2.avgpool(x_path2)
|
||||
x_path2 = self.path_2.conv(x_path2)
|
||||
# final path
|
||||
x_right = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
|
||||
|
||||
x_comb_iter_0_left = self.comb_iter_0_left(x_left)
|
||||
x_comb_iter_0_right = self.comb_iter_0_right(x_right)
|
||||
x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
|
||||
|
||||
x_comb_iter_1_left = self.comb_iter_1_left(x_left)
|
||||
x_comb_iter_1_right = self.comb_iter_1_right(x_right)
|
||||
x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
|
||||
|
||||
x_comb_iter_2_left = self.comb_iter_2_left(x_left)
|
||||
x_comb_iter_2_right = self.comb_iter_2_right(x_right)
|
||||
x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
|
||||
|
||||
x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
|
||||
x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
|
||||
|
||||
x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
|
||||
x_comb_iter_4_right = self.comb_iter_4_right(x_left)
|
||||
x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
|
||||
|
||||
x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
|
||||
return x_out
|
||||
|
||||
|
||||
class FirstCell(nn.Module):
|
||||
|
||||
def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
|
||||
super(FirstCell, self).__init__()
|
||||
self.conv_1x1 = nn.Sequential()
|
||||
self.conv_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
|
||||
self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
|
||||
|
||||
self.relu = nn.ReLU()
|
||||
self.path_1 = nn.Sequential()
|
||||
self.path_1.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
|
||||
self.path_1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
|
||||
self.path_2 = nn.ModuleList()
|
||||
self.path_2.add_module('pad', nn.ZeroPad2d((0, 1, 0, 1)))
|
||||
self.path_2.add_module('avgpool', nn.AvgPool2d(1, stride=2, count_include_pad=False))
|
||||
self.path_2.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
|
||||
|
||||
self.final_path_bn = nn.BatchNorm2d(out_channels_left * 2, eps=0.001, momentum=0.1, affine=True)
|
||||
|
||||
self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
|
||||
self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
||||
|
||||
self.comb_iter_1_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
|
||||
self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
||||
|
||||
self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
|
||||
self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
|
||||
self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
||||
|
||||
def forward(self, x, x_prev):
|
||||
x_relu = self.relu(x_prev)
|
||||
# path 1
|
||||
x_path1 = self.path_1(x_relu)
|
||||
# path 2
|
||||
x_path2 = self.path_2.pad(x_relu)
|
||||
x_path2 = x_path2[:, :, 1:, 1:]
|
||||
x_path2 = self.path_2.avgpool(x_path2)
|
||||
x_path2 = self.path_2.conv(x_path2)
|
||||
# final path
|
||||
x_left = self.final_path_bn(torch.cat([x_path1, x_path2], 1))
|
||||
|
||||
x_right = self.conv_1x1(x)
|
||||
|
||||
x_comb_iter_0_left = self.comb_iter_0_left(x_right)
|
||||
x_comb_iter_0_right = self.comb_iter_0_right(x_left)
|
||||
x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
|
||||
|
||||
x_comb_iter_1_left = self.comb_iter_1_left(x_left)
|
||||
x_comb_iter_1_right = self.comb_iter_1_right(x_left)
|
||||
x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
|
||||
|
||||
x_comb_iter_2_left = self.comb_iter_2_left(x_right)
|
||||
x_comb_iter_2 = x_comb_iter_2_left + x_left
|
||||
|
||||
x_comb_iter_3_left = self.comb_iter_3_left(x_left)
|
||||
x_comb_iter_3_right = self.comb_iter_3_right(x_left)
|
||||
x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right
|
||||
|
||||
x_comb_iter_4_left = self.comb_iter_4_left(x_right)
|
||||
x_comb_iter_4 = x_comb_iter_4_left + x_right
|
||||
|
||||
x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
|
||||
return x_out
|
||||
|
||||
|
||||
class NormalCell(nn.Module):
|
||||
|
||||
def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
|
||||
super(NormalCell, self).__init__()
|
||||
self.conv_prev_1x1 = nn.Sequential()
|
||||
self.conv_prev_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
|
||||
self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
|
||||
|
||||
self.conv_1x1 = nn.Sequential()
|
||||
self.conv_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
|
||||
self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
|
||||
|
||||
self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 1, 2, bias=False)
|
||||
self.comb_iter_0_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False)
|
||||
|
||||
self.comb_iter_1_left = BranchSeparables(out_channels_left, out_channels_left, 5, 1, 2, bias=False)
|
||||
self.comb_iter_1_right = BranchSeparables(out_channels_left, out_channels_left, 3, 1, 1, bias=False)
|
||||
|
||||
self.comb_iter_2_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
|
||||
self.comb_iter_3_left = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
|
||||
self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
||||
|
||||
def forward(self, x, x_prev):
|
||||
x_left = self.conv_prev_1x1(x_prev)
|
||||
x_right = self.conv_1x1(x)
|
||||
|
||||
x_comb_iter_0_left = self.comb_iter_0_left(x_right)
|
||||
x_comb_iter_0_right = self.comb_iter_0_right(x_left)
|
||||
x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
|
||||
|
||||
x_comb_iter_1_left = self.comb_iter_1_left(x_left)
|
||||
x_comb_iter_1_right = self.comb_iter_1_right(x_left)
|
||||
x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
|
||||
|
||||
x_comb_iter_2_left = self.comb_iter_2_left(x_right)
|
||||
x_comb_iter_2 = x_comb_iter_2_left + x_left
|
||||
|
||||
x_comb_iter_3_left = self.comb_iter_3_left(x_left)
|
||||
x_comb_iter_3_right = self.comb_iter_3_right(x_left)
|
||||
x_comb_iter_3 = x_comb_iter_3_left + x_comb_iter_3_right
|
||||
|
||||
x_comb_iter_4_left = self.comb_iter_4_left(x_right)
|
||||
x_comb_iter_4 = x_comb_iter_4_left + x_right
|
||||
|
||||
x_out = torch.cat([x_left, x_comb_iter_0, x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
|
||||
return x_out
|
||||
|
||||
|
||||
class ReductionCell0(nn.Module):
|
||||
|
||||
def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
|
||||
super(ReductionCell0, self).__init__()
|
||||
self.conv_prev_1x1 = nn.Sequential()
|
||||
self.conv_prev_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
|
||||
self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
|
||||
|
||||
self.conv_1x1 = nn.Sequential()
|
||||
self.conv_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
|
||||
self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
|
||||
|
||||
self.comb_iter_0_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
|
||||
self.comb_iter_0_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
|
||||
|
||||
self.comb_iter_1_left = MaxPoolPad()
|
||||
self.comb_iter_1_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 7, 2, 3, bias=False)
|
||||
|
||||
self.comb_iter_2_left = AvgPoolPad()
|
||||
self.comb_iter_2_right = BranchSeparablesReduction(out_channels_right, out_channels_right, 5, 2, 2, bias=False)
|
||||
|
||||
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
|
||||
self.comb_iter_4_left = BranchSeparablesReduction(out_channels_right, out_channels_right, 3, 1, 1, bias=False)
|
||||
self.comb_iter_4_right = MaxPoolPad()
|
||||
|
||||
def forward(self, x, x_prev):
|
||||
x_left = self.conv_prev_1x1(x_prev)
|
||||
x_right = self.conv_1x1(x)
|
||||
|
||||
x_comb_iter_0_left = self.comb_iter_0_left(x_right)
|
||||
x_comb_iter_0_right = self.comb_iter_0_right(x_left)
|
||||
x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
|
||||
|
||||
x_comb_iter_1_left = self.comb_iter_1_left(x_right)
|
||||
x_comb_iter_1_right = self.comb_iter_1_right(x_left)
|
||||
x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
|
||||
|
||||
x_comb_iter_2_left = self.comb_iter_2_left(x_right)
|
||||
x_comb_iter_2_right = self.comb_iter_2_right(x_left)
|
||||
x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
|
||||
|
||||
x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
|
||||
x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
|
||||
|
||||
x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
|
||||
x_comb_iter_4_right = self.comb_iter_4_right(x_right)
|
||||
x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
|
||||
|
||||
x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
|
||||
return x_out
|
||||
|
||||
|
||||
class ReductionCell1(nn.Module):
|
||||
|
||||
def __init__(self, in_channels_left, out_channels_left, in_channels_right, out_channels_right):
|
||||
super(ReductionCell1, self).__init__()
|
||||
self.conv_prev_1x1 = nn.Sequential()
|
||||
self.conv_prev_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_prev_1x1.add_module('conv', nn.Conv2d(in_channels_left, out_channels_left, 1, stride=1, bias=False))
|
||||
self.conv_prev_1x1.add_module('bn', nn.BatchNorm2d(out_channels_left, eps=0.001, momentum=0.1, affine=True))
|
||||
|
||||
self.conv_1x1 = nn.Sequential()
|
||||
self.conv_1x1.add_module('relu', nn.ReLU())
|
||||
self.conv_1x1.add_module('conv', nn.Conv2d(in_channels_right, out_channels_right, 1, stride=1, bias=False))
|
||||
self.conv_1x1.add_module('bn', nn.BatchNorm2d(out_channels_right, eps=0.001, momentum=0.1, affine=True))
|
||||
|
||||
self.comb_iter_0_left = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, name='specific', bias=False)
|
||||
self.comb_iter_0_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, name='specific', bias=False)
|
||||
|
||||
# self.comb_iter_1_left = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
self.comb_iter_1_left = MaxPoolPad()
|
||||
self.comb_iter_1_right = BranchSeparables(out_channels_right, out_channels_right, 7, 2, 3, name='specific', bias=False)
|
||||
|
||||
# self.comb_iter_2_left = nn.AvgPool2d(3, stride=2, padding=1, count_include_pad=False)
|
||||
self.comb_iter_2_left = AvgPoolPad()
|
||||
self.comb_iter_2_right = BranchSeparables(out_channels_right, out_channels_right, 5, 2, 2, name='specific', bias=False)
|
||||
|
||||
self.comb_iter_3_right = nn.AvgPool2d(3, stride=1, padding=1, count_include_pad=False)
|
||||
|
||||
self.comb_iter_4_left = BranchSeparables(out_channels_right, out_channels_right, 3, 1, 1, name='specific', bias=False)
|
||||
# self.comb_iter_4_right = nn.MaxPool2d(3, stride=2, padding=1)
|
||||
        self.comb_iter_4_right = MaxPoolPad()
|
||||
|
||||
def forward(self, x, x_prev):
|
||||
x_left = self.conv_prev_1x1(x_prev)
|
||||
x_right = self.conv_1x1(x)
|
||||
|
||||
x_comb_iter_0_left = self.comb_iter_0_left(x_right)
|
||||
x_comb_iter_0_right = self.comb_iter_0_right(x_left)
|
||||
x_comb_iter_0 = x_comb_iter_0_left + x_comb_iter_0_right
|
||||
|
||||
x_comb_iter_1_left = self.comb_iter_1_left(x_right)
|
||||
x_comb_iter_1_right = self.comb_iter_1_right(x_left)
|
||||
x_comb_iter_1 = x_comb_iter_1_left + x_comb_iter_1_right
|
||||
|
||||
x_comb_iter_2_left = self.comb_iter_2_left(x_right)
|
||||
x_comb_iter_2_right = self.comb_iter_2_right(x_left)
|
||||
x_comb_iter_2 = x_comb_iter_2_left + x_comb_iter_2_right
|
||||
|
||||
x_comb_iter_3_right = self.comb_iter_3_right(x_comb_iter_0)
|
||||
x_comb_iter_3 = x_comb_iter_3_right + x_comb_iter_1
|
||||
|
||||
x_comb_iter_4_left = self.comb_iter_4_left(x_comb_iter_0)
|
||||
x_comb_iter_4_right = self.comb_iter_4_right(x_right)
|
||||
x_comb_iter_4 = x_comb_iter_4_left + x_comb_iter_4_right
|
||||
|
||||
x_out = torch.cat([x_comb_iter_1, x_comb_iter_2, x_comb_iter_3, x_comb_iter_4], 1)
|
||||
return x_out
|
||||
|
||||
|
||||
class NASNetAMobile(nn.Module):
    """NASNetAMobile (4 @ 1056)"""

    def __init__(self, num_classes, stem_filters=32, penultimate_filters=1056, filters_multiplier=2, loss={'xent'}, **kwargs):
        super(NASNetAMobile, self).__init__()
        self.num_classes = num_classes
        self.stem_filters = stem_filters
        self.penultimate_filters = penultimate_filters
        self.filters_multiplier = filters_multiplier
        self.loss = loss

        filters = self.penultimate_filters // 24
        # 24 is the default value for the architecture

        self.conv0 = nn.Sequential()
        self.conv0.add_module('conv', nn.Conv2d(in_channels=3, out_channels=self.stem_filters, kernel_size=3, padding=0, stride=2,
                                                bias=False))
        self.conv0.add_module('bn', nn.BatchNorm2d(self.stem_filters, eps=0.001, momentum=0.1, affine=True))

        self.cell_stem_0 = CellStem0(self.stem_filters, num_filters=filters // (filters_multiplier ** 2))
        self.cell_stem_1 = CellStem1(self.stem_filters, num_filters=filters // filters_multiplier)

        self.cell_0 = FirstCell(in_channels_left=filters, out_channels_left=filters//2,  # 1, 0.5
                                in_channels_right=2*filters, out_channels_right=filters)  # 2, 1
        self.cell_1 = NormalCell(in_channels_left=2*filters, out_channels_left=filters,  # 2, 1
                                 in_channels_right=6*filters, out_channels_right=filters)  # 6, 1
        self.cell_2 = NormalCell(in_channels_left=6*filters, out_channels_left=filters,  # 6, 1
                                 in_channels_right=6*filters, out_channels_right=filters)  # 6, 1
        self.cell_3 = NormalCell(in_channels_left=6*filters, out_channels_left=filters,  # 6, 1
                                 in_channels_right=6*filters, out_channels_right=filters)  # 6, 1

        self.reduction_cell_0 = ReductionCell0(in_channels_left=6*filters, out_channels_left=2*filters,  # 6, 2
                                               in_channels_right=6*filters, out_channels_right=2*filters)  # 6, 2

        self.cell_6 = FirstCell(in_channels_left=6*filters, out_channels_left=filters,  # 6, 1
                                in_channels_right=8*filters, out_channels_right=2*filters)  # 8, 2
        self.cell_7 = NormalCell(in_channels_left=8*filters, out_channels_left=2*filters,  # 8, 2
                                 in_channels_right=12*filters, out_channels_right=2*filters)  # 12, 2
        self.cell_8 = NormalCell(in_channels_left=12*filters, out_channels_left=2*filters,  # 12, 2
                                 in_channels_right=12*filters, out_channels_right=2*filters)  # 12, 2
        self.cell_9 = NormalCell(in_channels_left=12*filters, out_channels_left=2*filters,  # 12, 2
                                 in_channels_right=12*filters, out_channels_right=2*filters)  # 12, 2

        self.reduction_cell_1 = ReductionCell1(in_channels_left=12*filters, out_channels_left=4*filters,  # 12, 4
                                               in_channels_right=12*filters, out_channels_right=4*filters)  # 12, 4

        self.cell_12 = FirstCell(in_channels_left=12*filters, out_channels_left=2*filters,  # 12, 2
                                 in_channels_right=16*filters, out_channels_right=4*filters)  # 16, 4
        self.cell_13 = NormalCell(in_channels_left=16*filters, out_channels_left=4*filters,  # 16, 4
                                  in_channels_right=24*filters, out_channels_right=4*filters)  # 24, 4
        self.cell_14 = NormalCell(in_channels_left=24*filters, out_channels_left=4*filters,  # 24, 4
                                  in_channels_right=24*filters, out_channels_right=4*filters)  # 24, 4
        self.cell_15 = NormalCell(in_channels_left=24*filters, out_channels_left=4*filters,  # 24, 4
                                  in_channels_right=24*filters, out_channels_right=4*filters)  # 24, 4

        self.relu = nn.ReLU()
        self.dropout = nn.Dropout()
        self.classifier = nn.Linear(24 * filters, self.num_classes)
        self.feat_dim = 24 * filters

        self.init_params()

    def init_params(self):
        """Load ImageNet pretrained weights"""
        settings = pretrained_settings['nasnetamobile']['imagenet']
        pretrained_dict = model_zoo.load_url(settings['url'], map_location=None)
        model_dict = self.state_dict()
        pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
        model_dict.update(pretrained_dict)
        self.load_state_dict(model_dict)

    def features(self, input):
        x_conv0 = self.conv0(input)
        x_stem_0 = self.cell_stem_0(x_conv0)
        x_stem_1 = self.cell_stem_1(x_conv0, x_stem_0)

        x_cell_0 = self.cell_0(x_stem_1, x_stem_0)
        x_cell_1 = self.cell_1(x_cell_0, x_stem_1)
        x_cell_2 = self.cell_2(x_cell_1, x_cell_0)
        x_cell_3 = self.cell_3(x_cell_2, x_cell_1)

        x_reduction_cell_0 = self.reduction_cell_0(x_cell_3, x_cell_2)

        x_cell_6 = self.cell_6(x_reduction_cell_0, x_cell_3)
        x_cell_7 = self.cell_7(x_cell_6, x_reduction_cell_0)
        x_cell_8 = self.cell_8(x_cell_7, x_cell_6)
        x_cell_9 = self.cell_9(x_cell_8, x_cell_7)

        x_reduction_cell_1 = self.reduction_cell_1(x_cell_9, x_cell_8)

        x_cell_12 = self.cell_12(x_reduction_cell_1, x_cell_9)
        x_cell_13 = self.cell_13(x_cell_12, x_reduction_cell_1)
        x_cell_14 = self.cell_14(x_cell_13, x_cell_12)
        x_cell_15 = self.cell_15(x_cell_14, x_cell_13)

        x_cell_15 = self.relu(x_cell_15)
        x_cell_15 = F.avg_pool2d(x_cell_15, x_cell_15.size()[2:])
        x_cell_15 = x_cell_15.view(x_cell_15.size(0), -1)
        x_cell_15 = self.dropout(x_cell_15)

        return x_cell_15

    def forward(self, input):
        f = self.features(input)

        if not self.training:
            return f

        y = self.classifier(f)

        if self.loss == {'xent'}:
            return y
        elif self.loss == {'xent', 'htri'}:
            return y, f
        elif self.loss == {'cent'}:
            return y, f
        elif self.loss == {'ring'}:
            return y, f
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


"""Following code is not used"""
def nasnetamobile(num_classes=1001, pretrained='imagenet'):
    r"""NASNetAMobile model architecture from the
    `"NASNet" <https://arxiv.org/abs/1707.07012>`_ paper.
    """
    if pretrained:
        settings = pretrained_settings['nasnetamobile'][pretrained]
        assert num_classes == settings['num_classes'], \
            "num_classes should be {}, but is {}".format(settings['num_classes'], num_classes)

        # both 'imagenet' & 'imagenet+background' are loaded from the same parameters
        model = NASNetAMobile(num_classes=num_classes)
        model.load_state_dict(model_zoo.load_url(settings['url'], map_location=None))

        # if pretrained == 'imagenet':
        #     new_last_linear = nn.Linear(model.last_linear.in_features, 1000)
        #     new_last_linear.weight.data = model.last_linear.weight.data[1:]
        #     new_last_linear.bias.data = model.last_linear.bias.data[1:]
        #     model.last_linear = new_last_linear

        model.input_space = settings['input_space']
        model.input_size = settings['input_size']
        model.input_range = settings['input_range']

        model.mean = settings['mean']
        model.std = settings['std']
    else:
        settings = pretrained_settings['nasnetamobile']['imagenet']
        model = NASNetAMobile(num_classes=num_classes)
        model.input_space = settings['input_space']
        model.input_size = settings['input_size']
        model.input_range = settings['input_range']

        model.mean = settings['mean']
        model.std = settings['std']
    return model
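
Like the other backbones added in this commit, NASNetAMobile returns pooled features in eval mode and loss-dependent outputs in train mode. A minimal usage sketch, assuming a Market-1501-style setup with 751 identities and 256x128 inputs (both are assumptions; note that construction downloads the ImageNet weights via init_params):

import torch

model = NASNetAMobile(num_classes=751, loss={'xent', 'htri'})
x = torch.randn(8, 3, 256, 128)

model.train()
y, f = model(x)   # loss == {'xent', 'htri'}: (logits, features)

model.eval()
f = model(x)      # eval mode: features of dimension model.feat_dim (1056)
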

@@ -0,0 +1,113 @@

from __future__ import absolute_import, division

import torch
from torch import nn
from torch.nn import functional as F
import torchvision


__all__ = ['ResNet50', 'ResNet101', 'ResNet50M']


class ResNet50(nn.Module):
    def __init__(self, num_classes, loss={'xent'}, **kwargs):
        super(ResNet50, self).__init__()
        self.loss = loss
        resnet50 = torchvision.models.resnet50(pretrained=True)
        self.base = nn.Sequential(*list(resnet50.children())[:-2])
        self.classifier = nn.Linear(2048, num_classes)
        self.feat_dim = 2048

    def forward(self, x):
        x = self.base(x)
        x = F.avg_pool2d(x, x.size()[2:])
        f = x.view(x.size(0), -1)
        if not self.training:
            return f
        y = self.classifier(f)

        if self.loss == {'xent'}:
            return y
        elif self.loss == {'xent', 'htri'}:
            return y, f
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


class ResNet101(nn.Module):
    def __init__(self, num_classes, loss={'xent'}, **kwargs):
        super(ResNet101, self).__init__()
        self.loss = loss
        resnet101 = torchvision.models.resnet101(pretrained=True)
        self.base = nn.Sequential(*list(resnet101.children())[:-2])
        self.classifier = nn.Linear(2048, num_classes)
        self.feat_dim = 2048  # feature dimension

    def forward(self, x):
        x = self.base(x)
        x = F.avg_pool2d(x, x.size()[2:])
        f = x.view(x.size(0), -1)
        if not self.training:
            return f
        y = self.classifier(f)

        if self.loss == {'xent'}:
            return y
        elif self.loss == {'xent', 'htri'}:
            return y, f
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


class ResNet50M(nn.Module):
    """ResNet50 + mid-level features.

    Reference:
    Yu et al. The Devil is in the Middle: Exploiting Mid-level Representations for
    Cross-Domain Instance Matching. arXiv:1711.08106.
    """
    def __init__(self, num_classes=0, loss={'xent'}, **kwargs):
        super(ResNet50M, self).__init__()
        self.loss = loss
        resnet50 = torchvision.models.resnet50(pretrained=True)
        base = nn.Sequential(*list(resnet50.children())[:-2])
        self.layers1 = nn.Sequential(base[0], base[1], base[2])
        self.layers2 = nn.Sequential(base[3], base[4])
        self.layers3 = base[5]
        self.layers4 = base[6]
        self.layers5a = base[7][0]
        self.layers5b = base[7][1]
        self.layers5c = base[7][2]
        self.fc_fuse = nn.Sequential(nn.Linear(4096, 1024), nn.BatchNorm1d(1024), nn.ReLU())
        self.classifier = nn.Linear(3072, num_classes)
        self.feat_dim = 3072  # feature dimension

    def forward(self, x):
        x1 = self.layers1(x)
        x2 = self.layers2(x1)
        x3 = self.layers3(x2)
        x4 = self.layers4(x3)
        x5a = self.layers5a(x4)
        x5b = self.layers5b(x5a)
        x5c = self.layers5c(x5b)

        x5a_feat = F.avg_pool2d(x5a, x5a.size()[2:]).view(x5a.size(0), x5a.size(1))
        x5b_feat = F.avg_pool2d(x5b, x5b.size()[2:]).view(x5b.size(0), x5b.size(1))
        x5c_feat = F.avg_pool2d(x5c, x5c.size()[2:]).view(x5c.size(0), x5c.size(1))

        midfeat = torch.cat((x5a_feat, x5b_feat), dim=1)
        midfeat = self.fc_fuse(midfeat)

        combofeat = torch.cat((x5c_feat, midfeat), dim=1)

        if not self.training:
            return combofeat

        prelogits = self.classifier(combofeat)

        if self.loss == {'xent'}:
            return prelogits
        elif self.loss == {'xent', 'htri'}:
            return prelogits, combofeat
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))
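
ResNet50M fuses mid-level and final features: the 2048-d pooled outputs of layers5a and layers5b are concatenated (4096-d) and projected to 1024-d by fc_fuse, then joined with the 2048-d layers5c feature to give the 3072-d combofeat. A small shape check, with the 256x128 input size as an assumed example:

import torch

model = ResNet50M(num_classes=751)
model.eval()
with torch.no_grad():
    feat = model(torch.randn(2, 3, 256, 128))
print(feat.shape)  # torch.Size([2, 3072]), i.e. model.feat_dim
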

File diff suppressed because it is too large

@@ -0,0 +1,549 @@

from __future__ import absolute_import, division
|
||||
|
||||
from collections import OrderedDict
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.utils import model_zoo
|
||||
from torch.nn import functional as F
|
||||
import torchvision
|
||||
|
||||
|
||||
"""
|
||||
Code imported from https://github.com/Cadene/pretrained-models.pytorch
|
||||
"""
|
||||
|
||||
|
||||
__all__ = ['SEResNet50', 'SEResNet101', 'SEResNeXt50', 'SEResNeXt101']
|
||||
|
||||
|
||||
pretrained_settings = {
|
||||
'senet154': {
|
||||
'imagenet': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/senet154-c7b49a05.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 224, 224],
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'num_classes': 1000
|
||||
}
|
||||
},
|
||||
'se_resnet50': {
|
||||
'imagenet': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet50-ce0d4300.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 224, 224],
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'num_classes': 1000
|
||||
}
|
||||
},
|
||||
'se_resnet101': {
|
||||
'imagenet': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet101-7e38fcc6.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 224, 224],
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'num_classes': 1000
|
||||
}
|
||||
},
|
||||
'se_resnet152': {
|
||||
'imagenet': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnet152-d17c99b7.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 224, 224],
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'num_classes': 1000
|
||||
}
|
||||
},
|
||||
'se_resnext50_32x4d': {
|
||||
'imagenet': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext50_32x4d-a260b3a4.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 224, 224],
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'num_classes': 1000
|
||||
}
|
||||
},
|
||||
'se_resnext101_32x4d': {
|
||||
'imagenet': {
|
||||
'url': 'http://data.lip6.fr/cadene/pretrainedmodels/se_resnext101_32x4d-3b2fe3d8.pth',
|
||||
'input_space': 'RGB',
|
||||
'input_size': [3, 224, 224],
|
||||
'input_range': [0, 1],
|
||||
'mean': [0.485, 0.456, 0.406],
|
||||
'std': [0.229, 0.224, 0.225],
|
||||
'num_classes': 1000
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
class SEModule(nn.Module):
|
||||
|
||||
def __init__(self, channels, reduction):
|
||||
super(SEModule, self).__init__()
|
||||
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
||||
self.fc1 = nn.Conv2d(channels, channels // reduction, kernel_size=1, padding=0)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.fc2 = nn.Conv2d(channels // reduction, channels, kernel_size=1, padding=0)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
|
||||
def forward(self, x):
|
||||
module_input = x
|
||||
x = self.avg_pool(x)
|
||||
x = self.fc1(x)
|
||||
x = self.relu(x)
|
||||
x = self.fc2(x)
|
||||
x = self.sigmoid(x)
|
||||
return module_input * x
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
"""
|
||||
Base class for bottlenecks that implements `forward()` method.
|
||||
"""
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out = self.se_module(out) + residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class SEBottleneck(Bottleneck):
|
||||
"""
|
||||
Bottleneck for SENet154.
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, groups, reduction, stride=1,
|
||||
downsample=None):
|
||||
super(SEBottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes * 2, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes * 2)
|
||||
self.conv2 = nn.Conv2d(planes * 2, planes * 4, kernel_size=3,
|
||||
stride=stride, padding=1, groups=groups,
|
||||
bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes * 4)
|
||||
self.conv3 = nn.Conv2d(planes * 4, planes * 4, kernel_size=1,
|
||||
bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.se_module = SEModule(planes * 4, reduction=reduction)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
|
||||
class SEResNetBottleneck(Bottleneck):
|
||||
"""
|
||||
ResNet bottleneck with a Squeeze-and-Excitation module. It follows Caffe
|
||||
implementation and uses `stride=stride` in `conv1` and not in `conv2`
|
||||
(the latter is used in the torchvision implementation of ResNet).
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, groups, reduction, stride=1,
|
||||
downsample=None):
|
||||
super(SEResNetBottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False,
|
||||
stride=stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, padding=1,
|
||||
groups=groups, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.se_module = SEModule(planes * 4, reduction=reduction)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
|
||||
class SEResNeXtBottleneck(Bottleneck):
|
||||
"""
|
||||
ResNeXt bottleneck type C with a Squeeze-and-Excitation module.
|
||||
"""
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, groups, reduction, stride=1,
|
||||
downsample=None, base_width=4):
|
||||
super(SEResNeXtBottleneck, self).__init__()
|
||||
width = int(math.floor(planes * (base_width / 64.)) * groups)
|
||||
self.conv1 = nn.Conv2d(inplanes, width, kernel_size=1, bias=False,
|
||||
stride=1)
|
||||
self.bn1 = nn.BatchNorm2d(width)
|
||||
self.conv2 = nn.Conv2d(width, width, kernel_size=3, stride=stride,
|
||||
padding=1, groups=groups, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(width)
|
||||
self.conv3 = nn.Conv2d(width, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.se_module = SEModule(planes * 4, reduction=reduction)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
|
||||
class SENet(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, groups, reduction, dropout_p=0.2,
|
||||
inplanes=128, input_3x3=True, downsample_kernel_size=3,
|
||||
downsample_padding=1, num_classes=1000):
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
block (nn.Module): Bottleneck class.
|
||||
- For SENet154: SEBottleneck
|
||||
- For SE-ResNet models: SEResNetBottleneck
|
||||
- For SE-ResNeXt models: SEResNeXtBottleneck
|
||||
layers (list of ints): Number of residual blocks for 4 layers of the
|
||||
network (layer1...layer4).
|
||||
groups (int): Number of groups for the 3x3 convolution in each
|
||||
bottleneck block.
|
||||
- For SENet154: 64
|
||||
- For SE-ResNet models: 1
|
||||
- For SE-ResNeXt models: 32
|
||||
reduction (int): Reduction ratio for Squeeze-and-Excitation modules.
|
||||
- For all models: 16
|
||||
dropout_p (float or None): Drop probability for the Dropout layer.
|
||||
If `None` the Dropout layer is not used.
|
||||
- For SENet154: 0.2
|
||||
- For SE-ResNet models: None
|
||||
- For SE-ResNeXt models: None
|
||||
inplanes (int): Number of input channels for layer1.
|
||||
- For SENet154: 128
|
||||
- For SE-ResNet models: 64
|
||||
- For SE-ResNeXt models: 64
|
||||
input_3x3 (bool): If `True`, use three 3x3 convolutions instead of
|
||||
a single 7x7 convolution in layer0.
|
||||
- For SENet154: True
|
||||
- For SE-ResNet models: False
|
||||
- For SE-ResNeXt models: False
|
||||
downsample_kernel_size (int): Kernel size for downsampling convolutions
|
||||
in layer2, layer3 and layer4.
|
||||
- For SENet154: 3
|
||||
- For SE-ResNet models: 1
|
||||
- For SE-ResNeXt models: 1
|
||||
downsample_padding (int): Padding for downsampling convolutions in
|
||||
layer2, layer3 and layer4.
|
||||
- For SENet154: 1
|
||||
- For SE-ResNet models: 0
|
||||
- For SE-ResNeXt models: 0
|
||||
num_classes (int): Number of outputs in `last_linear` layer.
|
||||
- For all models: 1000
|
||||
"""
|
||||
super(SENet, self).__init__()
|
||||
self.inplanes = inplanes
|
||||
if input_3x3:
|
||||
layer0_modules = [
|
||||
('conv1', nn.Conv2d(3, 64, 3, stride=2, padding=1,
|
||||
bias=False)),
|
||||
('bn1', nn.BatchNorm2d(64)),
|
||||
('relu1', nn.ReLU(inplace=True)),
|
||||
('conv2', nn.Conv2d(64, 64, 3, stride=1, padding=1,
|
||||
bias=False)),
|
||||
('bn2', nn.BatchNorm2d(64)),
|
||||
('relu2', nn.ReLU(inplace=True)),
|
||||
('conv3', nn.Conv2d(64, inplanes, 3, stride=1, padding=1,
|
||||
bias=False)),
|
||||
('bn3', nn.BatchNorm2d(inplanes)),
|
||||
('relu3', nn.ReLU(inplace=True)),
|
||||
]
|
||||
else:
|
||||
layer0_modules = [
|
||||
('conv1', nn.Conv2d(3, inplanes, kernel_size=7, stride=2,
|
||||
padding=3, bias=False)),
|
||||
('bn1', nn.BatchNorm2d(inplanes)),
|
||||
('relu1', nn.ReLU(inplace=True)),
|
||||
]
|
||||
# To preserve compatibility with Caffe weights `ceil_mode=True`
|
||||
# is used instead of `padding=1`.
|
||||
layer0_modules.append(('pool', nn.MaxPool2d(3, stride=2,
|
||||
ceil_mode=True)))
|
||||
self.layer0 = nn.Sequential(OrderedDict(layer0_modules))
|
||||
self.layer1 = self._make_layer(
|
||||
block,
|
||||
planes=64,
|
||||
blocks=layers[0],
|
||||
groups=groups,
|
||||
reduction=reduction,
|
||||
downsample_kernel_size=1,
|
||||
downsample_padding=0
|
||||
)
|
||||
self.layer2 = self._make_layer(
|
||||
block,
|
||||
planes=128,
|
||||
blocks=layers[1],
|
||||
stride=2,
|
||||
groups=groups,
|
||||
reduction=reduction,
|
||||
downsample_kernel_size=downsample_kernel_size,
|
||||
downsample_padding=downsample_padding
|
||||
)
|
||||
self.layer3 = self._make_layer(
|
||||
block,
|
||||
planes=256,
|
||||
blocks=layers[2],
|
||||
stride=2,
|
||||
groups=groups,
|
||||
reduction=reduction,
|
||||
downsample_kernel_size=downsample_kernel_size,
|
||||
downsample_padding=downsample_padding
|
||||
)
|
||||
self.layer4 = self._make_layer(
|
||||
block,
|
||||
planes=512,
|
||||
blocks=layers[3],
|
||||
stride=2,
|
||||
groups=groups,
|
||||
reduction=reduction,
|
||||
downsample_kernel_size=downsample_kernel_size,
|
||||
downsample_padding=downsample_padding
|
||||
)
|
||||
self.avg_pool = nn.AvgPool2d(7, stride=1)
|
||||
self.dropout = nn.Dropout(dropout_p) if dropout_p is not None else None
|
||||
self.last_linear = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
def _make_layer(self, block, planes, blocks, groups, reduction, stride=1,
|
||||
downsample_kernel_size=1, downsample_padding=0):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||
kernel_size=downsample_kernel_size, stride=stride,
|
||||
padding=downsample_padding, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, groups, reduction, stride,
|
||||
downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes, groups, reduction))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def features(self, x):
|
||||
x = self.layer0(x)
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
return x
|
||||
|
||||
def logits(self, x):
|
||||
x = self.avg_pool(x)
|
||||
if self.dropout is not None:
|
||||
x = self.dropout(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.last_linear(x)
|
||||
return x
|
||||
|
||||
def forward(self, x):
|
||||
x = self.features(x)
|
||||
x = self.logits(x)
|
||||
return x
|
||||
|
||||
|
||||
def initialize_pretrained_model(model, num_classes, settings):
|
||||
assert num_classes == settings['num_classes'], \
|
||||
'num_classes should be {}, but is {}'.format(
|
||||
settings['num_classes'], num_classes)
|
||||
model.load_state_dict(model_zoo.load_url(settings['url']))
|
||||
model.input_space = settings['input_space']
|
||||
model.input_size = settings['input_size']
|
||||
model.input_range = settings['input_range']
|
||||
model.mean = settings['mean']
|
||||
model.std = settings['std']
|
||||
|
||||
|
||||
def senet154(num_classes=1000, pretrained='imagenet'):
|
||||
model = SENet(SEBottleneck, [3, 8, 36, 3], groups=64, reduction=16,
|
||||
dropout_p=0.2, num_classes=num_classes)
|
||||
if pretrained is not None:
|
||||
settings = pretrained_settings['senet154'][pretrained]
|
||||
initialize_pretrained_model(model, num_classes, settings)
|
||||
return model
|
||||
|
||||
|
||||
def se_resnet50(num_classes=1000, pretrained='imagenet'):
|
||||
model = SENet(SEResNetBottleneck, [3, 4, 6, 3], groups=1, reduction=16,
|
||||
dropout_p=None, inplanes=64, input_3x3=False,
|
||||
downsample_kernel_size=1, downsample_padding=0,
|
||||
num_classes=num_classes)
|
||||
if pretrained is not None:
|
||||
settings = pretrained_settings['se_resnet50'][pretrained]
|
||||
initialize_pretrained_model(model, num_classes, settings)
|
||||
return model
|
||||
|
||||
|
||||
def se_resnet101(num_classes=1000, pretrained='imagenet'):
|
||||
model = SENet(SEResNetBottleneck, [3, 4, 23, 3], groups=1, reduction=16,
|
||||
dropout_p=None, inplanes=64, input_3x3=False,
|
||||
downsample_kernel_size=1, downsample_padding=0,
|
||||
num_classes=num_classes)
|
||||
if pretrained is not None:
|
||||
settings = pretrained_settings['se_resnet101'][pretrained]
|
||||
initialize_pretrained_model(model, num_classes, settings)
|
||||
return model
|
||||
|
||||
|
||||
def se_resnet152(num_classes=1000, pretrained='imagenet'):
|
||||
model = SENet(SEResNetBottleneck, [3, 8, 36, 3], groups=1, reduction=16,
|
||||
dropout_p=None, inplanes=64, input_3x3=False,
|
||||
downsample_kernel_size=1, downsample_padding=0,
|
||||
num_classes=num_classes)
|
||||
if pretrained is not None:
|
||||
settings = pretrained_settings['se_resnet152'][pretrained]
|
||||
initialize_pretrained_model(model, num_classes, settings)
|
||||
return model
|
||||
|
||||
|
||||
def se_resnext50_32x4d(num_classes=1000, pretrained='imagenet'):
|
||||
model = SENet(SEResNeXtBottleneck, [3, 4, 6, 3], groups=32, reduction=16,
|
||||
dropout_p=None, inplanes=64, input_3x3=False,
|
||||
downsample_kernel_size=1, downsample_padding=0,
|
||||
num_classes=num_classes)
|
||||
if pretrained is not None:
|
||||
settings = pretrained_settings['se_resnext50_32x4d'][pretrained]
|
||||
initialize_pretrained_model(model, num_classes, settings)
|
||||
return model
|
||||
|
||||
|
||||
def se_resnext101_32x4d(num_classes=1000, pretrained='imagenet'):
|
||||
model = SENet(SEResNeXtBottleneck, [3, 4, 23, 3], groups=32, reduction=16,
|
||||
dropout_p=None, inplanes=64, input_3x3=False,
|
||||
downsample_kernel_size=1, downsample_padding=0,
|
||||
num_classes=num_classes)
|
||||
if pretrained is not None:
|
||||
settings = pretrained_settings['se_resnext101_32x4d'][pretrained]
|
||||
initialize_pretrained_model(model, num_classes, settings)
|
||||
return model
|
||||
|
||||
|
||||
##################### Model Definition #########################
|
||||
|
||||
|
||||
class SEResNet50(nn.Module):
|
||||
def __init__(self, num_classes, loss={'xent'}, **kwargs):
|
||||
super(SEResNet50, self).__init__()
|
||||
self.loss = loss
|
||||
base = se_resnet50()
|
||||
self.base = nn.Sequential(*list(base.children())[:-2])
|
||||
self.classifier = nn.Linear(2048, num_classes)
|
||||
self.feat_dim = 2048
|
||||
|
||||
def forward(self, x):
|
||||
x = self.base(x)
|
||||
x = F.avg_pool2d(x, x.size()[2:])
|
||||
f = x.view(x.size(0), -1)
|
||||
if not self.training:
|
||||
return f
|
||||
y = self.classifier(f)
|
||||
|
||||
if self.loss == {'xent'}:
|
||||
return y
|
||||
elif self.loss == {'xent', 'htri'}:
|
||||
return y, f
|
||||
else:
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
||||
|
||||
|
||||
class SEResNet101(nn.Module):
|
||||
def __init__(self, num_classes, loss={'xent'}, **kwargs):
|
||||
super(SEResNet101, self).__init__()
|
||||
self.loss = loss
|
||||
base = se_resnet101()
|
||||
self.base = nn.Sequential(*list(base.children())[:-2])
|
||||
self.classifier = nn.Linear(2048, num_classes)
|
||||
self.feat_dim = 2048
|
||||
|
||||
def forward(self, x):
|
||||
x = self.base(x)
|
||||
x = F.avg_pool2d(x, x.size()[2:])
|
||||
f = x.view(x.size(0), -1)
|
||||
if not self.training:
|
||||
return f
|
||||
y = self.classifier(f)
|
||||
|
||||
if self.loss == {'xent'}:
|
||||
return y
|
||||
elif self.loss == {'xent', 'htri'}:
|
||||
return y, f
|
||||
else:
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
||||
|
||||
|
||||
class SEResNeXt50(nn.Module):
|
||||
def __init__(self, num_classes, loss={'xent'}, **kwargs):
|
||||
super(SEResNeXt50, self).__init__()
|
||||
self.loss = loss
|
||||
base = se_resnext50_32x4d()
|
||||
self.base = nn.Sequential(*list(base.children())[:-2])
|
||||
self.classifier = nn.Linear(2048, num_classes)
|
||||
self.feat_dim = 2048
|
||||
|
||||
def forward(self, x):
|
||||
x = self.base(x)
|
||||
x = F.avg_pool2d(x, x.size()[2:])
|
||||
f = x.view(x.size(0), -1)
|
||||
if not self.training:
|
||||
return f
|
||||
y = self.classifier(f)
|
||||
|
||||
if self.loss == {'xent'}:
|
||||
return y
|
||||
elif self.loss == {'xent', 'htri'}:
|
||||
return y, f
|
||||
else:
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
||||
|
||||
|
||||
class SEResNeXt101(nn.Module):
|
||||
def __init__(self, num_classes, loss={'xent'}, **kwargs):
|
||||
super(SEResNeXt101, self).__init__()
|
||||
self.loss = loss
|
||||
base = se_resnext101_32x4d()
|
||||
self.base = nn.Sequential(*list(base.children())[:-2])
|
||||
self.classifier = nn.Linear(2048, num_classes)
|
||||
self.feat_dim = 2048
|
||||
|
||||
def forward(self, x):
|
||||
x = self.base(x)
|
||||
x = F.avg_pool2d(x, x.size()[2:])
|
||||
f = x.view(x.size(0), -1)
|
||||
if not self.training:
|
||||
return f
|
||||
y = self.classifier(f)
|
||||
|
||||
if self.loss == {'xent'}:
|
||||
return y
|
||||
elif self.loss == {'xent', 'htri'}:
|
||||
return y, f
|
||||
else:
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
|
@@ -0,0 +1,133 @@

from __future__ import absolute_import, division

import torch
from torch import nn
from torch.nn import functional as F
import torchvision


__all__ = ['ShuffleNet']


class ChannelShuffle(nn.Module):
    def __init__(self, num_groups):
        super(ChannelShuffle, self).__init__()
        self.g = num_groups

    def forward(self, x):
        b, c, h, w = x.size()
        n = c // self.g  # channels per group (integer division; plain / yields a float under __future__ division)
        # reshape
        x = x.view(b, self.g, n, h, w)
        # transpose
        x = x.permute(0, 2, 1, 3, 4).contiguous()
        # flatten
        x = x.view(b, c, h, w)
        return x


class Bottleneck(nn.Module):
    def __init__(self, in_channels, out_channels, stride, num_groups, group_conv1x1=True):
        super(Bottleneck, self).__init__()
        assert stride in [1, 2], "Warning: stride must be either 1 or 2"
        self.stride = stride
        mid_channels = out_channels // 4
        if stride == 2: out_channels -= in_channels
        # group conv is not applied to first conv1x1 at stage 2
        num_groups_conv1x1 = num_groups if group_conv1x1 else 1
        self.conv1 = nn.Conv2d(in_channels, mid_channels, 1, groups=num_groups_conv1x1, bias=False)
        self.bn1 = nn.BatchNorm2d(mid_channels)
        self.shuffle1 = ChannelShuffle(num_groups)
        self.conv2 = nn.Conv2d(mid_channels, mid_channels, 3, stride=stride, padding=1, groups=mid_channels, bias=False)
        self.bn2 = nn.BatchNorm2d(mid_channels)
        self.conv3 = nn.Conv2d(mid_channels, out_channels, 1, groups=num_groups, bias=False)
        self.bn3 = nn.BatchNorm2d(out_channels)
        if stride == 2: self.shortcut = nn.AvgPool2d(3, stride=2, padding=1)

    def forward(self, x):
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.shuffle1(out)
        out = self.bn2(self.conv2(out))
        out = self.bn3(self.conv3(out))
        if self.stride == 2:
            res = self.shortcut(x)
            out = F.relu(torch.cat([res, out], 1))
        else:
            out = F.relu(x + out)
        return out


# configuration of (num_groups: #out_channels) based on Table 1 in the paper
cfg = {
    1: [144, 288, 576],
    2: [200, 400, 800],
    3: [240, 480, 960],
    4: [272, 544, 1088],
    8: [384, 768, 1536],
}


class ShuffleNet(nn.Module):
    """ShuffleNet

    Reference:
    Zhang et al. ShuffleNet: An Extremely Efficient Convolutional Neural
    Network for Mobile Devices. CVPR 2018.
    """
    def __init__(self, num_classes, loss={'xent'}, num_groups=3, **kwargs):
        super(ShuffleNet, self).__init__()
        self.loss = loss

        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 24, 3, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(24),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
        )

        self.stage2 = nn.Sequential(
            Bottleneck(24, cfg[num_groups][0], 2, num_groups, group_conv1x1=False),
            Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups),
            Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups),
            Bottleneck(cfg[num_groups][0], cfg[num_groups][0], 1, num_groups),
        )

        self.stage3 = nn.Sequential(
            Bottleneck(cfg[num_groups][0], cfg[num_groups][1], 2, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
            Bottleneck(cfg[num_groups][1], cfg[num_groups][1], 1, num_groups),
        )

        self.stage4 = nn.Sequential(
            Bottleneck(cfg[num_groups][1], cfg[num_groups][2], 2, num_groups),
            Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups),
            Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups),
            Bottleneck(cfg[num_groups][2], cfg[num_groups][2], 1, num_groups),
        )

        self.classifier = nn.Linear(cfg[num_groups][2], num_classes)
        self.feat_dim = cfg[num_groups][2]

    def forward(self, x):
        x = self.conv1(x)
        x = self.stage2(x)
        x = self.stage3(x)
        x = self.stage4(x)
        x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1)

        if not self.training:
            return x

        y = self.classifier(x)

        if self.loss == {'xent'}:
            return y
        elif self.loss == {'xent', 'htri'}:
            return y, x
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))
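
ChannelShuffle mixes information across groups by viewing (b, c, h, w) as (b, g, c // g, h, w), swapping the two group axes, and flattening back to (b, c, h, w). A toy sketch of the resulting channel permutation (the 6-channel tensor is just an example):

import torch

shuffle = ChannelShuffle(num_groups=3)
x = torch.arange(6.).view(1, 6, 1, 1)   # channels [0, 1, 2, 3, 4, 5]
print(shuffle(x).view(-1).tolist())     # [0.0, 2.0, 4.0, 1.0, 3.0, 5.0]
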
|
|
@ -0,0 +1,124 @@
|
|||
from __future__ import absolute_import, division
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
import torchvision
|
||||
|
||||
|
||||
__all__ = ['SqueezeNet']
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
"""Basic convolutional block:
|
||||
convolution + batch normalization + relu.
|
||||
|
||||
Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):
|
||||
in_c (int): number of input channels.
|
||||
out_c (int): number of output channels.
|
||||
k (int or tuple): kernel size.
|
||||
s (int or tuple): stride.
|
||||
p (int or tuple): padding.
|
||||
"""
|
||||
def __init__(self, in_c, out_c, k, s=1, p=0):
|
||||
super(ConvBlock, self).__init__()
|
||||
self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
|
||||
self.bn = nn.BatchNorm2d(out_c)
|
||||
|
||||
def forward(self, x):
|
||||
return F.relu(self.bn(self.conv(x)))
|
||||
|
||||
|
||||
class ExpandLayer(nn.Module):
|
||||
def __init__(self, in_channels, e1_channels, e3_channels):
|
||||
super(ExpandLayer, self).__init__()
|
||||
self.conv11 = ConvBlock(in_channels, e1_channels, 1)
|
||||
self.conv33 = ConvBlock(in_channels, e3_channels, 3, p=1)
|
||||
|
||||
def forward(self, x):
|
||||
x11 = self.conv11(x)
|
||||
x33 = self.conv33(x)
|
||||
x = torch.cat([x11, x33], 1)
|
||||
return x
|
||||
|
||||
|
||||
class FireModule(nn.Module):
|
||||
"""
|
||||
Args:
|
||||
in_channels (int): number of input channels.
|
||||
s1_channels (int): number of 1-by-1 filters for squeeze layer.
|
||||
e1_channels (int): number of 1-by-1 filters for expand layer.
|
||||
e3_channels (int): number of 3-by-3 filters for expand layer.
|
||||
|
||||
Number of output channels from FireModule is e1_channels + e3_channels.
|
||||
"""
|
||||
def __init__(self, in_channels, s1_channels, e1_channels, e3_channels):
|
||||
super(FireModule, self).__init__()
|
||||
self.squeeze = ConvBlock(in_channels, s1_channels, 1)
|
||||
self.expand = ExpandLayer(s1_channels, e1_channels, e3_channels)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.squeeze(x)
|
||||
x = self.expand(x)
|
||||
return x
|
||||
|
||||
|
||||
class SqueezeNet(nn.Module):
    """SqueezeNet

    Reference:
    Iandola et al. SqueezeNet: AlexNet-level accuracy with 50x fewer parameters
    and <0.5MB model size. arXiv:1602.07360.
    """
    def __init__(self, num_classes, loss={'xent'}, bypass=True, **kwargs):
        super(SqueezeNet, self).__init__()
        self.loss = loss
        self.bypass = bypass

        self.conv1 = ConvBlock(3, 96, 7, s=2, p=2)
        self.fire2 = FireModule(96, 16, 64, 64)
        self.fire3 = FireModule(128, 16, 64, 64)
        self.fire4 = FireModule(128, 32, 128, 128)
        self.fire5 = FireModule(256, 32, 128, 128)
        self.fire6 = FireModule(256, 48, 192, 192)
        self.fire7 = FireModule(384, 48, 192, 192)
        self.fire8 = FireModule(384, 64, 256, 256)
        self.fire9 = FireModule(512, 64, 256, 256)
        self.conv10 = nn.Conv2d(512, num_classes, 1)

        self.feat_dim = num_classes

    def forward(self, x):
        x1 = self.conv1(x)
        x1 = F.max_pool2d(x1, 3, stride=2)
        x2 = self.fire2(x1)
        x3 = self.fire3(x2)
        if self.bypass:
            x3 = x3 + x2
        x4 = self.fire4(x3)
        x4 = F.max_pool2d(x4, 3, stride=2)
        x5 = self.fire5(x4)
        if self.bypass:
            x5 = x5 + x4
        x6 = self.fire6(x5)
        x7 = self.fire7(x6)
        if self.bypass:
            x7 = x7 + x6
        x8 = self.fire8(x7)
        x8 = F.max_pool2d(x8, 3, stride=2)
        x9 = self.fire9(x8)
        if self.bypass:
            x9 = x9 + x8
        x9 = F.dropout(x9, training=self.training)
        x10 = F.relu(self.conv10(x9))
        f = F.avg_pool2d(x10, x10.size()[2:]).view(x10.size(0), -1)

        if not self.training:
            return f

        if self.loss == {'xent'}:
            return f
        elif self.loss == {'xent', 'htri'}:
            return f, f
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))
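
# Illustrative sketch, not part of the committed file: how SqueezeNet behaves in
# train vs. eval mode. num_classes=751 and the 256x128 input size are arbitrary
# example values.
if __name__ == '__main__':
    model = SqueezeNet(num_classes=751, loss={'xent'})
    imgs = torch.rand(4, 3, 256, 128)          # dummy batch of 4 person images

    model.train()
    print(model(imgs).size())                  # class scores: torch.Size([4, 751])

    model.eval()
    print(model(imgs).size())                  # pooled features: torch.Size([4, 751])
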
@@ -0,0 +1,199 @@
from __future__ import absolute_import, division

import torch
from torch import nn
from torch.nn import functional as F
import torchvision


__all__ = ['Xception']


class ConvBlock(nn.Module):
    """Basic convolutional block:
    convolution (bias discarded) + batch normalization + relu6.

    Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):
    - in_c (int): number of input channels.
    - out_c (int): number of output channels.
    - k (int or tuple): kernel size.
    - s (int or tuple): stride.
    - p (int or tuple): padding.
    - g (int): number of blocked connections from input channels
      to output channels (default: 1).
    """
    def __init__(self, in_c, out_c, k, s=1, p=0, g=1):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p, bias=False, groups=g)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        return F.relu6(self.bn(self.conv(x)))


class SepConv(nn.Module):
    """Depthwise separable convolution: a 3x3 depthwise convolution followed by
    a 1x1 pointwise convolution, each paired with batch normalization."""
    def __init__(self, in_channels, out_channels):
        super(SepConv, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(in_channels, in_channels, 3, padding=1, groups=in_channels, bias=False),
            nn.BatchNorm2d(in_channels),
        )
        self.conv2 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, 1, bias=False),
            nn.BatchNorm2d(out_channels),
        )

    def forward(self, x):
        return self.conv2(self.conv1(x))
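
# Illustrative sketch, not part of the committed file: why SepConv is used. At the
# 728-channel width of the mid flow, the factorized convolution needs far fewer
# parameters than a dense 3x3 convolution; the exact counts are printed below.
if __name__ == '__main__':
    sep = SepConv(728, 728)
    dense = nn.Conv2d(728, 728, 3, padding=1, bias=False)
    n_sep = sum(p.numel() for p in sep.parameters())      # ~0.54M, incl. BatchNorm
    n_dense = sum(p.numel() for p in dense.parameters())  # 728 * 728 * 9, ~4.77M
    print(n_sep, n_dense)
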
class EntryFlow(nn.Module):
    def __init__(self, nchannels):
        super(EntryFlow, self).__init__()
        self.conv1 = ConvBlock(3, nchannels[0], 3, s=2, p=1)
        self.conv2 = ConvBlock(nchannels[0], nchannels[1], 3, p=1)

        self.conv3 = nn.Sequential(
            SepConv(nchannels[1], nchannels[2]),
            nn.ReLU(),
            SepConv(nchannels[2], nchannels[2]),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
        )
        self.conv3s = nn.Sequential(
            nn.Conv2d(nchannels[1], nchannels[2], 1, stride=2, bias=False),
            nn.BatchNorm2d(nchannels[2]),
        )

        self.conv4 = nn.Sequential(
            nn.ReLU(),
            SepConv(nchannels[2], nchannels[3]),
            nn.ReLU(),
            SepConv(nchannels[3], nchannels[3]),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
        )
        self.conv4s = nn.Sequential(
            nn.Conv2d(nchannels[2], nchannels[3], 1, stride=2, bias=False),
            nn.BatchNorm2d(nchannels[3]),
        )

        self.conv5 = nn.Sequential(
            nn.ReLU(),
            SepConv(nchannels[3], nchannels[4]),
            nn.ReLU(),
            SepConv(nchannels[4], nchannels[4]),
            nn.ReLU(),
            nn.MaxPool2d(3, stride=2, padding=1),
        )
        self.conv5s = nn.Sequential(
            nn.Conv2d(nchannels[3], nchannels[4], 1, stride=2, bias=False),
            nn.BatchNorm2d(nchannels[4]),
        )

    def forward(self, x):
        x1 = self.conv1(x)
        x2 = self.conv2(x1)

        x3 = self.conv3(x2)
        x3s = self.conv3s(x2)
        x3 = x3 + x3s

        x4 = self.conv4(x3)
        x4s = self.conv4s(x3)
        x4 = x4 + x4s

        x5 = self.conv5(x4)
        x5s = self.conv5s(x4)
        x5 = x5 + x5s

        return x5
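
# Illustrative sketch, not part of the committed file: the entry flow halves the
# spatial resolution four times (conv1 plus three max-pooled stages). The 256x128
# input size is only an example.
if __name__ == '__main__':
    entry = EntryFlow(nchannels=[32, 64, 128, 256, 728])
    x = torch.rand(1, 3, 256, 128)
    print(entry(x).size())                     # torch.Size([1, 728, 16, 8])
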
class MidFlowBlock(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(MidFlowBlock, self).__init__()
        self.conv1 = SepConv(in_channels, out_channels)
        self.conv2 = SepConv(out_channels, out_channels)
        self.conv3 = SepConv(out_channels, out_channels)

    def forward(self, x):
        x = self.conv1(F.relu(x))
        x = self.conv2(F.relu(x))
        x = self.conv3(F.relu(x))
        return x


class MidFlow(nn.Module):
    def __init__(self, in_channels, out_channels, num_repeats):
        super(MidFlow, self).__init__()
        self.blocks = self._make_layer(in_channels, out_channels, num_repeats)

    def _make_layer(self, in_channels, out_channels, num):
        layers = []
        for i in range(num):
            layers.append(MidFlowBlock(in_channels, out_channels))
            in_channels = out_channels
        return nn.Sequential(*layers)

    def forward(self, x):
        return self.blocks(x)
class ExitFlow(nn.Module):
    def __init__(self, in_channels, nchannels):
        super(ExitFlow, self).__init__()
        self.conv1 = SepConv(in_channels, nchannels[0])
        self.conv2 = SepConv(nchannels[0], nchannels[1])
        self.conv2s = nn.Sequential(
            nn.Conv2d(in_channels, nchannels[1], 1, stride=2, bias=False),
            nn.BatchNorm2d(nchannels[1]),
        )

        self.conv3 = SepConv(nchannels[1], nchannels[2])
        self.conv4 = SepConv(nchannels[2], nchannels[3])

    def forward(self, x):
        x1 = self.conv1(F.relu(x))
        x2 = self.conv2(F.relu(x1))
        x2 = F.max_pool2d(x2, 3, stride=2, padding=1)
        x2s = self.conv2s(x)
        x2 = x2 + x2s
        x3 = F.relu(self.conv3(x2))
        x4 = F.relu(self.conv4(x3))
        x4 = F.avg_pool2d(x4, x4.size()[2:]).view(x4.size(0), -1)
        return x4
class Xception(nn.Module):
    """Xception

    Reference:
    Chollet. Xception: Deep Learning with Depthwise Separable Convolutions. CVPR 2017.
    """
    def __init__(self, num_classes, loss={'xent'}, num_mid_flows=8, **kwargs):
        super(Xception, self).__init__()
        self.loss = loss

        self.entryflow = EntryFlow(nchannels=[32, 64, 128, 256, 728])
        self.midflow = MidFlow(728, 728, num_mid_flows)
        self.exitflow = ExitFlow(728, nchannels=[728, 1024, 1536, 2048])
        self.classifier = nn.Linear(2048, num_classes)
        self.feat_dim = 2048

    def forward(self, x):
        x = self.entryflow(x)
        x = self.midflow(x)
        x = self.exitflow(x)

        if not self.training:
            return x

        y = self.classifier(x)

        if self.loss == {'xent'}:
            return y
        elif self.loss == {'xent', 'htri'}:
            return y, x
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))
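
# Illustrative sketch, not part of the committed file: end-to-end usage of the
# Xception re-ID head. num_classes=751 and the 256x128 input size are arbitrary
# example values.
if __name__ == '__main__':
    model = Xception(num_classes=751, loss={'xent', 'htri'})
    imgs = torch.rand(4, 3, 256, 128)

    model.train()
    y, f = model(imgs)
    print(y.size(), f.size())                  # torch.Size([4, 751]) torch.Size([4, 2048])

    model.eval()
    print(model(imgs).size())                  # torch.Size([4, 2048])
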