from __future__ import absolute_import

import torch
from torch import nn
from torch.nn import functional as F

__all__ = ['HACNN']

class ConvBlock(nn.Module):
    """Basic convolutional block:
    convolution + batch normalization + relu.

    Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):
        in_c (int): number of input channels.
        out_c (int): number of output channels.
        k (int or tuple): kernel size.
        s (int or tuple): stride.
        p (int or tuple): padding.
    """
    def __init__(self, in_c, out_c, k, s=1, p=0):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))

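
# A minimal usage sketch (added for illustration; not part of the original
# module). The batch and image sizes below are arbitrary assumptions.
def _demo_conv_block():
    block = ConvBlock(3, 32, 3, s=2, p=1)
    block.eval()  # use running BN statistics
    x = torch.rand(2, 3, 160, 64)
    y = block(x)
    assert y.size() == (2, 32, 80, 32)  # stride 2 halves H and W
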
class InceptionA(nn.Module):
    """InceptionA (https://github.com/Cysu/dgd_person_reid)

    Four parallel streams (1x1, 3x3, double 3x3, and avg-pool + 1x1), each
    producing out_channels maps, concatenated along the channel dimension.
    """
    def __init__(self, in_channels, out_channels):
        super(InceptionA, self).__init__()
        self.stream1 = ConvBlock(in_channels, out_channels, 1)
        self.stream2 = nn.Sequential(
            ConvBlock(in_channels, out_channels, 1),
            ConvBlock(out_channels, out_channels, 3, p=1),
        )
        self.stream3 = nn.Sequential(
            ConvBlock(in_channels, out_channels, 1),
            ConvBlock(out_channels, out_channels, 3, p=1),
            ConvBlock(out_channels, out_channels, 3, p=1),
        )
        self.stream4 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1),
            ConvBlock(in_channels, out_channels, 1),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        s4 = self.stream4(x)
        y = torch.cat([s1, s2, s3, s4], dim=1)
        return y

class InceptionB(nn.Module):
    """InceptionB (https://github.com/Cysu/dgd_person_reid)

    Two stride-2 conv streams plus a max-pooled copy of the input, so the
    spatial size is halved and the output has out_channels*2 + in_channels
    maps.
    """
    def __init__(self, in_channels, out_channels):
        super(InceptionB, self).__init__()
        self.stream1 = nn.Sequential(
            ConvBlock(in_channels, out_channels, 1),
            ConvBlock(out_channels, out_channels, 3, s=2, p=1),
        )
        self.stream2 = nn.Sequential(
            ConvBlock(in_channels, out_channels, 1),
            ConvBlock(out_channels, out_channels, 3, p=1),
            ConvBlock(out_channels, out_channels, 3, s=2, p=1),
        )
        self.stream3 = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        y = torch.cat([s1, s2, s3], dim=1)
        return y

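
# A hedged sketch (added for illustration) verifying the channel arithmetic
# relied on by HACNN below: InceptionA concatenates four streams of
# out_channels maps each; InceptionB concatenates two conv streams with a
# max-pooled copy of the input. Sizes are arbitrary assumptions.
def _demo_inception_channels():
    x = torch.rand(2, 32, 40, 16)
    a = InceptionA(32, 16)
    b = InceptionB(32, 16)
    a.eval()
    b.eval()
    ya = a(x)
    yb = b(x)
    assert ya.size() == (2, 16 * 4, 40, 16)      # out_channels*4, same H, W
    assert yb.size() == (2, 16 * 2 + 32, 20, 8)  # out_channels*2 + in, halved
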
class SpatialAttn(nn.Module):
    """Spatial Attention (Sec. 3.1.I.1)"""
    def __init__(self):
        super(SpatialAttn, self).__init__()
        self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
        self.conv2 = ConvBlock(1, 1, 1)

    def forward(self, x):
        # global cross-channel averaging
        x = x.mean(1, keepdim=True)
        # 3-by-3 conv with stride 2
        x = self.conv1(x)
        # bilinear resizing back to the pre-conv resolution
        x = F.interpolate(
            x, (x.size(2) * 2, x.size(3) * 2),
            mode='bilinear', align_corners=False,
        )
        # scaling conv
        x = self.conv2(x)
        return x

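
# A minimal sketch (added for illustration): spatial attention collapses the
# channels and returns a single-channel map at the input resolution. H and W
# should be even so the 2x downsample/upsample round-trips exactly; the
# sizes here are arbitrary assumptions.
def _demo_spatial_attn():
    attn = SpatialAttn()
    attn.eval()
    x = torch.rand(2, 64, 40, 16)
    y = attn(x)
    assert y.size() == (2, 1, 40, 16)
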
class ChannelAttn(nn.Module):
    """Channel Attention (Sec. 3.1.I.2)"""
    def __init__(self, in_channels, reduction_rate=16):
        super(ChannelAttn, self).__init__()
        assert in_channels % reduction_rate == 0
        # integer division: nn.Conv2d requires int channel counts
        self.conv1 = ConvBlock(in_channels, in_channels // reduction_rate, 1)
        self.conv2 = ConvBlock(in_channels // reduction_rate, in_channels, 1)

    def forward(self, x):
        # squeeze operation (global average pooling)
        x = F.avg_pool2d(x, x.size()[2:])
        # excitation operation (2 conv layers)
        x = self.conv2(self.conv1(x))
        return x

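
# A minimal sketch (added for illustration): channel attention squeezes the
# spatial dimensions and returns one weight per channel. in_channels must be
# divisible by reduction_rate; 64 and 16 are arbitrary assumptions.
def _demo_channel_attn():
    attn = ChannelAttn(64, reduction_rate=16)
    attn.eval()
    x = torch.rand(2, 64, 40, 16)
    y = attn(x)
    assert y.size() == (2, 64, 1, 1)
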
class SoftAttn(nn.Module):
    """Soft Attention (Sec. 3.1.I)

    Aim: Spatial Attention + Channel Attention.
    Output: attention maps with shape identical to input.
    """
    def __init__(self, in_channels):
        super(SoftAttn, self).__init__()
        self.spatial_attn = SpatialAttn()
        self.channel_attn = ChannelAttn(in_channels)
        self.conv = ConvBlock(in_channels, in_channels, 1)

    def forward(self, x):
        y_spatial = self.spatial_attn(x)  # (b, 1, h, w)
        y_channel = self.channel_attn(x)  # (b, c, 1, 1)
        y = y_spatial * y_channel         # broadcast to (b, c, h, w)
        y = torch.sigmoid(self.conv(y))
        return y

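
# A minimal sketch (added for illustration) of the docstring claim above:
# the broadcasted product of the spatial map and the channel weights gives
# an attention map matching the input shape, with values in (0, 1) after
# the sigmoid. Sizes are arbitrary assumptions.
def _demo_soft_attn():
    attn = SoftAttn(64)
    attn.eval()
    x = torch.rand(2, 64, 40, 16)
    y = attn(x)
    assert y.size() == x.size()
    assert 0.0 < float(y.min()) <= float(y.max()) < 1.0
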
class HardAttn(nn.Module):
    """Hard Attention (Sec. 3.1.II)"""
    def __init__(self):
        super(HardAttn, self).__init__()

    def forward(self, x):
        raise NotImplementedError

class HarmAttn(nn.Module):
    """Harmonious Attention (Sec. 3.1)

    Note: only the soft attention branch is wired in here; HardAttn above is
    still a stub.
    """
    def __init__(self, in_channels):
        super(HarmAttn, self).__init__()
        self.soft_attn = SoftAttn(in_channels)

    def forward(self, x):
        y_soft_attn = self.soft_attn(x)
        return y_soft_attn

class HACNN(nn.Module):
    """Harmonious Attention Convolutional Neural Network.

    Reference:
        Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.
    """
    def __init__(self, num_classes, loss={'xent'}, nchannels=[32, 64, 96], feat_dim=512, **kwargs):
        super(HACNN, self).__init__()
        self.loss = loss
        self.conv = ConvBlock(3, 32, 3, s=2, p=1)

        # construct Inception + HarmAttn blocks
        # output channels of InceptionA: out_channels*4
        # output channels of InceptionB: out_channels*2 + in_channels
        # so each A -> B pair widens nchannels[i] to nchannels[i]*6 (2 + 4)
        self.inception1 = nn.Sequential(
            InceptionA(32, nchannels[0]),
            InceptionB(nchannels[0]*4, nchannels[0]),
        )
        self.ha1 = HarmAttn(nchannels[0]*6)

        self.inception2 = nn.Sequential(
            InceptionA(nchannels[0]*6, nchannels[1]),
            InceptionB(nchannels[1]*4, nchannels[1]),
        )
        self.ha2 = HarmAttn(nchannels[1]*6)

        self.inception3 = nn.Sequential(
            InceptionA(nchannels[1]*6, nchannels[2]),
            InceptionB(nchannels[2]*4, nchannels[2]),
        )
        self.ha3 = HarmAttn(nchannels[2]*6)

        self.fc_global = nn.Sequential(
            nn.Linear(nchannels[2]*6, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(feat_dim, num_classes)
        self.feat_dim = feat_dim

    def forward(self, x):
        # input size (3, 160, 64)
        x = self.conv(x)

        # block 1
        x1 = self.inception1(x)
        x1_attn = self.ha1(x1)
        x1_out = x1 * x1_attn

        # block 2
        x2 = self.inception2(x1_out)
        x2_attn = self.ha2(x2)
        x2_out = x2 * x2_attn

        # block 3
        x3 = self.inception3(x2_out)
        x3_attn = self.ha3(x3)
        x3_out = x3 * x3_attn

        # global feature: global average pooling + fc
        x_global = F.avg_pool2d(x3_out, x3_out.size()[2:]).view(x3_out.size(0), x3_out.size(1))
        x_global = self.fc_global(x_global)

        if not self.training:
            return x_global

        prelogits = self.classifier(x_global)

        if self.loss == {'xent'}:
            return prelogits
        elif self.loss == {'xent', 'htri'}:
            return prelogits, x_global
        elif self.loss == {'cent'}:
            return prelogits, x_global
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))

if __name__ == '__main__':
    model = HACNN(10)
    model.eval()
    x = torch.rand(5, 3, 160, 64)
    print("input size {}".format(x.size()))
    y = model(x)
    print("output size {}".format(y.size()))