diff --git a/layers/center_loss.py b/layers/center_loss.py index 67875b4..0f5fd21 100644 --- a/layers/center_loss.py +++ b/layers/center_loss.py @@ -44,13 +44,15 @@ class CenterLoss(nn.Module): labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) mask = labels.eq(classes.expand(batch_size, self.num_classes)) - dist = [] - for i in range(batch_size): - value = distmat[i][mask[i]] - value = value.clamp(min=1e-12, max=1e+12) # for numerical stability - dist.append(value) - dist = torch.cat(dist) - loss = dist.mean() + dist = distmat * mask.float() + loss = dist.clamp(min=1e-12, max=1e+12).sum() / batch_size + #dist = [] + #for i in range(batch_size): + # value = distmat[i][mask[i]] + # value = value.clamp(min=1e-12, max=1e+12) # for numerical stability + # dist.append(value) + #dist = torch.cat(dist) + #loss = dist.mean() return loss @@ -64,4 +66,4 @@ if __name__ == '__main__': targets = torch.Tensor([0, 1, 2, 3, 2, 3, 1, 4, 5, 3, 2, 1, 0, 0, 5, 4]).cuda() loss = center_loss(features, targets) - print(loss) \ No newline at end of file + print(loss)