# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn

from mmcv.cnn import ConvModule, build_norm_layer
from mmengine.model import BaseModule

from mmseg.models.utils import DAPPM, BasicBlock, Bottleneck, resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType


@MODELS.register_module()
class DDRNet(BaseModule):
    """DDRNet backbone.

    This backbone is the implementation of `Deep Dual-resolution Networks for
    Real-time and Accurate Semantic Segmentation of Road Scenes
    <http://arxiv.org/abs/2101.06085>`_.
    Modified from https://github.com/ydhongHIT/DDRNet.

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        channels (int): The base channels of DDRNet. Default: 32.
        ppm_channels (int): The channels of PPM module. Default: 128.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        norm_cfg (dict): Config dict to build norm layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels: int = 3,
                 channels: int = 32,
                 ppm_channels: int = 128,
                 align_corners: bool = False,
                 norm_cfg: OptConfigType = dict(type='BN', requires_grad=True),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)

        self.in_channels = in_channels
        self.ppm_channels = ppm_channels

        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        # stage 0-2
        self.stem = self._make_stem_layer(in_channels, channels, num_blocks=2)
        self.relu = nn.ReLU()

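        # Dual-resolution design from the paper: the context branch below
        # keeps halving the resolution (1/16 -> 1/32 -> 1/64), while the
        # spatial branch further down stays at 1/8 with channels * 2 maps.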
        # low resolution(context) branch
        self.context_branch_layers = nn.ModuleList()
        for i in range(3):
            self.context_branch_layers.append(
                self._make_layer(
                    block=BasicBlock if i < 2 else Bottleneck,
                    inplanes=channels * 2**(i + 1),
                    planes=channels * 8 if i > 0 else channels * 4,
                    num_blocks=2 if i < 2 else 1,
                    stride=2))

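        # The fusion modules below exchange information between branches:
        # compression_* squeeze context features to channels * 2 so they can
        # be upsampled and added to the spatial branch, while down_* stride
        # spatial features down for addition into the context branch.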
        # bilateral fusion
        self.compression_1 = ConvModule(
            channels * 4,
            channels * 2,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.down_1 = ConvModule(
            channels * 2,
            channels * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)

        self.compression_2 = ConvModule(
            channels * 8,
            channels * 2,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.down_2 = nn.Sequential(
            ConvModule(
                channels * 2,
                channels * 4,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            ConvModule(
                channels * 4,
                channels * 8,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=None))

        # high resolution(spatial) branch
        self.spatial_branch_layers = nn.ModuleList()
        for i in range(3):
            self.spatial_branch_layers.append(
                self._make_layer(
                    block=BasicBlock if i < 2 else Bottleneck,
                    inplanes=channels * 2,
                    planes=channels * 2,
                    num_blocks=2 if i < 2 else 1,
                ))

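        # With Bottleneck.expansion == 2, the final context stage outputs
        # channels * 16 feature maps, which DAPPM condenses to channels * 4.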
        self.spp = DAPPM(
            channels * 16, ppm_channels, channels * 4, num_scales=5)

    def _make_stem_layer(self, in_channels, channels, num_blocks):
        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            ConvModule(
                channels,
                channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        ]

        layers.extend([
            self._make_layer(BasicBlock, channels, channels, num_blocks),
            nn.ReLU(),
            self._make_layer(
                BasicBlock, channels, channels * 2, num_blocks, stride=2),
            nn.ReLU(),
        ])

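        # Overall stem stride: two stride-2 convs (1/4 resolution) followed
        # by a stride-2 BasicBlock stage, so the stem emits channels * 2
        # feature maps at 1/8 of the input resolution.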
        return nn.Sequential(*layers)

    def _make_layer(self, block, inplanes, planes, num_blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, planes * block.expansion)[1])

        layers = [
            block(
                in_channels=inplanes,
                channels=planes,
                stride=stride,
                downsample=downsample)
        ]
        inplanes = planes * block.expansion
        for i in range(1, num_blocks):
            layers.append(
                block(
                    in_channels=inplanes,
                    channels=planes,
                    stride=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))

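        # The last block's output activation is disabled (act_cfg_out=None):
        # forward() applies self.relu explicitly before each fusion step, so
        # branch outputs are summed pre-activation.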
        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward function."""
        out_size = (x.shape[-2] // 8, x.shape[-1] // 8)

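        # out_size is the spatial-branch resolution (1/8 of the input); both
        # the fusion upsampling and the final context upsampling target it.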
        # stage 0-2
        x = self.stem(x)

        # stage3
        x_c = self.context_branch_layers[0](x)
        x_s = self.spatial_branch_layers[0](x)
        comp_c = self.compression_1(self.relu(x_c))
        x_c += self.down_1(self.relu(x_s))
        x_s += resize(
            comp_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)
        if self.training:
            temp_context = x_s.clone()

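        # temp_context is returned as an extra training-time feature,
        # typically consumed by an auxiliary head for deep supervision.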
        # stage4
        x_c = self.context_branch_layers[1](self.relu(x_c))
        x_s = self.spatial_branch_layers[1](self.relu(x_s))
        comp_c = self.compression_2(self.relu(x_c))
        x_c += self.down_2(self.relu(x_s))
        x_s += resize(
            comp_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)

        # stage5
        x_s = self.spatial_branch_layers[2](self.relu(x_s))
        x_c = self.context_branch_layers[2](self.relu(x_c))
        x_c = self.spp(x_c)
        x_c = resize(
            x_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)

        return (temp_context, x_s + x_c) if self.training else x_s + x_c
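

# Minimal usage sketch, not part of the upstream file: assumes
# mmsegmentation and its dependencies are importable. The expected output
# shape follows from the code above (channels * 4 maps at 1/8 resolution,
# with Bottleneck.expansion == 2 in mmseg's utils). In mmseg configs the
# backbone is normally built from a dict instead, e.g.
# MODELS.build(dict(type='DDRNet', channels=32, ppm_channels=128)).
if __name__ == '__main__':
    import torch

    model = DDRNet(in_channels=3, channels=32, ppm_channels=128)
    model.eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 1024, 1024))
    # Spatial/context fusion yields channels * 4 = 128 feature maps at
    # 1/8 of the input resolution.
    assert out.shape == (1, 128, 128, 128)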