add eval mode

pull/17/head
KaiyangZhou 2018-05-12 17:45:16 +01:00
parent e8c3f95c15
commit cc99c36bbf
1 changed files with 3 additions and 0 deletions

View File

@ -108,6 +108,9 @@ class SqueezeNet(nn.Module):
x10 = F.relu(self.conv10(x9))
f = F.avg_pool2d(x10, x10.size()[2:]).view(x10.size(0), -1)
if not self.training:
return f
if self.loss == {'xent'}:
return f
elif self.loss == {'xent', 'htri'}: