diff --git a/torchreid/losses/hard_mine_triplet_loss.py b/torchreid/losses/hard_mine_triplet_loss.py index 102881f..ef9019b 100644 --- a/torchreid/losses/hard_mine_triplet_loss.py +++ b/torchreid/losses/hard_mine_triplet_loss.py @@ -31,7 +31,7 @@ class TripletLoss(nn.Module): # Compute pairwise distance, replace by the official when merged dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n) dist = dist + dist.t() - dist.addmm_(1, -2, inputs, inputs.t()) + dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2) dist = dist.clamp(min=1e-12).sqrt() # for numerical stability # For each anchor, find the hardest positive and negative diff --git a/torchreid/metrics/distance.py b/torchreid/metrics/distance.py index 6278c34..f4fb383 100644 --- a/torchreid/metrics/distance.py +++ b/torchreid/metrics/distance.py @@ -60,7 +60,7 @@ def euclidean_squared_distance(input1, input2): mat1 = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n) mat2 = torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t() distmat = mat1 + mat2 - distmat.addmm_(1, -2, input1, input2.t()) + distmat.addmm_(input1, input2.t(), beta=1, alpha=-2) return distmat