Fix #2132, remove use of _C.set_grad_enable. Line endings were messed up too

This commit is contained in:
Ross Wightman 2024-04-09 09:00:23 -07:00
parent 59b3d86c1d
commit 5c5ae8d401

View File

@ -37,14 +37,14 @@ class AsymmetricLossMultiLabel(nn.Module):
# Asymmetric Focusing
if self.gamma_neg > 0 or self.gamma_pos > 0:
if self.disable_torch_grad_focal_loss:
torch._C.set_grad_enabled(False)
torch.set_grad_enabled(False)
pt0 = xs_pos * y
pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
pt = pt0 + pt1
one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
one_sided_w = torch.pow(1 - pt, one_sided_gamma)
if self.disable_torch_grad_focal_loss:
torch._C.set_grad_enabled(True)
torch.set_grad_enabled(True)
loss *= one_sided_w
return -loss.sum()