gaotongxiao 2022-05-26 17:40:16 +08:00
parent 1e1da7b395
commit 43c50eee82
6 changed files with 35 additions and 12 deletions

@@ -23,7 +23,7 @@ class MaskedBalancedBCELoss(nn.Module):
     def __init__(self,
                  reduction: str = 'none',
                  negative_ratio: Union[float, int] = 3,
-                 eps: float = 1e-6):
+                 eps: float = 1e-6) -> None:
         super().__init__()
         assert reduction in ['none', 'mean', 'sum']
         assert isinstance(negative_ratio, (float, int))

@@ -16,7 +16,7 @@ class MaskedDiceLoss(nn.Module):
             1e-6.
     """
 
-    def __init__(self, eps=1e-6):
+    def __init__(self, eps: float = 1e-6) -> None:
         super().__init__()
         assert isinstance(eps, float)
         self.eps = eps

@@ -14,16 +14,19 @@ class MaskedSmoothL1Loss(nn.Module):
     Args:
         beta (float, optional): The threshold in the piecewise function.
             Defaults to 1.
+        eps (float, optional): Eps to avoid zero-division error. Defaults to
+            1e-6.
     """
 
-    def __init__(self, beta: Union[float, int] = 1) -> None:
+    def __init__(self, beta: Union[float, int] = 1, eps: float = 1e-6) -> None:
         super().__init__()
-        self.smooth_l1_loss = nn.SmoothL1Loss(beta=beta)
+        self.smooth_l1_loss = nn.SmoothL1Loss(beta=beta, reduction='none')
+        self.eps = eps
 
     def forward(self,
                 pred: torch.Tensor,
                 gt: torch.Tensor,
-                mask: Optional[torch.BoolTensor] = None) -> torch.Tensor:
+                mask: Optional[torch.Tensor] = None) -> torch.Tensor:
         """Forward function.
 
         Args:
@@ -43,7 +46,5 @@ class MaskedSmoothL1Loss(nn.Module):
         if mask is None:
             mask = torch.ones_like(gt).bool()
         assert mask.size() == gt.size()
-        if not isinstance(mask, torch.BoolTensor):
-            assert torch.all(torch.logical_or(mask == 0, mask == 1))
-            mask = mask.bool()
+        loss = self.smooth_l1_loss(pred * mask, gt * mask)
-        return self.smooth_l1_loss(pred[mask], gt[mask])
+        return loss.sum() / (mask.sum() + self.eps)
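Note: the reworked forward pass above computes an elementwise smooth L1 loss (reduction='none'), zeroes out ignored positions by multiplying both inputs with the mask, and normalizes by the mask sum, with eps guarding against an all-zero mask. A minimal standalone sketch of that computation with toy tensors (the values here are illustrative, not taken from the tests):

import torch
import torch.nn as nn

pred = torch.FloatTensor([0.5, 1, 2, 3])
gt = torch.FloatTensor([1, 1, 1, 1])
mask = torch.BoolTensor([True, True, False, False])
eps = 1e-6

smooth_l1 = nn.SmoothL1Loss(beta=1, reduction='none')
# Multiplying by the mask zeroes both inputs at ignored positions,
# so those positions contribute 0 to the elementwise loss.
loss = smooth_l1(pred * mask, gt * mask)
# Normalize by the number of valid elements; eps prevents a zero
# division when the mask is all zeros (the loss then collapses to ~0).
print(loss.sum() / (mask.sum() + eps))  # tensor(0.0625)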

@@ -50,3 +50,8 @@ class TestMaskedBalancedBCELoss(TestCase):
             self.bce_loss(self.pred, self.gt, self.mask).item(),
             1.4067,
             delta=0.1)
+
+        # Test zero mask
+        zero_mask = torch.FloatTensor([0, 0, 0, 0])
+        self.assertAlmostEqual(
+            self.bce_loss(self.pred, self.gt, zero_mask).item(), 0)
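Note: this zero-mask test pins down the degenerate case the eps terms exist for: with no valid elements, a masked loss that normalizes by mask.sum() + eps returns ~0 instead of NaN. A sketch of the guarded failure mode, using a generic masked BCE rather than MaskedBalancedBCELoss's exact internals (which also do balanced negative mining):

import torch
import torch.nn.functional as F

pred = torch.FloatTensor([0.9, 0.1, 0.8, 0.2])
gt = torch.FloatTensor([1, 0, 1, 0])
zero_mask = torch.FloatTensor([0, 0, 0, 0])
eps = 1e-6

# Elementwise BCE, then mask out every position.
loss = F.binary_cross_entropy(pred, gt, reduction='none') * zero_mask
# Without eps this would be 0 / 0 -> NaN; with it, ~0 as the test asserts.
print(loss.sum() / (zero_mask.sum() + eps))  # tensor(0.)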

@@ -36,3 +36,8 @@ class TestMaskedDiceLoss(TestCase):
             self.loss(self.pred, self.gt, self.mask).item(),
             1 / 5,
             delta=0.001)
+
+        # Test zero mask
+        zero_mask = torch.FloatTensor([0, 0, 0, 0])
+        self.assertAlmostEqual(
+            self.loss(self.pred, self.gt, zero_mask).item(), 1)
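Note: the expected value of 1 here follows from the dice formulation: with an all-zero mask, both the intersection and the masked sums vanish, so the quotient is 0 / eps = 0 and the loss degrades to 1 - 0 = 1 rather than NaN. A sketch with a generic masked dice loss (not necessarily MaskedDiceLoss's exact implementation):

import torch

pred = torch.FloatTensor([0.9, 0.1, 0.8, 0.2])
gt = torch.FloatTensor([1, 0, 1, 0])
zero_mask = torch.FloatTensor([0, 0, 0, 0])
eps = 1e-6

masked_pred = pred * zero_mask
masked_gt = gt * zero_mask
intersection = (masked_pred * masked_gt).sum()
# Everything is masked out, so the intersection and both sums are 0;
# the eps in the denominator turns the quotient into 0 instead of NaN.
loss = 1 - 2 * intersection / (masked_pred.sum() + masked_gt.sum() + eps)
print(loss)  # tensor(1.)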

@@ -27,9 +27,21 @@ class TestMaskedSmoothL1Loss(TestCase):
             invalid_mask = torch.BoolTensor([True, False, False])
             self.l1_loss(self.pred, self.gt, invalid_mask)
 
-        self.assertAlmostEqual(self.l1_loss(self.pred, self.gt).item(), 0.5)
-        self.assertAlmostEqual(
-            self.l1_loss(self.pred, self.gt, self.mask).item(), 0.75)
+        # Test L1 loss results
+        self.assertAlmostEqual(
+            self.l1_loss(self.pred, self.gt).item(), 0.5, delta=0.01)
+        self.assertAlmostEqual(
+            self.l1_loss(self.pred, self.gt, self.mask).item(),
+            0.75,
+            delta=0.01)
 
-        self.assertAlmostEqual(
-            self.smooth_l1_loss(self.pred, self.gt, self.mask).item(), 0.3125)
+        # Test Smooth L1 loss results
+        self.assertAlmostEqual(
+            self.smooth_l1_loss(self.pred, self.gt, self.mask).item(),
+            0.3125,
+            delta=0.01)
+
+        # Test zero mask
+        zero_mask = torch.FloatTensor([0, 0, 0, 0])
+        self.assertAlmostEqual(
+            self.smooth_l1_loss(self.pred, self.gt, zero_mask).item(), 0)