diff --git a/fastreid/modeling/losses/circle_loss.py b/fastreid/modeling/losses/circle_loss.py index c379852..c95aac2 100644 --- a/fastreid/modeling/losses/circle_loss.py +++ b/fastreid/modeling/losses/circle_loss.py @@ -29,12 +29,12 @@ def circle_loss( dist_mat = torch.matmul(all_embedding, all_embedding.t()) N = dist_mat.size(0) - is_pos = targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t()).float() + is_pos = all_targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t()).float() # Compute the mask which ignores the relevance score of the query to itself is_pos = is_pos - torch.eye(N, N, device=is_pos.device) - is_neg = targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t()) + is_neg = all_targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t()) s_p = dist_mat * is_pos s_n = dist_mat * is_neg