mirror of https://github.com/JDAI-CV/DCL.git
49 lines
1.5 KiB
Python
49 lines
1.5 KiB
Python
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
|
|
class FocalLoss(nn.Module):  # 1d and 2d
    """Focal loss over softmax (multi-class) or sigmoid (binary) predictions.

    For each sample, with ``p_t`` the predicted probability of its target
    class and ``w_t`` the weight of that class, the per-sample loss is
    ``-w_t * (1 - p_t) ** gamma * log(p_t)``.

    Args:
        gamma: Focusing exponent; ``gamma=0`` reduces to (weighted)
            cross-entropy.
        size_average: If True, ``forward`` returns the mean over all samples;
            otherwise it returns the per-sample losses with shape ``(N, 1)``.
    """

    def __init__(self, gamma=2, size_average=True):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.size_average = size_average

    def forward(self, logit, target, class_weight=None, type='softmax'):
        """Compute the focal loss.

        Args:
            logit: Raw scores. For ``type='softmax'`` the class dimension is
                dim 1, i.e. ``(B, C)`` or ``(B, C, d1, d2, ...)``. For
                ``type='sigmoid'`` one score per target element, any shape.
            target: Integer class indices, one per sample/position; flattened
                to ``(N, 1)``. Values must lie in ``[0, C)`` (``[0, 2)`` for
                sigmoid).
            class_weight: Optional per-class weights (sequence or tensor of
                length ``C``, or 2 for sigmoid). Defaults to all ones.
            type: ``'softmax'`` (multi-class) or ``'sigmoid'`` (binary).

        Returns:
            A scalar mean loss if ``size_average`` is True, else an ``(N, 1)``
            tensor of per-sample losses.

        Raises:
            ValueError: If ``type`` is neither ``'softmax'`` nor ``'sigmoid'``.
        """
        target = target.view(-1, 1).long()

        if type == 'sigmoid':
            num_classes = 2
            p1 = torch.sigmoid(logit).reshape(-1, 1)
            # Column 0 holds P(class 0), column 1 holds P(class 1).
            prob = torch.cat((1 - p1, p1), 1)
        elif type == 'softmax':
            num_classes = logit.size(1)
            if logit.dim() > 2:
                # (B, C, d1, ...) -> (B*d1*..., C): move the class dim last so
                # row order matches target.view(-1).
                dims = [0] + list(range(2, logit.dim())) + [1]
                logit = logit.permute(*dims).reshape(-1, num_classes)
            prob = F.softmax(logit, 1)
        else:
            raise ValueError(
                "type must be 'softmax' or 'sigmoid', got %r" % (type,))

        if class_weight is None:
            class_weight = [1] * num_classes
        # Build the weights on the predictions' device/dtype rather than a
        # hard-coded .cuda(), so the loss also runs on CPU.
        class_weight = torch.as_tensor(
            class_weight, dtype=prob.dtype, device=prob.device).view(-1, 1)
        class_weight = torch.gather(class_weight, 0, target)

        # Probability of each sample's target class (same result as masking
        # with a one-hot tensor and summing, without building the mask).
        prob = prob.gather(1, target)
        # Clamp so prob.log() stays finite.
        prob = torch.clamp(prob, 1e-8, 1 - 1e-8)
        batch_loss = -class_weight * torch.pow(1 - prob, self.gamma) * prob.log()

        if self.size_average:
            return batch_loss.mean()
        return batch_loss
|
|
|