mirror of https://github.com/open-mmlab/mmocr.git
fix loss
parent
1e1da7b395
commit
43c50eee82
|
@ -23,7 +23,7 @@ class MaskedBalancedBCELoss(nn.Module):
|
|||
def __init__(self,
|
||||
reduction: str = 'none',
|
||||
negative_ratio: Union[float, int] = 3,
|
||||
eps: float = 1e-6):
|
||||
eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
assert reduction in ['none', 'mean', 'sum']
|
||||
assert isinstance(negative_ratio, (float, int))
|
||||
|
|
|
@ -16,7 +16,7 @@ class MaskedDiceLoss(nn.Module):
|
|||
1e-6.
|
||||
"""
|
||||
|
||||
def __init__(self, eps=1e-6):
|
||||
def __init__(self, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
assert isinstance(eps, float)
|
||||
self.eps = eps
|
||||
|
|
|
@ -14,16 +14,19 @@ class MaskedSmoothL1Loss(nn.Module):
|
|||
Args:
|
||||
beta (float, optional): The threshold in the piecewise function.
|
||||
Defaults to 1.
|
||||
eps (float, optional): Eps to avoid zero-division error. Defaults to
|
||||
1e-6.
|
||||
"""
|
||||
|
||||
def __init__(self, beta: Union[float, int] = 1) -> None:
|
||||
def __init__(self, beta: Union[float, int] = 1, eps: float = 1e-6) -> None:
|
||||
super().__init__()
|
||||
self.smooth_l1_loss = nn.SmoothL1Loss(beta=beta)
|
||||
self.smooth_l1_loss = nn.SmoothL1Loss(beta=beta, reduction='none')
|
||||
self.eps = eps
|
||||
|
||||
def forward(self,
|
||||
pred: torch.Tensor,
|
||||
gt: torch.Tensor,
|
||||
mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
|
||||
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
|
@ -43,7 +46,5 @@ class MaskedSmoothL1Loss(nn.Module):
|
|||
if mask is None:
|
||||
mask = torch.ones_like(gt).bool()
|
||||
assert mask.size() == gt.size()
|
||||
if not isinstance(mask, torch.BoolTensor):
|
||||
assert torch.all(torch.logical_or(mask == 0, mask == 1))
|
||||
mask = mask.bool()
|
||||
return self.smooth_l1_loss(pred[mask], gt[mask])
|
||||
loss = self.smooth_l1_loss(pred * mask, gt * mask)
|
||||
return loss.sum() / (mask.sum() + self.eps)
|
||||
|
|
|
@ -50,3 +50,8 @@ class TestMaskedBalancedBCELoss(TestCase):
|
|||
self.bce_loss(self.pred, self.gt, self.mask).item(),
|
||||
1.4067,
|
||||
delta=0.1)
|
||||
|
||||
# Test zero mask
|
||||
zero_mask = torch.FloatTensor([0, 0, 0, 0])
|
||||
self.assertAlmostEqual(
|
||||
self.bce_loss(self.pred, self.gt, zero_mask).item(), 0)
|
||||
|
|
|
@ -36,3 +36,8 @@ class TestMaskedDiceLoss(TestCase):
|
|||
self.loss(self.pred, self.gt, self.mask).item(),
|
||||
1 / 5,
|
||||
delta=0.001)
|
||||
|
||||
# Test zero mask
|
||||
zero_mask = torch.FloatTensor([0, 0, 0, 0])
|
||||
self.assertAlmostEqual(
|
||||
self.loss(self.pred, self.gt, zero_mask).item(), 1)
|
||||
|
|
|
@ -27,9 +27,21 @@ class TestMaskedSmoothL1Loss(TestCase):
|
|||
invalid_mask = torch.BoolTensor([True, False, False])
|
||||
self.l1_loss(self.pred, self.gt, invalid_mask)
|
||||
|
||||
self.assertAlmostEqual(self.l1_loss(self.pred, self.gt).item(), 0.5)
|
||||
# Test L1 loss results
|
||||
self.assertAlmostEqual(
|
||||
self.l1_loss(self.pred, self.gt, self.mask).item(), 0.75)
|
||||
self.l1_loss(self.pred, self.gt).item(), 0.5, delta=0.01)
|
||||
self.assertAlmostEqual(
|
||||
self.l1_loss(self.pred, self.gt, self.mask).item(),
|
||||
0.75,
|
||||
delta=0.01)
|
||||
|
||||
# Test Smooth L1 loss results
|
||||
self.assertAlmostEqual(
|
||||
self.smooth_l1_loss(self.pred, self.gt, self.mask).item(), 0.3125)
|
||||
self.smooth_l1_loss(self.pred, self.gt, self.mask).item(),
|
||||
0.3125,
|
||||
delta=0.01)
|
||||
|
||||
# Test zero mask
|
||||
zero_mask = torch.FloatTensor([0, 0, 0, 0])
|
||||
self.assertAlmostEqual(
|
||||
self.smooth_l1_loss(self.pred, self.gt, zero_mask).item(), 0)
|
||||
|
|
Loading…
Reference in New Issue