## Motivation

Support the SOTA real-time semantic segmentation method listed on [Papers with Code](https://paperswithcode.com/task/real-time-semantic-segmentation).

Paper: https://arxiv.org/pdf/2206.02066.pdf
Official repo: https://github.com/XuJiacong/PIDNet

## Current results

**Cityscapes**

| Model | Ref mIoU | mIoU (ours) |
| --- | --- | --- |
| PIDNet-S | 78.8 | 78.74 |
| PIDNet-M | 79.9 | 80.22 |
| PIDNet-L | 80.9 | 80.89 |

## TODO

- [x] Support inference with official weights
- [x] Support training on Cityscapes
- [x] Update docstring
- [x] Add unit test
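For a quick sanity check once this lands, inference should go through the standard 1.x APIs. The config and checkpoint paths below are illustrative placeholders, not the final filenames — check `configs/pidnet/` and the model zoo for the actual names added by this PR:

```python
from mmseg.apis import inference_model, init_model

# Illustrative paths -- not the final filenames from this PR.
config = 'configs/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes.py'
checkpoint = 'pidnet-s_cityscapes.pth'

model = init_model(config, checkpoint, device='cuda:0')
result = inference_model(model, 'demo/demo.png')
```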
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule, ModuleList, Sequential
from torch import Tensor


class DAPPM(BaseModule):
    """DAPPM module in `DDRNet <https://arxiv.org/abs/2101.06085>`_.

    Args:
        in_channels (int): Input channels.
        branch_channels (int): Branch channels.
        out_channels (int): Output channels.
        num_scales (int): Number of scales.
        kernel_sizes (list[int]): Kernel sizes of each scale.
        strides (list[int]): Strides of each scale.
        paddings (list[int]): Paddings of each scale.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', momentum=0.1).
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU', inplace=True).
        conv_cfg (dict): Config dict for convolution layer in ConvModule.
            Default: dict(order=('norm', 'act', 'conv'), bias=False).
        upsample_mode (str): Upsample mode. Default: 'bilinear'.
    """
    def __init__(self,
                 in_channels: int,
                 branch_channels: int,
                 out_channels: int,
                 num_scales: int,
                 kernel_sizes: List[int] = [5, 9, 17],
                 strides: List[int] = [2, 4, 8],
                 paddings: List[int] = [2, 4, 8],
                 norm_cfg: Dict = dict(type='BN', momentum=0.1),
                 act_cfg: Dict = dict(type='ReLU', inplace=True),
                 conv_cfg: Dict = dict(
                     order=('norm', 'act', 'conv'), bias=False),
                 upsample_mode: str = 'bilinear'):
        super().__init__()

        self.num_scales = num_scales
        self.upsample_mode = upsample_mode
        self.in_channels = in_channels
        self.branch_channels = branch_channels
        self.out_channels = out_channels
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.conv_cfg = conv_cfg

        # scales[0] is a plain 1x1 projection, scales[1:-1] downsample with
        # strided average pooling, and the last scale pools globally.
        self.scales = ModuleList([
            ConvModule(
                in_channels,
                branch_channels,
                kernel_size=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
                **conv_cfg)
        ])
        for i in range(1, num_scales - 1):
            self.scales.append(
                Sequential(*[
                    nn.AvgPool2d(
                        kernel_size=kernel_sizes[i - 1],
                        stride=strides[i - 1],
                        padding=paddings[i - 1]),
                    ConvModule(
                        in_channels,
                        branch_channels,
                        kernel_size=1,
                        norm_cfg=norm_cfg,
                        act_cfg=act_cfg,
                        **conv_cfg)
                ]))
        self.scales.append(
            Sequential(*[
                nn.AdaptiveAvgPool2d((1, 1)),
                ConvModule(
                    in_channels,
                    branch_channels,
                    kernel_size=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    **conv_cfg)
            ]))
        # One 3x3 conv per non-base scale to refine the hierarchical fusion.
        self.processes = ModuleList()
        for i in range(num_scales - 1):
            self.processes.append(
                ConvModule(
                    branch_channels,
                    branch_channels,
                    kernel_size=3,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg,
                    **conv_cfg))

        # Fuse the concatenated branches, then add a 1x1 shortcut of the
        # input.
        self.compression = ConvModule(
            branch_channels * num_scales,
            out_channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **conv_cfg)

        self.shortcut = ConvModule(
            in_channels,
            out_channels,
            kernel_size=1,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **conv_cfg)
    def forward(self, inputs: Tensor) -> Tensor:
        feats = []
        feats.append(self.scales[0](inputs))

        # Hierarchical fusion: each scale is upsampled back to the input size
        # and added to the previous branch before its 3x3 process conv.
        for i in range(1, self.num_scales):
            feat_up = F.interpolate(
                self.scales[i](inputs),
                size=inputs.shape[2:],
                mode=self.upsample_mode)
            feats.append(self.processes[i - 1](feat_up + feats[i - 1]))

        return self.compression(torch.cat(feats,
                                          dim=1)) + self.shortcut(inputs)

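# A minimal shape check for DAPPM (illustrative channel/scale values, not
# taken from a shipped config) -- spatial resolution is preserved:
#
#     >>> dappm = DAPPM(in_channels=1024, branch_channels=112,
#     ...               out_channels=256, num_scales=5)
#     >>> dappm(torch.randn(2, 1024, 8, 8)).shape
#     torch.Size([2, 256, 8, 8])
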
class PAPPM(DAPPM):
    """PAPPM module in `PIDNet <https://arxiv.org/abs/2206.02066>`_.

    Args:
        in_channels (int): Input channels.
        branch_channels (int): Branch channels.
        out_channels (int): Output channels.
        num_scales (int): Number of scales.
        kernel_sizes (list[int]): Kernel sizes of each scale.
        strides (list[int]): Strides of each scale.
        paddings (list[int]): Paddings of each scale.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN', momentum=0.1).
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: dict(type='ReLU', inplace=True).
        conv_cfg (dict): Config dict for convolution layer in ConvModule.
            Default: dict(order=('norm', 'act', 'conv'), bias=False).
        upsample_mode (str): Upsample mode. Default: 'bilinear'.
    """
    def __init__(self,
                 in_channels: int,
                 branch_channels: int,
                 out_channels: int,
                 num_scales: int,
                 kernel_sizes: List[int] = [5, 9, 17],
                 strides: List[int] = [2, 4, 8],
                 paddings: List[int] = [2, 4, 8],
                 norm_cfg: Dict = dict(type='BN', momentum=0.1),
                 act_cfg: Dict = dict(type='ReLU', inplace=True),
                 conv_cfg: Dict = dict(
                     order=('norm', 'act', 'conv'), bias=False),
                 upsample_mode: str = 'bilinear'):
        super().__init__(in_channels, branch_channels, out_channels,
                         num_scales, kernel_sizes, strides, paddings, norm_cfg,
                         act_cfg, conv_cfg, upsample_mode)

        # Unlike DAPPM, all non-base scales are refined in parallel by a
        # single grouped 3x3 conv (one group per scale).
        self.processes = ConvModule(
            self.branch_channels * (self.num_scales - 1),
            self.branch_channels * (self.num_scales - 1),
            kernel_size=3,
            padding=1,
            groups=self.num_scales - 1,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            **self.conv_cfg)

    def forward(self, inputs: Tensor) -> Tensor:
        x_ = self.scales[0](inputs)
        feats = []
        # Parallel fusion: every pooled scale is upsampled and added to the
        # base branch, then all scales share one grouped process conv.
        for i in range(1, self.num_scales):
            feat_up = F.interpolate(
                self.scales[i](inputs),
                size=inputs.shape[2:],
                mode=self.upsample_mode,
                align_corners=False)
            feats.append(feat_up + x_)
        scale_out = self.processes(torch.cat(feats, dim=1))
        return self.compression(torch.cat([x_, scale_out],
                                          dim=1)) + self.shortcut(inputs)
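
# A minimal shape check for PAPPM (illustrative channel/scale values, not
# taken from a shipped config):
#
#     >>> pappm = PAPPM(in_channels=512, branch_channels=96,
#     ...               out_channels=128, num_scales=5)
#     >>> pappm(torch.randn(2, 512, 8, 8)).shape
#     torch.Size([2, 128, 8, 8])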