From 6f479d39eea89cbcd9c091ffa8367c612f84799e Mon Sep 17 00:00:00 2001
From: KaiyangZhou
Date: Wed, 25 Apr 2018 10:47:34 +0100
Subject: [PATCH] update param name

---
 models/HACNN.py | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/models/HACNN.py b/models/HACNN.py
index 5768027..7c381e9 100644
--- a/models/HACNN.py
+++ b/models/HACNN.py
@@ -152,7 +152,7 @@ class HACNN(nn.Module):
     Reference:
     Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.
     """
-    def __init__(self, num_classes, loss={'xent'}, nchannels=[32, 64, 96], embed_dim=512, **kwargs):
+    def __init__(self, num_classes, loss={'xent'}, nchannels=[32, 64, 96], feat_dim=512, **kwargs):
         super(HACNN, self).__init__()
         self.loss = loss
         self.conv = ConvBlock(3, 32, 3, s=2, p=1)
@@ -178,10 +178,10 @@ class HACNN(nn.Module):
         )
         self.ha3 = HarmAttn(nchannels[2]*6)
 
-        self.fc_global = nn.Sequential(nn.Linear(nchannels[2]*6, embed_dim), nn.BatchNorm1d(embed_dim), nn.ReLU())
+        self.fc_global = nn.Sequential(nn.Linear(nchannels[2]*6, feat_dim), nn.BatchNorm1d(feat_dim), nn.ReLU())
 
-        self.classifier = nn.Linear(embed_dim, num_classes)
-        self.feat_dim = embed_dim
+        self.classifier = nn.Linear(feat_dim, num_classes)
+        self.feat_dim = feat_dim
 
     def forward(self, x):
         # input size (3, 160, 64)
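
Usage sketch: after this rename, callers that set the embedding size must pass feat_dim
instead of embed_dim. A minimal example, assuming the repository root is on PYTHONPATH
so models/HACNN.py is importable; num_classes=751 is only a placeholder identity count.

    import torch
    from models.HACNN import HACNN

    # The embedding-size keyword is now `feat_dim` (previously `embed_dim`);
    # num_classes=751 is a hypothetical identity count, not required by the model.
    model = HACNN(num_classes=751, loss={'xent'}, feat_dim=512)
    model.eval()

    # HACNN is built for 160x64 person crops (see the forward() comment above).
    x = torch.randn(4, 3, 160, 64)
    with torch.no_grad():
        y = model(x)  # returned value may differ between train/eval mode and loss settings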