# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch
import torch.nn as nn
from mmcv.cnn import ConvModule
from torch import Tensor

from mmseg.models.decode_heads.decode_head import BaseDecodeHead
from mmseg.models.losses import accuracy
from mmseg.models.utils import SelfAttentionBlock, resize
from mmseg.registry import MODELS
from mmseg.utils import SampleList


class ImageLevelContext(nn.Module):
    """Image-Level Context Module.

    Args:
        feats_channels (int): Input channels of query/key feature.
        transform_channels (int): Output channels of key/query transform.
        concat_input (bool): Whether to concatenate the input feature.
        align_corners (bool): align_corners argument of F.interpolate.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
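
    Example:
        >>> # Illustrative sketch only: the channel sizes below are
        >>> # arbitrary example values, not settings prescribed by ISNet.
        >>> import torch
        >>> ilc = ImageLevelContext(feats_channels=64, transform_channels=32)
        >>> x = torch.rand(2, 64, 16, 16)
        >>> ilc(x).shape
        torch.Size([2, 64, 16, 16])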
    """

    def __init__(self,
                 feats_channels,
                 transform_channels,
                 concat_input=False,
                 align_corners=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None):
        super().__init__()
        self.align_corners = align_corners
        self.global_avgpool = nn.AdaptiveAvgPool2d((1, 1))
        self.correlate_net = SelfAttentionBlock(
            key_in_channels=feats_channels * 2,
            query_in_channels=feats_channels,
            channels=transform_channels,
            out_channels=feats_channels,
            share_key_query=False,
            query_downsample=None,
            key_downsample=None,
            key_query_num_convs=2,
            value_out_num_convs=1,
            key_query_norm=True,
            value_out_norm=True,
            matmul_norm=True,
            with_out=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
        )
        if concat_input:
            self.bottleneck = ConvModule(
                feats_channels * 2,
                feats_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            )

    def forward(self, x):
        # Broadcast the global (image-level) representation to every spatial
        # position, then let each position attend to the concatenated
        # [global, local] features.
        x_global = self.global_avgpool(x)
        x_global = resize(
            x_global,
            size=x.shape[2:],
            mode='bilinear',
            align_corners=self.align_corners)
        feats_il = self.correlate_net(x, torch.cat([x_global, x], dim=1))
        if hasattr(self, 'bottleneck'):
            feats_il = self.bottleneck(torch.cat([x, feats_il], dim=1))
        return feats_il


class SemanticLevelContext(nn.Module):
    """Semantic-Level Context Module.

    Args:
        feats_channels (int): Input channels of query/key feature.
        transform_channels (int): Output channels of key/query transform.
        concat_input (bool): Whether to concatenate the input feature.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
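
    Example:
        >>> # Illustrative sketch only: channel and class counts are
        >>> # arbitrary example values.
        >>> import torch
        >>> slc = SemanticLevelContext(feats_channels=64,
        ...                            transform_channels=32)
        >>> x = torch.rand(2, 64, 16, 16)
        >>> preds = torch.rand(2, 19, 16, 16)  # coarse per-class logits
        >>> slc(x, preds, x).shape
        torch.Size([2, 64, 16, 16])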
    """

    def __init__(self,
                 feats_channels,
                 transform_channels,
                 concat_input=False,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=None):
        super().__init__()
        self.correlate_net = SelfAttentionBlock(
            key_in_channels=feats_channels,
            query_in_channels=feats_channels,
            channels=transform_channels,
            out_channels=feats_channels,
            share_key_query=False,
            query_downsample=None,
            key_downsample=None,
            key_query_num_convs=2,
            value_out_num_convs=1,
            key_query_norm=True,
            value_out_norm=True,
            matmul_norm=True,
            with_out=True,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
        )
        if concat_input:
            self.bottleneck = ConvModule(
                feats_channels * 2,
                feats_channels,
                kernel_size=3,
                stride=1,
                padding=1,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg,
            )

    def forward(self, x, preds, feats_il):
        inputs = x
        batch_size, num_channels, h, w = x.size()
        num_classes = preds.size(1)
        feats_sl = torch.zeros(batch_size, h * w, num_channels).type_as(x)
        for batch_idx in range(batch_size):
            # (C, H, W), (num_classes, H, W) --> (H*W, C), (H*W, num_classes)
            feats_iter, preds_iter = x[batch_idx], preds[batch_idx]
            feats_iter = feats_iter.reshape(num_channels, -1).permute(1, 0)
            preds_iter = preds_iter.reshape(num_classes, -1).permute(1, 0)
            # (H*W, )
            argmax = preds_iter.argmax(1)
            for clsid in range(num_classes):
                mask = (argmax == clsid)
                if mask.sum() == 0:
                    continue
                # Build a class center: average the features of all pixels
                # predicted as `clsid`, weighted by the softmax of their
                # class scores, then assign it back to those pixels.
                feats_iter_cls = feats_iter[mask]
                preds_iter_cls = preds_iter[:, clsid][mask]
                weight = torch.softmax(preds_iter_cls, dim=0)
                feats_iter_cls = feats_iter_cls * weight.unsqueeze(-1)
                feats_iter_cls = feats_iter_cls.sum(0)
                feats_sl[batch_idx][mask] = feats_iter_cls
        feats_sl = feats_sl.reshape(batch_size, h, w, num_channels)
        feats_sl = feats_sl.permute(0, 3, 1, 2).contiguous()
        feats_sl = self.correlate_net(inputs, feats_sl)
        if hasattr(self, 'bottleneck'):
            feats_sl = self.bottleneck(torch.cat([feats_il, feats_sl], dim=1))
        return feats_sl


@MODELS.register_module()
class ISNetHead(BaseDecodeHead):
    """ISNet: Integrate Image-Level and Semantic-Level Context for Semantic
    Segmentation.

    This head is the implementation of `ISNet
    <https://arxiv.org/pdf/2108.12382.pdf>`_.

    Args:
        transform_channels (int): Output channels of key/query transform.
        concat_input (bool): Whether to concatenate the input feature.
        with_shortcut (bool): Whether to use a shortcut connection.
        shortcut_in_channels (int): Input channels of shortcut.
        shortcut_feat_channels (int): Output channels of shortcut.
        dropout_ratio (float): Ratio of dropout.
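
    Example:
        >>> # Illustrative config fragment; every value below is an arbitrary
        >>> # example setting, not one taken from a shipped ISNet config.
        >>> # `loss_decode` needs one loss per stage, because `loss_by_feat`
        >>> # zips the two stage logits with `self.loss_decode`.
        >>> head_cfg = dict(
        ...     type='ISNetHead',
        ...     transform_channels=256,
        ...     concat_input=True,
        ...     with_shortcut=True,
        ...     shortcut_in_channels=256,
        ...     shortcut_feat_channels=48,
        ...     dropout_ratio=0.1,
        ...     in_channels=(256, 512, 1024, 2048),
        ...     in_index=(0, 1, 2, 3),
        ...     input_transform='multiple_select',
        ...     channels=512,
        ...     num_classes=19,
        ...     loss_decode=[
        ...         dict(type='CrossEntropyLoss',
        ...              loss_name='loss_ce_stage1', loss_weight=0.4),
        ...         dict(type='CrossEntropyLoss',
        ...              loss_name='loss_ce_stage2', loss_weight=1.0),
        ...     ])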
    """

    def __init__(self, transform_channels, concat_input, with_shortcut,
                 shortcut_in_channels, shortcut_feat_channels, dropout_ratio,
                 **kwargs):
        super().__init__(**kwargs)

        # Only the last selected input feature feeds the bottleneck; the
        # first one is optionally used by the shortcut branch below.
        self.in_channels = self.in_channels[-1]

        self.bottleneck = ConvModule(
            self.in_channels,
            self.channels,
            kernel_size=3,
            stride=1,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.ilc_net = ImageLevelContext(
            feats_channels=self.channels,
            transform_channels=transform_channels,
            concat_input=concat_input,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg,
            align_corners=self.align_corners)
        self.slc_net = SemanticLevelContext(
            feats_channels=self.channels,
            transform_channels=transform_channels,
            concat_input=concat_input,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)

        self.decoder_stage1 = nn.Sequential(
            ConvModule(
                self.channels,
                self.channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg),
            nn.Dropout2d(dropout_ratio),
            nn.Conv2d(
                self.channels,
                self.num_classes,
                kernel_size=1,
                stride=1,
                padding=0,
                bias=True),
        )

        if with_shortcut:
            self.shortcut = ConvModule(
                shortcut_in_channels,
                shortcut_feat_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
            self.decoder_stage2 = nn.Sequential(
                ConvModule(
                    self.channels + shortcut_feat_channels,
                    self.channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=self.act_cfg),
                nn.Dropout2d(dropout_ratio),
                nn.Conv2d(
                    self.channels,
                    self.num_classes,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=True),
            )
        else:
            self.decoder_stage2 = nn.Sequential(
                nn.Dropout2d(dropout_ratio),
                nn.Conv2d(
                    self.channels,
                    self.num_classes,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    bias=True),
            )

        # The two stage decoders replace the default classifier built by
        # BaseDecodeHead, so drop the unused modules.
        self.conv_seg = None
        self.dropout = None

    def forward(self, inputs):
        x = self._transform_inputs(inputs)
        feats = self.bottleneck(x[-1])

        # Image-level context from the globally pooled representation.
        feats_il = self.ilc_net(feats)

        # Stage-1 prediction, used as the coarse segmentation prior for the
        # semantic-level context module.
        preds_stage1 = self.decoder_stage1(feats)
        preds_stage1 = resize(
            preds_stage1,
            size=feats.size()[2:],
            mode='bilinear',
            align_corners=self.align_corners)

        feats_sl = self.slc_net(feats, preds_stage1, feats_il)

        if hasattr(self, 'shortcut'):
            shortcut_out = self.shortcut(x[0])
            feats_sl = resize(
                feats_sl,
                size=shortcut_out.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            feats_sl = torch.cat([feats_sl, shortcut_out], dim=1)
        preds_stage2 = self.decoder_stage2(feats_sl)

        return preds_stage1, preds_stage2

    def loss_by_feat(self, seg_logits: Tensor,
                     batch_data_samples: SampleList) -> dict:
        seg_label = self._stack_batch_gt(batch_data_samples)
        loss = dict()

        if self.sampler is not None:
            seg_weight = self.sampler.sample(seg_logits[-1], seg_label)
        else:
            seg_weight = None
        seg_label = seg_label.squeeze(1)

        # One loss per stage: the i-th stage logits are supervised by the
        # i-th entry of `self.loss_decode`.
        for seg_logit, loss_decode in zip(seg_logits, self.loss_decode):
            seg_logit = resize(
                input=seg_logit,
                size=seg_label.shape[2:],
                mode='bilinear',
                align_corners=self.align_corners)
            loss[loss_decode.loss_name] = loss_decode(
                seg_logit,
                seg_label,
                seg_weight,
                ignore_index=self.ignore_index)

        # After the loop, `seg_logit` holds the resized stage-2 logits, which
        # match the label resolution required by `accuracy`.
        loss['acc_seg'] = accuracy(
            seg_logit, seg_label, ignore_index=self.ignore_index)
        return loss

    def predict_by_feat(self, seg_logits: Tensor,
                        batch_img_metas: List[dict]) -> Tensor:
        # Only the stage-2 (final) logits are used at inference time.
        _, seg_logits_stage2 = seg_logits
        return super().predict_by_feat(seg_logits_stage2, batch_img_metas)
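
# Minimal smoke-test sketch (illustrative only): it assumes mmsegmentation is
# installed, and every channel/class number below is an arbitrary example
# value rather than a setting from the original configs.
#
#   import torch
#   head = ISNetHead(
#       transform_channels=64,
#       concat_input=True,
#       with_shortcut=True,
#       shortcut_in_channels=32,
#       shortcut_feat_channels=16,
#       dropout_ratio=0.1,
#       in_channels=[32, 128],
#       in_index=[0, 1],
#       input_transform='multiple_select',
#       channels=128,
#       num_classes=19,
#       loss_decode=[
#           dict(type='CrossEntropyLoss', loss_name='loss_ce_stage1',
#                loss_weight=0.4),
#           dict(type='CrossEntropyLoss', loss_name='loss_ce_stage2',
#                loss_weight=1.0),
#       ])
#   feats = [torch.rand(2, 32, 64, 64), torch.rand(2, 128, 16, 16)]
#   preds_stage1, preds_stage2 = head(feats)
#   # preds_stage1: (2, 19, 16, 16); preds_stage2: (2, 19, 64, 64)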