Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
Fix #2132: remove use of the private torch._C.set_grad_enabled binding. Line endings were messed up too.
This commit is contained in:
parent 59b3d86c1d
commit 5c5ae8d401
@@ -1,97 +1,97 @@
 import torch
 import torch.nn as nn


 class AsymmetricLossMultiLabel(nn.Module):
     def __init__(self, gamma_neg=4, gamma_pos=1, clip=0.05, eps=1e-8, disable_torch_grad_focal_loss=False):
         super(AsymmetricLossMultiLabel, self).__init__()

         self.gamma_neg = gamma_neg
         self.gamma_pos = gamma_pos
         self.clip = clip
         self.disable_torch_grad_focal_loss = disable_torch_grad_focal_loss
         self.eps = eps

     def forward(self, x, y):
         """"
         Parameters
         ----------
         x: input logits
         y: targets (multi-label binarized vector)
         """

         # Calculating Probabilities
         x_sigmoid = torch.sigmoid(x)
         xs_pos = x_sigmoid
         xs_neg = 1 - x_sigmoid

         # Asymmetric Clipping
         if self.clip is not None and self.clip > 0:
             xs_neg = (xs_neg + self.clip).clamp(max=1)

         # Basic CE calculation
         los_pos = y * torch.log(xs_pos.clamp(min=self.eps))
         los_neg = (1 - y) * torch.log(xs_neg.clamp(min=self.eps))
         loss = los_pos + los_neg

         # Asymmetric Focusing
         if self.gamma_neg > 0 or self.gamma_pos > 0:
             if self.disable_torch_grad_focal_loss:
-                torch._C.set_grad_enabled(False)
+                torch.set_grad_enabled(False)
             pt0 = xs_pos * y
             pt1 = xs_neg * (1 - y)  # pt = p if t > 0 else 1-p
             pt = pt0 + pt1
             one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
             one_sided_w = torch.pow(1 - pt, one_sided_gamma)
             if self.disable_torch_grad_focal_loss:
-                torch._C.set_grad_enabled(True)
+                torch.set_grad_enabled(True)
             loss *= one_sided_w

         return -loss.sum()


 class AsymmetricLossSingleLabel(nn.Module):
     def __init__(self, gamma_pos=1, gamma_neg=4, eps: float = 0.1, reduction='mean'):
         super(AsymmetricLossSingleLabel, self).__init__()

         self.eps = eps
         self.logsoftmax = nn.LogSoftmax(dim=-1)
         self.targets_classes = []  # prevent gpu repeated memory allocation
         self.gamma_pos = gamma_pos
         self.gamma_neg = gamma_neg
         self.reduction = reduction

     def forward(self, inputs, target, reduction=None):
         """"
         Parameters
         ----------
         x: input logits
         y: targets (1-hot vector)
         """

         num_classes = inputs.size()[-1]
         log_preds = self.logsoftmax(inputs)
         self.targets_classes = torch.zeros_like(inputs).scatter_(1, target.long().unsqueeze(1), 1)

         # ASL weights
         targets = self.targets_classes
         anti_targets = 1 - targets
         xs_pos = torch.exp(log_preds)
         xs_neg = 1 - xs_pos
         xs_pos = xs_pos * targets
         xs_neg = xs_neg * anti_targets
         asymmetric_w = torch.pow(1 - xs_pos - xs_neg,
                                  self.gamma_pos * targets + self.gamma_neg * anti_targets)
         log_preds = log_preds * asymmetric_w

         if self.eps > 0:  # label smoothing
             self.targets_classes = self.targets_classes.mul(1 - self.eps).add(self.eps / num_classes)

         # loss calculation
         loss = - self.targets_classes.mul(log_preds)

         loss = loss.sum(dim=-1)
         if self.reduction == 'mean':
             loss = loss.mean()

         return loss
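Editor's note on the fix: torch._C.set_grad_enabled is a private C binding with no stability guarantee, which is why #2132 asked for the public torch.set_grad_enabled instead. A minimal sketch of the public API (my illustration, not part of the commit); the context-manager form at the end is a hypothetical alternative to the commit's call/restore pair, with the advantage that grad mode is restored even if the block raises:

import torch

# The patched pattern: the public API replaces the removed torch._C binding.
torch.set_grad_enabled(False)        # grad tracking off, as in the focal-weight block
assert not torch.is_grad_enabled()
torch.set_grad_enabled(True)         # restored before the loss is scaled
assert torch.is_grad_enabled()

# The same public API also works as a context manager, which restores the
# previous grad mode automatically on exit, exception or not.
with torch.set_grad_enabled(False):
    pt = torch.sigmoid(torch.randn(4))
    one_sided_w = torch.pow(1 - pt, 2.0)  # focal weight, computed without grad
assert torch.is_grad_enabled()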
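And a quick smoke test of the patched code path (my sketch, assuming the class is exported from timm.loss as in current timm; shapes and values are illustrative):

import torch
from timm.loss import AsymmetricLossMultiLabel

loss_fn = AsymmetricLossMultiLabel(disable_torch_grad_focal_loss=True)
logits = torch.randn(8, 20, requires_grad=True)  # batch of 8, 20 labels
targets = torch.randint(0, 2, (8, 20)).float()   # multi-label binary targets

loss = loss_fn(logits, targets)
loss.backward()  # succeeds: grad mode is correctly restored after the focal block
print(loss.item(), logits.grad.shape)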