fix pair-wise circle loss

fix #252
pull/299/head
liaoxingyu 2020-09-09 15:28:52 +08:00
parent 919a515eb7
commit aa5c422606
1 changed files with 2 additions and 2 deletions

View File

@ -29,12 +29,12 @@ def circle_loss(
dist_mat = torch.matmul(all_embedding, all_embedding.t())
N = dist_mat.size(0)
is_pos = targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t()).float()
is_pos = all_targets.view(N, 1).expand(N, N).eq(all_targets.view(N, 1).expand(N, N).t()).float()
# Compute the mask which ignores the relevance score of the query to itself
is_pos = is_pos - torch.eye(N, N, device=is_pos.device)
is_neg = targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t())
is_neg = all_targets.view(N, 1).expand(N, N).ne(all_targets.view(N, 1).expand(N, N).t())
s_p = dist_mat * is_pos
s_n = dist_mat * is_neg