# DCL/utils/Asoftmax_loss.py (source listing: 48 lines, 1.3 KiB, Python)

import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.nn.functional as F
from torch.nn import Parameter
import math
import pdb
class AngleLoss(nn.Module):
    """A-Softmax style loss with an annealed blending weight and optional focal term.

    ``forward`` takes a pair ``(cos_theta, phi_theta)`` of (B, C) tensors. For each
    sample, only the logit at its target class is moved from ``cos_theta`` towards
    ``phi_theta`` by a fraction ``1 / (1 + lamb)``. ``phi_theta`` is presumably the
    margin-adjusted angular logit produced by an upstream layer (confirm against
    caller). ``lamb`` is annealed from ``LambdaMax`` down to a floor of ``LambdaMin``,
    so the blend towards ``phi_theta`` grows over training. The result is scored with
    log-softmax, optionally reweighted by the focal factor ``(1 - p_t) ** gamma``.

    Args:
        gamma: focal-loss exponent; 0 yields plain cross-entropy on the blended logits.
        lambda_min: floor for the annealed ``lamb`` (stored as ``LambdaMin``).
        lambda_max: starting value of ``lamb`` (stored as ``LambdaMax``).
    """

    def __init__(self, gamma=0, lambda_min=50.0, lambda_max=1500.0):
        super(AngleLoss, self).__init__()
        self.gamma = gamma
        self.it = 0  # number of forward() calls so far; drives the default schedule
        self.LambdaMin = lambda_min
        self.LambdaMax = lambda_max
        self.lamb = lambda_max

    def forward(self, input, target, decay=None):
        """Compute the mean loss over a batch.

        Args:
            input: tuple ``(cos_theta, phi_theta)``, both of shape (B, C).
            target: int64 class indices in ``[0, C)`` with B elements (any shape).
            decay: if None, ``lamb = max(LambdaMin, LambdaMax / (1 + 0.1 * it))``.
                Otherwise ``LambdaMax`` is multiplied by ``decay`` IN PLACE, so the
                decay compounds across calls, and ``lamb = max(LambdaMin, LambdaMax)``.

        Returns:
            Scalar tensor: mean per-sample loss.
        """
        # Incremented on every call, including eval-mode calls.
        self.it += 1
        cos_theta, phi_theta = input
        target = target.view(-1, 1)  # (B, 1)

        # Boolean one-hot mask of each sample's target class, shape (B, C).
        # bool (not uint8): uint8 mask indexing is deprecated in PyTorch.
        index = torch.zeros_like(cos_theta).scatter_(1, target, 1.0).bool()

        if decay is None:
            self.lamb = max(self.LambdaMin, self.LambdaMax / (1 + 0.1 * self.it))
        else:
            self.LambdaMax *= decay
            self.lamb = max(self.LambdaMin, self.LambdaMax)

        # Copy first so the masked in-place updates never modify cos_theta itself.
        output = cos_theta.clone()  # (B, C)
        output[index] -= cos_theta[index] / (1 + self.lamb)
        output[index] += phi_theta[index] / (1 + self.lamb)

        logpt = F.log_softmax(output, 1).gather(1, target).view(-1)
        # pt is detached, so the focal factor acts as a constant per-sample weight.
        pt = logpt.detach().exp()
        loss = -1 * (1 - pt) ** self.gamma * logpt
        return loss.mean()