mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #2143 from huggingface/fix_asymm_set_grad_enable
Fix #2132, remove use of _C.set_grad_enable. Line endings were messed up too
This commit is contained in:
commit
f5ea076a46
@ -37,14 +37,14 @@ class AsymmetricLossMultiLabel(nn.Module):
|
||||
# Asymmetric Focusing
|
||||
if self.gamma_neg > 0 or self.gamma_pos > 0:
|
||||
if self.disable_torch_grad_focal_loss:
|
||||
torch._C.set_grad_enabled(False)
|
||||
torch.set_grad_enabled(False)
|
||||
pt0 = xs_pos * y
|
||||
pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
|
||||
pt = pt0 + pt1
|
||||
one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
|
||||
one_sided_w = torch.pow(1 - pt, one_sided_gamma)
|
||||
if self.disable_torch_grad_focal_loss:
|
||||
torch._C.set_grad_enabled(True)
|
||||
torch.set_grad_enabled(True)
|
||||
loss *= one_sided_w
|
||||
|
||||
return -loss.sum()
|
||||
|
Loading…
x
Reference in New Issue
Block a user