Changed addmm_ to new pytorch API to remove deprecated warnings
parent
7cedf08a5b
commit
d78d9a86aa
|
@ -31,7 +31,7 @@ class TripletLoss(nn.Module):
|
|||
# Compute pairwise distance, replace by the official when merged
|
||||
dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
|
||||
dist = dist + dist.t()
|
||||
dist.addmm_(1, -2, inputs, inputs.t())
|
||||
dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
|
||||
dist = dist.clamp(min=1e-12).sqrt() # for numerical stability
|
||||
|
||||
# For each anchor, find the hardest positive and negative
|
||||
|
|
|
@ -60,7 +60,7 @@ def euclidean_squared_distance(input1, input2):
|
|||
mat1 = torch.pow(input1, 2).sum(dim=1, keepdim=True).expand(m, n)
|
||||
mat2 = torch.pow(input2, 2).sum(dim=1, keepdim=True).expand(n, m).t()
|
||||
distmat = mat1 + mat2
|
||||
distmat.addmm_(1, -2, input1, input2.t())
|
||||
distmat.addmm_(input1, input2.t(), beta=1, alpha=-2)
|
||||
return distmat
|
||||
|
||||
|
||||
|
|
Loading…
Reference in New Issue