diff --git a/mmseg/models/losses/tversky_loss.py b/mmseg/models/losses/tversky_loss.py index 4ad14f783..96ef92c4d 100644 --- a/mmseg/models/losses/tversky_loss.py +++ b/mmseg/models/losses/tversky_loss.py @@ -16,6 +16,7 @@ def tversky_loss(pred, valid_mask, alpha=0.3, beta=0.7, + gamma=1.0, smooth=1, class_weight=None, ignore_index=255): @@ -31,6 +32,8 @@ def tversky_loss(pred, alpha=alpha, beta=beta, smooth=smooth) + if gamma > 1.0: + tversky_loss **= (1 / gamma) if class_weight is not None: tversky_loss *= class_weight[i] total_loss += tversky_loss @@ -62,7 +65,11 @@ class TverskyLoss(nn.Module): """TverskyLoss. This loss is proposed in `Tversky loss function for image segmentation using 3D fully convolutional deep networks. - `_. + ` + and `A novel focal Tversky loss function with improved attention U-Net for + lesion segmentation. + + `_. Args: smooth (float): A float number to smooth loss, and avoid NaN error. Default: 1. @@ -75,6 +82,9 @@ class TverskyLoss(nn.Module): beta (float, in [0, 1]): The coefficient of false negatives. Default: 0.7. Note: alpha + beta = 1. + gamma (float, in [1, inf]): The focal term. When `gamma` > 1, + the loss focuses more on less accurate predictions that + have been misclassified. Default: 1.0. loss_name (str, optional): Name of the loss item. If you want this loss item to be included into the backward graph, `loss_` must be the prefix of the name. Defaults to 'loss_tversky'. @@ -87,6 +97,7 @@ class TverskyLoss(nn.Module): ignore_index=255, alpha=0.3, beta=0.7, + gamma=1.0, loss_name='loss_tversky'): super(TverskyLoss, self).__init__() self.smooth = smooth @@ -94,8 +105,10 @@ class TverskyLoss(nn.Module): self.loss_weight = loss_weight self.ignore_index = ignore_index assert (alpha + beta == 1.0), 'Sum of alpha and beta but be 1.0!' + assert gamma >= 1.0, 'gamma should be at least 1.0!' self.alpha = alpha self.beta = beta + self.gamma = gamma self._loss_name = loss_name def forward(self, pred, target, **kwargs): @@ -117,6 +130,7 @@ class TverskyLoss(nn.Module): valid_mask=valid_mask, alpha=self.alpha, beta=self.beta, + gamma=self.gamma, smooth=self.smooth, class_weight=class_weight, ignore_index=self.ignore_index) diff --git a/tests/test_models/test_losses/test_tversky_loss.py b/tests/test_models/test_losses/test_tversky_loss.py index 24a4b57e9..70b91e6f1 100644 --- a/tests/test_models/test_losses/test_tversky_loss.py +++ b/tests/test_models/test_losses/test_tversky_loss.py @@ -20,6 +20,21 @@ def test_tversky_lose(): labels = (torch.rand(8, 4, 4) * 3).long() tversky_loss(logits, labels, ignore_index=1) + # test gamma < 1.0 + with pytest.raises(AssertionError): + loss_cfg = dict( + type='TverskyLoss', + class_weight=[1.0, 2.0, 3.0], + loss_weight=1.0, + alpha=0.4, + beta=0.7, + gamma=0.9999, + loss_name='loss_tversky') + tversky_loss = build_loss(loss_cfg) + logits = torch.rand(8, 3, 4, 4) + labels = (torch.rand(8, 4, 4) * 3).long() + tversky_loss(logits, labels, ignore_index=1) + # test tversky loss loss_cfg = dict( type='TverskyLoss',