pull/1178/head
gaotongxiao 2022-05-26 17:40:16 +08:00
parent 1e1da7b395
commit 43c50eee82
6 changed files with 35 additions and 12 deletions

View File

@ -23,7 +23,7 @@ class MaskedBalancedBCELoss(nn.Module):
def __init__(self,
reduction: str = 'none',
negative_ratio: Union[float, int] = 3,
eps: float = 1e-6):
eps: float = 1e-6) -> None:
super().__init__()
assert reduction in ['none', 'mean', 'sum']
assert isinstance(negative_ratio, (float, int))

View File

@ -16,7 +16,7 @@ class MaskedDiceLoss(nn.Module):
1e-6.
"""
def __init__(self, eps=1e-6):
def __init__(self, eps: float = 1e-6) -> None:
super().__init__()
assert isinstance(eps, float)
self.eps = eps

View File

@ -14,16 +14,19 @@ class MaskedSmoothL1Loss(nn.Module):
Args:
beta (float, optional): The threshold in the piecewise function.
Defaults to 1.
eps (float, optional): Eps to avoid zero-division error. Defaults to
1e-6.
"""
def __init__(self, beta: Union[float, int] = 1) -> None:
def __init__(self, beta: Union[float, int] = 1, eps: float = 1e-6) -> None:
super().__init__()
self.smooth_l1_loss = nn.SmoothL1Loss(beta=beta)
self.smooth_l1_loss = nn.SmoothL1Loss(beta=beta, reduction='none')
self.eps = eps
def forward(self,
pred: torch.Tensor,
gt: torch.Tensor,
mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
mask: Optional[torch.Tensor] = None) -> torch.Tensor:
"""Forward function.
Args:
@ -43,7 +46,5 @@ class MaskedSmoothL1Loss(nn.Module):
if mask is None:
mask = torch.ones_like(gt).bool()
assert mask.size() == gt.size()
if not isinstance(mask, torch.BoolTensor):
assert torch.all(torch.logical_or(mask == 0, mask == 1))
mask = mask.bool()
return self.smooth_l1_loss(pred[mask], gt[mask])
loss = self.smooth_l1_loss(pred * mask, gt * mask)
return loss.sum() / (mask.sum() + self.eps)

View File

@ -50,3 +50,8 @@ class TestMaskedBalancedBCELoss(TestCase):
self.bce_loss(self.pred, self.gt, self.mask).item(),
1.4067,
delta=0.1)
# Test zero mask
zero_mask = torch.FloatTensor([0, 0, 0, 0])
self.assertAlmostEqual(
self.bce_loss(self.pred, self.gt, zero_mask).item(), 0)

View File

@ -36,3 +36,8 @@ class TestMaskedDiceLoss(TestCase):
self.loss(self.pred, self.gt, self.mask).item(),
1 / 5,
delta=0.001)
# Test zero mask
zero_mask = torch.FloatTensor([0, 0, 0, 0])
self.assertAlmostEqual(
self.loss(self.pred, self.gt, zero_mask).item(), 1)

View File

@ -27,9 +27,21 @@ class TestMaskedSmoothL1Loss(TestCase):
invalid_mask = torch.BoolTensor([True, False, False])
self.l1_loss(self.pred, self.gt, invalid_mask)
self.assertAlmostEqual(self.l1_loss(self.pred, self.gt).item(), 0.5)
# Test L1 loss results
self.assertAlmostEqual(
self.l1_loss(self.pred, self.gt, self.mask).item(), 0.75)
self.l1_loss(self.pred, self.gt).item(), 0.5, delta=0.01)
self.assertAlmostEqual(
self.l1_loss(self.pred, self.gt, self.mask).item(),
0.75,
delta=0.01)
# Test Smooth L1 loss results
self.assertAlmostEqual(
self.smooth_l1_loss(self.pred, self.gt, self.mask).item(), 0.3125)
self.smooth_l1_loss(self.pred, self.gt, self.mask).item(),
0.3125,
delta=0.01)
# Test zero mask
zero_mask = torch.FloatTensor([0, 0, 0, 0])
self.assertAlmostEqual(
self.smooth_l1_loss(self.pred, self.gt, zero_mask).item(), 0)