mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] huasdorff distance loss (#2820)
Thanks for your contribution and we appreciate it a lot. The following instructions would make your pull request more healthy and more easily get feedback. If you do not understand some items, don't worry, just make the pull request and seek help from maintainers. ## Motivation Add Huasdorff distance loss --------- Co-authored-by: xiexinch <xiexinch@outlook.com>
This commit is contained in:
parent
b2f4b4fe33
commit
bb93b482b8
@ -5,6 +5,7 @@ from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
|
||||
cross_entropy, mask_cross_entropy)
|
||||
from .dice_loss import DiceLoss
|
||||
from .focal_loss import FocalLoss
|
||||
from .huasdorff_distance_loss import HuasdorffDisstanceLoss
|
||||
from .lovasz_loss import LovaszLoss
|
||||
from .ohem_cross_entropy_loss import OhemCrossEntropy
|
||||
from .tversky_loss import TverskyLoss
|
||||
@ -14,5 +15,6 @@ __all__ = [
|
||||
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
|
||||
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
|
||||
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
|
||||
'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss'
|
||||
'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss',
|
||||
'HuasdorffDisstanceLoss'
|
||||
]
|
||||
|
160
mmseg/models/losses/huasdorff_distance_loss.py
Normal file
160
mmseg/models/losses/huasdorff_distance_loss.py
Normal file
@ -0,0 +1,160 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
"""Modified from https://github.com/JunMa11/SegWithDistMap/blob/
|
||||
master/code/train_LA_HD.py (Apache-2.0 License)"""
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from scipy.ndimage import distance_transform_edt as distance
|
||||
from torch import Tensor
|
||||
|
||||
from mmseg.registry import MODELS
|
||||
from .utils import get_class_weight, weighted_loss
|
||||
|
||||
|
||||
def compute_dtm(img_gt: Tensor, pred: Tensor) -> Tensor:
|
||||
"""
|
||||
compute the distance transform map of foreground in mask
|
||||
Args:
|
||||
img_gt: Ground truth of the image, (b, h, w)
|
||||
pred: Predictions of the segmentation head after softmax, (b, c, h, w)
|
||||
|
||||
Returns:
|
||||
output: the foreground Distance Map (SDM)
|
||||
dtm(x) = 0; x in segmentation boundary
|
||||
inf|x-y|; x in segmentation
|
||||
"""
|
||||
|
||||
fg_dtm = torch.zeros_like(pred)
|
||||
out_shape = pred.shape
|
||||
for b in range(out_shape[0]): # batch size
|
||||
for c in range(1, out_shape[1]): # default 0 channel is background
|
||||
posmask = img_gt[b].byte()
|
||||
if posmask.any():
|
||||
posdis = distance(posmask)
|
||||
fg_dtm[b][c] = torch.from_numpy(posdis)
|
||||
|
||||
return fg_dtm
|
||||
|
||||
|
||||
@weighted_loss
|
||||
def hd_loss(seg_soft: Tensor,
|
||||
gt: Tensor,
|
||||
seg_dtm: Tensor,
|
||||
gt_dtm: Tensor,
|
||||
class_weight=None,
|
||||
ignore_index=255) -> Tensor:
|
||||
"""
|
||||
compute huasdorff distance loss for segmentation
|
||||
Args:
|
||||
seg_soft: softmax results, shape=(b,c,x,y)
|
||||
gt: ground truth, shape=(b,x,y)
|
||||
seg_dtm: segmentation distance transform map, shape=(b,c,x,y)
|
||||
gt_dtm: ground truth distance transform map, shape=(b,c,x,y)
|
||||
|
||||
Returns:
|
||||
output: hd_loss
|
||||
"""
|
||||
assert seg_soft.shape[0] == gt.shape[0]
|
||||
total_loss = 0
|
||||
num_class = seg_soft.shape[1]
|
||||
if class_weight is not None:
|
||||
assert class_weight.ndim == num_class
|
||||
for i in range(1, num_class):
|
||||
if i != ignore_index:
|
||||
delta_s = (seg_soft[:, i, ...] - gt.float())**2
|
||||
s_dtm = seg_dtm[:, i, ...]**2
|
||||
g_dtm = gt_dtm[:, i, ...]**2
|
||||
dtm = s_dtm + g_dtm
|
||||
multiplied = torch.einsum('bxy, bxy->bxy', delta_s, dtm)
|
||||
hd_loss = multiplied.mean()
|
||||
if class_weight is not None:
|
||||
hd_loss *= class_weight[i]
|
||||
total_loss += hd_loss
|
||||
|
||||
return total_loss / num_class
|
||||
|
||||
|
||||
@MODELS.register_module()
|
||||
class HuasdorffDisstanceLoss(nn.Module):
|
||||
"""HuasdorffDisstanceLoss. This loss is proposed in `How Distance Transform
|
||||
Maps Boost Segmentation CNNs: An Empirical Study.
|
||||
|
||||
<http://proceedings.mlr.press/v121/ma20b.html>`_.
|
||||
Args:
|
||||
reduction (str, optional): The method used to reduce the loss into
|
||||
a scalar. Defaults to 'mean'.
|
||||
class_weight (list[float] | str, optional): Weight of each class. If in
|
||||
str format, read them from a file. Defaults to None.
|
||||
loss_weight (float): Weight of the loss. Defaults to 1.0.
|
||||
ignore_index (int | None): The label index to be ignored. Default: 255.
|
||||
loss_name (str): Name of the loss item. If you want this loss
|
||||
item to be included into the backward graph, `loss_` must be the
|
||||
prefix of the name. Defaults to 'loss_boundary'.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
reduction='mean',
|
||||
class_weight=None,
|
||||
loss_weight=1.0,
|
||||
ignore_index=255,
|
||||
loss_name='loss_huasdorff_disstance',
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.reduction = reduction
|
||||
self.loss_weight = loss_weight
|
||||
self.class_weight = get_class_weight(class_weight)
|
||||
self._loss_name = loss_name
|
||||
self.ignore_index = ignore_index
|
||||
|
||||
def forward(self,
|
||||
pred: Tensor,
|
||||
target: Tensor,
|
||||
avg_factor=None,
|
||||
reduction_override=None,
|
||||
**kwargs) -> Tensor:
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
pred (Tensor): Predictions of the segmentation head. (B, C, H, W)
|
||||
target (Tensor): Ground truth of the image. (B, H, W)
|
||||
avg_factor (int, optional): Average factor that is used to
|
||||
average the loss. Defaults to None.
|
||||
reduction_override (str, optional): The reduction method used
|
||||
to override the original reduction method of the loss.
|
||||
Options are "none", "mean" and "sum".
|
||||
Returns:
|
||||
Tensor: Loss tensor.
|
||||
"""
|
||||
assert reduction_override in (None, 'none', 'mean', 'sum')
|
||||
reduction = (
|
||||
reduction_override if reduction_override else self.reduction)
|
||||
if self.class_weight is not None:
|
||||
class_weight = pred.new_tensor(self.class_weight)
|
||||
else:
|
||||
class_weight = None
|
||||
|
||||
pred_soft = F.softmax(pred, dim=1)
|
||||
valid_mask = (target != self.ignore_index).long()
|
||||
target = target * valid_mask
|
||||
|
||||
with torch.no_grad():
|
||||
gt_dtm = compute_dtm(target.cpu(), pred_soft)
|
||||
gt_dtm = gt_dtm.float()
|
||||
seg_dtm2 = compute_dtm(
|
||||
pred_soft.argmax(dim=1, keepdim=False).cpu(), pred_soft)
|
||||
seg_dtm2 = seg_dtm2.float()
|
||||
|
||||
loss_hd = self.loss_weight * hd_loss(
|
||||
pred_soft,
|
||||
target,
|
||||
seg_dtm=seg_dtm2,
|
||||
gt_dtm=gt_dtm,
|
||||
reduction=reduction,
|
||||
avg_factor=avg_factor,
|
||||
class_weight=class_weight,
|
||||
ignore_index=self.ignore_index)
|
||||
return loss_hd
|
||||
|
||||
@property
|
||||
def loss_name(self):
|
||||
return self._loss_name
|
@ -0,0 +1,29 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.losses import HuasdorffDisstanceLoss
|
||||
|
||||
|
||||
def test_huasdorff_distance_loss():
|
||||
loss_class = HuasdorffDisstanceLoss
|
||||
pred = torch.rand((10, 8, 6, 6))
|
||||
target = torch.rand((10, 6, 6))
|
||||
class_weight = torch.rand(8)
|
||||
|
||||
# Test loss forward
|
||||
loss = loss_class()(pred, target)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
|
||||
# Test loss forward with avg_factor
|
||||
loss = loss_class()(pred, target, avg_factor=10)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
|
||||
# Test loss forward with avg_factor and reduction is None, 'sum' and 'mean'
|
||||
for reduction in [None, 'sum', 'mean']:
|
||||
loss = loss_class()(pred, target, avg_factor=10, reduction=reduction)
|
||||
assert isinstance(loss, torch.Tensor)
|
||||
|
||||
# Test loss forward with class_weight
|
||||
with pytest.raises(AssertionError):
|
||||
loss_class(class_weight=class_weight)(pred, target)
|
Loading…
x
Reference in New Issue
Block a user