diff --git a/configs/ddrnet/README.md b/configs/ddrnet/README.md
new file mode 100644
index 000000000..882198866
--- /dev/null
+++ b/configs/ddrnet/README.md
@@ -0,0 +1,47 @@
+# DDRNet
+
+> [Deep Dual-resolution Networks for Real-time and Accurate Semantic Segmentation of Road Scenes](http://arxiv.org/abs/2101.06085)
+
+## Introduction
+
+
+
+Official Repo
+
+## Abstract
+
+
+
+Semantic segmentation is a key technology for autonomous vehicles to understand the surrounding scenes. The appealing performances of contemporary models usually come at the expense of heavy computations and lengthy inference time, which is intolerable for self-driving. Using light-weight architectures (encoder-decoder or two-pathway) or reasoning on low-resolution images, recent methods realize very fast scene parsing, even running at more than 100 FPS on a single 1080Ti GPU. However, there is still a significant gap in performance between these real-time methods and the models based on dilation backbones. To tackle this problem, we proposed a family of efficient backbones specially designed for real-time semantic segmentation. The proposed deep dual-resolution networks (DDRNets) are composed of two deep branches between which multiple bilateral fusions are performed. Additionally, we design a new contextual information extractor named Deep Aggregation Pyramid Pooling Module (DAPPM) to enlarge effective receptive fields and fuse multi-scale context based on low-resolution feature maps. Our method achieves a new state-of-the-art trade-off between accuracy and speed on both Cityscapes and CamVid dataset. In particular, on a single 2080Ti GPU, DDRNet-23-slim yields 77.4% mIoU at 102 FPS on Cityscapes test set and 74.7% mIoU at 230 FPS on CamVid test set. With widely used test augmentation, our method is superior to most state-of-the-art models and requires much less computation. Codes and trained models are available online.
+
+
+
+
+

+
+
+## Results and models
+
+### Cityscapes
+
+| Method | Backbone | Crop Size | Lr schd | Mem(GB) | Inf time(fps) | Device | mIoU | mIoU(ms+flip) | config | download |
+| ------ | ------------- | --------- | ------- | ------- | ------------- | ------ | ----- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
+| DDRNet | DDRNet23-slim | 1024x1024 | 120000 | 1.70 | 85.85 | A100 | 77.84 | 80.15 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024_20230426_145312-6a5e5174.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024_20230426_145312.json) |
+| DDRNet | DDRNet23 | 1024x1024 | 120000 | 7.26 | 33.41 | A100 | 79.99 | 81.71 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024_20230425_162633-81601db0.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024_20230425_162633.json) |
+
+## Notes
+
+The pretrained weights in config files are converted from [the official repo](https://github.com/ydhongHIT/DDRNet#pretrained-models).
+
+## Citation
+
+```bibtex
+@misc{hong2021ddrnet,
+ title={Deep Dual-resolution Networks for Real-time and Accurate Semantic Segmentation of Road Scenes},
+ author={Hong, Yuanduo and Pan, Huihui and Sun, Weichao and Jia, Yisong},
+ year={2021},
+ eprint={2101.06085},
+ archivePrefix={arXiv},
+ primaryClass={cs.CV},
+}
+```
diff --git a/configs/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024.py b/configs/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024.py
new file mode 100644
index 000000000..d911de4dc
--- /dev/null
+++ b/configs/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024.py
@@ -0,0 +1,95 @@
+_base_ = [
+ '../_base_/datasets/cityscapes_1024x1024.py',
+ '../_base_/default_runtime.py',
+]
+
+# The class_weight is borrowed from https://github.com/openseg-group/OCNet.pytorch/issues/14 # noqa
+# Licensed under the MIT License
+class_weight = [
+ 0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786,
+ 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529,
+ 1.0507
+]
+
+crop_size = (1024, 1024)
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ size=crop_size,
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True,
+ pad_val=0,
+ seg_pad_val=255)
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ backbone=dict(
+ type='DDRNet',
+ in_channels=3,
+ channels=32,
+ ppm_channels=128,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ init_cfg=dict(
+ type='Pretrained',
+ checkpoint='pretrained/ddrnet23s_in1k_mmseg.pth')),
+ decode_head=dict(
+ type='DDRHead',
+ in_channels=32 * 4,
+ channels=64,
+ dropout_ratio=0.,
+ num_classes=19,
+ align_corners=False,
+ norm_cfg=norm_cfg,
+ loss_decode=[
+ dict(
+ type='OhemCrossEntropy',
+ thres=0.9,
+ min_kept=131072,
+ class_weight=class_weight,
+ loss_weight=1.0),
+ dict(
+ type='OhemCrossEntropy',
+ thres=0.9,
+ min_kept=131072,
+ class_weight=class_weight,
+ loss_weight=0.4),
+ ]),
+
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
+
+train_dataloader = dict(batch_size=6, num_workers=4)
+
+iters = 120000
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
+# learning policy
+param_scheduler = [
+ dict(
+ type='PolyLR',
+ eta_min=0,
+ power=0.9,
+ begin=0,
+ end=iters,
+ by_epoch=False)
+]
+
+# training schedule for 120k
+train_cfg = dict(
+ type='IterBasedTrainLoop', max_iters=iters, val_interval=iters // 10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(
+ type='CheckpointHook', by_epoch=False, interval=iters // 10),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'))
+
+randomness = dict(seed=304)
diff --git a/configs/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024.py b/configs/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024.py
new file mode 100644
index 000000000..b59638b25
--- /dev/null
+++ b/configs/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024.py
@@ -0,0 +1,95 @@
+_base_ = [
+ '../_base_/datasets/cityscapes_1024x1024.py',
+ '../_base_/default_runtime.py',
+]
+
+# The class_weight is borrowed from https://github.com/openseg-group/OCNet.pytorch/issues/14 # noqa
+# Licensed under the MIT License
+class_weight = [
+ 0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786,
+ 1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529,
+ 1.0507
+]
+
+crop_size = (1024, 1024)
+data_preprocessor = dict(
+ type='SegDataPreProcessor',
+ size=crop_size,
+ mean=[123.675, 116.28, 103.53],
+ std=[58.395, 57.12, 57.375],
+ bgr_to_rgb=True,
+ pad_val=0,
+ seg_pad_val=255)
+norm_cfg = dict(type='SyncBN', requires_grad=True)
+model = dict(
+ type='EncoderDecoder',
+ data_preprocessor=data_preprocessor,
+ backbone=dict(
+ type='DDRNet',
+ in_channels=3,
+ channels=64,
+ ppm_channels=128,
+ norm_cfg=norm_cfg,
+ align_corners=False,
+ init_cfg=dict(
+ type='Pretrained',
+ checkpoint='pretrained/ddrnet23_in1k_mmseg.pth')),
+ decode_head=dict(
+ type='DDRHead',
+ in_channels=64 * 4,
+ channels=128,
+ dropout_ratio=0.,
+ num_classes=19,
+ align_corners=False,
+ norm_cfg=norm_cfg,
+ loss_decode=[
+ dict(
+ type='OhemCrossEntropy',
+ thres=0.9,
+ min_kept=131072,
+ class_weight=class_weight,
+ loss_weight=1.0),
+ dict(
+ type='OhemCrossEntropy',
+ thres=0.9,
+ min_kept=131072,
+ class_weight=class_weight,
+ loss_weight=0.4),
+ ]),
+
+ # model training and testing settings
+ train_cfg=dict(),
+ test_cfg=dict(mode='whole'))
+
+train_dataloader = dict(batch_size=6, num_workers=4)
+
+iters = 120000
+# optimizer
+optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
+optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
+# learning policy
+param_scheduler = [
+ dict(
+ type='PolyLR',
+ eta_min=0,
+ power=0.9,
+ begin=0,
+ end=iters,
+ by_epoch=False)
+]
+
+# training schedule for 120k
+train_cfg = dict(
+ type='IterBasedTrainLoop', max_iters=iters, val_interval=iters // 10)
+val_cfg = dict(type='ValLoop')
+test_cfg = dict(type='TestLoop')
+default_hooks = dict(
+ timer=dict(type='IterTimerHook'),
+ logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
+ param_scheduler=dict(type='ParamSchedulerHook'),
+ checkpoint=dict(
+ type='CheckpointHook', by_epoch=False, interval=iters // 10),
+ sampler_seed=dict(type='DistSamplerSeedHook'),
+ visualization=dict(type='SegVisualizationHook'))
+
+randomness = dict(seed=304)
diff --git a/configs/ddrnet/metafile.yaml b/configs/ddrnet/metafile.yaml
new file mode 100644
index 000000000..01e701871
--- /dev/null
+++ b/configs/ddrnet/metafile.yaml
@@ -0,0 +1,14 @@
+Collections:
+- Name: ''
+ License: Apache License 2.0
+ Metadata:
+ Training Data:
+ - Cityscapes
+ Paper:
+ Title: Deep Dual-resolution Networks for Real-time and Accurate Semantic Segmentation
+ of Road Scenes
+ URL: http://arxiv.org/abs/2101.06085
+ README: configs/ddrnet/README.md
+ Frameworks:
+ - PyTorch
+Models: []
diff --git a/mmseg/models/backbones/__init__.py b/mmseg/models/backbones/__init__.py
index e3107306e..d9228a500 100644
--- a/mmseg/models/backbones/__init__.py
+++ b/mmseg/models/backbones/__init__.py
@@ -3,6 +3,7 @@ from .beit import BEiT
from .bisenetv1 import BiSeNetV1
from .bisenetv2 import BiSeNetV2
from .cgnet import CGNet
+from .ddrnet import DDRNet
from .erfnet import ERFNet
from .fast_scnn import FastSCNN
from .hrnet import HRNet
@@ -28,5 +29,6 @@ __all__ = [
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
- 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN'
+ 'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN',
+ 'DDRNet'
]
diff --git a/mmseg/models/backbones/ddrnet.py b/mmseg/models/backbones/ddrnet.py
new file mode 100644
index 000000000..4508aade8
--- /dev/null
+++ b/mmseg/models/backbones/ddrnet.py
@@ -0,0 +1,222 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+import torch.nn as nn
+from mmcv.cnn import ConvModule, build_norm_layer
+from mmengine.model import BaseModule
+
+from mmseg.models.utils import DAPPM, BasicBlock, Bottleneck, resize
+from mmseg.registry import MODELS
+from mmseg.utils import OptConfigType
+
+
+@MODELS.register_module()
+class DDRNet(BaseModule):
+ """DDRNet backbone.
+
+ This backbone is the implementation of `Deep Dual-resolution Networks for
+ Real-time and Accurate Semantic Segmentation of Road Scenes
+ `_.
+ Modified from https://github.com/ydhongHIT/DDRNet.
+
+ Args:
+ in_channels (int): Number of input image channels. Default: 3.
+ channels: (int): The base channels of DDRNet. Default: 32.
+ ppm_channels (int): The channels of PPM module. Default: 128.
+ align_corners (bool): align_corners argument of F.interpolate.
+ Default: False.
+ norm_cfg (dict): Config dict to build norm layer.
+ Default: dict(type='BN', requires_grad=True).
+ act_cfg (dict): Config dict for activation layer.
+ Default: dict(type='ReLU', inplace=True).
+ init_cfg (dict, optional): Initialization config dict.
+ Default: None.
+ """
+
+ def __init__(self,
+ in_channels: int = 3,
+ channels: int = 32,
+ ppm_channels: int = 128,
+ align_corners: bool = False,
+ norm_cfg: OptConfigType = dict(type='BN', requires_grad=True),
+ act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
+ init_cfg: OptConfigType = None):
+ super().__init__(init_cfg)
+
+ self.in_channels = in_channels
+ self.ppm_channels = ppm_channels
+
+ self.norm_cfg = norm_cfg
+ self.act_cfg = act_cfg
+ self.align_corners = align_corners
+
+ # stage 0-2
+ self.stem = self._make_stem_layer(in_channels, channels, num_blocks=2)
+ self.relu = nn.ReLU()
+
+ # low resolution(context) branch
+ self.context_branch_layers = nn.ModuleList()
+ for i in range(3):
+ self.context_branch_layers.append(
+ self._make_layer(
+ block=BasicBlock if i < 2 else Bottleneck,
+ inplanes=channels * 2**(i + 1),
+ planes=channels * 8 if i > 0 else channels * 4,
+ num_blocks=2 if i < 2 else 1,
+ stride=2))
+
+ # bilateral fusion
+ self.compression_1 = ConvModule(
+ channels * 4,
+ channels * 2,
+ kernel_size=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)
+ self.down_1 = ConvModule(
+ channels * 2,
+ channels * 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)
+
+ self.compression_2 = ConvModule(
+ channels * 8,
+ channels * 2,
+ kernel_size=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None)
+ self.down_2 = nn.Sequential(
+ ConvModule(
+ channels * 2,
+ channels * 4,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ ConvModule(
+ channels * 4,
+ channels * 8,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=None))
+
+ # high resolution(spatial) branch
+ self.spatial_branch_layers = nn.ModuleList()
+ for i in range(3):
+ self.spatial_branch_layers.append(
+ self._make_layer(
+ block=BasicBlock if i < 2 else Bottleneck,
+ inplanes=channels * 2,
+ planes=channels * 2,
+ num_blocks=2 if i < 2 else 1,
+ ))
+
+ self.spp = DAPPM(
+ channels * 16, ppm_channels, channels * 4, num_scales=5)
+
+ def _make_stem_layer(self, in_channels, channels, num_blocks):
+ layers = [
+ ConvModule(
+ in_channels,
+ channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg),
+ ConvModule(
+ channels,
+ channels,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg)
+ ]
+
+ layers.extend([
+ self._make_layer(BasicBlock, channels, channels, num_blocks),
+ nn.ReLU(),
+ self._make_layer(
+ BasicBlock, channels, channels * 2, num_blocks, stride=2),
+ nn.ReLU(),
+ ])
+
+ return nn.Sequential(*layers)
+
+ def _make_layer(self, block, inplanes, planes, num_blocks, stride=1):
+ downsample = None
+ if stride != 1 or inplanes != planes * block.expansion:
+ downsample = nn.Sequential(
+ nn.Conv2d(
+ inplanes,
+ planes * block.expansion,
+ kernel_size=1,
+ stride=stride,
+ bias=False),
+ build_norm_layer(self.norm_cfg, planes * block.expansion)[1])
+
+ layers = [
+ block(
+ in_channels=inplanes,
+ channels=planes,
+ stride=stride,
+ downsample=downsample)
+ ]
+ inplanes = planes * block.expansion
+ for i in range(1, num_blocks):
+ layers.append(
+ block(
+ in_channels=inplanes,
+ channels=planes,
+ stride=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))
+
+ return nn.Sequential(*layers)
+
+ def forward(self, x):
+ """Forward function."""
+ out_size = (x.shape[-2] // 8, x.shape[-1] // 8)
+
+ # stage 0-2
+ x = self.stem(x)
+
+ # stage3
+ x_c = self.context_branch_layers[0](x)
+ x_s = self.spatial_branch_layers[0](x)
+ comp_c = self.compression_1(self.relu(x_c))
+ x_c += self.down_1(self.relu(x_s))
+ x_s += resize(
+ comp_c,
+ size=out_size,
+ mode='bilinear',
+ align_corners=self.align_corners)
+ if self.training:
+ temp_context = x_s.clone()
+
+ # stage4
+ x_c = self.context_branch_layers[1](self.relu(x_c))
+ x_s = self.spatial_branch_layers[1](self.relu(x_s))
+ comp_c = self.compression_2(self.relu(x_c))
+ x_c += self.down_2(self.relu(x_s))
+ x_s += resize(
+ comp_c,
+ size=out_size,
+ mode='bilinear',
+ align_corners=self.align_corners)
+
+ # stage5
+ x_s = self.spatial_branch_layers[2](self.relu(x_s))
+ x_c = self.context_branch_layers[2](self.relu(x_c))
+ x_c = self.spp(x_c)
+ x_c = resize(
+ x_c,
+ size=out_size,
+ mode='bilinear',
+ align_corners=self.align_corners)
+
+ return (temp_context, x_s + x_c) if self.training else x_s + x_c
diff --git a/mmseg/models/decode_heads/__init__.py b/mmseg/models/decode_heads/__init__.py
index 18235456b..36c37ec2d 100644
--- a/mmseg/models/decode_heads/__init__.py
+++ b/mmseg/models/decode_heads/__init__.py
@@ -4,6 +4,7 @@ from .apc_head import APCHead
from .aspp_head import ASPPHead
from .cc_head import CCHead
from .da_head import DAHead
+from .ddr_head import DDRHead
from .dm_head import DMHead
from .dnl_head import DNLHead
from .dpt_head import DPTHead
@@ -41,5 +42,5 @@ __all__ = [
'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead',
- 'LightHamHead', 'PIDHead'
+ 'LightHamHead', 'PIDHead', 'DDRHead'
]
diff --git a/mmseg/models/decode_heads/ddr_head.py b/mmseg/models/decode_heads/ddr_head.py
new file mode 100644
index 000000000..ba26d6503
--- /dev/null
+++ b/mmseg/models/decode_heads/ddr_head.py
@@ -0,0 +1,116 @@
+# Copyright (c) OpenMMLab. All rights reserved.
+from typing import Tuple, Union
+
+import torch.nn as nn
+from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
+from torch import Tensor
+
+from mmseg.models.decode_heads.decode_head import BaseDecodeHead
+from mmseg.models.losses import accuracy
+from mmseg.models.utils import resize
+from mmseg.registry import MODELS
+from mmseg.utils import OptConfigType, SampleList
+
+
+@MODELS.register_module()
+class DDRHead(BaseDecodeHead):
+ """Decode head for DDRNet.
+
+ Args:
+ in_channels (int): Number of input channels.
+ channels (int): Number of output channels.
+ num_classes (int): Number of classes.
+ norm_cfg (dict, optional): Config dict for normalization layer.
+ Default: dict(type='BN').
+ act_cfg (dict, optional): Config dict for activation layer.
+ Default: dict(type='ReLU', inplace=True).
+ """
+
+ def __init__(self,
+ in_channels: int,
+ channels: int,
+ num_classes: int,
+ norm_cfg: OptConfigType = dict(type='BN'),
+ act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
+ **kwargs):
+ super().__init__(
+ in_channels,
+ channels,
+ num_classes=num_classes,
+ norm_cfg=norm_cfg,
+ act_cfg=act_cfg,
+ **kwargs)
+
+ self.head = self._make_base_head(self.in_channels, self.channels)
+ self.aux_head = self._make_base_head(self.in_channels // 2,
+ self.channels)
+ self.aux_cls_seg = nn.Conv2d(
+ self.channels, self.out_channels, kernel_size=1)
+
+ def init_weights(self):
+ for m in self.modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(
+ m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.constant_(m.weight, 1)
+ nn.init.constant_(m.bias, 0)
+
+ def forward(
+ self,
+ inputs: Union[Tensor,
+ Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
+ if self.training:
+ c3_feat, c5_feat = inputs
+ x_c = self.head(c5_feat)
+ x_c = self.cls_seg(x_c)
+ x_s = self.aux_head(c3_feat)
+ x_s = self.aux_cls_seg(x_s)
+
+ return x_c, x_s
+ else:
+ x_c = self.head(inputs)
+ x_c = self.cls_seg(x_c)
+ return x_c
+
+ def _make_base_head(self, in_channels: int,
+ channels: int) -> nn.Sequential:
+ layers = [
+ ConvModule(
+ in_channels,
+ channels,
+ kernel_size=3,
+ padding=1,
+ norm_cfg=self.norm_cfg,
+ act_cfg=self.act_cfg,
+ order=('norm', 'act', 'conv')),
+ build_norm_layer(self.norm_cfg, channels)[1],
+ build_activation_layer(self.act_cfg),
+ ]
+
+ return nn.Sequential(*layers)
+
+ def loss_by_feat(self, seg_logits: Tuple[Tensor],
+ batch_data_samples: SampleList) -> dict:
+ loss = dict()
+ context_logit, spatial_logit = seg_logits
+ seg_label = self._stack_batch_gt(batch_data_samples)
+
+ context_logit = resize(
+ context_logit,
+ size=seg_label.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ spatial_logit = resize(
+ spatial_logit,
+ size=seg_label.shape[2:],
+ mode='bilinear',
+ align_corners=self.align_corners)
+ seg_label = seg_label.squeeze(1)
+
+ loss['loss_context'] = self.loss_decode[0](context_logit, seg_label)
+ loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label)
+ loss['acc_seg'] = accuracy(
+ context_logit, seg_label, ignore_index=self.ignore_index)
+
+ return loss
diff --git a/model-index.yml b/model-index.yml
index 5e87c386d..3ed1c1cdc 100644
--- a/model-index.yml
+++ b/model-index.yml
@@ -8,6 +8,7 @@ Import:
- configs/cgnet/metafile.yaml
- configs/convnext/metafile.yaml
- configs/danet/metafile.yaml
+- configs/ddrnet/metafile.yaml
- configs/deeplabv3/metafile.yaml
- configs/deeplabv3plus/metafile.yaml
- configs/dmnet/metafile.yaml