[New Feature] Add Lovasz loss (#351)

* add lovasz loss
* Modify as per comments
* Modify paper url
* add unittest and remove Var
* improve unittest
parent 11471ad9bb
commit c8bbd3fa95
@@ -1,10 +1,11 @@
 from .accuracy import Accuracy, accuracy
 from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
                                  cross_entropy, mask_cross_entropy)
+from .lovasz_loss import LovaszLoss
 from .utils import reduce_loss, weight_reduce_loss, weighted_loss

 __all__ = [
     'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
     'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
-    'weight_reduce_loss', 'weighted_loss'
+    'weight_reduce_loss', 'weighted_loss', 'LovaszLoss'
 ]
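With this export, the new loss is importable from the losses package like the existing ones. A minimal sketch, assuming the package path is mmseg.models.losses (inferred from the relative imports above):

# Hypothetical import of the newly exported symbol; the path is inferred
# from the relative imports in the diff above.
from mmseg.models.losses import LovaszLoss

# per_image defaults to False, which requires reduction='none' (see the
# constructor assertion in the new module below).
criterion = LovaszLoss(loss_type='multi_class', reduction='none')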
@@ -0,0 +1,303 @@
"""Modified from
https://github.com/bermanmaxim/LovaszSoftmax/blob/master/pytorch/lovasz_losses.py
Lovasz-Softmax and Jaccard hinge loss in PyTorch, Maxim Berman 2018 ESAT-PSI
KU Leuven (MIT License)."""

import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES
from .utils import weight_reduce_loss


def lovasz_grad(gt_sorted):
    """Computes the gradient of the Lovasz extension w.r.t. sorted errors.

    See Alg. 1 in the paper.
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:  # cover 1-pixel case
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard


def flatten_binary_logits(logits, labels, ignore_index=None):
    """Flattens predictions in the batch (binary case). Removes labels equal
    to 'ignore_index'."""
    logits = logits.view(-1)
    labels = labels.view(-1)
    if ignore_index is None:
        return logits, labels
    valid = (labels != ignore_index)
    vlogits = logits[valid]
    vlabels = labels[valid]
    return vlogits, vlabels


def flatten_probs(probs, labels, ignore_index=None):
    """Flattens predictions in the batch."""
    if probs.dim() == 3:
        # assumes output of a sigmoid layer
        B, H, W = probs.size()
        probs = probs.view(B, 1, H, W)
    B, C, H, W = probs.size()
    probs = probs.permute(0, 2, 3, 1).contiguous().view(-1, C)  # B * H * W, C -> P, C
    labels = labels.view(-1)
    if ignore_index is None:
        return probs, labels
    valid = (labels != ignore_index)
    vprobs = probs[valid.nonzero().squeeze()]
    vlabels = labels[valid]
    return vprobs, vlabels


def lovasz_hinge_flat(logits, labels):
    """Binary Lovasz hinge loss.

    Args:
        logits (torch.Tensor): [P], logits at each prediction
            (between -infty and +infty).
        labels (torch.Tensor): [P], binary ground truth labels (0 or 1).

    Returns:
        torch.Tensor: The calculated loss.
    """
    if len(labels) == 0:
        # only void pixels, the gradients should be 0
        return logits.sum() * 0.
    signs = 2. * labels.float() - 1.
    errors = (1. - logits * signs)
    errors_sorted, perm = torch.sort(errors, dim=0, descending=True)
    perm = perm.data
    gt_sorted = labels[perm]
    grad = lovasz_grad(gt_sorted)
    loss = torch.dot(F.relu(errors_sorted), grad)
    return loss


def lovasz_hinge(logits,
                 labels,
                 classes='present',
                 per_image=False,
                 class_weight=None,
                 reduction='mean',
                 avg_factor=None,
                 ignore_index=255):
    """Binary Lovasz hinge loss.

    Args:
        logits (torch.Tensor): [B, H, W], logits at each pixel
            (between -infty and +infty).
        labels (torch.Tensor): [B, H, W], binary ground truth masks (0 or 1).
        classes (str | list[int], optional): Placeholder, to be consistent
            with other losses. Default: 'present'.
        per_image (bool, optional): If per_image is True, compute the loss
            per image instead of per batch. Default: False.
        class_weight (list[float], optional): Placeholder, to be consistent
            with other losses. Default: None.
        reduction (str, optional): The method used to reduce the loss.
            Options are "none", "mean" and "sum". This parameter only works
            when per_image is True. Default: 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. This parameter only works when per_image is True.
            Default: None.
        ignore_index (int | None): The label index to be ignored. Default: 255.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if per_image:
        loss = [
            lovasz_hinge_flat(*flatten_binary_logits(
                logit.unsqueeze(0), label.unsqueeze(0), ignore_index))
            for logit, label in zip(logits, labels)
        ]
        loss = weight_reduce_loss(
            torch.stack(loss), None, reduction, avg_factor)
    else:
        loss = lovasz_hinge_flat(
            *flatten_binary_logits(logits, labels, ignore_index))
    return loss


def lovasz_softmax_flat(probs, labels, classes='present', class_weight=None):
    """Multi-class Lovasz-Softmax loss.

    Args:
        probs (torch.Tensor): [P, C], class probabilities at each prediction
            (between 0 and 1).
        labels (torch.Tensor): [P], ground truth labels (between 0 and C - 1).
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        class_weight (list[float], optional): The weight for each class.
            Default: None.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if probs.numel() == 0:
        # only void pixels, the gradients should be 0
        return probs * 0.
    C = probs.size(1)
    losses = []
    class_to_sum = list(range(C)) if classes in ['all', 'present'] else classes
    for c in class_to_sum:
        fg = (labels == c).float()  # foreground for class c
        if (classes == 'present' and fg.sum() == 0):
            continue
        if C == 1:
            if len(classes) > 1:
                raise ValueError('Sigmoid output possible only with 1 class')
            class_pred = probs[:, 0]
        else:
            class_pred = probs[:, c]
        errors = (fg - class_pred).abs()
        errors_sorted, perm = torch.sort(errors, 0, descending=True)
        perm = perm.data
        fg_sorted = fg[perm]
        loss = torch.dot(errors_sorted, lovasz_grad(fg_sorted))
        if class_weight is not None:
            loss *= class_weight[c]
        losses.append(loss)
    return torch.stack(losses).mean()


def lovasz_softmax(probs,
                   labels,
                   classes='present',
                   per_image=False,
                   class_weight=None,
                   reduction='mean',
                   avg_factor=None,
                   ignore_index=255):
    """Multi-class Lovasz-Softmax loss.

    Args:
        probs (torch.Tensor): [B, C, H, W], class probabilities at each
            prediction (between 0 and 1).
        labels (torch.Tensor): [B, H, W], ground truth labels (between 0 and
            C - 1).
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        per_image (bool, optional): If per_image is True, compute the loss per
            image instead of per batch. Default: False.
        class_weight (list[float], optional): The weight for each class.
            Default: None.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". This parameter only works when
            per_image is True. Default: 'mean'.
        avg_factor (int, optional): Average factor that is used to average
            the loss. This parameter only works when per_image is True.
            Default: None.
        ignore_index (int | None): The label index to be ignored. Default: 255.

    Returns:
        torch.Tensor: The calculated loss.
    """
    if per_image:
        loss = [
            lovasz_softmax_flat(
                *flatten_probs(
                    prob.unsqueeze(0), label.unsqueeze(0), ignore_index),
                classes=classes,
                class_weight=class_weight)
            for prob, label in zip(probs, labels)
        ]
        loss = weight_reduce_loss(
            torch.stack(loss), None, reduction, avg_factor)
    else:
        loss = lovasz_softmax_flat(
            *flatten_probs(probs, labels, ignore_index),
            classes=classes,
            class_weight=class_weight)
    return loss


@LOSSES.register_module()
class LovaszLoss(nn.Module):
    """LovaszLoss.

    This loss is proposed in `The Lovasz-Softmax loss: A tractable surrogate
    for the optimization of the intersection-over-union measure in neural
    networks <https://arxiv.org/abs/1705.08790>`_.

    Args:
        loss_type (str, optional): Binary or multi-class loss.
            Default: 'multi_class'. Options are "binary" and "multi_class".
        classes (str | list[int], optional): Classes chosen to calculate loss.
            'all' for all classes, 'present' for classes present in labels, or
            a list of classes to average. Default: 'present'.
        per_image (bool, optional): If per_image is True, compute the loss per
            image instead of per batch. Default: False.
        reduction (str, optional): The method used to reduce the loss. Options
            are "none", "mean" and "sum". This parameter only works when
            per_image is True. Default: 'mean'.
        class_weight (list[float], optional): The weight for each class.
            Default: None.
        loss_weight (float, optional): Weight of the loss. Defaults to 1.0.
    """

    def __init__(self,
                 loss_type='multi_class',
                 classes='present',
                 per_image=False,
                 reduction='mean',
                 class_weight=None,
                 loss_weight=1.0):
        super(LovaszLoss, self).__init__()
        assert loss_type in ('binary', 'multi_class'), \
            "loss_type should be 'binary' or 'multi_class'."

        if loss_type == 'binary':
            self.cls_criterion = lovasz_hinge
        else:
            self.cls_criterion = lovasz_softmax
        assert classes in ('all', 'present') or mmcv.is_list_of(classes, int)
        if not per_image:
            assert reduction == 'none', \
                "reduction should be 'none' when per_image is False."

        self.classes = classes
        self.per_image = per_image
        self.reduction = reduction
        self.loss_weight = loss_weight
        self.class_weight = class_weight

    def forward(self,
                cls_score,
                label,
                weight=None,
                avg_factor=None,
                reduction_override=None,
                **kwargs):
        """Forward function."""
        assert reduction_override in (None, 'none', 'mean', 'sum')
        reduction = (
            reduction_override if reduction_override else self.reduction)
        if self.class_weight is not None:
            class_weight = cls_score.new_tensor(self.class_weight)
        else:
            class_weight = None

        # if multi-class loss, transform logits to probs
        if self.cls_criterion == lovasz_softmax:
            cls_score = F.softmax(cls_score, dim=1)

        loss_cls = self.loss_weight * self.cls_criterion(
            cls_score,
            label,
            self.classes,
            self.per_image,
            class_weight=class_weight,
            reduction=reduction,
            avg_factor=avg_factor,
            **kwargs)
        return loss_cls
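The registered class above plugs into mmseg's config system. A minimal usage sketch (not part of the commit; build_loss comes from mmseg.models, and shapes are illustrative). Note that with per_image=False the criterion already returns a single scalar for the whole batch, which is why the constructor insists on reduction='none' in that case.

# Hedged usage sketch, assuming this file is registered under the LOSSES
# registry and built via mmseg.models.build_loss.
import torch
from mmseg.models import build_loss

# Multi-class: forward() applies softmax to the logits internally.
criterion = build_loss(
    dict(type='LovaszLoss', per_image=True, reduction='mean'))
logits = torch.randn(2, 4, 8, 8)             # [B, C, H, W]
labels = torch.randint(0, 4, (2, 8, 8))      # [B, H, W], values in [0, C - 1]
loss = criterion(logits, labels)             # scalar tensor

# Binary: logits and labels are both [B, H, W].
bin_criterion = build_loss(
    dict(type='LovaszLoss', loss_type='binary', reduction='none'))
bin_logits = torch.randn(2, 8, 8)
bin_labels = torch.randint(0, 2, (2, 8, 8))
bin_loss = bin_criterion(bin_logits, bin_labels)  # scalar (whole batch)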
@@ -142,3 +142,63 @@ def test_accuracy():
    with pytest.raises(AssertionError):
        accuracy = Accuracy()
        accuracy(pred[:, :, None], true_label)


def test_lovasz_loss():
    from mmseg.models import build_loss

    # loss_type should be 'binary' or 'multi_class'
    with pytest.raises(AssertionError):
        loss_cfg = dict(
            type='LovaszLoss',
            loss_type='Binary',
            reduction='none',
            loss_weight=1.0)
        build_loss(loss_cfg)

    # reduction should be 'none' when per_image is False.
    with pytest.raises(AssertionError):
        loss_cfg = dict(type='LovaszLoss', loss_type='multi_class')
        build_loss(loss_cfg)

    # test lovasz loss with loss_type = 'multi_class' and per_image = False
    loss_cfg = dict(type='LovaszLoss', reduction='none', loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(1, 3, 4, 4)
    labels = (torch.rand(1, 4, 4) * 2).long()
    lovasz_loss(logits, labels)

    # test lovasz loss with loss_type = 'multi_class' and per_image = True
    loss_cfg = dict(
        type='LovaszLoss',
        per_image=True,
        reduction='mean',
        class_weight=[1.0, 2.0, 3.0],
        loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(1, 3, 4, 4)
    labels = (torch.rand(1, 4, 4) * 2).long()
    lovasz_loss(logits, labels, ignore_index=None)

    # test lovasz loss with loss_type = 'binary' and per_image = False
    loss_cfg = dict(
        type='LovaszLoss',
        loss_type='binary',
        reduction='none',
        loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(2, 4, 4)
    labels = (torch.rand(2, 4, 4)).long()
    lovasz_loss(logits, labels)

    # test lovasz loss with loss_type = 'binary' and per_image = True
    loss_cfg = dict(
        type='LovaszLoss',
        loss_type='binary',
        per_image=True,
        reduction='mean',
        loss_weight=1.0)
    lovasz_loss = build_loss(loss_cfg)
    logits = torch.rand(2, 4, 4)
    labels = (torch.rand(2, 4, 4)).long()
    lovasz_loss(logits, labels, ignore_index=None)
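As a quick sanity check on the lovasz_grad helper (Alg. 1 of the paper), here is a small worked example; treat it as an illustration rather than part of the test suite:

# Illustrative check of lovasz_grad; the import path follows the new module
# added above (mmseg.models.losses.lovasz_loss).
import torch
from mmseg.models.losses.lovasz_loss import lovasz_grad

gt_sorted = torch.tensor([1, 0, 1])
# intersection = [1, 1, 0], union = [2, 3, 3] -> jaccard = [1/2, 2/3, 1];
# after first-differencing: [1/2, 1/6, 1/3], which sums to jaccard[-1].
g = lovasz_grad(gt_sorted)
print(g)        # tensor([0.5000, 0.1667, 0.3333])
print(g.sum())  # tensor(1.)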