## Motivation

Support PIDNet, a SOTA real-time semantic segmentation method listed on [Papers with Code](https://paperswithcode.com/task/real-time-semantic-segmentation).

- Paper: https://arxiv.org/pdf/2206.02066.pdf
- Official repo: https://github.com/XuJiacong/PIDNet

## Current results

**Cityscapes**

| Model | Ref mIoU | mIoU (ours) |
| --- | --- | --- |
| PIDNet-S | 78.8 | 78.74 |
| PIDNet-M | 79.9 | 80.22 |
| PIDNet-L | 80.9 | 80.89 |

## TODO

- [x] Support inference with official weights
- [x] Support training on Cityscapes
- [x] Update docstring
- [x] Add unit test
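A minimal way to sanity-check the "inference with official weights" item is the standard MMSegmentation 1.x inference API. The sketch below is not part of this PR: the config and checkpoint paths are placeholders for a PIDNet config and a converted official checkpoint.

```python
# Sketch only: config and checkpoint paths are placeholders, not files in this PR.
from mmseg.apis import inference_model, init_model, show_result_pyplot

config_file = 'configs/pidnet/pidnet-s_cityscapes.py'    # placeholder config name
checkpoint_file = 'checkpoints/pidnet-s_cityscapes.pth'  # converted official weights (placeholder)

model = init_model(config_file, checkpoint_file, device='cpu')
result = inference_model(model, 'demo/demo.png')  # demo image bundled with mmsegmentation
show_result_pyplot(model, 'demo/demo.png', result)
```

The decode head implementation added by this PR follows.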
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule
from torch import Tensor

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType, SampleList


class BasePIDHead(BaseModule):
    """Base class for PID head.

    Args:
        in_channels (int): Number of input channels.
        channels (int): Number of output channels.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict or list[dict], optional): Init config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        self.conv = ConvModule(
            in_channels,
            channels,
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            order=('norm', 'act', 'conv'))
        _, self.norm = build_norm_layer(norm_cfg, num_features=channels)
        self.act = build_activation_layer(act_cfg)

    def forward(self, x: Tensor, cls_seg: Optional[nn.Module]) -> Tensor:
        """Forward function.

        Args:
            x (Tensor): Input tensor.
            cls_seg (nn.Module, optional): The classification head.

        Returns:
            Tensor: Output tensor.
        """
        x = self.conv(x)
        x = self.norm(x)
        x = self.act(x)
        if cls_seg is not None:
            x = cls_seg(x)
        return x


@MODELS.register_module()
class PIDHead(BaseDecodeHead):
    """Decode head for PIDNet.

    Args:
        in_channels (int): Number of input channels.
        channels (int): Number of output channels.
        num_classes (int): Number of classes.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 num_classes: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 **kwargs):
        super().__init__(
            in_channels,
            channels,
            num_classes=num_classes,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)
        self.i_head = BasePIDHead(in_channels, channels, norm_cfg, act_cfg)
        self.p_head = BasePIDHead(in_channels // 2, channels, norm_cfg,
                                  act_cfg)
        self.d_head = BasePIDHead(
            in_channels // 2,
            in_channels // 4,
            norm_cfg,
        )
        self.p_cls_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
        self.d_cls_seg = nn.Conv2d(in_channels // 4, 1, kernel_size=1)

    def init_weights(self):
        """Kaiming-initialize conv layers and reset BN weight/bias to 1/0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(
            self,
            inputs: Union[Tensor,
                          Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
        """Forward function.

        Args:
            inputs (Tensor | tuple[Tensor]): Input tensor or tuple of
                Tensor. In training, the input is a tuple of three tensors,
                (p_feat, i_feat, d_feat), and the output is a tuple of three
                tensors, (p_seg_logit, i_seg_logit, d_seg_logit).
                In inference, only the integral-branch head is used: the
                input is the integral feature map and the output is the
                segmentation logit.

        Returns:
            Tensor | tuple[Tensor]: Output tensor or tuple of tensors.
        """
        if self.training:
            x_p, x_i, x_d = inputs
            x_p = self.p_head(x_p, self.p_cls_seg)
            x_i = self.i_head(x_i, self.cls_seg)
            x_d = self.d_head(x_d, self.d_cls_seg)
            return x_p, x_i, x_d
        else:
            return self.i_head(inputs, self.cls_seg)

    def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tuple[Tensor]:
        """Stack the per-sample semantic and boundary (edge) ground truths
        into batched tensors."""
        gt_semantic_segs = [
            data_sample.gt_sem_seg.data for data_sample in batch_data_samples
        ]
        gt_edge_segs = [
            data_sample.gt_edge_map.data for data_sample in batch_data_samples
        ]
        gt_sem_segs = torch.stack(gt_semantic_segs, dim=0)
        gt_edge_segs = torch.stack(gt_edge_segs, dim=0)
        return gt_sem_segs, gt_edge_segs

    def loss_by_feat(self, seg_logits: Tuple[Tensor],
                     batch_data_samples: SampleList) -> dict:
        """Compute segmentation and boundary losses.

        ``self.loss_decode`` is expected to hold four losses, applied in
        order to the P-branch semantic logit, the I-branch semantic logit,
        the D-branch boundary logit, and the boundary-aware semantic logit
        (the I-branch logit supervised only where the predicted boundary
        probability exceeds 0.8).

        Args:
            seg_logits (tuple[Tensor]): (p_logit, i_logit, d_logit) returned
                by :meth:`forward` in training mode.
            batch_data_samples (list): Data samples carrying ``gt_sem_seg``
                and ``gt_edge_map``.

        Returns:
            dict[str, Tensor]: A dictionary of loss components.
        """
        loss = dict()
        p_logit, i_logit, d_logit = seg_logits
        sem_label, bd_label = self._stack_batch_gt(batch_data_samples)
        p_logit = resize(
            input=p_logit,
            size=sem_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        i_logit = resize(
            input=i_logit,
            size=sem_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        d_logit = resize(
            input=d_logit,
            size=bd_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        sem_label = sem_label.squeeze(1)
        bd_label = bd_label.squeeze(1)
        loss['loss_sem_p'] = self.loss_decode[0](
            p_logit, sem_label, ignore_index=self.ignore_index)
        loss['loss_sem_i'] = self.loss_decode[1](i_logit, sem_label)
        loss['loss_bd'] = self.loss_decode[2](d_logit, bd_label)
        filler = torch.ones_like(sem_label) * self.ignore_index
        sem_bd_label = torch.where(
            torch.sigmoid(d_logit[:, 0, :, :]) > 0.8, sem_label, filler)
        loss['loss_sem_bd'] = self.loss_decode[3](i_logit, sem_bd_label)
        loss['acc_seg'] = accuracy(
            i_logit, sem_label, ignore_index=self.ignore_index)
        return loss
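For reference, a small sketch (not part of the PR) of the two modes described in `PIDHead.forward`: in training, a `(p_feat, i_feat, d_feat)` tuple maps to three logits consumed by `loss_by_feat`; in inference, only the integral-branch feature is used. Channel sizes follow a PIDNet-S-like layout (`in_channels=128`), and the import path is an assumption about how the head is exported.

```python
import torch

from mmseg.models import PIDHead  # assumed export path once the head is registered

head = PIDHead(in_channels=128, channels=128, num_classes=19)

# Inference: a single integral-branch feature map in, one segmentation logit out.
head.eval()
x_i = torch.rand(2, 128, 64, 128)
with torch.no_grad():
    seg_logit = head(x_i)
assert seg_logit.shape == (2, 19, 64, 128)

# Training: (p_feat, i_feat, d_feat) in, (p, i, d) logits out. The P and D
# branch features carry half the channels of the I branch (see __init__).
head.train()
x_p = torch.rand(2, 64, 64, 128)
x_d = torch.rand(2, 64, 64, 128)
p_logit, i_logit, d_logit = head((x_p, x_i, x_d))
assert p_logit.shape == (2, 19, 64, 128)
assert i_logit.shape == (2, 19, 64, 128)
assert d_logit.shape == (2, 1, 64, 128)
```

Computing the training losses additionally requires `loss_decode` to be configured as a list of four losses, in the order documented in `loss_by_feat` above.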