Thanks for your contribution; we appreciate it a lot. The following instructions will help keep your pull request healthy and make it easier to get feedback. If you do not understand some items, don't worry: just open the pull request and ask the maintainers for help.

## Motivation

Support DDRNet.

Paper: [Deep Dual-resolution Networks for Real-time and Accurate Semantic Segmentation of Road Scenes](https://arxiv.org/pdf/2101.06085)

Official code: https://github.com/ydhongHIT/DDRNet

There is already a PR https://github.com/open-mmlab/mmsegmentation/pull/1722, but it has been inactive for a long time.

## Current Result

### Cityscapes

#### Inference with converted official weights

| Method | Backbone      | mIoU (official) | mIoU (converted weights) |
| ------ | ------------- | --------------- | ------------------------ |
| DDRNet | DDRNet23-slim | 77.8            | 77.84                    |
| DDRNet | DDRNet23      | 79.5            | 79.53                    |

#### Training with converted pretrained backbone

| Method | Backbone      | Crop Size | Lr schd | Inf time (fps) | Device   | mIoU  | mIoU (ms+flip) | config | download |
| ------ | ------------- | --------- | ------- | -------------- | -------- | ----- | -------------- | ------ | -------- |
| DDRNet | DDRNet23-slim | 1024x1024 | 120000  | 85.85          | RTX 8000 | 77.85 | 79.80          | [config](https://github.com/whu-pzhang/mmsegmentation/blob/ddrnet/configs/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024.py) | model \| log |
| DDRNet | DDRNet23      | 1024x1024 | 120000  | 33.41          | RTX 8000 | 79.53 | 80.98          | [config](https://github.com/whu-pzhang/mmsegmentation/blob/ddrnet/configs/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024.py) | model \| log |

Download links for the converted pretrained backbone weights:

1. [ddrnet23s_in1k_mmseg.pth](https://drive.google.com/file/d/1Ni4F1PMGGjuld-1S9fzDTmneLfpMuPTG/view?usp=sharing)
2. [ddrnet23_in1k_mmseg.pth](https://drive.google.com/file/d/11rsijC1xOWB6B0LgNQkAG-W6e1OdbCyJ/view?usp=sharing)

## To do

- [x] support inference with converted official weights
- [x] support training on cityscapes dataset

---------

Co-authored-by: xiexinch <xiexinch@outlook.com>
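For reference, inference with one of the linked configs could look like the minimal sketch below; `init_model` and `inference_model` are the standard MMSegmentation 1.x Python APIs, and the checkpoint filename is a placeholder for a converted DDRNet checkpoint, not a file shipped with this PR:

```python
# Minimal inference sketch, assuming an MMSegmentation 1.x checkout with this
# PR applied. The checkpoint filename is a placeholder; the config path is
# taken from the table above.
from mmseg.apis import inference_model, init_model

cfg = 'configs/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024.py'
ckpt = 'ddrnet23-slim_cityscapes_converted.pth'  # placeholder path

model = init_model(cfg, ckpt, device='cuda:0')
result = inference_model(model, 'demo/demo.png')
print(result.pred_sem_seg.data.shape)  # (1, H, W) predicted label map
```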
117 lines · 3.9 KiB · Python
```python
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union

import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from torch import Tensor

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType, SampleList


@MODELS.register_module()
class DDRHead(BaseDecodeHead):
    """Decode head for DDRNet.

    Args:
        in_channels (int): Number of input channels.
        channels (int): Number of output channels.
        num_classes (int): Number of classes.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 num_classes: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 **kwargs):
        super().__init__(
            in_channels,
            channels,
            num_classes=num_classes,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)

        # Main head consumes the context-branch features; the auxiliary head
        # consumes the spatial-branch features, which have half the channels.
        self.head = self._make_base_head(self.in_channels, self.channels)
        self.aux_head = self._make_base_head(self.in_channels // 2,
                                             self.channels)
        self.aux_cls_seg = nn.Conv2d(
            self.channels, self.out_channels, kernel_size=1)

    def init_weights(self):
        # Kaiming init for convolutions, constant init for batch norms.
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(
            self,
            inputs: Union[Tensor,
                          Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
        if self.training:
            # During training the backbone yields both spatial-branch (c3)
            # and context-branch (c5) features; the head returns two sets of
            # logits for the two-term loss in `loss_by_feat`.
            c3_feat, c5_feat = inputs
            x_c = self.head(c5_feat)
            x_c = self.cls_seg(x_c)
            x_s = self.aux_head(c3_feat)
            x_s = self.aux_cls_seg(x_s)

            return x_c, x_s
        else:
            # At inference only the fused feature map is passed in.
            x_c = self.head(inputs)
            x_c = self.cls_seg(x_c)
            return x_c

    def _make_base_head(self, in_channels: int,
                        channels: int) -> nn.Sequential:
        # Pre-activation block (norm -> act -> conv) followed by a trailing
        # norm and activation, so the 1x1 classification convs receive
        # normalized, activated features.
        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                order=('norm', 'act', 'conv')),
            build_norm_layer(self.norm_cfg, channels)[1],
            build_activation_layer(self.act_cfg),
        ]

        return nn.Sequential(*layers)

    def loss_by_feat(self, seg_logits: Tuple[Tensor],
                     batch_data_samples: SampleList) -> dict:
        loss = dict()
        context_logit, spatial_logit = seg_logits
        seg_label = self._stack_batch_gt(batch_data_samples)

        # Upsample both logits to the label resolution before computing the
        # losses.
        context_logit = resize(
            context_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        spatial_logit = resize(
            spatial_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        seg_label = seg_label.squeeze(1)

        # One decode loss per branch: `loss_decode` must hold two loss
        # modules, indexed 0 (context) and 1 (spatial).
        loss['loss_context'] = self.loss_decode[0](context_logit, seg_label)
        loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label)
        loss['acc_seg'] = accuracy(
            context_logit, seg_label, ignore_index=self.ignore_index)

        return loss
```
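Because `loss_by_feat` indexes `self.loss_decode[0]` and `self.loss_decode[1]`, the head expects a list of two decode losses. A minimal sketch of how the head might be wired into a model config follows; the channel numbers and loss settings are illustrative assumptions, not values copied from the PR's configs:

```python
# Illustrative decode-head config for a DDRNet-style model. All values below
# are assumptions for the sketch: the first loss supervises the context
# logits, the second the spatial (auxiliary) logits.
norm_cfg = dict(type='SyncBN', requires_grad=True)
decode_head = dict(
    type='DDRHead',
    in_channels=128,   # fused context-branch channels (assumed)
    channels=64,
    num_classes=19,    # Cityscapes
    norm_cfg=norm_cfg,
    align_corners=False,
    loss_decode=[
        dict(type='OhemCrossEntropy', loss_weight=1.0),  # context branch
        dict(type='OhemCrossEntropy', loss_weight=0.4),  # spatial branch
    ])
```

Supplying fewer than two entries in `loss_decode` would raise an `IndexError` in `loss_by_feat`, so the two-loss list is a hard requirement of this head rather than a tuning choice.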