add return mode for center loss

This commit is contained in:
KaiyangZhou 2018-03-21 22:26:43 +00:00
parent 36b477c250
commit 032a94b9aa
2 changed files with 12 additions and 3 deletions

View File

@ -15,6 +15,7 @@ class DenseNet121(nn.Module):
densenet121 = torchvision.models.densenet121(pretrained=True)
self.base = densenet121.features
self.classifier = nn.Linear(1024, num_classes)
self.feat_dim = 1024 # feature dimension
def forward(self, x):
x = self.base(x)
@ -28,5 +29,7 @@ class DenseNet121(nn.Module):
return y
elif self.loss == {'xent', 'htri'}:
return y, f
elif self.loss == {'cent'}:
return y, f
else:
raise KeyError("Unknown loss: {}".format(self.loss))
raise KeyError("Unsupported loss: {}".format(self.loss))

View File

@ -15,6 +15,7 @@ class ResNet50(nn.Module):
resnet50 = torchvision.models.resnet50(pretrained=True)
self.base = nn.Sequential(*list(resnet50.children())[:-2])
self.classifier = nn.Linear(2048, num_classes)
self.feat_dim = 2048 # feature dimension
def forward(self, x):
x = self.base(x)
@ -28,8 +29,10 @@ class ResNet50(nn.Module):
return y
elif self.loss == {'xent', 'htri'}:
return y, f
elif self.loss == {'cent'}:
return y, f
else:
raise KeyError("Unknown loss: {}".format(self.loss))
raise KeyError("Unsupported loss: {}".format(self.loss))
class ResNet50M(nn.Module):
"""ResNet50 + mid-level features.
@ -52,6 +55,7 @@ class ResNet50M(nn.Module):
self.layers5c = self.base[7][2]
self.fc_fuse = nn.Sequential(nn.Linear(4096, 1024), nn.BatchNorm1d(1024), nn.ReLU())
self.classifier = nn.Linear(3072, num_classes)
self.feat_dim = 3072 # feature dimension
def forward(self, x):
x1 = self.layers1(x)
@ -78,8 +82,10 @@ class ResNet50M(nn.Module):
return prelogits
elif self.loss == {'xent', 'htri'}:
return prelogits, combofeat
elif self.loss == {'cent'}:
return prelogits, combofeat
else:
raise KeyError("Unknown loss: {}".format(self.loss))
raise KeyError("Unsupported loss: {}".format(self.loss))