[Feature] Support PIDNet (#2609)

## Motivation

Support SOTA real-time semantic segmentation method in [Paper with
code](https://paperswithcode.com/task/real-time-semantic-segmentation)

Paper: https://arxiv.org/pdf/2206.02066.pdf
Official repo: https://github.com/XuJiacong/PIDNet

## Current results

**Cityscapes**

|Model|Ref mIoU|mIoU (ours)|
|---|---|---|
|PIDNet-S|78.8|78.74|
|PIDNet-M|79.9|80.22|
|PIDNet-L|80.9|80.89|

## TODO

- [x] Support inference with official weights
- [x] Support training on Cityscapes
- [x] Update docstring
- [x] Add unit test
pull/2754/head
谢昕辰 2023-03-15 14:55:30 +08:00 committed by GitHub
parent 8c89ff3dd1
commit dd47cef801
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
20 changed files with 1646 additions and 4 deletions

View File

@ -159,6 +159,7 @@ Results and models are available in the [model zoo](docs/en/model_zoo.md).
- [x] [K-Net (NeurIPS'2021)](configs/knet)
- [x] [MaskFormer (NeurIPS'2021)](configs/maskformer)
- [x] [Mask2Former (CVPR'2022)](configs/mask2former)
- [x] [PIDNet (ArXiv'2022)](configs/pidnet)
</details>

View File

@ -140,6 +140,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [K-Net (NeurIPS'2021)](configs/knet)
- [x] [MaskFormer (NeurIPS'2021)](configs/maskformer)
- [x] [Mask2Former (CVPR'2022)](configs/mask2former)
- [x] [PIDNet (ArXiv'2022)](configs/pidnet)
</details>

View File

@ -0,0 +1,50 @@
# PIDNet
> [PIDNet: A Real-time Semantic Segmentation Network Inspired from PID Controller](https://arxiv.org/pdf/2206.02066.pdf)
## Introduction
<!-- [ALGORITHM] -->
<a href="https://github.com/XuJiacong/PIDNet">Official Repo</a>
<a href="https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/backbones/pidnet.py">Code Snippet</a>
## Abstract
<!-- [ABSTRACT] -->
Two-branch network architecture has shown its efficiency and effectiveness for real-time semantic segmentation tasks. However, direct fusion of low-level details and high-level semantics will lead to a phenomenon that the detailed features are easily overwhelmed by surrounding contextual information, namely overshoot in this paper, which limits the improvement of the accuracy of existed two-branch models. In this paper, we bridge a connection between Convolutional Neural Network (CNN) and Proportional-IntegralDerivative (PID) controller and reveal that the two-branch network is nothing but a Proportional-Integral (PI) controller, which inherently suffers from the similar overshoot issue. To alleviate this issue, we propose a novel threebranch network architecture: PIDNet, which possesses three branches to parse the detailed, context and boundary information (derivative of semantics), respectively, and employs boundary attention to guide the fusion of detailed and context branches in final stage. The family of PIDNets achieve the best trade-off between inference speed and accuracy and their test accuracy surpasses all the existed models with similar inference speed on Cityscapes, CamVid and COCO-Stuff datasets. Especially, PIDNet-S achieves 78.6% mIOU with inference speed of 93.2 FPS on Cityscapes test set and 80.1% mIOU with speed of 153.7 FPS on CamVid test set.
<!-- [IMAGE] -->
<div align=center>
<img src="https://raw.githubusercontent.com/XuJiacong/PIDNet/main/figs/pidnet.jpg" width="800"/>
</div>
## Results and models
### Cityscapes
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ------- | -------- | -------------- | ----- | ------------- | ----------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| PIDNet | PIDNet-S | 1024x1024 | 120000 | 3.38 | 80.82 | 78.74 | 80.87 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes/pidnet-s_2xb6-120k_1024x1024-cityscapes_20230302_191700-bb8e3bcc.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes/pidnet-s_2xb6-120k_1024x1024-cityscapes_20230302_191700.json) |
| PIDNet | PIDNet-M | 1024x1024 | 120000 | 5.14 | 71.98 | 80.22 | 82.05 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes/pidnet-m_2xb6-120k_1024x1024-cityscapes_20230301_143452-f9bcdbf3.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes/pidnet-m_2xb6-120k_1024x1024-cityscapes_20230301_143452.json) |
| PIDNet | PIDNet-L | 1024x1024 | 120000 | 5.83 | 60.06 | 80.89 | 82.37 | [config](https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/configs/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes/pidnet-l_2xb6-120k_1024x1024-cityscapes_20230303_114514-0783ca6b.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes/pidnet-l_2xb6-120k_1024x1024-cityscapes_20230303_114514.json) |
## Notes
The pretrained weights in config files are converted from [the official repo](https://github.com/XuJiacong/PIDNet#models).
## Citation
```bibtex
@misc{xu2022pidnet,
title={PIDNet: A Real-time Semantic Segmentation Network Inspired from PID Controller},
author={Jiacong Xu and Zixiang Xiong and Shankar P. Bhattacharyya},
year={2022},
eprint={2206.02066},
archivePrefix={arXiv},
primaryClass={cs.CV}
}
```

View File

@ -0,0 +1,10 @@
_base_ = './pidnet-s_2xb6-120k_1024x1024-cityscapes.py'
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-l_imagenet1k_20230306-67889109.pth' # noqa
model = dict(
backbone=dict(
channels=64,
ppm_channels=112,
num_stem_blocks=3,
num_branch_blocks=4,
init_cfg=dict(checkpoint=checkpoint_file)),
decode_head=dict(in_channels=256, channels=256))

View File

@ -0,0 +1,5 @@
_base_ = './pidnet-s_2xb6-120k_1024x1024-cityscapes.py'
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-m_imagenet1k_20230306-39893c52.pth' # noqa
model = dict(
backbone=dict(channels=64, init_cfg=dict(checkpoint=checkpoint_file)),
decode_head=dict(in_channels=256))

View File

@ -0,0 +1,113 @@
_base_ = [
'../_base_/datasets/cityscapes_1024x1024.py',
'../_base_/default_runtime.py'
]
# The class_weight is borrowed from https://github.com/openseg-group/OCNet.pytorch/issues/14 # noqa
# Licensed under the MIT License
class_weight = [
0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786,
1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529,
1.0507
]
checkpoint_file = 'https://download.openmmlab.com/mmsegmentation/v0.5/pretrain/pidnet/pidnet-s_imagenet1k_20230306-715e6273.pth' # noqa
crop_size = (1024, 1024)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255,
size=crop_size)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
backbone=dict(
type='PIDNet',
in_channels=3,
channels=32,
ppm_channels=96,
num_stem_blocks=2,
num_branch_blocks=3,
align_corners=False,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU', inplace=True),
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file)),
decode_head=dict(
type='PIDHead',
in_channels=128,
channels=128,
num_classes=19,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU', inplace=True),
align_corners=True,
loss_decode=[
dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=class_weight,
loss_weight=0.4),
dict(
type='OhemCrossEntropy',
thres=0.9,
min_kept=131072,
class_weight=class_weight,
loss_weight=1.0),
dict(type='BoundaryLoss', loss_weight=20.0),
dict(
type='OhemCrossEntropy',
thres=0.9,
min_kept=131072,
class_weight=class_weight,
loss_weight=1.0)
]),
train_cfg=dict(),
test_cfg=dict(mode='whole'))
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(
type='RandomResize',
scale=(2048, 1024),
ratio_range=(0.5, 2.0),
keep_ratio=True),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='GenerateEdge', edge_width=4),
dict(type='PackSegInputs')
]
train_dataloader = dict(batch_size=6, dataset=dict(pipeline=train_pipeline))
iters = 120000
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
dict(
type='PolyLR',
eta_min=0,
power=0.9,
begin=0,
end=iters,
by_epoch=False)
]
# training schedule for 120k
train_cfg = dict(
type='IterBasedTrainLoop', max_iters=iters, val_interval=iters // 10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
timer=dict(type='IterTimerHook'),
logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
param_scheduler=dict(type='ParamSchedulerHook'),
checkpoint=dict(
type='CheckpointHook', by_epoch=False, interval=iters // 10),
sampler_seed=dict(type='DistSamplerSeedHook'),
visualization=dict(type='SegVisualizationHook'))
randomness = dict(seed=304)

View File

@ -0,0 +1,81 @@
Collections:
- Name: PIDNet
Metadata:
Training Data:
- Cityscapes
Paper:
URL: https://arxiv.org/pdf/2206.02066.pdf
Title: 'PIDNet: A Real-time Semantic Segmentation Network Inspired from PID Controller'
README: configs/pidnet/README.md
Code:
URL: https://github.com/open-mmlab/mmsegmentation/blob/dev-1.x/mmseg/models/backbones/pidnet.py
Version: dev-1.x
Converted From:
Code: https://github.com/XuJiacong/PIDNet
Models:
- Name: pidnet-s_2xb6-120k_1024x1024-cityscapes
In Collection: PIDNet
Metadata:
backbone: PIDNet-S
crop size: (1024,1024)
lr schd: 120000
inference time (ms/im):
- value: 12.37
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (1024,1024)
Training Memory (GB): 3.38
Results:
- Task: Semantic Segmentation
Dataset: Cityscapes
Metrics:
mIoU: 78.74
mIoU(ms+flip): 80.87
Config: configs/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-s_2xb6-120k_1024x1024-cityscapes/pidnet-s_2xb6-120k_1024x1024-cityscapes_20230302_191700-bb8e3bcc.pth
- Name: pidnet-m_2xb6-120k_1024x1024-cityscapes
In Collection: PIDNet
Metadata:
backbone: PIDNet-M
crop size: (1024,1024)
lr schd: 120000
inference time (ms/im):
- value: 13.89
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (1024,1024)
Training Memory (GB): 5.14
Results:
- Task: Semantic Segmentation
Dataset: Cityscapes
Metrics:
mIoU: 80.22
mIoU(ms+flip): 82.05
Config: configs/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-m_2xb6-120k_1024x1024-cityscapes/pidnet-m_2xb6-120k_1024x1024-cityscapes_20230301_143452-f9bcdbf3.pth
- Name: pidnet-l_2xb6-120k_1024x1024-cityscapes
In Collection: PIDNet
Metadata:
backbone: PIDNet-L
crop size: (1024,1024)
lr schd: 120000
inference time (ms/im):
- value: 16.65
hardware: V100
backend: PyTorch
batch size: 1
mode: FP32
resolution: (1024,1024)
Training Memory (GB): 5.83
Results:
- Task: Semantic Segmentation
Dataset: Cityscapes
Metrics:
mIoU: 80.89
mIoU(ms+flip): 82.37
Config: configs/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes.py
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/pidnet/pidnet-l_2xb6-120k_1024x1024-cityscapes/pidnet-l_2xb6-120k_1024x1024-cityscapes_20230303_114514-0783ca6b.pth

View File

@ -11,6 +11,7 @@ from .mae import MAE
from .mit import MixVisionTransformer
from .mobilenet_v2 import MobileNetV2
from .mobilenet_v3 import MobileNetV3
from .pidnet import PIDNet
from .resnest import ResNeSt
from .resnet import ResNet, ResNetV1c, ResNetV1d
from .resnext import ResNeXt
@ -26,5 +27,5 @@ __all__ = [
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE'
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet'
]

View File

@ -0,0 +1,522 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from mmengine.runner import CheckpointLoader
from torch import Tensor
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType
from ..utils import DAPPM, PAPPM, BasicBlock, Bottleneck
class PagFM(BaseModule):
"""Pixel-attention-guided fusion module.
Args:
in_channels (int): The number of input channels.
channels (int): The number of channels.
after_relu (bool): Whether to use ReLU before attention.
Default: False.
with_channel (bool): Whether to use channel attention.
Default: False.
upsample_mode (str): The mode of upsample. Default: 'bilinear'.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(typ='ReLU', inplace=True).
init_cfg (dict): Config dict for initialization. Default: None.
"""
def __init__(self,
in_channels: int,
channels: int,
after_relu: bool = False,
with_channel: bool = False,
upsample_mode: str = 'bilinear',
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = dict(typ='ReLU', inplace=True),
init_cfg: OptConfigType = None):
super().__init__(init_cfg)
self.after_relu = after_relu
self.with_channel = with_channel
self.upsample_mode = upsample_mode
self.f_i = ConvModule(
in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None)
self.f_p = ConvModule(
in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=None)
if with_channel:
self.up = ConvModule(
channels, in_channels, 1, norm_cfg=norm_cfg, act_cfg=None)
if after_relu:
self.relu = MODELS.build(act_cfg)
def forward(self, x_p: Tensor, x_i: Tensor) -> Tensor:
"""Forward function.
Args:
x_p (Tensor): The featrue map from P branch.
x_i (Tensor): The featrue map from I branch.
Returns:
Tensor: The feature map with pixel-attention-guided fusion.
"""
if self.after_relu:
x_p = self.relu(x_p)
x_i = self.relu(x_i)
f_i = self.f_i(x_i)
f_i = F.interpolate(
f_i,
size=x_p.shape[2:],
mode=self.upsample_mode,
align_corners=False)
f_p = self.f_p(x_p)
if self.with_channel:
sigma = torch.sigmoid(self.up(f_p * f_i))
else:
sigma = torch.sigmoid(torch.sum(f_p * f_i, dim=1).unsqueeze(1))
x_i = F.interpolate(
x_i,
size=x_p.shape[2:],
mode=self.upsample_mode,
align_corners=False)
out = sigma * x_i + (1 - sigma) * x_p
return out
class Bag(BaseModule):
"""Boundary-attention-guided fusion module.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
kernel_size (int): The kernel size of the convolution. Default: 3.
padding (int): The padding of the convolution. Default: 1.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU', inplace=True).
conv_cfg (dict): Config dict for convolution layer.
Default: dict(order=('norm', 'act', 'conv')).
init_cfg (dict): Config dict for initialization. Default: None.
"""
def __init__(self,
in_channels: int,
out_channels: int,
kernel_size: int = 3,
padding: int = 1,
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
conv_cfg: OptConfigType = dict(order=('norm', 'act', 'conv')),
init_cfg: OptConfigType = None):
super().__init__(init_cfg)
self.conv = ConvModule(
in_channels,
out_channels,
kernel_size,
padding=padding,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**conv_cfg)
def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor:
"""Forward function.
Args:
x_p (Tensor): The featrue map from P branch.
x_i (Tensor): The featrue map from I branch.
x_d (Tensor): The featrue map from D branch.
Returns:
Tensor: The feature map with boundary-attention-guided fusion.
"""
sigma = torch.sigmoid(x_d)
return self.conv(sigma * x_p + (1 - sigma) * x_i)
class LightBag(BaseModule):
"""Light Boundary-attention-guided fusion module.
Args:
in_channels (int): The number of input channels.
out_channels (int): The number of output channels.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer. Default: None.
init_cfg (dict): Config dict for initialization. Default: None.
"""
def __init__(self,
in_channels: int,
out_channels: int,
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = None,
init_cfg: OptConfigType = None):
super().__init__(init_cfg)
self.f_p = ConvModule(
in_channels,
out_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.f_i = ConvModule(
in_channels,
out_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, x_p: Tensor, x_i: Tensor, x_d: Tensor) -> Tensor:
"""Forward function.
Args:
x_p (Tensor): The featrue map from P branch.
x_i (Tensor): The featrue map from I branch.
x_d (Tensor): The featrue map from D branch.
Returns:
Tensor: The feature map with light boundary-attention-guided
fusion.
"""
sigma = torch.sigmoid(x_d)
f_p = self.f_p((1 - sigma) * x_i + x_p)
f_i = self.f_i(x_i + sigma * x_p)
return f_p + f_i
@MODELS.register_module()
class PIDNet(BaseModule):
"""PIDNet backbone.
This backbone is the implementation of `PIDNet: A Real-time Semantic
Segmentation Network Inspired from PID Controller
<https://arxiv.org/abs/2206.02066>`_.
Modified from https://github.com/XuJiacong/PIDNet.
Licensed under the MIT License.
Args:
in_channels (int): The number of input channels. Default: 3.
channels (int): The number of channels in the stem layer. Default: 64.
ppm_channels (int): The number of channels in the PPM layer.
Default: 96.
num_stem_blocks (int): The number of blocks in the stem layer.
Default: 2.
num_branch_blocks (int): The number of blocks in the branch layer.
Default: 3.
align_corners (bool): The align_corners argument of F.interpolate.
Default: False.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU', inplace=True).
init_cfg (dict): Config dict for initialization. Default: None.
"""
def __init__(self,
in_channels: int = 3,
channels: int = 64,
ppm_channels: int = 96,
num_stem_blocks: int = 2,
num_branch_blocks: int = 3,
align_corners: bool = False,
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
init_cfg: OptConfigType = None,
**kwargs):
super().__init__(init_cfg)
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.align_corners = align_corners
# stem layer
self.stem = self._make_stem_layer(in_channels, channels,
num_stem_blocks)
self.relu = nn.ReLU()
# I Branch
self.i_branch_layers = nn.ModuleList()
for i in range(3):
self.i_branch_layers.append(
self._make_layer(
block=BasicBlock if i < 2 else Bottleneck,
in_channels=channels * 2**(i + 1),
channels=channels * 8 if i > 0 else channels * 4,
num_blocks=num_branch_blocks if i < 2 else 2,
stride=2))
# P Branch
self.p_branch_layers = nn.ModuleList()
for i in range(3):
self.p_branch_layers.append(
self._make_layer(
block=BasicBlock if i < 2 else Bottleneck,
in_channels=channels * 2,
channels=channels * 2,
num_blocks=num_stem_blocks if i < 2 else 1))
self.compression_1 = ConvModule(
channels * 4,
channels * 2,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None)
self.compression_2 = ConvModule(
channels * 8,
channels * 2,
kernel_size=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None)
self.pag_1 = PagFM(channels * 2, channels)
self.pag_2 = PagFM(channels * 2, channels)
# D Branch
if num_stem_blocks == 2:
self.d_branch_layers = nn.ModuleList([
self._make_single_layer(BasicBlock, channels * 2, channels),
self._make_layer(Bottleneck, channels, channels, 1)
])
channel_expand = 1
spp_module = PAPPM
dfm_module = LightBag
act_cfg_dfm = None
else:
self.d_branch_layers = nn.ModuleList([
self._make_single_layer(BasicBlock, channels * 2,
channels * 2),
self._make_single_layer(BasicBlock, channels * 2, channels * 2)
])
channel_expand = 2
spp_module = DAPPM
dfm_module = Bag
act_cfg_dfm = act_cfg
self.diff_1 = ConvModule(
channels * 4,
channels * channel_expand,
kernel_size=3,
padding=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None)
self.diff_2 = ConvModule(
channels * 8,
channels * 2,
kernel_size=3,
padding=1,
bias=False,
norm_cfg=norm_cfg,
act_cfg=None)
self.spp = spp_module(
channels * 16, ppm_channels, channels * 4, num_scales=5)
self.dfm = dfm_module(
channels * 4, channels * 4, norm_cfg=norm_cfg, act_cfg=act_cfg_dfm)
self.d_branch_layers.append(
self._make_layer(Bottleneck, channels * 2, channels * 2, 1))
def _make_stem_layer(self, in_channels: int, channels: int,
num_blocks: int) -> nn.Sequential:
"""Make stem layer.
Args:
in_channels (int): Number of input channels.
channels (int): Number of output channels.
num_blocks (int): Number of blocks.
Returns:
nn.Sequential: The stem layer.
"""
layers = [
ConvModule(
in_channels,
channels,
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg),
ConvModule(
channels,
channels,
kernel_size=3,
stride=2,
padding=1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg)
]
layers.append(
self._make_layer(BasicBlock, channels, channels, num_blocks))
layers.append(nn.ReLU())
layers.append(
self._make_layer(
BasicBlock, channels, channels * 2, num_blocks, stride=2))
layers.append(nn.ReLU())
return nn.Sequential(*layers)
def _make_layer(self,
block: BasicBlock,
in_channels: int,
channels: int,
num_blocks: int,
stride: int = 1) -> nn.Sequential:
"""Make layer for PIDNet backbone.
Args:
block (BasicBlock): Basic block.
in_channels (int): Number of input channels.
channels (int): Number of output channels.
num_blocks (int): Number of blocks.
stride (int): Stride of the first block. Default: 1.
Returns:
nn.Sequential: The Branch Layer.
"""
downsample = None
if stride != 1 or in_channels != channels * block.expansion:
downsample = ConvModule(
in_channels,
channels * block.expansion,
kernel_size=1,
stride=stride,
norm_cfg=self.norm_cfg,
act_cfg=None)
layers = [block(in_channels, channels, stride, downsample)]
in_channels = channels * block.expansion
for i in range(1, num_blocks):
layers.append(
block(
in_channels,
channels,
stride=1,
act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))
return nn.Sequential(*layers)
def _make_single_layer(self,
block: Union[BasicBlock, Bottleneck],
in_channels: int,
channels: int,
stride: int = 1) -> nn.Module:
"""Make single layer for PIDNet backbone.
Args:
block (BasicBlock or Bottleneck): Basic block or Bottleneck.
in_channels (int): Number of input channels.
channels (int): Number of output channels.
stride (int): Stride of the first block. Default: 1.
Returns:
nn.Module
"""
downsample = None
if stride != 1 or in_channels != channels * block.expansion:
downsample = ConvModule(
in_channels,
channels * block.expansion,
kernel_size=1,
stride=stride,
norm_cfg=self.norm_cfg,
act_cfg=None)
return block(
in_channels, channels, stride, downsample, act_cfg_out=None)
def init_weights(self):
"""Initialize the weights in backbone.
Since the D branch is not initialized by the pre-trained model, we
initialize it with the same method as the ResNet.
"""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
if self.init_cfg is not None:
assert 'checkpoint' in self.init_cfg, f'Only support ' \
f'specify `Pretrained` in ' \
f'`init_cfg` in ' \
f'{self.__class__.__name__} '
ckpt = CheckpointLoader.load_checkpoint(
self.init_cfg['checkpoint'], map_location='cpu')
self.load_state_dict(ckpt, strict=False)
def forward(self, x: Tensor) -> Union[Tensor, Tuple[Tensor]]:
"""Forward function.
Args:
x (Tensor): Input tensor with shape (B, C, H, W).
Returns:
Tensor or tuple[Tensor]: If self.training is True, return
tuple[Tensor], else return Tensor.
"""
w_out = x.shape[-1] // 8
h_out = x.shape[-2] // 8
# stage 0-2
x = self.stem(x)
# stage 3
x_i = self.relu(self.i_branch_layers[0](x))
x_p = self.p_branch_layers[0](x)
x_d = self.d_branch_layers[0](x)
comp_i = self.compression_1(x_i)
x_p = self.pag_1(x_p, comp_i)
diff_i = self.diff_1(x_i)
x_d += F.interpolate(
diff_i,
size=[h_out, w_out],
mode='bilinear',
align_corners=self.align_corners)
if self.training:
temp_p = x_p.clone()
# stage 4
x_i = self.relu(self.i_branch_layers[1](x_i))
x_p = self.p_branch_layers[1](self.relu(x_p))
x_d = self.d_branch_layers[1](self.relu(x_d))
comp_i = self.compression_2(x_i)
x_p = self.pag_2(x_p, comp_i)
diff_i = self.diff_2(x_i)
x_d += F.interpolate(
diff_i,
size=[h_out, w_out],
mode='bilinear',
align_corners=self.align_corners)
if self.training:
temp_d = x_d.clone()
# stage 5
x_i = self.i_branch_layers[2](x_i)
x_p = self.p_branch_layers[2](self.relu(x_p))
x_d = self.d_branch_layers[2](self.relu(x_d))
x_i = self.spp(x_i)
x_i = F.interpolate(
x_i,
size=[h_out, w_out],
mode='bilinear',
align_corners=self.align_corners)
out = self.dfm(x_p, x_i, x_d)
return (temp_p, out, temp_d) if self.training else out

View File

@ -19,6 +19,7 @@ from .mask2former_head import Mask2FormerHead
from .maskformer_head import MaskFormerHead
from .nl_head import NLHead
from .ocr_head import OCRHead
from .pid_head import PIDHead
from .point_head import PointHead
from .psa_head import PSAHead
from .psp_head import PSPHead
@ -38,5 +39,6 @@ __all__ = [
'PointHead', 'APCHead', 'DMHead', 'LRASPPHead', 'SETRUPHead',
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead'
'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead',
'PIDHead'
]

View File

@ -0,0 +1,183 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from mmengine.model import BaseModule
from torch import Tensor
from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType, SampleList
class BasePIDHead(BaseModule):
"""Base class for PID head.
Args:
in_channels (int): Number of input channels.
channels (int): Number of output channels.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU', inplace=True).
init_cfg (dict or list[dict], optional): Init config dict.
Default: None.
"""
def __init__(self,
in_channels: int,
channels: int,
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
init_cfg: OptConfigType = None):
super().__init__(init_cfg)
self.conv = ConvModule(
in_channels,
channels,
kernel_size=3,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
order=('norm', 'act', 'conv'))
_, self.norm = build_norm_layer(norm_cfg, num_features=channels)
self.act = build_activation_layer(act_cfg)
def forward(self, x: Tensor, cls_seg: Optional[nn.Module]) -> Tensor:
"""Forward function.
Args:
x (Tensor): Input tensor.
cls_seg (nn.Module, optional): The classification head.
Returns:
Tensor: Output tensor.
"""
x = self.conv(x)
x = self.norm(x)
x = self.act(x)
if cls_seg is not None:
x = cls_seg(x)
return x
@MODELS.register_module()
class PIDHead(BaseDecodeHead):
"""Decode head for PIDNet.
Args:
in_channels (int): Number of input channels.
channels (int): Number of output channels.
num_classes (int): Number of classes.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer.
Default: dict(type='ReLU', inplace=True).
"""
def __init__(self,
in_channels: int,
channels: int,
num_classes: int,
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
**kwargs):
super().__init__(
in_channels,
channels,
num_classes=num_classes,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**kwargs)
self.i_head = BasePIDHead(in_channels, channels, norm_cfg, act_cfg)
self.p_head = BasePIDHead(in_channels // 2, channels, norm_cfg,
act_cfg)
self.d_head = BasePIDHead(
in_channels // 2,
in_channels // 4,
norm_cfg,
)
self.p_cls_seg = nn.Conv2d(channels, self.out_channels, kernel_size=1)
self.d_cls_seg = nn.Conv2d(in_channels // 4, 1, kernel_size=1)
def init_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(
m.weight, mode='fan_out', nonlinearity='relu')
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
def forward(
self,
inputs: Union[Tensor,
Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
"""Forward function.
Args:
inputs (Tensor | tuple[Tensor]): Input tensor or tuple of
Tensor. When training, the input is a tuple of three tensors,
(p_feat, i_feat, d_feat), and the output is a tuple of three
tensors, (p_seg_logit, i_seg_logit, d_seg_logit).
When inference, only the head of integral branch is used, and
input is a tensor of integral feature map, and the output is
the segmentation logit.
Returns:
Tensor | tuple[Tensor]: Output tensor or tuple of tensors.
"""
if self.training:
x_p, x_i, x_d = inputs
x_p = self.p_head(x_p, self.p_cls_seg)
x_i = self.i_head(x_i, self.cls_seg)
x_d = self.d_head(x_d, self.d_cls_seg)
return x_p, x_i, x_d
else:
return self.i_head(inputs, self.cls_seg)
def _stack_batch_gt(self, batch_data_samples: SampleList) -> Tuple[Tensor]:
gt_semantic_segs = [
data_sample.gt_sem_seg.data for data_sample in batch_data_samples
]
gt_edge_segs = [
data_sample.gt_edge_map.data for data_sample in batch_data_samples
]
gt_sem_segs = torch.stack(gt_semantic_segs, dim=0)
gt_edge_segs = torch.stack(gt_edge_segs, dim=0)
return gt_sem_segs, gt_edge_segs
def loss_by_feat(self, seg_logits: Tuple[Tensor],
batch_data_samples: SampleList) -> dict:
loss = dict()
p_logit, i_logit, d_logit = seg_logits
sem_label, bd_label = self._stack_batch_gt(batch_data_samples)
p_logit = resize(
input=p_logit,
size=sem_label.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
i_logit = resize(
input=i_logit,
size=sem_label.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
d_logit = resize(
input=d_logit,
size=bd_label.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
sem_label = sem_label.squeeze(1)
bd_label = bd_label.squeeze(1)
loss['loss_sem_p'] = self.loss_decode[0](
p_logit, sem_label, ignore_index=self.ignore_index)
loss['loss_sem_i'] = self.loss_decode[1](i_logit, sem_label)
loss['loss_bd'] = self.loss_decode[2](d_logit, bd_label)
filler = torch.ones_like(sem_label) * self.ignore_index
sem_bd_label = torch.where(
torch.sigmoid(d_logit[:, 0, :, :]) > 0.8, sem_label, filler)
loss['loss_sem_bd'] = self.loss_decode[3](i_logit, sem_bd_label)
loss['acc_seg'] = accuracy(
i_logit, sem_label, ignore_index=self.ignore_index)
return loss

View File

@ -1,10 +1,12 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .accuracy import Accuracy, accuracy
from .boundary_loss import BoundaryLoss
from .cross_entropy_loss import (CrossEntropyLoss, binary_cross_entropy,
cross_entropy, mask_cross_entropy)
from .dice_loss import DiceLoss
from .focal_loss import FocalLoss
from .lovasz_loss import LovaszLoss
from .ohem_cross_entropy_loss import OhemCrossEntropy
from .tversky_loss import TverskyLoss
from .utils import reduce_loss, weight_reduce_loss, weighted_loss
@ -12,5 +14,5 @@ __all__ = [
'accuracy', 'Accuracy', 'cross_entropy', 'binary_cross_entropy',
'mask_cross_entropy', 'CrossEntropyLoss', 'reduce_loss',
'weight_reduce_loss', 'weighted_loss', 'LovaszLoss', 'DiceLoss',
'FocalLoss', 'TverskyLoss'
'FocalLoss', 'TverskyLoss', 'OhemCrossEntropy', 'BoundaryLoss'
]

View File

@ -0,0 +1,62 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from mmseg.registry import MODELS
@MODELS.register_module()
class BoundaryLoss(nn.Module):
"""Boundary loss.
This function is modified from
`PIDNet <https://github.com/XuJiacong/PIDNet/blob/main/utils/criterion.py#L122>`_. # noqa
Licensed under the MIT License.
Args:
loss_weight (float): Weight of the loss. Defaults to 1.0.
loss_name (str): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_boundary'.
"""
def __init__(self,
loss_weight: float = 1.0,
loss_name: str = 'loss_boundary'):
super().__init__()
self.loss_weight = loss_weight
self.loss_name_ = loss_name
def forward(self, bd_pre: Tensor, bd_gt: Tensor) -> Tensor:
"""Forward function.
Args:
bd_pre (Tensor): Predictions of the boundary head.
bd_gt (Tensor): Ground truth of the boundary.
Returns:
Tensor: Loss tensor.
"""
log_p = bd_pre.permute(0, 2, 3, 1).contiguous().view(1, -1)
target_t = bd_gt.view(1, -1).float()
pos_index = (target_t == 1)
neg_index = (target_t == 0)
weight = torch.zeros_like(log_p)
pos_num = pos_index.sum()
neg_num = neg_index.sum()
sum_num = pos_num + neg_num
weight[pos_index] = neg_num * 1.0 / sum_num
weight[neg_index] = pos_num * 1.0 / sum_num
loss = F.binary_cross_entropy_with_logits(
log_p, target_t, weight, reduction='mean')
return self.loss_weight * loss
@property
def loss_name(self):
return self.loss_name_

View File

@ -0,0 +1,94 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Optional, Union
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor
from mmseg.registry import MODELS
@MODELS.register_module()
class OhemCrossEntropy(nn.Module):
"""OhemCrossEntropy loss.
This func is modified from
`PIDNet <https://github.com/XuJiacong/PIDNet/blob/main/utils/criterion.py#L43>`_. # noqa
Licensed under the MIT License.
Args:
ignore_label (int): Labels to ignore when computing the loss.
Default: 255
thresh (float, optional): The threshold for hard example selection.
Below which, are prediction with low confidence. If not
specified, the hard examples will be pixels of top ``min_kept``
loss. Default: 0.7.
min_kept (int, optional): The minimum number of predictions to keep.
Default: 100000.
loss_weight (float): Weight of the loss. Defaults to 1.0.
class_weight (list[float] | str, optional): Weight of each class. If in
str format, read them from a file. Defaults to None.
loss_name (str): Name of the loss item. If you want this loss
item to be included into the backward graph, `loss_` must be the
prefix of the name. Defaults to 'loss_boundary'.
"""
def __init__(self,
ignore_label: int = 255,
thres: float = 0.7,
min_kept: int = 100000,
loss_weight: float = 1.0,
class_weight: Optional[Union[List[float], str]] = None,
loss_name: str = 'loss_ohem'):
super().__init__()
self.thresh = thres
self.min_kept = max(1, min_kept)
self.ignore_label = ignore_label
self.loss_weight = loss_weight
self.loss_name_ = loss_name
self.class_weight = class_weight
def forward(self, score: Tensor, target: Tensor) -> Tensor:
"""Forward function.
Args:
score (Tensor): Predictions of the segmentation head.
target (Tensor): Ground truth of the image.
Returns:
Tensor: Loss tensor.
"""
# score: (N, C, H, W)
pred = F.softmax(score, dim=1)
if self.class_weight is not None:
class_weight = score.new_tensor(self.class_weight)
else:
class_weight = None
pixel_losses = F.cross_entropy(
score,
target,
weight=class_weight,
ignore_index=self.ignore_label,
reduction='none').contiguous().view(-1) # (N*H*W)
mask = target.contiguous().view(-1) != self.ignore_label # (N*H*W)
tmp_target = target.clone() # (N, H, W)
tmp_target[tmp_target == self.ignore_label] = 0
# pred: (N, C, H, W) -> (N*H*W, C)
pred = pred.gather(1, tmp_target.unsqueeze(1))
# pred: (N*H*W, C) -> (N*H*W), ind: (N*H*W)
pred, ind = pred.contiguous().view(-1, )[mask].contiguous().sort()
if pred.numel() > 0:
min_value = pred[min(self.min_kept, pred.numel() - 1)]
else:
return score.new_tensor(0.0)
threshold = max(min_value, self.thresh)
pixel_losses = pixel_losses[mask][ind]
pixel_losses = pixel_losses[pred < threshold]
return self.loss_weight * pixel_losses.mean()
@property
def loss_name(self):
return self.loss_name_

View File

@ -1,8 +1,10 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .basic_block import BasicBlock, Bottleneck
from .embed import PatchEmbed
from .encoding import Encoding
from .inverted_residual import InvertedResidual, InvertedResidualV3
from .make_divisible import make_divisible
from .ppm import DAPPM, PAPPM
from .res_layer import ResLayer
from .se_layer import SELayer
from .self_attention_block import SelfAttentionBlock
@ -15,5 +17,5 @@ __all__ = [
'ResLayer', 'SelfAttentionBlock', 'make_divisible', 'InvertedResidual',
'UpConvBlock', 'InvertedResidualV3', 'SELayer', 'PatchEmbed',
'nchw_to_nlc', 'nlc_to_nchw', 'nchw2nlc2nchw', 'nlc2nchw2nlc', 'Encoding',
'Upsample', 'resize'
'Upsample', 'resize', 'DAPPM', 'PAPPM', 'BasicBlock', 'Bottleneck'
]

View File

@ -0,0 +1,143 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch.nn as nn
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule
from torch import Tensor
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType
class BasicBlock(BaseModule):
"""Basic block from `ResNet <https://arxiv.org/abs/1512.03385>`_.
Args:
in_channels (int): Input channels.
channels (int): Output channels.
stride (int): Stride of the first block. Default: 1.
downsample (nn.Module, optional): Downsample operation on identity.
Default: None.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict, optional): Config dict for activation layer in
ConvModule. Default: dict(type='ReLU', inplace=True).
act_cfg_out (dict, optional): Config dict for activation layer at the
last of the block. Default: None.
init_cfg (dict, optional): Initialization config dict. Default: None.
"""
expansion = 1
def __init__(self,
in_channels: int,
channels: int,
stride: int = 1,
downsample: nn.Module = None,
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
act_cfg_out: OptConfigType = dict(type='ReLU', inplace=True),
init_cfg: OptConfigType = None):
super().__init__(init_cfg)
self.conv1 = ConvModule(
in_channels,
channels,
kernel_size=3,
stride=stride,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv2 = ConvModule(
channels,
channels,
kernel_size=3,
padding=1,
norm_cfg=norm_cfg,
act_cfg=None)
self.downsample = downsample
if act_cfg_out:
self.act = MODELS.build(act_cfg_out)
def forward(self, x: Tensor) -> Tensor:
residual = x
out = self.conv1(x)
out = self.conv2(out)
if self.downsample:
residual = self.downsample(x)
out += residual
if hasattr(self, 'act'):
out = self.act(out)
return out
class Bottleneck(BaseModule):
"""Bottleneck block from `ResNet <https://arxiv.org/abs/1512.03385>`_.
Args:
in_channels (int): Input channels.
channels (int): Output channels.
stride (int): Stride of the first block. Default: 1.
downsample (nn.Module, optional): Downsample operation on identity.
Default: None.
norm_cfg (dict, optional): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict, optional): Config dict for activation layer in
ConvModule. Default: dict(type='ReLU', inplace=True).
act_cfg_out (dict, optional): Config dict for activation layer at
the last of the block. Default: None.
init_cfg (dict, optional): Initialization config dict. Default: None.
"""
expansion = 2
def __init__(self,
in_channels: int,
channels: int,
stride: int = 1,
downsample: Optional[nn.Module] = None,
norm_cfg: OptConfigType = dict(type='BN'),
act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
act_cfg_out: OptConfigType = None,
init_cfg: OptConfigType = None):
super().__init__(init_cfg)
self.conv1 = ConvModule(
in_channels, channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
self.conv2 = ConvModule(
channels,
channels,
3,
stride,
1,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.conv3 = ConvModule(
channels,
channels * self.expansion,
1,
norm_cfg=norm_cfg,
act_cfg=None)
if act_cfg_out:
self.act = MODELS.build(act_cfg_out)
self.downsample = downsample
def forward(self, x: Tensor) -> Tensor:
residual = x
out = self.conv1(x)
out = self.conv2(out)
out = self.conv3(out)
if self.downsample:
residual = self.downsample(x)
out += residual
if hasattr(self, 'act'):
out = self.act(out)
return out

View File

@ -0,0 +1,193 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmengine.model import BaseModule, ModuleList, Sequential
from torch import Tensor
class DAPPM(BaseModule):
"""DAPPM module in `DDRNet <https://arxiv.org/abs/2101.06085>`_.
Args:
in_channels (int): Input channels.
branch_channels (int): Branch channels.
out_channels (int): Output channels.
num_scales (int): Number of scales.
kernel_sizes (list[int]): Kernel sizes of each scale.
strides (list[int]): Strides of each scale.
paddings (list[int]): Paddings of each scale.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN').
act_cfg (dict): Config dict for activation layer in ConvModule.
Default: dict(type='ReLU', inplace=True).
conv_cfg (dict): Config dict for convolution layer in ConvModule.
Default: dict(order=('norm', 'act', 'conv'), bias=False).
upsample_mode (str): Upsample mode. Default: 'bilinear'.
"""
def __init__(self,
in_channels: int,
branch_channels: int,
out_channels: int,
num_scales: int,
kernel_sizes: List[int] = [5, 9, 17],
strides: List[int] = [2, 4, 8],
paddings: List[int] = [2, 4, 8],
norm_cfg: Dict = dict(type='BN', momentum=0.1),
act_cfg: Dict = dict(type='ReLU', inplace=True),
conv_cfg: Dict = dict(
order=('norm', 'act', 'conv'), bias=False),
upsample_mode: str = 'bilinear'):
super().__init__()
self.num_scales = num_scales
self.unsample_mode = upsample_mode
self.in_channels = in_channels
self.branch_channels = branch_channels
self.out_channels = out_channels
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.conv_cfg = conv_cfg
self.scales = ModuleList([
ConvModule(
in_channels,
branch_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**conv_cfg)
])
for i in range(1, num_scales - 1):
self.scales.append(
Sequential(*[
nn.AvgPool2d(
kernel_size=kernel_sizes[i - 1],
stride=strides[i - 1],
padding=paddings[i - 1]),
ConvModule(
in_channels,
branch_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**conv_cfg)
]))
self.scales.append(
Sequential(*[
nn.AdaptiveAvgPool2d((1, 1)),
ConvModule(
in_channels,
branch_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**conv_cfg)
]))
self.processes = ModuleList()
for i in range(num_scales - 1):
self.processes.append(
ConvModule(
branch_channels,
branch_channels,
kernel_size=3,
padding=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**conv_cfg))
self.compression = ConvModule(
branch_channels * num_scales,
out_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**conv_cfg)
self.shortcut = ConvModule(
in_channels,
out_channels,
kernel_size=1,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
**conv_cfg)
def forward(self, inputs: Tensor):
feats = []
feats.append(self.scales[0](inputs))
for i in range(1, self.num_scales):
feat_up = F.interpolate(
self.scales[i](inputs),
size=inputs.shape[2:],
mode=self.unsample_mode)
feats.append(self.processes[i - 1](feat_up + feats[i - 1]))
return self.compression(torch.cat(feats,
dim=1)) + self.shortcut(inputs)
class PAPPM(DAPPM):
"""PAPPM module in `PIDNet <https://arxiv.org/abs/2206.02066>`_.
Args:
in_channels (int): Input channels.
branch_channels (int): Branch channels.
out_channels (int): Output channels.
num_scales (int): Number of scales.
kernel_sizes (list[int]): Kernel sizes of each scale.
strides (list[int]): Strides of each scale.
paddings (list[int]): Paddings of each scale.
norm_cfg (dict): Config dict for normalization layer.
Default: dict(type='BN', momentum=0.1).
act_cfg (dict): Config dict for activation layer in ConvModule.
Default: dict(type='ReLU', inplace=True).
conv_cfg (dict): Config dict for convolution layer in ConvModule.
Default: dict(order=('norm', 'act', 'conv'), bias=False).
upsample_mode (str): Upsample mode. Default: 'bilinear'.
"""
def __init__(self,
in_channels: int,
branch_channels: int,
out_channels: int,
num_scales: int,
kernel_sizes: List[int] = [5, 9, 17],
strides: List[int] = [2, 4, 8],
paddings: List[int] = [2, 4, 8],
norm_cfg: Dict = dict(type='BN', momentum=0.1),
act_cfg: Dict = dict(type='ReLU', inplace=True),
conv_cfg: Dict = dict(
order=('norm', 'act', 'conv'), bias=False),
upsample_mode: str = 'bilinear'):
super().__init__(in_channels, branch_channels, out_channels,
num_scales, kernel_sizes, strides, paddings, norm_cfg,
act_cfg, conv_cfg, upsample_mode)
self.processes = ConvModule(
self.branch_channels * (self.num_scales - 1),
self.branch_channels * (self.num_scales - 1),
kernel_size=3,
padding=1,
groups=self.num_scales - 1,
norm_cfg=self.norm_cfg,
act_cfg=self.act_cfg,
**self.conv_cfg)
def forward(self, inputs: Tensor):
x_ = self.scales[0](inputs)
feats = []
for i in range(1, self.num_scales):
feat_up = F.interpolate(
self.scales[i](inputs),
size=inputs.shape[2:],
mode=self.unsample_mode,
align_corners=False)
feats.append(feat_up + x_)
scale_out = self.processes(torch.cat(feats, dim=1))
return self.compression(torch.cat([x_, scale_out],
dim=1)) + self.shortcut(inputs)

View File

@ -31,6 +31,7 @@ Import:
- configs/mobilenet_v3/mobilenet_v3.yml
- configs/nonlocal_net/nonlocal_net.yml
- configs/ocrnet/ocrnet.yml
- configs/pidnet/pidnet.yml
- configs/point_rend/point_rend.yml
- configs/poolformer/poolformer.yml
- configs/psanet/psanet.yml

View File

@ -0,0 +1,87 @@
# Copyright (c) OpenMMLab. All rights reserved.
import os
import tempfile
import torch
from mmengine.registry import init_default_scope
from mmseg.registry import MODELS
init_default_scope('mmseg')
def test_pidnet_backbone():
# Test PIDNet Standard Forward
norm_cfg = dict(type='BN', requires_grad=True)
backbone_cfg = dict(
type='PIDNet',
in_channels=3,
channels=32,
ppm_channels=96,
num_stem_blocks=2,
num_branch_blocks=3,
align_corners=False,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU', inplace=True))
model = MODELS.build(backbone_cfg)
model.init_weights()
# Test init weights
temp_file = tempfile.NamedTemporaryFile()
temp_file.close()
torch.save(model.state_dict(), temp_file.name)
backbone_cfg.update(
init_cfg=dict(type='Pretrained', checkpoint=temp_file.name))
model = MODELS.build(backbone_cfg)
model.init_weights()
os.remove(temp_file.name)
# Test eval mode
model.eval()
batch_size = 1
imgs = torch.randn(batch_size, 3, 64, 128)
feats = model(imgs)
assert type(feats) == torch.Tensor
assert feats.shape == torch.Size([batch_size, 128, 8, 16])
# Test train mode
model.train()
batch_size = 2
imgs = torch.randn(batch_size, 3, 64, 128)
feats = model(imgs)
assert len(feats) == 3
# test output for P branch
assert feats[0].shape == torch.Size([batch_size, 64, 8, 16])
# test output for I branch
assert feats[1].shape == torch.Size([batch_size, 128, 8, 16])
# test output for D branch
assert feats[2].shape == torch.Size([batch_size, 64, 8, 16])
# Test pidnet-m
backbone_cfg.update(channels=64)
model = MODELS.build(backbone_cfg)
feats = model(imgs)
assert len(feats) == 3
# test output for P branch
assert feats[0].shape == torch.Size([batch_size, 128, 8, 16])
# test output for I branch
assert feats[1].shape == torch.Size([batch_size, 256, 8, 16])
# test output for D branch
assert feats[2].shape == torch.Size([batch_size, 128, 8, 16])
# Test pidnet-l
backbone_cfg.update(
channels=64, ppm_channesl=112, num_stem_blocks=3, num_branch_blocks=4)
model = MODELS.build(backbone_cfg)
feats = model(imgs)
assert len(feats) == 3
# test output for P branch
assert feats[0].shape == torch.Size([batch_size, 128, 8, 16])
# test output for I branch
assert feats[1].shape == torch.Size([batch_size, 256, 8, 16])
# test output for D branch
assert feats[2].shape == torch.Size([batch_size, 128, 8, 16])

View File

@ -0,0 +1,89 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmengine.registry import init_default_scope
from mmseg.registry import MODELS
def test_pidnet_head():
init_default_scope('mmseg')
# Test PIDNet decode head Standard Forward
norm_cfg = dict(type='BN', requires_grad=True)
backbone_cfg = dict(
type='PIDNet',
in_channels=3,
channels=32,
ppm_channels=96,
num_stem_blocks=2,
num_branch_blocks=3,
align_corners=False,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU', inplace=True))
decode_head_cfg = dict(
type='PIDHead',
in_channels=128,
channels=128,
num_classes=19,
norm_cfg=norm_cfg,
act_cfg=dict(type='ReLU', inplace=True),
align_corners=True,
loss_decode=[
dict(
type='CrossEntropyLoss',
use_sigmoid=False,
class_weight=[
0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754,
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
1.0865, 1.0955, 1.0865, 1.1529, 1.0507
],
loss_weight=0.4),
dict(
type='OhemCrossEntropy',
thres=0.9,
min_kept=131072,
class_weight=[
0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754,
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
1.0865, 1.0955, 1.0865, 1.1529, 1.0507
],
loss_weight=1.0),
dict(type='BoundaryLoss', loss_weight=20.0),
dict(
type='OhemCrossEntropy',
thres=0.9,
min_kept=131072,
class_weight=[
0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754,
1.0489, 0.8786, 1.0023, 0.9539, 0.9843, 1.1116, 0.9037,
1.0865, 1.0955, 1.0865, 1.1529, 1.0507
],
loss_weight=1.0)
])
backbone = MODELS.build(backbone_cfg)
head = MODELS.build(decode_head_cfg)
# Test train mode
backbone.train()
head.train()
batch_size = 2
imgs = torch.randn(batch_size, 3, 64, 128)
feats = backbone(imgs)
seg_logit = head(feats)
assert isinstance(seg_logit, tuple)
assert len(seg_logit) == 3
p_logits, i_logits, d_logits = seg_logit
assert p_logits.shape == (batch_size, 19, 8, 16)
assert i_logits.shape == (batch_size, 19, 8, 16)
assert d_logits.shape == (batch_size, 1, 8, 16)
# Test eval mode
backbone.eval()
head.eval()
feats = backbone(imgs)
seg_logit = head(feats)
assert isinstance(seg_logit, torch.Tensor)
assert seg_logit.shape == (batch_size, 19, 8, 16)