mirror of
https://github.com/open-mmlab/mmocr.git
synced 2025-06-03 21:54:47 +08:00
fix loss
This commit is contained in:
parent
1e1da7b395
commit
43c50eee82
@ -23,7 +23,7 @@ class MaskedBalancedBCELoss(nn.Module):
|
|||||||
def __init__(self,
|
def __init__(self,
|
||||||
reduction: str = 'none',
|
reduction: str = 'none',
|
||||||
negative_ratio: Union[float, int] = 3,
|
negative_ratio: Union[float, int] = 3,
|
||||||
eps: float = 1e-6):
|
eps: float = 1e-6) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert reduction in ['none', 'mean', 'sum']
|
assert reduction in ['none', 'mean', 'sum']
|
||||||
assert isinstance(negative_ratio, (float, int))
|
assert isinstance(negative_ratio, (float, int))
|
||||||
|
@ -16,7 +16,7 @@ class MaskedDiceLoss(nn.Module):
|
|||||||
1e-6.
|
1e-6.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, eps=1e-6):
|
def __init__(self, eps: float = 1e-6) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
assert isinstance(eps, float)
|
assert isinstance(eps, float)
|
||||||
self.eps = eps
|
self.eps = eps
|
||||||
|
@ -14,16 +14,19 @@ class MaskedSmoothL1Loss(nn.Module):
|
|||||||
Args:
|
Args:
|
||||||
beta (float, optional): The threshold in the piecewise function.
|
beta (float, optional): The threshold in the piecewise function.
|
||||||
Defaults to 1.
|
Defaults to 1.
|
||||||
|
eps (float, optional): Eps to avoid zero-division error. Defaults to
|
||||||
|
1e-6.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, beta: Union[float, int] = 1) -> None:
|
def __init__(self, beta: Union[float, int] = 1, eps: float = 1e-6) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.smooth_l1_loss = nn.SmoothL1Loss(beta=beta)
|
self.smooth_l1_loss = nn.SmoothL1Loss(beta=beta, reduction='none')
|
||||||
|
self.eps = eps
|
||||||
|
|
||||||
def forward(self,
|
def forward(self,
|
||||||
pred: torch.Tensor,
|
pred: torch.Tensor,
|
||||||
gt: torch.Tensor,
|
gt: torch.Tensor,
|
||||||
mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
|
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||||
"""Forward function.
|
"""Forward function.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
@ -43,7 +46,5 @@ class MaskedSmoothL1Loss(nn.Module):
|
|||||||
if mask is None:
|
if mask is None:
|
||||||
mask = torch.ones_like(gt).bool()
|
mask = torch.ones_like(gt).bool()
|
||||||
assert mask.size() == gt.size()
|
assert mask.size() == gt.size()
|
||||||
if not isinstance(mask, torch.BoolTensor):
|
loss = self.smooth_l1_loss(pred * mask, gt * mask)
|
||||||
assert torch.all(torch.logical_or(mask == 0, mask == 1))
|
return loss.sum() / (mask.sum() + self.eps)
|
||||||
mask = mask.bool()
|
|
||||||
return self.smooth_l1_loss(pred[mask], gt[mask])
|
|
||||||
|
@ -50,3 +50,8 @@ class TestMaskedBalancedBCELoss(TestCase):
|
|||||||
self.bce_loss(self.pred, self.gt, self.mask).item(),
|
self.bce_loss(self.pred, self.gt, self.mask).item(),
|
||||||
1.4067,
|
1.4067,
|
||||||
delta=0.1)
|
delta=0.1)
|
||||||
|
|
||||||
|
# Test zero mask
|
||||||
|
zero_mask = torch.FloatTensor([0, 0, 0, 0])
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
self.bce_loss(self.pred, self.gt, zero_mask).item(), 0)
|
||||||
|
@ -36,3 +36,8 @@ class TestMaskedDiceLoss(TestCase):
|
|||||||
self.loss(self.pred, self.gt, self.mask).item(),
|
self.loss(self.pred, self.gt, self.mask).item(),
|
||||||
1 / 5,
|
1 / 5,
|
||||||
delta=0.001)
|
delta=0.001)
|
||||||
|
|
||||||
|
# Test zero mask
|
||||||
|
zero_mask = torch.FloatTensor([0, 0, 0, 0])
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
self.loss(self.pred, self.gt, zero_mask).item(), 1)
|
||||||
|
@ -27,9 +27,21 @@ class TestMaskedSmoothL1Loss(TestCase):
|
|||||||
invalid_mask = torch.BoolTensor([True, False, False])
|
invalid_mask = torch.BoolTensor([True, False, False])
|
||||||
self.l1_loss(self.pred, self.gt, invalid_mask)
|
self.l1_loss(self.pred, self.gt, invalid_mask)
|
||||||
|
|
||||||
self.assertAlmostEqual(self.l1_loss(self.pred, self.gt).item(), 0.5)
|
# Test L1 loss results
|
||||||
self.assertAlmostEqual(
|
self.assertAlmostEqual(
|
||||||
self.l1_loss(self.pred, self.gt, self.mask).item(), 0.75)
|
self.l1_loss(self.pred, self.gt).item(), 0.5, delta=0.01)
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
self.l1_loss(self.pred, self.gt, self.mask).item(),
|
||||||
|
0.75,
|
||||||
|
delta=0.01)
|
||||||
|
|
||||||
|
# Test Smooth L1 loss results
|
||||||
self.assertAlmostEqual(
|
self.assertAlmostEqual(
|
||||||
self.smooth_l1_loss(self.pred, self.gt, self.mask).item(), 0.3125)
|
self.smooth_l1_loss(self.pred, self.gt, self.mask).item(),
|
||||||
|
0.3125,
|
||||||
|
delta=0.01)
|
||||||
|
|
||||||
|
# Test zero mask
|
||||||
|
zero_mask = torch.FloatTensor([0, 0, 0, 0])
|
||||||
|
self.assertAlmostEqual(
|
||||||
|
self.smooth_l1_loss(self.pred, self.gt, zero_mask).item(), 0)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user