diff --git a/models/DenseNet.py b/models/DenseNet.py index 0a9a45e..3a23a06 100644 --- a/models/DenseNet.py +++ b/models/DenseNet.py @@ -15,6 +15,7 @@ class DenseNet121(nn.Module): densenet121 = torchvision.models.densenet121(pretrained=True) self.base = densenet121.features self.classifier = nn.Linear(1024, num_classes) + self.feat_dim = 1024 # feature dimension def forward(self, x): x = self.base(x) @@ -28,5 +29,7 @@ class DenseNet121(nn.Module): return y elif self.loss == {'xent', 'htri'}: return y, f + elif self.loss == {'cent'}: + return y, f else: - raise KeyError("Unknown loss: {}".format(self.loss)) \ No newline at end of file + raise KeyError("Unsupported loss: {}".format(self.loss)) \ No newline at end of file diff --git a/models/ResNet.py b/models/ResNet.py index 29572cf..c811cea 100755 --- a/models/ResNet.py +++ b/models/ResNet.py @@ -15,6 +15,7 @@ class ResNet50(nn.Module): resnet50 = torchvision.models.resnet50(pretrained=True) self.base = nn.Sequential(*list(resnet50.children())[:-2]) self.classifier = nn.Linear(2048, num_classes) + self.feat_dim = 2048 # feature dimension def forward(self, x): x = self.base(x) @@ -28,8 +29,10 @@ class ResNet50(nn.Module): return y elif self.loss == {'xent', 'htri'}: return y, f + elif self.loss == {'cent'}: + return y, f else: - raise KeyError("Unknown loss: {}".format(self.loss)) + raise KeyError("Unsupported loss: {}".format(self.loss)) class ResNet50M(nn.Module): """ResNet50 + mid-level features. @@ -52,6 +55,7 @@ class ResNet50M(nn.Module): self.layers5c = self.base[7][2] self.fc_fuse = nn.Sequential(nn.Linear(4096, 1024), nn.BatchNorm1d(1024), nn.ReLU()) self.classifier = nn.Linear(3072, num_classes) + self.feat_dim = 3072 # feature dimension def forward(self, x): x1 = self.layers1(x) @@ -78,8 +82,10 @@ class ResNet50M(nn.Module): return prelogits elif self.loss == {'xent', 'htri'}: return prelogits, combofeat + elif self.loss == {'cent'}: + return prelogits, combofeat else: - raise KeyError("Unknown loss: {}".format(self.loss)) + raise KeyError("Unsupported loss: {}".format(self.loss))