update model

pull/17/head
KaiyangZhou 2018-04-26 12:22:00 +01:00
parent 6f479d39ee
commit 4435fcd312
1 changed file with 98 additions and 30 deletions

View File

@ -3,7 +3,6 @@ from __future__ import absolute_import
import torch
from torch import nn
from torch.nn import functional as F
from torch.autograd import Variable
import torchvision
__all__ = ['HACNN']
@ -28,7 +27,6 @@ class ConvBlock(nn.Module):
return F.relu(self.bn(self.conv(x)))
class InceptionA(nn.Module):
"""InceptionA (https://github.com/Cysu/dgd_person_reid)"""
def __init__(self, in_channels, out_channels):
super(InceptionA, self).__init__()
self.stream1 = ConvBlock(in_channels, out_channels, 1)
@ -55,7 +53,6 @@ class InceptionA(nn.Module):
return y
class InceptionB(nn.Module):
"""InceptionB (https://github.com/Cysu/dgd_person_reid)"""
def __init__(self, in_channels, out_channels):
super(InceptionB, self).__init__()
self.stream1 = nn.Sequential(
@ -89,7 +86,7 @@ class SpatialAttn(nn.Module):
# 3-by-3 conv
x = self.conv1(x)
# bilinear resizing
x = F.upsample(x, (x.size(2)*2, x.size(3)*2), mode='bilinear')
x = F.upsample(x, (x.size(2)*2, x.size(3)*2), mode='bilinear', align_corners=True)
# scaling conv
x = self.conv2(x)
return x
@ -128,22 +125,37 @@ class SoftAttn(nn.Module):
return y
class HardAttn(nn.Module):
    """Hard Attention (Sec. 3.1.II)

    Output: num_regions*2 transformation parameters (i.e. t_x, t_y).

    Args:
        in_channels: number of channels of the incoming feature map; this is
            the input size of the regression layer.
        num_regions: number of regions for which a (t_x, t_y) pair is
            predicted.
    """

    def __init__(self, in_channels, num_regions):
        super(HardAttn, self).__init__()
        self.fc = nn.Linear(in_channels, num_regions*2)
        self.num_regions = num_regions

    def init_params(self):
        """Zero the regression weights so the output depends on the bias only."""
        self.fc.weight.data.zero_()
        # TODO: initialise self.fc.bias with the starting offset of each region.

    def forward(self, x):
        """Predict the translation parameters of every region.

        Args:
            x: 4-D feature map whose second dimension has ``in_channels``
                entries.

        Returns:
            Tensor of shape (batch, num_regions, 2) with values in [-1, 1].
        """
        # squeeze operation (global average pooling)
        x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), x.size(1))
        # torch.tanh replaces the deprecated F.tanh; the result is identical.
        theta = torch.tanh(self.fc(x))
        theta = theta.view(-1, self.num_regions, 2)
        return theta
class HarmAttn(nn.Module):
    """Harmonious Attention (Sec. 3.1)

    Runs the soft-attention and hard-attention branches on the same input.

    Args:
        in_channels: number of channels of the incoming feature map, passed
            to both branches.
        num_regions: number of regions handled by the hard-attention branch.
    """

    def __init__(self, in_channels, num_regions):
        super(HarmAttn, self).__init__()
        self.soft_attn = SoftAttn(in_channels)
        self.hard_attn = HardAttn(in_channels, num_regions)

    def forward(self, x):
        """Return ``(y_soft_attn, theta)`` for the feature map ``x``.

        ``y_soft_attn`` is the output of the soft-attention branch and
        ``theta`` is the output of the hard-attention branch, of shape
        (batch, num_regions, 2).
        """
        y_soft_attn = self.soft_attn(x)
        theta = self.hard_attn(x)
        return y_soft_attn, theta
