mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
## Motivation

Support the SOTA real-time semantic segmentation method listed on [Papers with Code](https://paperswithcode.com/task/real-time-semantic-segmentation).

Paper: https://arxiv.org/pdf/2206.02066.pdf

Official repo: https://github.com/XuJiacong/PIDNet

## Current results

**Cityscapes**

|Model|Ref mIoU|mIoU (ours)|
|---|---|---|
|PIDNet-S|78.8|78.74|
|PIDNet-M|79.9|80.22|
|PIDNet-L|80.9|80.89|

## TODO

- [x] Support inference with official weights
- [x] Support training on Cityscapes
- [x] Update docstring
- [x] Add unit test
63 lines
1.8 KiB
Python
63 lines
1.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
import torch.nn as nn
|
|
import torch.nn.functional as F
|
|
from torch import Tensor
|
|
|
|
from mmseg.registry import MODELS
|
|
|
|
|
|
@MODELS.register_module()
|
|
class BoundaryLoss(nn.Module):
    """Boundary loss.

    Class-balanced binary cross entropy between boundary logits and a binary
    boundary map. Boundary (``1``) elements are weighted by the fraction of
    non-boundary elements and non-boundary (``0``) elements by the fraction
    of boundary elements, so the rarer class contributes more per element.
    Elements whose label is neither ``0`` nor ``1`` receive zero weight.

    This function is modified from
    `PIDNet <https://github.com/XuJiacong/PIDNet/blob/main/utils/criterion.py#L122>`_.  # noqa
    Licensed under the MIT License.

    Args:
        loss_weight (float): Weight of the loss. Defaults to 1.0.
        loss_name (str): Name of the loss item. If you want this loss
            item to be included into the backward graph, `loss_` must be the
            prefix of the name. Defaults to 'loss_boundary'.
    """

    def __init__(self,
                 loss_weight: float = 1.0,
                 loss_name: str = 'loss_boundary'):
        super().__init__()
        self.loss_weight = loss_weight
        self.loss_name_ = loss_name

    def forward(self, bd_pre: Tensor, bd_gt: Tensor) -> Tensor:
        """Forward function.

        Args:
            bd_pre (Tensor): Predictions (logits) of the boundary head. Must
                be 4-D because it is permuted with ``(0, 2, 3, 1)``;
                presumably (N, 1, H, W) - confirm against the caller.
            bd_gt (Tensor): Ground truth of the boundary, with the same
                number of elements as ``bd_pre``. It may be non-contiguous.

        Returns:
            Tensor: Scalar loss, already multiplied by ``loss_weight``.
        """
        # ``reshape`` (rather than ``view``) also accepts non-contiguous
        # inputs, e.g. a ground truth that was transposed or sliced upstream.
        logits = bd_pre.permute(0, 2, 3, 1).reshape(1, -1)
        target = bd_gt.reshape(1, -1).float()

        pos_index = (target == 1)
        neg_index = (target == 0)
        pos_num = pos_index.sum()
        neg_num = neg_index.sum()
        sum_num = pos_num + neg_num

        # Weights start at zero, so labels other than 0/1 are ignored. When
        # no 0/1 label exists both masks are empty and nothing is assigned,
        # which keeps the 0/0 ratio below out of the weights.
        weight = torch.zeros_like(logits)
        weight[pos_index] = neg_num * 1.0 / sum_num
        weight[neg_index] = pos_num * 1.0 / sum_num

        # The mean is taken over every element, zero-weighted ones included.
        loss = F.binary_cross_entropy_with_logits(
            logits, target, weight, reduction='mean')

        return self.loss_weight * loss

    @property
    def loss_name(self):
        """str: Name under which this loss item is reported."""
        return self.loss_name_
|