Changed addmm_ to new pytorch API to remove deprecated warnings

pull/412/head
Carmel Barak 2021-02-03 15:45:08 +02:00
parent 7cedf08a5b
commit d78d9a86aa
2 changed files with 2 additions and 2 deletions

View File

@ -31,7 +31,7 @@ class TripletLoss(nn.Module):
# Compute pairwise distance, replace by the official when merged
dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
dist = dist + dist.t()
dist.addmm_(1, -2, inputs, inputs.t())
dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
# For each anchor, find the hardest positive and negative

View File

@ -60,7 +60,7 @@ def euclidean_squared_distance(input1, input2):
mat1 = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n)
mat2 = torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
distmat = mat1 + mat2
distmat.addmm_(1, -2, input1, input2.t())
distmat.addmm_(input1, input2.t(), beta=1, alpha=-2)
return distmat