mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Fix #2132, remove use of _C.set_grad_enable. Line endings were messed up too
This commit is contained in:
parent
59b3d86c1d
commit
5c5ae8d401
@ -37,14 +37,14 @@ class AsymmetricLossMultiLabel(nn.Module):
|
|||||||
# Asymmetric Focusing
|
# Asymmetric Focusing
|
||||||
if self.gamma_neg > 0 or self.gamma_pos > 0:
|
if self.gamma_neg > 0 or self.gamma_pos > 0:
|
||||||
if self.disable_torch_grad_focal_loss:
|
if self.disable_torch_grad_focal_loss:
|
||||||
torch._C.set_grad_enabled(False)
|
torch.set_grad_enabled(False)
|
||||||
pt0 = xs_pos * y
|
pt0 = xs_pos * y
|
||||||
pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
|
pt1 = xs_neg * (1 - y) # pt = p if t > 0 else 1-p
|
||||||
pt = pt0 + pt1
|
pt = pt0 + pt1
|
||||||
one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
|
one_sided_gamma = self.gamma_pos * y + self.gamma_neg * (1 - y)
|
||||||
one_sided_w = torch.pow(1 - pt, one_sided_gamma)
|
one_sided_w = torch.pow(1 - pt, one_sided_gamma)
|
||||||
if self.disable_torch_grad_focal_loss:
|
if self.disable_torch_grad_focal_loss:
|
||||||
torch._C.set_grad_enabled(True)
|
torch.set_grad_enabled(True)
|
||||||
loss *= one_sided_w
|
loss *= one_sided_w
|
||||||
|
|
||||||
return -loss.sum()
|
return -loss.sum()
|
||||||
|
Loading…
x
Reference in New Issue
Block a user