[Feature] Support DDRNet (#2855)


## Motivation

Support DDRNet.
Paper: [Deep Dual-resolution Networks for Real-time and Accurate Semantic Segmentation of Road Scenes](https://arxiv.org/pdf/2101.06085)
Official code: https://github.com/ydhongHIT/DDRNet


An earlier PR, https://github.com/open-mmlab/mmsegmentation/pull/1722, attempted this, but it has been inactive for a long time.

## Current Results

### Cityscapes

#### Inference with converted official weights

| Method | Backbone      | mIoU (official) | mIoU (converted weight) |
| ------ | ------------- | --------------- | ----------------------- |
| DDRNet | DDRNet23-slim | 77.8            | 77.84                   |
| DDRNet | DDRNet23      | 79.5            | 79.53                   |

#### Training with the converted pretrained backbone

| Method | Backbone      | Crop Size | Lr schd | Inf time (fps) | Device   | mIoU  | mIoU (ms+flip) | config | download |
| ------ | ------------- | --------- | ------- | -------------- | -------- | ----- | -------------- | ------ | -------- |
| DDRNet | DDRNet23-slim | 1024x1024 | 120000  | 85.85          | RTX 8000 | 77.85 | 79.80          | [config](https://github.com/whu-pzhang/mmsegmentation/blob/ddrnet/configs/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024.py) | model \| log |
| DDRNet | DDRNet23      | 1024x1024 | 120000  | 33.41          | RTX 8000 | 79.53 | 80.98          | [config](https://github.com/whu-pzhang/mmsegmentation/blob/ddrnet/configs/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024.py) | model \| log |


The converted pretrained backbone weights can be downloaded below; a sketch of the kind of conversion involved follows the list.

1. [ddrnet23s_in1k_mmseg.pth](https://drive.google.com/file/d/1Ni4F1PMGGjuld-1S9fzDTmneLfpMuPTG/view?usp=sharing)
2. [ddrnet23_in1k_mmseg.pth](https://drive.google.com/file/d/11rsijC1xOWB6B0LgNQkAG-W6e1OdbCyJ/view?usp=sharing)
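
Converting the official checkpoints mostly amounts to renaming state-dict keys to match the mmseg module layout. A minimal sketch is below; the rename rules shown are hypothetical placeholders, not the mapping actually used to produce the files above.

```python
# Hypothetical sketch of an official -> mmseg checkpoint conversion.
# The RENAME_RULES are illustrative placeholders; the real mapping used for
# ddrnet23s_in1k_mmseg.pth / ddrnet23_in1k_mmseg.pth is not part of this PR.
from collections import OrderedDict

import torch

RENAME_RULES = [
    ('conv1', 'stem'),  # e.g. official stem block -> DDRNet.stem (assumed)
]


def convert_official_ddrnet(src_path: str, dst_path: str) -> None:
    state = torch.load(src_path, map_location='cpu')
    state = state.get('state_dict', state)
    new_state = OrderedDict()
    for key, value in state.items():
        for old, new in RENAME_RULES:
            if key.startswith(old):
                key = new + key[len(old):]
                break
        new_state[key] = value
    torch.save({'state_dict': new_state}, dst_path)
```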

## To do

- [x] support inference with converted official weights
- [x] support training on cityscapes dataset

---------

Co-authored-by: xiexinch <xiexinch@outlook.com>
Pan Zhang, 2023-04-27 09:44:30 +08:00, committed by GitHub
commit 990063e59b (parent 3b072eb213)
9 changed files with 595 additions and 2 deletions

configs/ddrnet/README.md (new file)

@@ -0,0 +1,47 @@
# DDRNet

> [Deep Dual-resolution Networks for Real-time and Accurate Semantic Segmentation of Road Scenes](http://arxiv.org/abs/2101.06085)

## Introduction

<!-- [ALGORITHM] -->

<a href="https://github.com/ydhongHIT/DDRNet">Official Repo</a>

## Abstract

<!-- [ABSTRACT] -->

Semantic segmentation is a key technology for autonomous vehicles to understand the surrounding scenes. The appealing performances of contemporary models usually come at the expense of heavy computations and lengthy inference time, which is intolerable for self-driving. Using light-weight architectures (encoder-decoder or two-pathway) or reasoning on low-resolution images, recent methods realize very fast scene parsing, even running at more than 100 FPS on a single 1080Ti GPU. However, there is still a significant gap in performance between these real-time methods and the models based on dilation backbones. To tackle this problem, we proposed a family of efficient backbones specially designed for real-time semantic segmentation. The proposed deep dual-resolution networks (DDRNets) are composed of two deep branches between which multiple bilateral fusions are performed. Additionally, we design a new contextual information extractor named Deep Aggregation Pyramid Pooling Module (DAPPM) to enlarge effective receptive fields and fuse multi-scale context based on low-resolution feature maps. Our method achieves a new state-of-the-art trade-off between accuracy and speed on both Cityscapes and CamVid dataset. In particular, on a single 2080Ti GPU, DDRNet-23-slim yields 77.4% mIoU at 102 FPS on Cityscapes test set and 74.7% mIoU at 230 FPS on CamVid test set. With widely used test augmentation, our method is superior to most state-of-the-art models and requires much less computation. Codes and trained models are available online.

<!-- [IMAGE] -->

<div align=center>
<img src="https://raw.githubusercontent.com/ydhongHIT/DDRNet/main/figs/DDRNet_seg.png" width="80%"/>
</div>

## Results and models

### Cityscapes

| Method | Backbone | Crop Size | Lr schd | Mem(GB) | Inf time(fps) | Device | mIoU | mIoU(ms+flip) | config | download |
| ------ | ------------- | --------- | ------- | ------- | ------------- | ------ | ----- | ------------- | ----------------------------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| DDRNet | DDRNet23-slim | 1024x1024 | 120000 | 1.70 | 85.85 | A100 | 77.84 | 80.15 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024_20230426_145312-6a5e5174.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024_20230426_145312.json) |
| DDRNet | DDRNet23 | 1024x1024 | 120000 | 7.26 | 33.41 | A100 | 79.99 | 81.71 | [config](https://github.com/open-mmlab/mmsegmentation/blob/main/configs/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024_20230425_162633-81601db0.pth) \| [log](https://download.openmmlab.com/mmsegmentation/v0.5/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024_20230425_162633.json) |

## Notes

The pretrained weights in the config files are converted from [the official repo](https://github.com/ydhongHIT/DDRNet#pretrained-models).

## Citation

```bibtex
@misc{hong2021ddrnet,
  title={Deep Dual-resolution Networks for Real-time and Accurate Semantic Segmentation of Road Scenes},
  author={Hong, Yuanduo and Pan, Huihui and Sun, Weichao and Jia, Yisong},
  year={2021},
  eprint={2101.06085},
  archivePrefix={arXiv},
  primaryClass={cs.CV},
}
```
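
For completeness, inference with one of the checkpoints above might look like the following (a sketch assuming an MMSegmentation 1.x install; the local paths are placeholders):

```python
# Sketch: single-image inference with the released DDRNet23-slim checkpoint.
from mmseg.apis import inference_model, init_model

config = 'configs/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024.py'
checkpoint = 'ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024_20230426_145312-6a5e5174.pth'

model = init_model(config, checkpoint, device='cuda:0')
result = inference_model(model, 'demo/demo.png')  # any Cityscapes-like image
print(result.pred_sem_seg.data.shape)  # (1, H, W) predicted class indices
```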

configs/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024.py (new file)

@@ -0,0 +1,95 @@
_base_ = [
    '../_base_/datasets/cityscapes_1024x1024.py',
    '../_base_/default_runtime.py',
]

# The class_weight is borrowed from https://github.com/openseg-group/OCNet.pytorch/issues/14 # noqa
# Licensed under the MIT License
class_weight = [
    0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786,
    1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529,
    1.0507
]

crop_size = (1024, 1024)
data_preprocessor = dict(
    type='SegDataPreProcessor',
    size=crop_size,
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type='DDRNet',
        in_channels=3,
        channels=32,
        ppm_channels=128,
        norm_cfg=norm_cfg,
        align_corners=False,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='pretrained/ddrnet23s_in1k_mmseg.pth')),
    decode_head=dict(
        type='DDRHead',
        in_channels=32 * 4,
        channels=64,
        dropout_ratio=0.,
        num_classes=19,
        align_corners=False,
        norm_cfg=norm_cfg,
        loss_decode=[
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=1.0),
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=0.4),
        ]),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

train_dataloader = dict(batch_size=6, num_workers=4)

iters = 120000
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=0,
        power=0.9,
        begin=0,
        end=iters,
        by_epoch=False)
]
# training schedule for 120k
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=iters, val_interval=iters // 10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook', by_epoch=False, interval=iters // 10),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

randomness = dict(seed=304)
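
For intuition about the schedule above, PolyLR with power=0.9 decays the learning rate as lr(t) = (lr0 - eta_min) * (1 - t/T)^0.9 + eta_min. A back-of-the-envelope helper (just arithmetic, not mmengine's implementation):

```python
# Sanity-check helper for the PolyLR schedule configured above.
def poly_lr(base_lr: float, cur_iter: int, max_iters: int,
            power: float = 0.9, eta_min: float = 0.0) -> float:
    coeff = (1 - cur_iter / max_iters)**power
    return (base_lr - eta_min) * coeff + eta_min


print(poly_lr(0.01, 0, 120000))       # 0.01 at the start
print(poly_lr(0.01, 60000, 120000))   # ~0.0054 halfway through
print(poly_lr(0.01, 119999, 120000))  # ~0 near the end
```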

configs/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024.py (new file)

@@ -0,0 +1,95 @@
_base_ = [
    '../_base_/datasets/cityscapes_1024x1024.py',
    '../_base_/default_runtime.py',
]

# The class_weight is borrowed from https://github.com/openseg-group/OCNet.pytorch/issues/14 # noqa
# Licensed under the MIT License
class_weight = [
    0.8373, 0.918, 0.866, 1.0345, 1.0166, 0.9969, 0.9754, 1.0489, 0.8786,
    1.0023, 0.9539, 0.9843, 1.1116, 0.9037, 1.0865, 1.0955, 1.0865, 1.1529,
    1.0507
]

crop_size = (1024, 1024)
data_preprocessor = dict(
    type='SegDataPreProcessor',
    size=crop_size,
    mean=[123.675, 116.28, 103.53],
    std=[58.395, 57.12, 57.375],
    bgr_to_rgb=True,
    pad_val=0,
    seg_pad_val=255)
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
    type='EncoderDecoder',
    data_preprocessor=data_preprocessor,
    backbone=dict(
        type='DDRNet',
        in_channels=3,
        channels=64,
        ppm_channels=128,
        norm_cfg=norm_cfg,
        align_corners=False,
        init_cfg=dict(
            type='Pretrained',
            checkpoint='pretrained/ddrnet23_in1k_mmseg.pth')),
    decode_head=dict(
        type='DDRHead',
        in_channels=64 * 4,
        channels=128,
        dropout_ratio=0.,
        num_classes=19,
        align_corners=False,
        norm_cfg=norm_cfg,
        loss_decode=[
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=1.0),
            dict(
                type='OhemCrossEntropy',
                thres=0.9,
                min_kept=131072,
                class_weight=class_weight,
                loss_weight=0.4),
        ]),
    # model training and testing settings
    train_cfg=dict(),
    test_cfg=dict(mode='whole'))

train_dataloader = dict(batch_size=6, num_workers=4)

iters = 120000
# optimizer
optimizer = dict(type='SGD', lr=0.01, momentum=0.9, weight_decay=0.0005)
optim_wrapper = dict(type='OptimWrapper', optimizer=optimizer, clip_grad=None)
# learning policy
param_scheduler = [
    dict(
        type='PolyLR',
        eta_min=0,
        power=0.9,
        begin=0,
        end=iters,
        by_epoch=False)
]
# training schedule for 120k
train_cfg = dict(
    type='IterBasedTrainLoop', max_iters=iters, val_interval=iters // 10)
val_cfg = dict(type='ValLoop')
test_cfg = dict(type='TestLoop')
default_hooks = dict(
    timer=dict(type='IterTimerHook'),
    logger=dict(type='LoggerHook', interval=50, log_metric_by_epoch=False),
    param_scheduler=dict(type='ParamSchedulerHook'),
    checkpoint=dict(
        type='CheckpointHook', by_epoch=False, interval=iters // 10),
    sampler_seed=dict(type='DistSamplerSeedHook'),
    visualization=dict(type='SegVisualizationHook'))

randomness = dict(seed=304)
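
Aside from the pretrained checkpoint, this config differs from the slim variant only in width (backbone channels=64, head channels=128). To see the resulting size gap, one could build both models and count parameters; a sketch assuming a dev install run from the repo root:

```python
# Sketch: compare parameter counts of the two DDRNet variants.
import mmseg.models  # noqa: F401  # importing populates the MODELS registry
from mmengine import Config

from mmseg.registry import MODELS

for cfg_path in [
        'configs/ddrnet/ddrnet_23-slim_in1k-pre_2xb6-120k_cityscapes-1024x1024.py',
        'configs/ddrnet/ddrnet_23_in1k-pre_2xb6-120k_cityscapes-1024x1024.py',
]:
    cfg = Config.fromfile(cfg_path)
    cfg.model.backbone.init_cfg = None  # skip loading the pretrained weights
    model = MODELS.build(cfg.model)
    n_params = sum(p.numel() for p in model.parameters())
    print(cfg_path, f'{n_params / 1e6:.2f}M parameters')
```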

configs/ddrnet/metafile.yaml (new file)

@@ -0,0 +1,14 @@
Collections:
- Name: DDRNet
  License: Apache License 2.0
  Metadata:
    Training Data:
    - Cityscapes
  Paper:
    Title: Deep Dual-resolution Networks for Real-time and Accurate Semantic Segmentation
      of Road Scenes
    URL: http://arxiv.org/abs/2101.06085
  README: configs/ddrnet/README.md
  Frameworks:
  - PyTorch
Models: []

mmseg/models/backbones/__init__.py

@@ -3,6 +3,7 @@ from .beit import BEiT
 from .bisenetv1 import BiSeNetV1
 from .bisenetv2 import BiSeNetV2
 from .cgnet import CGNet
+from .ddrnet import DDRNet
 from .erfnet import ERFNet
 from .fast_scnn import FastSCNN
 from .hrnet import HRNet
@@ -28,5 +29,6 @@ __all__ = [
     'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
     'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
     'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
-    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN'
+    'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN',
+    'DDRNet'
 ]

mmseg/models/backbones/ddrnet.py (new file)

@@ -0,0 +1,222 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
from mmcv.cnn import ConvModule, build_norm_layer
from mmengine.model import BaseModule

from mmseg.models.utils import DAPPM, BasicBlock, Bottleneck, resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType


@MODELS.register_module()
class DDRNet(BaseModule):
    """DDRNet backbone.

    This backbone is the implementation of `Deep Dual-resolution Networks for
    Real-time and Accurate Semantic Segmentation of Road Scenes
    <http://arxiv.org/abs/2101.06085>`_.
    Modified from https://github.com/ydhongHIT/DDRNet.

    Args:
        in_channels (int): Number of input image channels. Default: 3.
        channels (int): The base channels of DDRNet. Default: 32.
        ppm_channels (int): The channels of PPM module. Default: 128.
        align_corners (bool): align_corners argument of F.interpolate.
            Default: False.
        norm_cfg (dict): Config dict to build norm layer.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
        init_cfg (dict, optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels: int = 3,
                 channels: int = 32,
                 ppm_channels: int = 128,
                 align_corners: bool = False,
                 norm_cfg: OptConfigType = dict(type='BN', requires_grad=True),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 init_cfg: OptConfigType = None):
        super().__init__(init_cfg)
        self.in_channels = in_channels
        self.ppm_channels = ppm_channels
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.align_corners = align_corners

        # stage 0-2
        self.stem = self._make_stem_layer(in_channels, channels, num_blocks=2)
        self.relu = nn.ReLU()

        # low resolution (context) branch
        self.context_branch_layers = nn.ModuleList()
        for i in range(3):
            self.context_branch_layers.append(
                self._make_layer(
                    block=BasicBlock if i < 2 else Bottleneck,
                    inplanes=channels * 2**(i + 1),
                    planes=channels * 8 if i > 0 else channels * 4,
                    num_blocks=2 if i < 2 else 1,
                    stride=2))

        # bilateral fusion
        self.compression_1 = ConvModule(
            channels * 4,
            channels * 2,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.down_1 = ConvModule(
            channels * 2,
            channels * 4,
            kernel_size=3,
            stride=2,
            padding=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)

        self.compression_2 = ConvModule(
            channels * 8,
            channels * 2,
            kernel_size=1,
            norm_cfg=self.norm_cfg,
            act_cfg=None)
        self.down_2 = nn.Sequential(
            ConvModule(
                channels * 2,
                channels * 4,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            ConvModule(
                channels * 4,
                channels * 8,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=None))

        # high resolution (spatial) branch
        self.spatial_branch_layers = nn.ModuleList()
        for i in range(3):
            self.spatial_branch_layers.append(
                self._make_layer(
                    block=BasicBlock if i < 2 else Bottleneck,
                    inplanes=channels * 2,
                    planes=channels * 2,
                    num_blocks=2 if i < 2 else 1,
                ))

        self.spp = DAPPM(
            channels * 16, ppm_channels, channels * 4, num_scales=5)

    def _make_stem_layer(self, in_channels, channels, num_blocks):
        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            ConvModule(
                channels,
                channels,
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        ]

        layers.extend([
            self._make_layer(BasicBlock, channels, channels, num_blocks),
            nn.ReLU(),
            self._make_layer(
                BasicBlock, channels, channels * 2, num_blocks, stride=2),
            nn.ReLU(),
        ])

        return nn.Sequential(*layers)

    def _make_layer(self, block, inplanes, planes, num_blocks, stride=1):
        downsample = None
        if stride != 1 or inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                nn.Conv2d(
                    inplanes,
                    planes * block.expansion,
                    kernel_size=1,
                    stride=stride,
                    bias=False),
                build_norm_layer(self.norm_cfg, planes * block.expansion)[1])

        layers = [
            block(
                in_channels=inplanes,
                channels=planes,
                stride=stride,
                downsample=downsample)
        ]
        inplanes = planes * block.expansion
        for i in range(1, num_blocks):
            layers.append(
                block(
                    in_channels=inplanes,
                    channels=planes,
                    stride=1,
                    norm_cfg=self.norm_cfg,
                    act_cfg_out=None if i == num_blocks - 1 else self.act_cfg))
        return nn.Sequential(*layers)

    def forward(self, x):
        """Forward function."""
        out_size = (x.shape[-2] // 8, x.shape[-1] // 8)

        # stage 0-2
        x = self.stem(x)

        # stage 3
        x_c = self.context_branch_layers[0](x)
        x_s = self.spatial_branch_layers[0](x)
        comp_c = self.compression_1(self.relu(x_c))
        x_c += self.down_1(self.relu(x_s))
        x_s += resize(
            comp_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)
        if self.training:
            temp_context = x_s.clone()

        # stage 4
        x_c = self.context_branch_layers[1](self.relu(x_c))
        x_s = self.spatial_branch_layers[1](self.relu(x_s))
        comp_c = self.compression_2(self.relu(x_c))
        x_c += self.down_2(self.relu(x_s))
        x_s += resize(
            comp_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)

        # stage 5
        x_s = self.spatial_branch_layers[2](self.relu(x_s))
        x_c = self.context_branch_layers[2](self.relu(x_c))
        x_c = self.spp(x_c)
        x_c = resize(
            x_c,
            size=out_size,
            mode='bilinear',
            align_corners=self.align_corners)

        return (temp_context, x_s + x_c) if self.training else x_s + x_c
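
A rough smoke test of the forward pass (a sketch assuming a dev install where DDRNet is importable; the expected shapes follow from channels=32 and the 1/8-resolution output computed in `forward`):

```python
# Sketch: shape check for DDRNet23-slim in both modes.
import torch

from mmseg.models.backbones import DDRNet

backbone = DDRNet(in_channels=3, channels=32, ppm_channels=128)
x = torch.randn(1, 3, 1024, 1024)

backbone.train()
temp_context, out = backbone(x)
# Both maps are at 1/8 input resolution (128x128 here); with channels=32,
# temp_context has channels * 2 = 64 channels (auxiliary supervision) and
# out has channels * 4 = 128 channels (main head input).
print(temp_context.shape, out.shape)

backbone.eval()
with torch.no_grad():
    out = backbone(x)  # eval mode returns only the fused map
print(out.shape)
```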

mmseg/models/decode_heads/__init__.py

@@ -4,6 +4,7 @@ from .apc_head import APCHead
 from .aspp_head import ASPPHead
 from .cc_head import CCHead
 from .da_head import DAHead
+from .ddr_head import DDRHead
 from .dm_head import DMHead
 from .dnl_head import DNLHead
 from .dpt_head import DPTHead
@@ -41,5 +42,5 @@ __all__ = [
     'SETRMLAHead', 'DPTHead', 'SETRMLAHead', 'SegmenterMaskTransformerHead',
     'SegformerHead', 'ISAHead', 'STDCHead', 'IterativeDecodeHead',
     'KernelUpdateHead', 'KernelUpdator', 'MaskFormerHead', 'Mask2FormerHead',
-    'LightHamHead', 'PIDHead'
+    'LightHamHead', 'PIDHead', 'DDRHead'
 ]

mmseg/models/decode_heads/ddr_head.py (new file)

@@ -0,0 +1,116 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union

import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from torch import Tensor

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType, SampleList


@MODELS.register_module()
class DDRHead(BaseDecodeHead):
    """Decode head for DDRNet.

    Args:
        in_channels (int): Number of input channels.
        channels (int): Number of output channels.
        num_classes (int): Number of classes.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 num_classes: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 **kwargs):
        super().__init__(
            in_channels,
            channels,
            num_classes=num_classes,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)

        self.head = self._make_base_head(self.in_channels, self.channels)
        self.aux_head = self._make_base_head(self.in_channels // 2,
                                             self.channels)
        self.aux_cls_seg = nn.Conv2d(
            self.channels, self.out_channels, kernel_size=1)

    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(
        self,
        inputs: Union[Tensor,
                      Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
        if self.training:
            c3_feat, c5_feat = inputs
            x_c = self.head(c5_feat)
            x_c = self.cls_seg(x_c)
            x_s = self.aux_head(c3_feat)
            x_s = self.aux_cls_seg(x_s)

            return x_c, x_s
        else:
            x_c = self.head(inputs)
            x_c = self.cls_seg(x_c)
            return x_c

    def _make_base_head(self, in_channels: int,
                        channels: int) -> nn.Sequential:
        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                order=('norm', 'act', 'conv')),
            build_norm_layer(self.norm_cfg, channels)[1],
            build_activation_layer(self.act_cfg),
        ]

        return nn.Sequential(*layers)

    def loss_by_feat(self, seg_logits: Tuple[Tensor],
                     batch_data_samples: SampleList) -> dict:
        loss = dict()
        context_logit, spatial_logit = seg_logits
        seg_label = self._stack_batch_gt(batch_data_samples)

        context_logit = resize(
            context_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        spatial_logit = resize(
            spatial_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        seg_label = seg_label.squeeze(1)

        loss['loss_context'] = self.loss_decode[0](context_logit, seg_label)
        loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label)
        loss['acc_seg'] = accuracy(
            context_logit, seg_label, ignore_index=self.ignore_index)

        return loss
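
As a quick check of the head's two modes, the shapes below are what one would expect for the slim variant (in_channels=128, channels=64, 19 classes); a sketch assuming an mmsegmentation dev install:

```python
# Sketch: DDRHead returns one logit map in eval mode and two in train mode.
import torch

from mmseg.models.decode_heads import DDRHead

head = DDRHead(in_channels=128, channels=64, num_classes=19)

head.eval()
c5_feat = torch.randn(2, 128, 128, 128)  # fused backbone output, 1/8 scale
with torch.no_grad():
    logits = head(c5_feat)
print(logits.shape)  # expected: (2, 19, 128, 128)

head.train()
c3_feat = torch.randn(2, 64, 128, 128)  # spatial-branch feature (in_channels // 2)
main_logit, aux_logit = head((c3_feat, c5_feat))
print(main_logit.shape, aux_logit.shape)  # both expected: (2, 19, 128, 128)
```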

model-index.yml

@@ -8,6 +8,7 @@ Import:
 - configs/cgnet/metafile.yaml
 - configs/convnext/metafile.yaml
 - configs/danet/metafile.yaml
+- configs/ddrnet/metafile.yaml
 - configs/deeplabv3/metafile.yaml
 - configs/deeplabv3plus/metafile.yaml
 - configs/dmnet/metafile.yaml