From 397f8f6fb1ad1747aa28b90f3288f4e1bc117f8d Mon Sep 17 00:00:00 2001 From: KaiyangZhou Date: Thu, 3 May 2018 10:47:20 +0100 Subject: [PATCH] adapt to ring loss --- models/DenseNet.py | 2 ++ models/ResNet.py | 4 ++++ 2 files changed, 6 insertions(+) diff --git a/models/DenseNet.py b/models/DenseNet.py index 0c05e2c..70e3e24 100644 --- a/models/DenseNet.py +++ b/models/DenseNet.py @@ -30,5 +30,7 @@ class DenseNet121(nn.Module): return y, f elif self.loss == {'cent'}: return y, f + elif self.loss == {'ring'}: + return y, f else: raise KeyError("Unsupported loss: {}".format(self.loss)) \ No newline at end of file diff --git a/models/ResNet.py b/models/ResNet.py index f398401..457aecf 100755 --- a/models/ResNet.py +++ b/models/ResNet.py @@ -30,6 +30,8 @@ class ResNet50(nn.Module): return y, f elif self.loss == {'cent'}: return y, f + elif self.loss == {'ring'}: + return y, f else: raise KeyError("Unsupported loss: {}".format(self.loss)) @@ -83,5 +85,7 @@ class ResNet50M(nn.Module): return prelogits, combofeat elif self.loss == {'cent'}: return prelogits, combofeat + elif self.loss == {'ring'}: + return prelogits, combofeat else: raise KeyError("Unsupported loss: {}".format(self.loss)) \ No newline at end of file