adapt to ring loss

pull/17/head
KaiyangZhou 2018-05-03 10:47:20 +01:00
parent 2d68dc3d47
commit 397f8f6fb1
2 changed files with 6 additions and 0 deletions

View File

@ -30,5 +30,7 @@ class DenseNet121(nn.Module):
return y, f
elif self.loss == {'cent'}:
return y, f
elif self.loss == {'ring'}:
return y, f
else:
raise KeyError("Unsupported loss: {}".format(self.loss))

View File

@ -30,6 +30,8 @@ class ResNet50(nn.Module):
return y, f
elif self.loss == {'cent'}:
return y, f
elif self.loss == {'ring'}:
return y, f
else:
raise KeyError("Unsupported loss: {}".format(self.loss))
@ -83,5 +85,7 @@ class ResNet50M(nn.Module):
return prelogits, combofeat
elif self.loss == {'cent'}:
return prelogits, combofeat
elif self.loss == {'ring'}:
return prelogits, combofeat
else:
raise KeyError("Unsupported loss: {}".format(self.loss))