deep-person-reid/models/HACNN.py

from __future__ import absolute_import

import torch
from torch import nn
from torch.nn import functional as F

__all__ = ['HACNN']


class ConvBlock(nn.Module):
    """Basic convolutional block:
    convolution + batch normalization + relu.

    Args (following http://pytorch.org/docs/master/nn.html#torch.nn.Conv2d):
        in_c (int): number of input channels.
        out_c (int): number of output channels.
        k (int or tuple): kernel size.
        s (int or tuple): stride.
        p (int or tuple): padding.
    """
    def __init__(self, in_c, out_c, k, s=1, p=0):
        super(ConvBlock, self).__init__()
        self.conv = nn.Conv2d(in_c, out_c, k, stride=s, padding=p)
        self.bn = nn.BatchNorm2d(out_c)

    def forward(self, x):
        return F.relu(self.bn(self.conv(x)))
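
# With k=3, s=2, p=1 a ConvBlock halves each (even) spatial dimension:
# floor((H + 2 - 3)/2) + 1 = H/2; e.g. the stem conv in HACNN below maps
# (3, 160, 64) -> (32, 80, 32).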


class InceptionA(nn.Module):
    """InceptionA (https://github.com/Cysu/dgd_person_reid)"""
    def __init__(self, in_channels, out_channels):
        super(InceptionA, self).__init__()
        self.stream1 = ConvBlock(in_channels, out_channels, 1)
        self.stream2 = nn.Sequential(
            ConvBlock(in_channels, out_channels, 1),
            ConvBlock(out_channels, out_channels, 3, p=1),
        )
        self.stream3 = nn.Sequential(
            ConvBlock(in_channels, out_channels, 1),
            ConvBlock(out_channels, out_channels, 3, p=1),
            ConvBlock(out_channels, out_channels, 3, p=1),
        )
        self.stream4 = nn.Sequential(
            nn.AvgPool2d(3, stride=1, padding=1),
            ConvBlock(in_channels, out_channels, 1),
        )

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        s4 = self.stream4(x)
        y = torch.cat([s1, s2, s3, s4], dim=1)
        return y
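
# Every stream preserves the input resolution and emits out_channels maps, so
# the dim=1 concat yields out_channels*4 channels, e.g. InceptionA(32, 32)
# maps (32, 80, 32) -> (128, 80, 32).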


class InceptionB(nn.Module):
    """InceptionB (https://github.com/Cysu/dgd_person_reid)"""
    def __init__(self, in_channels, out_channels):
        super(InceptionB, self).__init__()
        self.stream1 = nn.Sequential(
            ConvBlock(in_channels, out_channels, 1),
            ConvBlock(out_channels, out_channels, 3, s=2, p=1),
        )
        self.stream2 = nn.Sequential(
            ConvBlock(in_channels, out_channels, 1),
            ConvBlock(out_channels, out_channels, 3, p=1),
            ConvBlock(out_channels, out_channels, 3, s=2, p=1),
        )
        self.stream3 = nn.MaxPool2d(3, stride=2, padding=1)

    def forward(self, x):
        s1 = self.stream1(x)
        s2 = self.stream2(x)
        s3 = self.stream3(x)
        y = torch.cat([s1, s2, s3], dim=1)
        return y
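
# Streams 1 and 2 each emit out_channels maps at half resolution (stride-2 3x3
# convs), while stream 3 max-pools the input and keeps its in_channels, so the
# concat yields out_channels*2 + in_channels channels, e.g. InceptionB(128, 32)
# maps (128, 80, 32) -> (192, 40, 16).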


class SpatialAttn(nn.Module):
    """Spatial Attention (Sec. 3.1.I.1)"""
    def __init__(self):
        super(SpatialAttn, self).__init__()
        self.conv1 = ConvBlock(1, 1, 3, s=2, p=1)
        self.conv2 = ConvBlock(1, 1, 1)

    def forward(self, x):
        # global cross-channel averaging
        x = x.mean(1, keepdim=True)
        # 3-by-3 conv with stride 2
        x = self.conv1(x)
        # bilinear resizing back to the pre-conv resolution
        # (align_corners=True mirrors the legacy F.upsample behaviour)
        x = F.interpolate(x, (x.size(2) * 2, x.size(3) * 2), mode='bilinear', align_corners=True)
        # scaling conv
        x = self.conv2(x)
        return x
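
# Shape flow: (b, c, h, w) --mean--> (b, 1, h, w) --stride-2 conv--> (b, 1, h/2, w/2)
# --bilinear upsample--> (b, 1, h, w) --1x1 conv--> (b, 1, h, w) spatial map.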


class ChannelAttn(nn.Module):
    """Channel Attention (Sec. 3.1.I.2)"""
    def __init__(self, in_channels, reduction_rate=16):
        super(ChannelAttn, self).__init__()
        assert in_channels % reduction_rate == 0
        # integer division: under Python 3, `/` would yield a float and break nn.Conv2d
        self.conv1 = ConvBlock(in_channels, in_channels // reduction_rate, 1)
        self.conv2 = ConvBlock(in_channels // reduction_rate, in_channels, 1)

    def forward(self, x):
        # squeeze operation (global average pooling)
        x = F.avg_pool2d(x, x.size()[2:])
        # excitation operation (2 conv layers)
        x = self.conv2(self.conv1(x))
        return x
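
# Squeeze-and-excitation style gating: e.g. with in_channels=192 (the first
# HarmAttn below) and the default reduction_rate=16, the bottleneck has 12
# channels and the output is a (b, 192, 1, 1) gate that broadcasts over the
# spatial map.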


class SoftAttn(nn.Module):
    """Soft Attention (Sec. 3.1.I)

    Aim: Spatial Attention + Channel Attention
    Output: attention maps with shape identical to input.
    """
    def __init__(self, in_channels):
        super(SoftAttn, self).__init__()
        self.spatial_attn = SpatialAttn()
        self.channel_attn = ChannelAttn(in_channels)
        self.conv = ConvBlock(in_channels, in_channels, 1)

    def forward(self, x):
        y_spatial = self.spatial_attn(x)
        y_channel = self.channel_attn(x)
        y = y_spatial * y_channel
        y = torch.sigmoid(self.conv(y))
        return y
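
# The (b, 1, h, w) spatial map and (b, c, 1, 1) channel gate broadcast to a full
# (b, c, h, w) tensor; the 1x1 conv plus sigmoid squashes it into [0, 1]
# attention weights with the same shape as the block input.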


class HardAttn(nn.Module):
    """Hard Attention (Sec. 3.1.II)"""
    def __init__(self):
        super(HardAttn, self).__init__()

    def forward(self, x):
        raise NotImplementedError
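
# NOTE: the STN-based hard regional attention of Sec. 3.1.II is not implemented
# in this snapshot; HarmAttn below therefore returns the soft attention map only.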


class HarmAttn(nn.Module):
    """Harmonious Attention (Sec. 3.1)"""
    def __init__(self, in_channels):
        super(HarmAttn, self).__init__()
        self.soft_attn = SoftAttn(in_channels)

    def forward(self, x):
        y_soft_attn = self.soft_attn(x)
        return y_soft_attn


class HACNN(nn.Module):
    """Harmonious Attention Convolutional Neural Network

    Reference:
        Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.
    """
    def __init__(self, num_classes, loss={'xent'}, nchannels=[32, 64, 96], feat_dim=512, **kwargs):
        super(HACNN, self).__init__()
        self.loss = loss
        self.conv = ConvBlock(3, 32, 3, s=2, p=1)

        # construct Inception + HarmAttn blocks
        # output channel of InceptionA is out_channels*4
        # output channel of InceptionB is out_channels*2 + in_channels
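        # with the default nchannels=[32, 64, 96]:
        #   inception1: A(32, 32) -> 128,  B(128, 32) -> 64 + 128 = 192 = 32*6
        #   inception2: A(192, 64) -> 256, B(256, 64) -> 128 + 256 = 384 = 64*6
        #   inception3: A(384, 96) -> 384, B(384, 96) -> 192 + 384 = 576 = 96*6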
        self.inception1 = nn.Sequential(
            InceptionA(32, nchannels[0]),
            InceptionB(nchannels[0]*4, nchannels[0]),
        )
        self.ha1 = HarmAttn(nchannels[0]*6)
        self.inception2 = nn.Sequential(
            InceptionA(nchannels[0]*6, nchannels[1]),
            InceptionB(nchannels[1]*4, nchannels[1]),
        )
        self.ha2 = HarmAttn(nchannels[1]*6)
        self.inception3 = nn.Sequential(
            InceptionA(nchannels[1]*6, nchannels[2]),
            InceptionB(nchannels[2]*4, nchannels[2]),
        )
        self.ha3 = HarmAttn(nchannels[2]*6)

        self.fc_global = nn.Sequential(
            nn.Linear(nchannels[2]*6, feat_dim),
            nn.BatchNorm1d(feat_dim),
            nn.ReLU(),
        )
        self.classifier = nn.Linear(feat_dim, num_classes)
        self.feat_dim = feat_dim
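
    # NOTE: only the global feature branch is implemented in this snapshot; the
    # paper's local branch (hard-attention regions with their own streams) is omitted.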

    def forward(self, x):
        # input size (3, 160, 64)
        x = self.conv(x)

        # block 1
        x1 = self.inception1(x)
        x1_attn = self.ha1(x1)
        x1_out = x1 * x1_attn

        # block 2
        x2 = self.inception2(x1_out)
        x2_attn = self.ha2(x2)
        x2_out = x2 * x2_attn

        # block 3
        x3 = self.inception3(x2_out)
        x3_attn = self.ha3(x3)
        x3_out = x3 * x3_attn

        x_global = F.avg_pool2d(x3_out, x3_out.size()[2:]).view(x3_out.size(0), x3_out.size(1))
        x_global = self.fc_global(x_global)

        if not self.training:
            return x_global

        prelogits = self.classifier(x_global)
        if self.loss == {'xent'}:
            return prelogits
        elif self.loss == {'xent', 'htri'}:
            return prelogits, x_global
        elif self.loss == {'cent'}:
            return prelogits, x_global
        else:
            raise KeyError("Unsupported loss: {}".format(self.loss))


if __name__ == '__main__':
    model = HACNN(10)
    model.eval()
    x = torch.rand(5, 3, 160, 64)
    print("input size {}".format(x.size()))
    y = model(x)
    print("output size {}".format(y.size()))