Fix #2132, remove use of _C.set_grad_enable. Line endings were messed up too

This commit is contained in:
Ross Wightman 2024-04-09 09:00:23 -07:00
parent 59b3d86c1d
commit 5c5ae8d401

View File

@ -1,97 +1,97 @@
import torch import torch
import torch.nn as nn import torch.nn as nn
class AsymmetricLossMultiLabel(nn.Module): class AsymmetricLossMultiLabel(nn.Module):
def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False): def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
super(AsymmetricLossMultiLabel, self).__init__() super(AsymmetricLossMultiLabel, self).__init__()
self.gamma_neg = gamma_neg self.gamma_neg = gamma_neg
self.gamma_pos = gamma_pos self.gamma_pos = gamma_pos
self.clip = clip self.clip = clip
self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
self.eps = eps self.eps = eps
def forward(self, x, y): def forward(self, x, y):
"""" """"
Parameters Parameters
---------- ----------
x: input logits x: input logits
y: targets (multi-label binarized vector) y: targets (multi-label binarized vector)
""" """
# Calculating Probabilities # Calculating Probabilities
x_sigmoid = torch.sigmoid(x) x_sigmoid = torch.sigmoid(x)
xs_pos = x_sigmoid xs_pos = x_sigmoid
xs_neg = 1 - x_sigmoid xs_neg = 1 - x_sigmoid
# Asymmetric Clipping # Asymmetric Clipping
if self.clip is not None and self.clip > 0: if self.clip is not None and self.clip > 0:
xs_neg = (xs_neg + self.clip).clamp(max=1) xs_neg = (xs_neg + self.clip).clamp(max=1)
# Basic CE calculation # Basic CE calculation
los_pos = y * torch.log(xs_pos.clamp(min=self.eps)) los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps)) los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
loss = los_pos + los_neg loss = los_pos + los_neg
# Asymmetric Focusing # Asymmetric Focusing
if self.gamma_neg > 0 or self.gamma_pos > 0: if self.gamma_neg > 0 or self.gamma_pos > 0:
if self.disable_torch_grad_focal_loss: if self.disable_torch_grad_focal_loss:
torch._C.set_grad_enabled(False) torch.set_grad_enabled(False)
pt0 = xs_pos * y pt0 = xs_pos * y
pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
pt = pt0 + pt1 pt = pt0 + pt1
one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y) one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
one_sided_w = torch.pow(1 - pt, one_sided_gamma) one_sided_w = torch.pow(1 - pt, one_sided_gamma)
if self.disable_torch_grad_focal_loss: if self.disable_torch_grad_focal_loss:
torch._C.set_grad_enabled(True) torch.set_grad_enabled(True)
loss *= one_sided_w loss *= one_sided_w
return -loss.sum() return -loss.sum()
class AsymmetricLossSingleLabel(nn.Module): class AsymmetricLossSingleLabel(nn.Module):
def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'): def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'):
super(AsymmetricLossSingleLabel, self).__init__() super(AsymmetricLossSingleLabel, self).__init__()
self.eps = eps self.eps = eps
self.logsoftmax = nn.LogSoftmax(dim=-1) self.logsoftmax = nn.LogSoftmax(dim=-1)
self.targets_classes = [] # prevent gpu repeated memory allocation self.targets_classes = [] # prevent gpu repeated memory allocation
self.gamma_pos = gamma_pos self.gamma_pos = gamma_pos
self.gamma_neg = gamma_neg self.gamma_neg = gamma_neg
self.reduction = reduction self.reduction = reduction
def forward(self, inputs, target, reduction=None): def forward(self, inputs, target, reduction=None):
"""" """"
Parameters Parameters
---------- ----------
x: input logits x: input logits
y: targets (1-hot vector) y: targets (1-hot vector)
""" """
num_classes = inputs.size()[-1] num_classes = inputs.size()[-1]
log_preds = self.logsoftmax(inputs) log_preds = self.logsoftmax(inputs)
self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1) self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)
# ASL weights # ASL weights
targets = self.targets_classes targets = self.targets_classes
anti_targets = 1 - targets anti_targets = 1 - targets
xs_pos = torch.exp(log_preds) xs_pos = torch.exp(log_preds)
xs_neg = 1 - xs_pos xs_neg = 1 - xs_pos
xs_pos = xs_pos * targets xs_pos = xs_pos * targets
xs_neg = xs_neg * anti_targets xs_neg = xs_neg * anti_targets
asymmetric_w = torch.pow(1 - xs_pos - xs_neg, asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
self.gamma_pos * targets + self.gamma_neg * anti_targets) self.gamma_pos * targets + self.gamma_neg * anti_targets)
log_preds = log_preds * asymmetric_w log_preds = log_preds * asymmetric_w
if self.eps > 0: # label smoothing if self.eps > 0: # label smoothing
self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes) self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes)
# loss calculation # loss calculation
loss = - self.targets_classes.mul(log_preds) loss = - self.targets_classes.mul(log_preds)
loss = loss.sum(dim=-1) loss = loss.sum(dim=-1)
if self.reduction == 'mean': if self.reduction == 'mean':
loss = loss.mean() loss = loss.mean()
return loss return loss