Pan Zhang 990063e59b
[Feature] Support DDRNet (#2855)

## Motivation

Support DDRNet.

- Paper: [Deep Dual-resolution Networks for Real-time and Accurate Semantic Segmentation of Road Scenes](https://arxiv.org/pdf/2101.06085)
- Official code: https://github.com/ydhongHIT/DDRNet


An earlier PR, https://github.com/open-mmlab/mmsegmentation/pull/1722, already
exists but has been inactive for a long time.

## Current Result

### Cityscapes

#### inference with converted official weights

| Method | Backbone      | mIoU(official) | mIoU(converted weight) |
| ------ | ------------- | -------------- | ---------------------- |
| DDRNet | DDRNet23-slim | 77.8           | 77.84                  |
| DDRNet | DDRNet23 | 79.5 | 79.53 |

#### training with converted pretrained backbone

| Method | Backbone      | Crop Size | Lr schd | Inf time(fps) | Device   | mIoU  | mIoU(ms+flip) | config | download |
| ------ | ------------- | --------- | ------- | ------------- | -------- | ----- | ------------- | ------ | -------- |
| DDRNet | DDRNet23-slim | 1024x1024 | 120000  | 85.85         | RTX 8000 | 77.85 | 79.80         | [config](https://github.com/whu-pzhang/mmsegmentation/blob/ddrnet/configs/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024.py) | model \| log |
| DDRNet | DDRNet23      | 1024x1024 | 120000  | 33.41         | RTX 8000 | 79.53 | 80.98         | [config](https://github.com/whu-pzhang/mmsegmentation/blob/ddrnet/configs/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024.py) | model \| log |


Download links for the converted pretrained backbone weights:

1. [ddrnet23s_in1k_mmseg.pth](https://drive.google.com/file/d/1Ni4F1PMGGjuld-1S9fzDTmneLfpMuPTG/view?usp=sharing)
2. [ddrnet23_in1k_mmseg.pth](https://drive.google.com/file/d/11rsijC1xOWB6B0LgNQkAG-W6e1OdbCyJ/view?usp=sharing)
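
As a rough sketch of how one of these checkpoints might be wired into a config: MMEngine-style models accept an `init_cfg` of type `Pretrained` that points at a checkpoint file. The local path and the `'DDRNet'` registry name below are assumptions for illustration and are not taken from this PR; other backbone arguments are omitted.

```python
# Hypothetical local path: wherever the downloaded file was saved.
checkpoint = 'pretrained/ddrnet23s_in1k_mmseg.pth'

# Backbone entry of a model config.
# Assumption: the backbone is registered under the name 'DDRNet'.
backbone = dict(
    type='DDRNet',
    # Load the converted ImageNet-pretrained weights at initialization.
    init_cfg=dict(type='Pretrained', checkpoint=checkpoint),
)
```

For the DDRNet23 variant, the same pattern would apply with the second checkpoint file.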

## To do

- [x] support inference with converted official weights
- [x] support training on cityscapes dataset

---------

Co-authored-by: xiexinch <xiexinch@outlook.com>
2023-04-27 09:44:30 +08:00

# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union

import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from torch import Tensor

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType, SampleList


@MODELS.register_module()
class DDRHead(BaseDecodeHead):
    """Decode head for DDRNet.

    Args:
        in_channels (int): Number of input channels.
        channels (int): Number of output channels.
        num_classes (int): Number of classes.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 num_classes: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 **kwargs):
        super().__init__(
            in_channels,
            channels,
            num_classes=num_classes,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)

        self.head = self._make_base_head(self.in_channels, self.channels)
        self.aux_head = self._make_base_head(self.in_channels // 2,
                                             self.channels)
        self.aux_cls_seg = nn.Conv2d(
            self.channels, self.out_channels, kernel_size=1)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(
            self,
            inputs: Union[Tensor,
                          Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
        if self.training:
            c3_feat, c5_feat = inputs
            x_c = self.head(c5_feat)
            x_c = self.cls_seg(x_c)
            x_s = self.aux_head(c3_feat)
            x_s = self.aux_cls_seg(x_s)

            return x_c, x_s
        else:
            x_c = self.head(inputs)
            x_c = self.cls_seg(x_c)
            return x_c

    def _make_base_head(self, in_channels: int,
                        channels: int) -> nn.Sequential:
        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                order=('norm', 'act', 'conv')),
            build_norm_layer(self.norm_cfg, channels)[1],
            build_activation_layer(self.act_cfg),
        ]

        return nn.Sequential(*layers)

    def loss_by_feat(self, seg_logits: Tuple[Tensor],
                     batch_data_samples: SampleList) -> dict:
        loss = dict()
        context_logit, spatial_logit = seg_logits
        seg_label = self._stack_batch_gt(batch_data_samples)

        context_logit = resize(
            context_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        spatial_logit = resize(
            spatial_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        seg_label = seg_label.squeeze(1)

        loss['loss_context'] = self.loss_decode[0](context_logit, seg_label)
        loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label)
        loss['acc_seg'] = accuracy(
            context_logit, seg_label, ignore_index=self.ignore_index)

        return loss
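
`loss_by_feat` indexes `self.loss_decode[0]` and `self.loss_decode[1]`, so the head appears to expect `loss_decode` to be configured as a list of two losses: the first applied to the context branch, the second to the auxiliary spatial branch. A minimal config sketch under that reading follows; the channel counts, loss type and loss weights are illustrative assumptions, not values taken from this PR.

```python
# Illustrative decode-head config for DDRHead (values are assumptions).
decode_head = dict(
    type='DDRHead',
    # Channels of c5_feat; the auxiliary branch is built with
    # in_channels // 2, so c3_feat is expected to have half as many.
    in_channels=256,
    channels=128,
    num_classes=19,  # Cityscapes has 19 evaluation classes
    align_corners=False,
    # Two entries are needed: [0] -> loss_context, [1] -> loss_spatial.
    loss_decode=[
        dict(type='CrossEntropyLoss', loss_weight=1.0),
        dict(type='CrossEntropyLoss', loss_weight=0.4),
    ],
)
```

With a single loss dict instead of a list, the `[0]` and `[1]` lookups in `loss_by_feat` would not select separate per-branch losses, so the two-entry list is the safer configuration.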