diff --git a/models/Inception.py b/models/Inception.py index aac1174..61f7631 100644 --- a/models/Inception.py +++ b/models/Inception.py @@ -343,6 +343,7 @@ def inceptionv4(num_classes=1000, pretrained='imagenet'): class InceptionV4ReID(nn.Module): def __init__(self, num_classes, loss={'xent'}, **kwargs): super(InceptionV4ReID, self).__init__() + self.loss = loss base = inceptionv4() self.features = base.features self.classifier = nn.Linear(1536, num_classes)