class HACNN(nn.Module):
"""
@ -152,55 +164,117 @@ class HACNN(nn.Module):
Reference:
Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.
"""
def __init__(self, num_classes, loss={'xent'}, num_regions=4, nchannels=[32, 64, 96], feat_dim=512, **kwargs):
    """Build the global and local branches of the network.

    Args:
        num_classes: output size of both classifiers.
        loss: set of loss names, stored on ``self.loss``.
        num_regions: number of regions used by the hard attention and the
            local branch.
        nchannels: three channel sizes, one per Inception block.
        feat_dim: size of the global and local feature embeddings.
        **kwargs: accepted but not used here.

    NOTE: the mutable defaults are kept for interface compatibility; they
    are only read in this method, never modified.
    """
    super(HACNN, self).__init__()
    self.loss = loss
    self.num_regions = num_regions
    self.init_scale_factors(num_regions)

    self.conv = ConvBlock(3, 32, 3, s=2, p=1)

    # construct Inception + HarmAttn blocks
    # output channel of InceptionA is out_channels*4
    # output channel of InceptionB is out_channels*2+in_channels
    # ============== Block 1 ==============
    self.inception1 = nn.Sequential(
        InceptionA(32, nchannels[0]),
        InceptionB(nchannels[0]*4, nchannels[0]),
    )
    self.ha1 = HarmAttn(nchannels[0]*6, num_regions)
    self.local_conv1 = InceptionB(32, nchannels[0])

    # ============== Block 2 ==============
    self.inception2 = nn.Sequential(
        InceptionA(nchannels[0]*6, nchannels[1]),
        InceptionB(nchannels[1]*4, nchannels[1]),
    )
    self.ha2 = HarmAttn(nchannels[1]*6, num_regions)
    # input channels = output channels of local_conv1
    self.local_conv2 = InceptionB(nchannels[0]*2+32, nchannels[1])

    # ============== Block 3 ==============
    self.inception3 = nn.Sequential(
        InceptionA(nchannels[1]*6, nchannels[2]),
        InceptionB(nchannels[2]*4, nchannels[2]),
    )
    self.ha3 = HarmAttn(nchannels[2]*6, num_regions)
    # input channels = output channels of local_conv2
    self.local_conv3 = InceptionB(nchannels[1]*2+nchannels[0]*2+32, nchannels[2])

    # feature embedding layers
    self.fc_global = nn.Sequential(
        nn.Linear(nchannels[2]*6, feat_dim),
        nn.BatchNorm1d(feat_dim),
        nn.ReLU(),
    )
    # input size = output channels of local_conv3, once per region
    self.fc_local = nn.Sequential(
        nn.Linear((nchannels[2]*2+nchannels[1]*2+nchannels[0]*2+32)*num_regions, feat_dim),
        nn.BatchNorm1d(feat_dim),
        nn.ReLU(),
    )

    self.classifier_global = nn.Linear(feat_dim, num_classes)
    self.classifier_local = nn.Linear(feat_dim, num_classes)
    self.feat_dim = feat_dim
def init_scale_factors(self, num_regions):
    """Create one 2x2 scale matrix per region and store the list on ``self.scale_factors``.

    Every region currently gets the identity matrix as a float tensor.
    """
    # TODO: initialize scale factors
    self.scale_factors = [
        torch.tensor([[1, 0], [0, 1]], dtype=torch.float)
        for _ in range(num_regions)
    ]
def stn(self, x, theta):
    """Perform spatial transform

    x: (batch, channel, height, width)
    theta: (batch, 2, 3)

    Builds a sampling grid of the same size as ``x`` from the affine
    parameters ``theta`` and resamples ``x`` on that grid.
    """
    sampling_grid = F.affine_grid(theta, x.size())
    return F.grid_sample(x, sampling_grid)
def transform_theta(self, theta_i, region_idx):
    """Transform theta from (batch, 2) to (batch, 2, 3),
    which includes (s_w, s_h)

    Args:
        theta_i: (batch, 2) translation parameters of one region.
        region_idx: index into ``self.scale_factors`` selecting the 2x2
            scale matrix of that region.

    Returns:
        (batch, 2, 3) tensor: the first two columns hold the scale matrix
        and the last column holds ``theta_i``. It is allocated with the
        device and dtype of ``theta_i``.
    """
    scale_factors = self.scale_factors[region_idx]
    # Allocate from theta_i so the result follows the model to GPU or to
    # another dtype; a bare torch.zeros() is always CPU/float32.
    theta = theta_i.new_zeros(theta_i.size(0), 2, 3)
    theta[:, :, :2] = scale_factors.to(device=theta.device, dtype=theta.dtype)
    theta[:, :, -1] = theta_i
    return theta
def forward(self, x):
# input size (3, 160, 64)
x = self.conv(x)
# block 1
# ============== Block 1 ==============
# global branch
x1 = self.inception1(x)
x1_attn = self.ha1(x1)
x1_attn, x1_theta = self.ha1(x1)
x1_out = x1 * x1_attn
# local branch
x1_local = []
for region_idx in range(self.num_regions):
x1_theta_i = x1_theta[:,region_idx,:]
x1_theta_i = self.transform_theta(x1_theta_i, region_idx)
x1_trans_i = self.stn(x, x1_theta_i)
# TODO: reduce size of x1_trans_i to (24, 28)
sys.exit()
x1_local_i = self.local_conv1(x1_trans_i)
x1_local.append(x1_local_i)
# block 2
# ============== Block 2 ==============
# Block 2
# global branch
x2 = self.inception2(x1_out)
x2_attn = self.ha2(x2)
x2_attn, x2_theta = self.ha2(x2)
x2_out = x2 * x2_attn
# local branch
# block 3
# ============== Block 3 ==============
# Block 3
# global branch
x3 = self.inception3(x2_out)
x3_attn = self.ha3(x3)
x3_attn, x3_theta = self.ha3(x3)
x3_out = x3 * x3_attn
# local branch
x_global = F.avg_pool2d(x3_out, x3_out.size()[2:]).view(x3_out.size(0), x3_out.size(1))
x_global = self.fc_global(x_global)
@ -208,7 +282,7 @@ class HACNN(nn.Module):
if not self.training:
return x_global
prelogits = self.classifier(x_global)
prelogits = self.classifier_global(x_global)
if self.loss == {'xent'}:
return prelogits
@ -220,10 +294,4 @@ class HACNN(nn.Module):
raise KeyError("Unsupported loss: {}".format(self.loss))
if __name__ == '__main__':
    # Intentionally a no-op: running this module as a script does nothing.
    pass