# Copyright (c) OpenMMLab. All rights reserved.
from typing import Tuple, Union

import torch.nn as nn
from mmcv.cnn import ConvModule, build_activation_layer, build_norm_layer
from torch import Tensor

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import resize
from mmseg.registry import MODELS
from mmseg.utils import OptConfigType, SampleList


@MODELS.register_module()
class DDRHead(BaseDecodeHead):
    """Decode head for DDRNet.

    Args:
        in_channels (int): Number of input channels of the deep (context)
            feature; the auxiliary head takes ``in_channels // 2`` from the
            spatial branch.
        channels (int): Number of channels in the intermediate feature maps
            before classification.
        num_classes (int): Number of classes.
        norm_cfg (dict, optional): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict, optional): Config dict for activation layer.
            Default: dict(type='ReLU', inplace=True).
    """

    def __init__(self,
                 in_channels: int,
                 channels: int,
                 num_classes: int,
                 norm_cfg: OptConfigType = dict(type='BN'),
                 act_cfg: OptConfigType = dict(type='ReLU', inplace=True),
                 **kwargs):
        super().__init__(
            in_channels,
            channels,
            num_classes=num_classes,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            **kwargs)

        self.head = self._make_base_head(self.in_channels, self.channels)
        self.aux_head = self._make_base_head(self.in_channels // 2,
                                             self.channels)
        self.aux_cls_seg = nn.Conv2d(
            self.channels, self.out_channels, kernel_size=1)
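
    # Kaiming-initialize convolution weights and reset BatchNorm to identity
    # (weight 1, bias 0), overriding the default ``init_weights`` inherited
    # from the base class.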
    def init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)
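
    # During training the head consumes a (spatial_feat, context_feat) pair
    # and returns (context_logit, spatial_logit); at inference it consumes
    # the context feature alone and returns a single logit.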
    def forward(
            self,
            inputs: Union[Tensor,
                          Tuple[Tensor]]) -> Union[Tensor, Tuple[Tensor]]:
        if self.training:
            c3_feat, c5_feat = inputs
            x_c = self.head(c5_feat)
            x_c = self.cls_seg(x_c)
            x_s = self.aux_head(c3_feat)
            x_s = self.aux_cls_seg(x_s)

            return x_c, x_s
        else:
            x_c = self.head(inputs)
            x_c = self.cls_seg(x_c)
            return x_c
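
    # Pre-activation block: ``ConvModule`` runs norm -> act -> conv (note the
    # ``order`` argument), and the trailing norm + activation leave the
    # feature normalized and activated for the 1x1 classification conv.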
    def _make_base_head(self, in_channels: int,
                        channels: int) -> nn.Sequential:
        layers = [
            ConvModule(
                in_channels,
                channels,
                kernel_size=3,
                padding=1,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg,
                order=('norm', 'act', 'conv')),
            build_norm_layer(self.norm_cfg, channels)[1],
            build_activation_layer(self.act_cfg),
        ]

        return nn.Sequential(*layers)
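
    # Expects ``loss_decode`` to be configured as a list of two losses:
    # index 0 supervises the context logit, index 1 the spatial (auxiliary)
    # logit. Both logits are upsampled to the label resolution first.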
    def loss_by_feat(self, seg_logits: Tuple[Tensor],
                     batch_data_samples: SampleList) -> dict:
        loss = dict()
        context_logit, spatial_logit = seg_logits
        seg_label = self._stack_batch_gt(batch_data_samples)

        context_logit = resize(
            context_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        spatial_logit = resize(
            spatial_logit,
            size=seg_label.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        seg_label = seg_label.squeeze(1)

        loss['loss_context'] = self.loss_decode[0](context_logit, seg_label)
        loss['loss_spatial'] = self.loss_decode[1](spatial_logit, seg_label)
        loss['acc_seg'] = accuracy(
            context_logit, seg_label, ignore_index=self.ignore_index)

        return loss
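

# ---------------------------------------------------------------------------
# Usage sketch (not part of the upstream file): a minimal, self-contained way
# to exercise this head outside a full DDRNet config. All shapes, channel
# counts, and the CrossEntropyLoss settings below are illustrative
# assumptions; real configs pair this head with the DDRNet backbone and
# typically use OHEM-style losses.
# ---------------------------------------------------------------------------
if __name__ == '__main__':
    import torch
    from mmengine.structures import PixelData

    from mmseg.structures import SegDataSample

    head = DDRHead(
        in_channels=128,  # assumed channels of the deep (context) feature
        channels=64,
        num_classes=19,  # e.g. Cityscapes
        # loss_decode must be a list of two losses: index 0 supervises the
        # context logit, index 1 the spatial (auxiliary) logit.
        loss_decode=[
            dict(type='CrossEntropyLoss', loss_weight=1.0),
            dict(type='CrossEntropyLoss', loss_weight=0.4),
        ])

    # Training: forward takes (spatial_feat, context_feat); the spatial
    # feature carries in_channels // 2 channels.
    head.train()
    c3_feat = torch.rand(1, 64, 16, 32)
    c5_feat = torch.rand(1, 128, 8, 16)
    context_logit, spatial_logit = head((c3_feat, c5_feat))

    data_sample = SegDataSample()
    data_sample.gt_sem_seg = PixelData(
        data=torch.randint(0, 19, (1, 64, 128)))
    losses = head.loss_by_feat((context_logit, spatial_logit), [data_sample])
    print(sorted(losses))  # ['acc_seg', 'loss_context', 'loss_spatial']

    # Inference: a single deep feature in, one logit out.
    head.eval()
    with torch.no_grad():
        seg_logit = head(c5_feat)
    print(seg_logit.shape)  # torch.Size([1, 19, 8, 16])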