diff --git a/torchreid/models/mudeep.py b/torchreid/models/mudeep.py index f6a4043..e6bb20c 100644 --- a/torchreid/models/mudeep.py +++ b/torchreid/models/mudeep.py @@ -194,7 +194,10 @@ class MuDeep(nn.Module): x = x.view(x.size(0), -1) x = self.fc(x) y = self.classifier(x) - + + if not self.training: + return x + if self.loss == 'softmax': return y elif self.loss == 'triplet':