deep-person-reid/models/DenseNet.py

34 lines
1018 B
Python
Raw Normal View History

2018-03-12 05:17:48 +08:00
from __future__ import absolute_import
import torch
from torch import nn
from torch.nn import functional as F
import torchvision
# Public names exported by `from <this module> import *`.
__all__ = ['DenseNet121']
class DenseNet121(nn.Module):
    """DenseNet-121 convolutional trunk from torchvision plus a linear classifier.

    Args:
        num_classes: number of logits produced by ``self.classifier``.
        loss: set of loss names the training-mode forward pass serves.
            Supported: ``{'xent'}``, ``{'xent', 'htri'}``, ``{'cent'}``.
            Defaults to ``{'xent'}`` when omitted or ``None``.
        **kwargs: accepted and ignored (presumably so all model factories
            share one constructor signature -- confirm with callers).
    """

    def __init__(self, num_classes, loss=None, **kwargs):
        super(DenseNet121, self).__init__()
        # Build a fresh set per instance; a set literal as the default value
        # would be a single object shared by every model built without `loss`.
        self.loss = {'xent'} if loss is None else loss
        # NOTE(review): `pretrained=` is deprecated in newer torchvision
        # releases in favour of `weights=`; update when the pinned version allows.
        densenet121 = torchvision.models.densenet121(pretrained=True)
        # Keep only the convolutional trunk; torchvision's own head is discarded.
        self.base = densenet121.features
        # Pooled feature width, read from the pretrained head (1024 for
        # DenseNet-121) so it cannot drift from the classifier's input size.
        self.feat_dim = densenet121.classifier.in_features
        self.classifier = nn.Linear(self.feat_dim, num_classes)

    def forward(self, x):
        """Compute pooled features and, in training mode, class logits.

        Args:
            x: input batch; presumably images shaped (batch, 3, H, W) -- TODO confirm.

        Returns:
            Eval mode: the pooled feature matrix ``f`` of shape
            (batch, channels of ``self.base`` output).
            Training mode: logits ``y`` for ``{'xent'}``; the tuple ``(y, f)``
            for ``{'xent', 'htri'}`` or ``{'cent'}``.

        Raises:
            KeyError: in training mode when ``self.loss`` is not one of the
                supported combinations.
        """
        x = self.base(x)
        # NOTE(review): torchvision's own DenseNet.forward appears to apply a
        # ReLU to the trunk output before pooling; this model pools the raw
        # output. Left unchanged to keep existing checkpoints valid -- confirm.
        # Average over the full spatial extent (kernel = trailing dims of x).
        x = F.avg_pool2d(x, x.size()[2:])
        f = x.view(x.size(0), -1)
        if not self.training:
            # Evaluation/retrieval only needs the feature embedding.
            return f
        y = self.classifier(f)
        if self.loss == {'xent'}:
            return y
        # Metric-learning losses also need the raw features.
        if self.loss in ({'xent', 'htri'}, {'cent'}):
            return y, f
        raise KeyError("Unsupported loss: {}".format(self.loss))