Merge pull request #2143 from huggingface/fix_asymm_set_grad_enable

Fix #2132, remove use of _C.set_grad_enable. Line endings were messed up too
This commit is contained in:
Ross Wightman 2024-04-09 10:14:13 -07:00 committed by GitHub
commit f5ea076a46
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -37,14 +37,14 @@ class AsymmetricLossMultiLabel(nn.Module):
# Asymmetric Focusing
if self.gamma_neg > 0 or self.gamma_pos > 0:
if self.disable_torch_grad_focal_loss:
torch._C.set_grad_enabled(False)
torch.set_grad_enabled(False)
pt0 = xs_pos * y
pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
pt = pt0 + pt1
one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
one_sided_w = torch.pow(1 - pt, one_sided_gamma)
if self.disable_torch_grad_focal_loss:
torch._C.set_grad_enabled(True)
torch.set_grad_enabled(True)
loss *= one_sided_w
return -loss.sum()