[CodeCamp2023-526] Kullback-Leibler divergence Loss implementation (#3242)

Thanks for your contribution and we appreciate it a lot. The following
instructions would make your pull request more healthy and more easily
get feedback. If you do not understand some items, don't worry, just
make the pull request and seek help from maintainers.

## Motivation

It's OpenMMLab  Codecamp task.

## Modification

Implementd Kullback-Leibler divergence loss and also added tests for it.

## Checklist

1. Pre-commit or other linting tools are used to fix the potential lint
issues.
2. The modification is covered by complete unit tests. If not, please
add more unit test to ensure the correctness.
3. If the modification has potential influence on downstream projects,
this PR should be tested with downstream projects, like MMDet or
MMDet3D.
4. The documentation has been modified accordingly, like docstring or
example tutorials.

---------

Co-authored-by: xiexinch <xiexinch@outlook.com>
This commit is contained in:
Andrey Dolgovyazov 2023-08-28 11:48:26 +03:00 committed by GitHub
parent b2f10954e6
commit 8233e64c7e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 139 additions and 0 deletions

View File

@ -0,0 +1,99 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmseg.registry import MODELS
@MODELS.register_module()
class KLDivLoss(nn.Module):
def __init__(self,
temperature: float = 1.0,
reduction: str = 'mean',
loss_name: str = 'loss_kld'):
"""Kullback-Leibler divergence Loss.
<https://en.wikipedia.org/wiki/Kullback%E2%80%93Leibler_divergence>
Args:
temperature (float, optional): Temperature param
reduction (str, optional): The method to reduce the loss into a
scalar. Default is "mean". Options are "none", "sum",
and "mean"
"""
assert isinstance(temperature, (float, int)), \
'Expected temperature to be' \
f'float or int, but got {temperature.__class__.__name__} instead'
assert temperature != 0., 'Temperature must not be zero'
assert reduction in ['mean', 'none', 'sum'], \
'Reduction must be one of the options ("mean", ' \
f'"sum", "none"), but got {reduction}'
super().__init__()
self.temperature = temperature
self.reduction = reduction
self._loss_name = loss_name
def forward(self, input: torch.Tensor, target: torch.Tensor):
"""Forward function. Calculate KL divergence Loss.
Args:
input (Tensor): Logit tensor,
the data type is float32 or float64.
The shape is (N, C) where N is batchsize and C is number of
channels.
If there more than 2 dimensions, shape is (N, C, D1, D2, ...
Dk), k>= 1
target (Tensor): Logit tensor,
the data type is float32 or float64.
input and target must be with the same shape.
Returns:
(Tensor): Reduced loss.
"""
assert isinstance(input, torch.Tensor), 'Expected input to' \
f'be Tensor, but got {input.__class__.__name__} instead'
assert isinstance(target, torch.Tensor), 'Expected target to' \
f'be Tensor, but got {target.__class__.__name__} instead'
assert input.shape == target.shape, 'Input and target ' \
'must have same shape,' \
f'but got shapes {input.shape} and {target.shape}'
input = F.softmax(input / self.temperature, dim=1)
target = F.softmax(target / self.temperature, dim=1)
loss = F.kl_div(input, target, reduction='none', log_target=False)
loss = loss * self.temperature**2
batch_size = input.shape[0]
if self.reduction == 'sum':
# Change view to calculate instance-wise sum
loss = loss.view(batch_size, -1)
return torch.sum(loss, dim=1)
elif self.reduction == 'mean':
# Change view to calculate instance-wise mean
loss = loss.view(batch_size, -1)
return torch.mean(loss, dim=1)
return loss
@property
def loss_name(self):
"""Loss Name.
This function must be implemented and will return the name of this
loss function. This name will be used to combine different loss items
by simple sum operation. In addition, if you want this loss item to be
included into the backward graph, `loss_` must be the prefix of the
name.
Returns:
str: The name of this loss item.
"""
return self._loss_name

View File

@ -0,0 +1,40 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmseg.models.losses.kldiv_loss import KLDivLoss
def test_kldiv_loss_with_none_reduction():
loss_class = KLDivLoss
pred = torch.rand((8, 5, 5))
target = torch.rand((8, 5, 5))
reduction = 'none'
# Test loss forward
loss = loss_class(reduction=reduction)(pred, target)
assert isinstance(loss, torch.Tensor)
assert loss.shape == (8, 5, 5), f'{loss.shape}'
def test_kldiv_loss_with_mean_reduction():
loss_class = KLDivLoss
pred = torch.rand((8, 5, 5))
target = torch.rand((8, 5, 5))
reduction = 'mean'
# Test loss forward
loss = loss_class(reduction=reduction)(pred, target)
assert isinstance(loss, torch.Tensor)
assert loss.shape == (8, ), f'{loss.shape}'
def test_kldiv_loss_with_sum_reduction():
loss_class = KLDivLoss
pred = torch.rand((8, 5, 5))
target = torch.rand((8, 5, 5))
reduction = 'sum'
# Test loss forward
loss = loss_class(reduction=reduction)(pred, target)
assert isinstance(loss, torch.Tensor)
assert loss.shape == (8, ), f'{loss.shape}'