mirror of
https://github.com/KaiyangZhou/deep-person-reid.git
synced 2025-06-03 14:53:23 +08:00
add return mode for center loss
This commit is contained in:
parent
36b477c250
commit
032a94b9aa
@ -15,6 +15,7 @@ class DenseNet121(nn.Module):
|
||||
densenet121 = torchvision.models.densenet121(pretrained=True)
|
||||
self.base = densenet121.features
|
||||
self.classifier = nn.Linear(1024, num_classes)
|
||||
self.feat_dim = 1024 # feature dimension
|
||||
|
||||
def forward(self, x):
|
||||
x = self.base(x)
|
||||
@ -28,5 +29,7 @@ class DenseNet121(nn.Module):
|
||||
return y
|
||||
elif self.loss == {'xent', 'htri'}:
|
||||
return y, f
|
||||
elif self.loss == {'cent'}:
|
||||
return y, f
|
||||
else:
|
||||
raise KeyError("Unknown loss: {}".format(self.loss))
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
@ -15,6 +15,7 @@ class ResNet50(nn.Module):
|
||||
resnet50 = torchvision.models.resnet50(pretrained=True)
|
||||
self.base = nn.Sequential(*list(resnet50.children())[:-2])
|
||||
self.classifier = nn.Linear(2048, num_classes)
|
||||
self.feat_dim = 2048 # feature dimension
|
||||
|
||||
def forward(self, x):
|
||||
x = self.base(x)
|
||||
@ -28,8 +29,10 @@ class ResNet50(nn.Module):
|
||||
return y
|
||||
elif self.loss == {'xent', 'htri'}:
|
||||
return y, f
|
||||
elif self.loss == {'cent'}:
|
||||
return y, f
|
||||
else:
|
||||
raise KeyError("Unknown loss: {}".format(self.loss))
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
||||
|
||||
class ResNet50M(nn.Module):
|
||||
"""ResNet50 + mid-level features.
|
||||
@ -52,6 +55,7 @@ class ResNet50M(nn.Module):
|
||||
self.layers5c = self.base[7][2]
|
||||
self.fc_fuse = nn.Sequential(nn.Linear(4096, 1024), nn.BatchNorm1d(1024), nn.ReLU())
|
||||
self.classifier = nn.Linear(3072, num_classes)
|
||||
self.feat_dim = 3072 # feature dimension
|
||||
|
||||
def forward(self, x):
|
||||
x1 = self.layers1(x)
|
||||
@ -78,8 +82,10 @@ class ResNet50M(nn.Module):
|
||||
return prelogits
|
||||
elif self.loss == {'xent', 'htri'}:
|
||||
return prelogits, combofeat
|
||||
elif self.loss == {'cent'}:
|
||||
return prelogits, combofeat
|
||||
else:
|
||||
raise KeyError("Unknown loss: {}".format(self.loss))
|
||||
raise KeyError("Unsupported loss: {}".format(self.loss))
|
||||
|
||||
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user