update model
parent 6f479d39ee
commit 4435fcd312
models/HACNN.py (126 changed lines)
@@ -3,7 +3,6 @@ from __future__ import absolute_import
 import torch
 from torch import nn
 from torch.nn import functional as F
 from torch.autograd import Variable
 import torchvision
 
 __all__ = ['HACNN']
@@ -28,7 +27,6 @@ class ConvBlock(nn.Module):
         return F.relu(self.bn(self.conv(x)))
 
 class InceptionA(nn.Module):
     """InceptionA (https://github.com/Cysu/dgd_person_reid)"""
     def __init__(self, in_channels, out_channels):
         super(InceptionA, self).__init__()
         self.stream1 = ConvBlock(in_channels, out_channels, 1)
@@ -55,7 +53,6 @@ class InceptionA(nn.Module):
         return y
 
 class InceptionB(nn.Module):
     """InceptionB (https://github.com/Cysu/dgd_person_reid)"""
     def __init__(self, in_channels, out_channels):
         super(InceptionB, self).__init__()
         self.stream1 = nn.Sequential(
@@ -89,7 +86,7 @@ class SpatialAttn(nn.Module):
         # 3-by-3 conv
         x = self.conv1(x)
         # bilinear resizing
-        x = F.upsample(x, (x.size(2)*2, x.size(3)*2), mode='bilinear')
+        x = F.upsample(x, (x.size(2)*2, x.size(3)*2), mode='bilinear', align_corners=True)
         # scaling conv
         x = self.conv2(x)
         return x
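Note: the change above pins align_corners=True, which silences the behavior-ambiguity warning PyTorch 0.4 introduced for bilinear resizing. Since PyTorch 0.4.1, F.upsample is itself deprecated in favor of F.interpolate, which takes the same arguments; a minimal sketch of the equivalent call:

    import torch
    from torch.nn import functional as F

    x = torch.rand(1, 1, 20, 8)
    # doubles the spatial size exactly as the F.upsample call above
    y = F.interpolate(x, (x.size(2)*2, x.size(3)*2), mode='bilinear', align_corners=True)
    print(y.shape)  # torch.Size([1, 1, 40, 16])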
@@ -128,22 +125,37 @@ class SoftAttn(nn.Module):
         return y
 
 class HardAttn(nn.Module):
-    """Hard Attention (Sec. 3.1.II)"""
-    def __init__(self):
+    """Hard Attention (Sec. 3.1.II)
+    Output: num_regions*2 transformation parameters (i.e. t_x, t_y).
+    """
+    def __init__(self, in_channels, num_regions):
         super(HardAttn, self).__init__()
+        self.fc = nn.Linear(in_channels, num_regions*2)
+        self.num_regions = num_regions
+
+    def init_params(self):
+        self.fc.weight.data.zero_()
+        # TODO
+        #self.fc.bias.data.copy_(torch.tensor([BLAH BLAH], dtype=torch.float))
 
     def forward(self, x):
-        raise NotImplementedError
+        # squeeze operation (global average pooling)
+        x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
+        theta = F.tanh(self.fc(x))
+        theta = theta.view(-1, self.num_regions, 2)
+        return theta
 
 class HarmAttn(nn.Module):
     """Harmonious Attention (Sec. 3.1)"""
-    def __init__(self, in_channels):
+    def __init__(self, in_channels, num_regions):
         super(HarmAttn, self).__init__()
         self.soft_attn = SoftAttn(in_channels)
+        self.hard_attn = HardAttn(in_channels, num_regions)
 
     def forward(self, x):
         y_soft_attn = self.soft_attn(x)
-        return y_soft_attn
+        theta = self.hard_attn(x)
+        return y_soft_attn, theta
 
 class HACNN(nn.Module):
     """
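Note: the new hard-attention path can be exercised standalone. A minimal sketch, with shapes assumed from the code above (in_channels = 192 is just an example, matching nchannels[0]*6 as ha1 sees it): global average pooling squeezes the map to (batch, in_channels), the linear layer regresses two numbers per region, and tanh bounds them to [-1, 1] so they can act as translations (t_x, t_y). Also worth flagging: F.tanh as committed was deprecated in later PyTorch releases in favor of torch.tanh.

    import torch
    from torch import nn
    from torch.nn import functional as F

    batch, in_channels, num_regions = 5, 192, 4
    x = torch.rand(batch, in_channels, 24, 28)
    fc = nn.Linear(in_channels, num_regions * 2)
    # squeeze: global average pooling down to (batch, in_channels)
    pooled = F.avg_pool2d(x, x.size()[2:]).view(batch, in_channels)
    theta = torch.tanh(fc(pooled)).view(-1, num_regions, 2)
    print(theta.shape)  # torch.Size([5, 4, 2])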
@@ -152,55 +164,117 @@ class HACNN(nn.Module):
     Reference:
     Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.
     """
-    def __init__(self, num_classes, loss={'xent'}, nchannels=[32, 64, 96], feat_dim=512, **kwargs):
+    def __init__(self, num_classes, loss={'xent'}, num_regions=4, nchannels=[32, 64, 96], feat_dim=512, **kwargs):
         super(HACNN, self).__init__()
         self.loss = loss
+        self.num_regions = num_regions
+        self.init_scale_factors(num_regions)
 
         self.conv = ConvBlock(3, 32, 3, s=2, p=1)
 
         # construct Inception + HarmAttn blocks
         # output channel of InceptionA is out_channels*4
         # output channel of InceptionB is out_channels*2+in_channels
         # ============== Block 1 ==============
         self.inception1 = nn.Sequential(
             InceptionA(32, nchannels[0]),
             InceptionB(nchannels[0]*4, nchannels[0]),
         )
-        self.ha1 = HarmAttn(nchannels[0]*6)
+        self.ha1 = HarmAttn(nchannels[0]*6, num_regions)
+        self.local_conv1 = InceptionB(32, nchannels[0])
 
         # ============== Block 2 ==============
         self.inception2 = nn.Sequential(
             InceptionA(nchannels[0]*6, nchannels[1]),
             InceptionB(nchannels[1]*4, nchannels[1]),
         )
-        self.ha2 = HarmAttn(nchannels[1]*6)
+        self.ha2 = HarmAttn(nchannels[1]*6, num_regions)
+        self.local_conv2 = InceptionB(nchannels[0]*2+32, nchannels[1])
 
         # ============== Block 3 ==============
         self.inception3 = nn.Sequential(
             InceptionA(nchannels[1]*6, nchannels[2]),
             InceptionB(nchannels[2]*4, nchannels[2]),
         )
-        self.ha3 = HarmAttn(nchannels[2]*6)
+        self.ha3 = HarmAttn(nchannels[2]*6, num_regions)
+        self.local_conv3 = InceptionB(nchannels[1]*2+nchannels[0]*2+32, nchannels[2])
 
-        self.fc_global = nn.Sequential(nn.Linear(nchannels[2]*6, feat_dim), nn.BatchNorm1d(feat_dim), nn.ReLU())
+        # feature embedding layers
+        self.fc_global = nn.Sequential(
+            nn.Linear(nchannels[2]*6, feat_dim),
+            nn.BatchNorm1d(feat_dim),
+            nn.ReLU(),
+        )
+        self.fc_local = nn.Sequential(
+            nn.Linear((nchannels[2]*2+nchannels[1]*2+nchannels[0]*2+32)*num_regions, feat_dim),
+            nn.BatchNorm1d(feat_dim),
+            nn.ReLU(),
+        )
 
-        self.classifier = nn.Linear(feat_dim, num_classes)
+        self.classifier_global = nn.Linear(feat_dim, num_classes)
+        self.classifier_local = nn.Linear(feat_dim, num_classes)
         self.feat_dim = feat_dim
 
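Note: the fc_local input size implies how the local branch is presumably meant to be assembled; this is an inference from the Linear dimensions, not code in this commit. Each local_conv3 output has nchannels[2]*2 + nchannels[1]*2 + nchannels[0]*2 + 32 = 416 channels with the default nchannels=[32, 64, 96], so the num_regions pooled local features would be concatenated before the embedding:

    import torch

    batch, num_regions = 5, 4
    c_local = 96*2 + 64*2 + 32*2 + 32  # 416: local_conv3 output channels
    # hypothetical pooled local features, one per region
    local_feats = [torch.rand(batch, c_local) for _ in range(num_regions)]
    x_local = torch.cat(local_feats, 1)  # (5, 1664) matches fc_local's Linear input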
+    def init_scale_factors(self, num_regions):
+        self.scale_factors = []
+        for region_idx in range(num_regions):
+            # TODO: initialize scale factors
+            scale_factors = torch.tensor([[1, 0], [0, 1]]).float()
+            self.scale_factors.append(scale_factors)
+
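Note: the identity scale above is a placeholder (see the TODO). One plausible fill-in, assuming the regions are meant to be fixed-size horizontal stripes of the tall person image as in the HA-CNN paper: full width (s_w = 1) and 1/num_regions of the height, so the learned (t_x, t_y) only shift fixed-size boxes.

    # Hypothetical initialization (an assumption, not the committed code):
    def init_scale_factors(self, num_regions):
        self.scale_factors = []
        for region_idx in range(num_regions):
            # s_w = 1 (full width), s_h = 1/num_regions (e.g. 0.25 for 4 stripes)
            scale_factors = torch.tensor([[1, 0], [0, 1.0/num_regions]]).float()
            self.scale_factors.append(scale_factors)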
+    def stn(self, x, theta):
+        """Perform spatial transform
+        x: (batch, channel, height, width)
+        theta: (batch, 2, 3)
+        """
+        grid = F.affine_grid(theta, x.size())
+        x = F.grid_sample(x, grid)
+        return x
+
+    def transform_theta(self, theta_i, region_idx):
+        """Transform theta from (batch, 2) to (batch, 2, 3),
+        which includes (s_w, s_h)"""
+        scale_factors = self.scale_factors[region_idx]
+        theta = torch.zeros(theta_i.size(0), 2, 3)
+        theta[:,:,:2] = scale_factors
+        theta[:,:,-1] = theta_i
+        return theta
+
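Note: these two helpers implement a standard spatial transformer (Jaderberg et al., NIPS 2015): transform_theta builds a (batch, 2, 3) affine matrix of the form [[s_w, 0, t_x], [0, s_h, t_y]] and stn resamples the input through it. Two review points: theta is allocated with torch.zeros on the CPU, so a GPU input would need theta_i.new_zeros(...) or an explicit device move; and F.affine_grid/F.grid_sample gained align_corners arguments in later PyTorch releases. A standalone sketch with a hand-picked theta that grabs the top quarter of the image:

    import torch
    from torch.nn import functional as F

    x = torch.rand(1, 3, 160, 64)
    theta = torch.tensor([[[1.0, 0.0,  0.0],
                           [0.0, 0.25, -0.75]]])  # full width, top quarter
    grid = F.affine_grid(theta, x.size())
    crop = F.grid_sample(x, grid)  # output matches grid size: (1, 3, 160, 64)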
     def forward(self, x):
         # input size (3, 160, 64)
         x = self.conv(x)
 
-        # block 1
+        # ============== Block 1 ==============
+        # global branch
         x1 = self.inception1(x)
-        x1_attn = self.ha1(x1)
+        x1_attn, x1_theta = self.ha1(x1)
         x1_out = x1 * x1_attn
+        # local branch
+        x1_local = []
+        for region_idx in range(self.num_regions):
+            x1_theta_i = x1_theta[:,region_idx,:]
+            x1_theta_i = self.transform_theta(x1_theta_i, region_idx)
+            x1_trans_i = self.stn(x, x1_theta_i)
+            # TODO: reduce size of x1_trans_i to (24, 28)
+            sys.exit()
+            x1_local_i = self.local_conv1(x1_trans_i)
+            x1_local.append(x1_local_i)
 
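Note: the local branch is still a work in progress here: the loop bails out with sys.exit() before local_conv1 ever runs, and as far as this diff shows, sys is never imported at module level (the only import sys lives in the __main__ block removed below), so the call would raise NameError first. One way the TODO could be realized, bilinearly shrinking each transformed crop to the (24, 28) size the TODO names (an assumption, not the committed code):

    # hypothetical replacement for the TODO + sys.exit() lines:
    x1_trans_i = F.upsample(x1_trans_i, (24, 28), mode='bilinear', align_corners=True)
    x1_local_i = self.local_conv1(x1_trans_i)
    x1_local.append(x1_local_i)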
-        # block 2
+        # ============== Block 2 ==============
+        # global branch
         x2 = self.inception2(x1_out)
-        x2_attn = self.ha2(x2)
+        x2_attn, x2_theta = self.ha2(x2)
         x2_out = x2 * x2_attn
+        # local branch
 
-        # block 3
+        # ============== Block 3 ==============
+        # global branch
         x3 = self.inception3(x2_out)
-        x3_attn = self.ha3(x3)
+        x3_attn, x3_theta = self.ha3(x3)
         x3_out = x3 * x3_attn
+        # local branch
 
         x_global = F.avg_pool2d(x3_out, x3_out.size()[2:]).view(x3_out.size(0), x3_out.size(1))
         x_global = self.fc_global(x_global)
@@ -208,7 +282,7 @@ class HACNN(nn.Module):
         if not self.training:
             return x_global
 
-        prelogits = self.classifier(x_global)
+        prelogits = self.classifier_global(x_global)
 
         if self.loss == {'xent'}:
             return prelogits
@@ -220,10 +294,4 @@ class HACNN(nn.Module):
             raise KeyError("Unsupported loss: {}".format(self.loss))
 
 if __name__ == '__main__':
-    import sys
-    model = HACNN(10)
-    model.eval()
-    x = Variable(torch.rand(5, 3, 160, 64))
-    print "input size {}".format(x.size())
-    y = model(x)
-    print "output size {}".format(y.size())
+    pass
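Note: the removed smoke test used the pre-0.4 Variable wrapper and Python 2 print statements. An updated sketch for PyTorch >= 0.4 / Python 3, runnable once the TODOs in forward are resolved:

    import torch

    if __name__ == '__main__':
        model = HACNN(10)
        model.eval()
        x = torch.rand(5, 3, 160, 64)
        print("input size {}".format(x.size()))
        y = model(x)
        print("output size {}".format(y.size()))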