update param name
parent
4f90b517a3
commit
6f479d39ee
|
@ -152,7 +152,7 @@ class HACNN(nn.Module):
|
|||
Reference:
|
||||
Li et al. Harmonious Attention Network for Person Re-identification. CVPR 2018.
|
||||
"""
|
||||
def __init__(self, num_classes, loss={'xent'}, nchannels=[32, 64, 96], embed_dim=512, **kwargs):
|
||||
def __init__(self, num_classes, loss={'xent'}, nchannels=[32, 64, 96], feat_dim=512, **kwargs):
|
||||
super(HACNN, self).__init__()
|
||||
self.loss = loss
|
||||
self.conv = ConvBlock(3, 32, 3, s=2, p=1)
|
||||
|
@ -178,10 +178,10 @@ class HACNN(nn.Module):
|
|||
)
|
||||
self.ha3 = HarmAttn(nchannels[2]*6)
|
||||
|
||||
self.fc_global = nn.Sequential(nn.Linear(nchannels[2]*6, embed_dim), nn.BatchNorm1d(embed_dim), nn.ReLU())
|
||||
self.fc_global = nn.Sequential(nn.Linear(nchannels[2]*6, feat_dim), nn.BatchNorm1d(feat_dim), nn.ReLU())
|
||||
|
||||
self.classifier = nn.Linear(embed_dim, num_classes)
|
||||
self.feat_dim = embed_dim
|
||||
self.classifier = nn.Linear(feat_dim, num_classes)
|
||||
self.feat_dim = feat_dim
|
||||
|
||||
def forward(self, x):
|
||||
# input size (3, 160, 64)
|
||||
|
|
Loading…
Reference in New Issue