update param name

pull/17/head
KaiyangZhou 2018-04-25 10:47:34 +01:00
parent 4f90b517a3
commit 6f479d39ee
1 changed files with 4 additions and 4 deletions

View File

@ -152,7 +152,7 @@ class HACNN(nn.Module):
Reference:
Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.
"""
def __init__(self, num_classes, loss={'xent'}, nchannels=[32, 64, 96], embed_dim=512, **kwargs):
def __init__(self, num_classes, loss={'xent'}, nchannels=[32, 64, 96], feat_dim=512, **kwargs):
super(HACNN, self).__init__()
self.loss = loss
self.conv = ConvBlock(3, 32, 3, s=2, p=1)
@ -178,10 +178,10 @@ class HACNN(nn.Module):
)
self.ha3 = HarmAttn(nchannels[2]*6)
self.fc_global = nn.Sequential(nn.Linear(nchannels[2]*6, embed_dim), nn.BatchNorm1d(embed_dim), nn.ReLU())
self.fc_global = nn.Sequential(nn.Linear(nchannels[2]*6, feat_dim), nn.BatchNorm1d(feat_dim), nn.ReLU())
self.classifier = nn.Linear(embed_dim, num_classes)
self.feat_dim = embed_dim
self.classifier = nn.Linear(feat_dim, num_classes)
self.feat_dim = feat_dim
def forward(self, x):
# input size (3, 160, 64)