diff --git a/mmseg/models/losses/kldiv_loss.py b/mmseg/models/losses/kldiv_loss.py
new file mode 100644
index 000000000..496ef9713
--- /dev/null
+++ b/mmseg/models/losses/kldiv_loss.py
@@ -0,0 +1,99 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from mmseg.registry import MODELS
+
+
+@MODELS.register_module()
+class KLDivLoss(nn.Module):
+
+    def __init__(self,
+                 temperature: float = 1.0,
+                 reduction: str = 'mean',
+                 loss_name: str = 'loss_kld'):
+        """Kullback-Leibler divergence Loss.
+
+        Args:
+            temperature (float, optional): Temperature used to soften the
+                input and target distributions. Defaults to 1.0.
+            reduction (str, optional): The method to reduce the loss into a
+                scalar. Options are "none", "sum" and "mean".
+                Defaults to "mean".
+            loss_name (str, optional): Name of the loss item.
+                Defaults to 'loss_kld'.
+        """
+
+        assert isinstance(temperature, (float, int)), \
+            'Expected temperature to be ' \
+            f'float or int, but got {temperature.__class__.__name__} instead'
+        assert temperature != 0., 'Temperature must not be zero'
+
+        assert reduction in ['mean', 'none', 'sum'], \
+            'Reduction must be one of the options ("mean", ' \
+            f'"sum", "none"), but got {reduction}'
+
+        super().__init__()
+        self.temperature = temperature
+        self.reduction = reduction
+        self._loss_name = loss_name
+
+    def forward(self, input: torch.Tensor, target: torch.Tensor):
+        """Forward function. Calculate KL divergence loss.
+
+        Args:
+            input (Tensor): Logit tensor with dtype float32 or float64.
+                The shape is (N, C), where N is the batch size and C is
+                the number of channels. With more than two dimensions,
+                the shape is (N, C, D1, D2, ..., Dk), k >= 1.
+            target (Tensor): Logit tensor with dtype float32 or float64.
+                input and target must have the same shape.
+
+        Returns:
+            (Tensor): Reduced loss.
+        """
+        assert isinstance(input, torch.Tensor), 'Expected input to ' \
+            f'be Tensor, but got {input.__class__.__name__} instead'
+        assert isinstance(target, torch.Tensor), 'Expected target to ' \
+            f'be Tensor, but got {target.__class__.__name__} instead'
+
+        assert input.shape == target.shape, 'Input and target ' \
+            'must have the same shape, ' \
+            f'but got shapes {input.shape} and {target.shape}'
+
+        # F.kl_div expects log-probabilities as input and, with
+        # log_target=False, probabilities as target.
+        input = F.log_softmax(input / self.temperature, dim=1)
+        target = F.softmax(target / self.temperature, dim=1)
+
+        loss = F.kl_div(input, target, reduction='none', log_target=False)
+        loss = loss * self.temperature**2
+
+        batch_size = input.shape[0]
+
+        if self.reduction == 'sum':
+            # Change view to calculate instance-wise sum
+            loss = loss.view(batch_size, -1)
+            return torch.sum(loss, dim=1)
+
+        elif self.reduction == 'mean':
+            # Change view to calculate instance-wise mean
+            loss = loss.view(batch_size, -1)
+            return torch.mean(loss, dim=1)
+
+        return loss
+
+    @property
+    def loss_name(self):
+        """Loss Name.
+
+        This function must be implemented and will return the name of this
+        loss function. This name will be used to combine different loss items
+        by a simple sum operation. In addition, if you want this loss item to
+        be included in the backward graph, `loss_` must be the prefix of the
+        name.
+
+        Returns:
+            str: The name of this loss item.
+        """
+        return self._loss_name
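Because the class is decorated with @MODELS.register_module(), it can be built from a config dict through the registry rather than instantiated directly. A minimal sketch of what that might look like once this patch is applied; the temperature and reduction values below are illustrative, not part of the patch:

    from mmseg.registry import MODELS

    # Hypothetical config snippet: the field values are examples only.
    loss_cfg = dict(
        type='KLDivLoss',
        temperature=4.0,  # soften distributions, e.g. for distillation
        reduction='mean',
        loss_name='loss_kld')
    loss_kld = MODELS.build(loss_cfg)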
diff --git a/tests/test_models/test_losses/test_kldiv_loss.py b/tests/test_models/test_losses/test_kldiv_loss.py
new file mode 100644
index 000000000..48bcc4bfd
--- /dev/null
+++ b/tests/test_models/test_losses/test_kldiv_loss.py
@@ -0,0 +1,40 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch
+
+from mmseg.models.losses.kldiv_loss import KLDivLoss
+
+
+def test_kldiv_loss_with_none_reduction():
+    loss_class = KLDivLoss
+    pred = torch.rand((8, 5, 5))
+    target = torch.rand((8, 5, 5))
+    reduction = 'none'
+
+    # Test loss forward
+    loss = loss_class(reduction=reduction)(pred, target)
+    assert isinstance(loss, torch.Tensor)
+    assert loss.shape == (8, 5, 5), f'{loss.shape}'
+
+
+def test_kldiv_loss_with_mean_reduction():
+    loss_class = KLDivLoss
+    pred = torch.rand((8, 5, 5))
+    target = torch.rand((8, 5, 5))
+    reduction = 'mean'
+
+    # Test loss forward
+    loss = loss_class(reduction=reduction)(pred, target)
+    assert isinstance(loss, torch.Tensor)
+    assert loss.shape == (8, ), f'{loss.shape}'
+
+
+def test_kldiv_loss_with_sum_reduction():
+    loss_class = KLDivLoss
+    pred = torch.rand((8, 5, 5))
+    target = torch.rand((8, 5, 5))
+    reduction = 'sum'
+
+    # Test loss forward
+    loss = loss_class(reduction=reduction)(pred, target)
+    assert isinstance(loss, torch.Tensor)
+    assert loss.shape == (8, ), f'{loss.shape}'
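The tests above only check output shapes. A complementary sanity check, sketched here rather than part of the patch, is that the KL divergence of a distribution against itself is zero: with identical logits as input and target, log_softmax(input) equals log(softmax(target)), so every element of the elementwise loss vanishes:

    import torch

    from mmseg.models.losses.kldiv_loss import KLDivLoss

    # KL(p || p) == 0: identical logits as input and target should give
    # a (numerically) zero loss for every instance in the batch.
    logits = torch.randn(4, 3, 8, 8)
    loss = KLDivLoss(temperature=2.0, reduction='mean')(logits, logits)
    assert torch.allclose(loss, torch.zeros(4), atol=1e-6)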