update model

parent 6f479d39ee
commit 4435fcd312

128  models/HACNN.py
@@ -3,7 +3,6 @@ from __future__ import absolute_import
 import torch
 from torch import nn
 from torch.nn import functional as F
-from torch.autograd import Variable
 import torchvision

 __all__ = ['HACNN']
@@ -28,7 +27,6 @@ class ConvBlock(nn.Module):
         return F.relu(self.bn(self.conv(x)))

 class InceptionA(nn.Module):
-    """InceptionA (https://github.com/Cysu/dgd_person_reid)"""
     def __init__(self, in_channels, out_channels):
         super(InceptionA, self).__init__()
         self.stream1 = ConvBlock(in_channels, out_channels, 1)
@@ -55,7 +53,6 @@ class InceptionA(nn.Module):
         return y

 class InceptionB(nn.Module):
-    """InceptionB (https://github.com/Cysu/dgd_person_reid)"""
     def __init__(self, in_channels, out_channels):
         super(InceptionB, self).__init__()
         self.stream1 = nn.Sequential(
@@ -89,7 +86,7 @@ class SpatialAttn(nn.Module):
         # 3-by-3 conv
         x = self.conv1(x)
         # bilinear resizing
-        x = F.upsample(x, (x.size(2)*2, x.size(3)*2), mode='bilinear')
+        x = F.upsample(x, (x.size(2)*2, x.size(3)*2), mode='bilinear', align_corners=True)
         # scaling conv
         x = self.conv2(x)
         return x
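Note on the align_corners change above: recent PyTorch releases deprecate F.upsample in favour of F.interpolate with the same arguments. A minimal sketch of the equivalent 2x bilinear resize (the tensor shape here is illustrative, not from the commit):

import torch
import torch.nn.functional as F

x = torch.rand(5, 1, 20, 8)  # e.g. a single-channel spatial attention map
# double the spatial size with bilinear interpolation, matching the settings in the diff
y = F.interpolate(x, size=(x.size(2)*2, x.size(3)*2), mode='bilinear', align_corners=True)
print(y.shape)  # torch.Size([5, 1, 40, 16])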
@@ -128,22 +125,37 @@ class SoftAttn(nn.Module):
         return y

 class HardAttn(nn.Module):
-    """Hard Attention (Sec. 3.1.II)"""
-    def __init__(self):
+    """Hard Attention (Sec. 3.1.II)
+    Output: num_regions*2 transformation parameters (i.e. t_x, t_y).
+    """
+    def __init__(self, in_channels, num_regions):
         super(HardAttn, self).__init__()
+        self.fc = nn.Linear(in_channels, num_regions*2)
+        self.num_regions = num_regions

+    def init_params(self):
+        self.fc.weight.data.zero_()
+        # TODO
+        #self.fc.bias.data.copy_(torch.tensor([BLAH BLAH], dtype=torch.float))
+
     def forward(self, x):
-        raise NotImplementedError
+        # squeeze operation (global average pooling)
+        x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
+        theta = F.tanh(self.fc(x))
+        theta = theta.view(-1, self.num_regions, 2)
+        return theta

 class HarmAttn(nn.Module):
     """Harmonious Attention (Sec. 3.1)"""
-    def __init__(self, in_channels):
+    def __init__(self, in_channels, num_regions):
         super(HarmAttn, self).__init__()
         self.soft_attn = SoftAttn(in_channels)
+        self.hard_attn = HardAttn(in_channels, num_regions)

     def forward(self, x):
         y_soft_attn = self.soft_attn(x)
-        return y_soft_attn
+        theta = self.hard_attn(x)
+        return y_soft_attn, theta

 class HACNN(nn.Module):
     """
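The rewritten HardAttn above turns globally pooled features into num_regions (t_x, t_y) pairs through a single linear layer and a tanh. A self-contained sketch of that head, useful for checking the theta shape the rest of the diff relies on (the class and variable names are illustrative, not part of the commit; torch.tanh is used here because F.tanh is deprecated in newer PyTorch):

import torch
import torch.nn.functional as F
from torch import nn

class HardAttnSketch(nn.Module):
    def __init__(self, in_channels, num_regions):
        super(HardAttnSketch, self).__init__()
        self.fc = nn.Linear(in_channels, num_regions*2)
        self.num_regions = num_regions

    def forward(self, x):
        # squeeze: global average pooling -> (batch, in_channels)
        x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
        theta = torch.tanh(self.fc(x))
        return theta.view(-1, self.num_regions, 2)

attn = HardAttnSketch(in_channels=32*6, num_regions=4)   # 32*6 matches nchannels[0]*6 below
theta = attn(torch.rand(5, 32*6, 40, 16))
print(theta.shape)  # torch.Size([5, 4, 2])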
@@ -152,55 +164,117 @@ class HACNN(nn.Module):
     Reference:
     Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.
     """
-    def __init__(self, num_classes, loss={'xent'}, nchannels=[32, 64, 96], feat_dim=512, **kwargs):
+    def __init__(self, num_classes, loss={'xent'}, num_regions=4, nchannels=[32, 64, 96], feat_dim=512, **kwargs):
         super(HACNN, self).__init__()
         self.loss = loss
+        self.num_regions = num_regions
+        self.init_scale_factors(num_regions)

         self.conv = ConvBlock(3, 32, 3, s=2, p=1)

         # construct Inception + HarmAttn blocks
         # output channel of InceptionA is out_channels*4
         # output channel of InceptionB is out_channels*2+in_channels
+        # ============== Block 1 ==============
         self.inception1 = nn.Sequential(
             InceptionA(32, nchannels[0]),
             InceptionB(nchannels[0]*4, nchannels[0]),
         )
-        self.ha1 = HarmAttn(nchannels[0]*6)
+        self.ha1 = HarmAttn(nchannels[0]*6, num_regions)
+        self.local_conv1 = InceptionB(32, nchannels[0])

+        # ============== Block 2 ==============
         self.inception2 = nn.Sequential(
             InceptionA(nchannels[0]*6, nchannels[1]),
             InceptionB(nchannels[1]*4, nchannels[1]),
         )
-        self.ha2 = HarmAttn(nchannels[1]*6)
+        self.ha2 = HarmAttn(nchannels[1]*6, num_regions)
+        self.local_conv2 = InceptionB(nchannels[0]*2+32, nchannels[1])

+        # ============== Block 3 ==============
         self.inception3 = nn.Sequential(
             InceptionA(nchannels[1]*6, nchannels[2]),
             InceptionB(nchannels[2]*4, nchannels[2]),
         )
-        self.ha3 = HarmAttn(nchannels[2]*6)
+        self.ha3 = HarmAttn(nchannels[2]*6, num_regions)
+        self.local_conv3 = InceptionB(nchannels[1]*2+nchannels[0]*2+32, nchannels[2])

-        self.fc_global = nn.Sequential(nn.Linear(nchannels[2]*6, feat_dim), nn.BatchNorm1d(feat_dim), nn.ReLU())
+        # feature embedding layers
+        self.fc_global = nn.Sequential(
+            nn.Linear(nchannels[2]*6, feat_dim),
+            nn.BatchNorm1d(feat_dim),
+            nn.ReLU(),
+        )
+        self.fc_local = nn.Sequential(
+            nn.Linear((nchannels[2]*2+nchannels[1]*2+nchannels[0]*2+32)*num_regions, feat_dim),
+            nn.BatchNorm1d(feat_dim),
+            nn.ReLU(),
+        )
+
-        self.classifier = nn.Linear(feat_dim, num_classes)
+        self.classifier_global = nn.Linear(feat_dim, num_classes)
+        self.classifier_local = nn.Linear(feat_dim, num_classes)
         self.feat_dim = feat_dim

+    def init_scale_factors(self, num_regions):
+        self.scale_factors = []
+        for region_idx in range(num_regions):
+            # TODO: initialize scale factors
+            scale_factors = torch.tensor([[1, 0], [0, 1]]).float()
+            self.scale_factors.append(scale_factors)
+
+    def stn(self, x, theta):
+        """Perform spatial transform
+        x: (batch, channel, height, width)
+        theta: (batch, 2, 3)
+        """
+        grid = F.affine_grid(theta, x.size())
+        x = F.grid_sample(x, grid)
+        return x
+
+    def transform_theta(self, theta_i, region_idx):
+        """Transform theta from (batch, 2) to (batch, 2, 3),
+        which includes (s_w, s_h)"""
+        scale_factors = self.scale_factors[region_idx]
+        theta = torch.zeros(theta_i.size(0), 2, 3)
+        theta[:,:,:2] = scale_factors
+        theta[:,:,-1] = theta_i
+        return theta
+
     def forward(self, x):
         # input size (3, 160, 64)
         x = self.conv(x)

-        # block 1
+        # ============== Block 1 ==============
+        # global branch
         x1 = self.inception1(x)
-        x1_attn = self.ha1(x1)
+        x1_attn, x1_theta = self.ha1(x1)
         x1_out = x1 * x1_attn
+        # local branch
+        x1_local = []
+        for region_idx in range(self.num_regions):
+            x1_theta_i = x1_theta[:,region_idx,:]
+            x1_theta_i = self.transform_theta(x1_theta_i, region_idx)
+            x1_trans_i = self.stn(x, x1_theta_i)
+            # TODO: reduce size of x1_trans_i to (24, 28)
+            sys.exit()
+            x1_local_i = self.local_conv1(x1_trans_i)
+            x1_local.append(x1_local_i)

-        # block 2
+        # ============== Block 2 ==============
+        # Block 2
+        # global branch
         x2 = self.inception2(x1_out)
-        x2_attn = self.ha2(x2)
+        x2_attn, x2_theta = self.ha2(x2)
         x2_out = x2 * x2_attn
+        # local branch

-        # block 3
+        # ============== Block 3 ==============
+        # Block 3
+        # global branch
         x3 = self.inception3(x2_out)
-        x3_attn = self.ha3(x3)
+        x3_attn, x3_theta = self.ha3(x3)
         x3_out = x3 * x3_attn
+        # local branch

         x_global = F.avg_pool2d(x3_out, x3_out.size()[2:]).view(x3_out.size(0), x3_out.size(1))
         x_global = self.fc_global(x_global)
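The new stn/transform_theta pair implements the spatial-transformer crop used by the local branch: each (t_x, t_y) from HardAttn is placed into a 2x3 affine matrix whose diagonal carries the region scale, then F.affine_grid and F.grid_sample extract the patch. As committed, the scale factors are left as identity placeholders, the resize to (24, 28) is a TODO, and the loop stops at sys.exit() (sys is no longer imported at module level). A hedged sketch of what that step could look like; the 0.5 scale and the (24, 28) target size are assumptions for illustration only:

import torch
import torch.nn.functional as F

def crop_region(x, theta_i, out_size=(24, 28), scale=0.5):
    # theta_i: (batch, 2) translation from HardAttn; x: (batch, C, H, W)
    batch = theta_i.size(0)
    theta = torch.zeros(batch, 2, 3)
    theta[:, 0, 0] = scale          # s_w: sample a window half as wide as the input
    theta[:, 1, 1] = scale          # s_h: sample a window half as tall as the input
    theta[:, :, -1] = theta_i       # (t_x, t_y) offsets
    grid = F.affine_grid(theta, x.size())
    patch = F.grid_sample(x, grid)
    # shrink the sampled patch to a fixed size before the local InceptionB branch
    return F.interpolate(patch, size=out_size, mode='bilinear', align_corners=True)

x = torch.rand(5, 32, 80, 32)                 # feature map after the stem conv
theta_i = torch.tanh(torch.rand(5, 2))
print(crop_region(x, theta_i).shape)          # torch.Size([5, 32, 24, 28])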
@@ -208,7 +282,7 @@ class HACNN(nn.Module):
         if not self.training:
             return x_global

-        prelogits = self.classifier(x_global)
+        prelogits = self.classifier_global(x_global)

         if self.loss == {'xent'}:
             return prelogits
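classifier_local is defined in __init__ but not yet used in forward(). Purely as an assumption about where this is heading, not something taken from the commit, the usual pattern is to give the global and local embeddings their own cross-entropy heads and sum the losses:

import torch
from torch import nn

num_classes, feat_dim, batch = 10, 512, 5
classifier_global = nn.Linear(feat_dim, num_classes)
classifier_local = nn.Linear(feat_dim, num_classes)
criterion = nn.CrossEntropyLoss()

x_global = torch.rand(batch, feat_dim)        # stand-in for the fc_global output
x_local = torch.rand(batch, feat_dim)         # stand-in for the fc_local output
labels = torch.randint(0, num_classes, (batch,))

# one cross-entropy term per branch, added together
loss = criterion(classifier_global(x_global), labels) + criterion(classifier_local(x_local), labels)
loss.backward()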
@@ -220,10 +294,4 @@ class HACNN(nn.Module):
             raise KeyError("Unsupported loss: {}".format(self.loss))

 if __name__ == '__main__':
-    import sys
-    model = HACNN(10)
-    model.eval()
-    x = Variable(torch.rand(5, 3, 160, 64))
-    print "input size {}".format(x.size())
-    y = model(x)
-    print "output size {}".format(y.size())
+    pass
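The removed __main__ smoke test can be reproduced without the deprecated Variable wrapper and with Python 3 print. This sketch assumes the file lives at models/HACNN.py as shown above, and it will only run end-to-end once the sys.exit() placeholder in the local branch is removed:

import torch
from models.HACNN import HACNN

model = HACNN(10)
model.eval()
x = torch.rand(5, 3, 160, 64)
print("input size {}".format(x.size()))
with torch.no_grad():
    y = model(x)
print("output size {}".format(y.size()))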