From b84cf42f7affbe8d9010b8060fb4538d782c6ac1 Mon Sep 17 00:00:00 2001
From: mark <makangi@163.com>
Date: Thu, 11 Jul 2019 09:54:33 +0800
Subject: [PATCH] Update center_loss.py

Loop structure is very time consuming
---
 layers/center_loss.py | 18 ++++++++++--------
 1 file changed, 10 insertions(+), 8 deletions(-)

diff --git a/layers/center_loss.py b/layers/center_loss.py
index 67875b4..0f5fd21 100644
--- a/layers/center_loss.py
+++ b/layers/center_loss.py
@@ -44,13 +44,15 @@ class CenterLoss(nn.Module):
         labels = labels.unsqueeze(1).expand(batch_size, self.num_classes)
         mask = labels.eq(classes.expand(batch_size, self.num_classes))
 
-        dist = []
-        for i in range(batch_size):
-            value = distmat[i][mask[i]]
-            value = value.clamp(min=1e-12, max=1e+12)  # for numerical stability
-            dist.append(value)
-        dist = torch.cat(dist)
-        loss = dist.mean()
+        dist = distmat * mask.float()
+        loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size
+        #dist = []
+        #for i in range(batch_size):
+        #    value = distmat[i][mask[i]]
+        #    value = value.clamp(min=1e-12, max=1e+12)  # for numerical stability
+        #    dist.append(value)
+        #dist = torch.cat(dist)
+        #loss = dist.mean()
         return loss
 
 
@@ -64,4 +66,4 @@ if __name__ == '__main__':
         targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda()
 
     loss = center_loss(features, targets)
-    print(loss)
\ No newline at end of file
+    print(loss)