import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, build_norm_layer

from mmseg.ops import Encoding, resize
from ..builder import HEADS, build_loss
from .decode_head import BaseDecodeHead


class EncModule(nn.Module):
    """Encoding Module used in EncNet.

    Args:
        in_channels (int): Input channels.
        num_codes (int): Number of code words.
        conv_cfg (dict|None): Config of conv layers.
        norm_cfg (dict|None): Config of norm layers.
        act_cfg (dict): Config of activation layers.
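
    Example:
        A minimal shape check (a sketch; the channel/code/spatial sizes
        below are illustrative, not taken from a real config):

        >>> import torch
        >>> module = EncModule(64, num_codes=32, conv_cfg=None,
        ...                    norm_cfg=dict(type='BN'), act_cfg=dict(type='ReLU'))
        >>> x = torch.rand(2, 64, 16, 16)
        >>> encoding_feat, output = module(x)
        >>> encoding_feat.shape  # (batch, in_channels), pooled over code words
        torch.Size([2, 64])
        >>> output.shape  # input re-weighted channel-wise, same shape as x
        torch.Size([2, 64, 16, 16])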
    """

    def __init__(self, in_channels, num_codes, conv_cfg, norm_cfg, act_cfg):
        super(EncModule, self).__init__()
        self.encoding_project = ConvModule(
            in_channels,
            in_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg)
        # TODO: resolve this hack
        # The Encoding layer outputs (batch, num_codes, channels), so the
        # configured 2d norm has to be swapped for its 1d counterpart,
        # e.g. dict(type='BN') -> dict(type='BN1d').
        if norm_cfg is not None:
            encoding_norm_cfg = norm_cfg.copy()
            if encoding_norm_cfg['type'] in ['BN', 'IN']:
                encoding_norm_cfg['type'] += '1d'
            else:
                encoding_norm_cfg['type'] = encoding_norm_cfg['type'].replace(
                    '2d', '1d')
        else:
            # fall back to BN1d when no norm is configured
            encoding_norm_cfg = dict(type='BN1d')
        self.encoding = nn.Sequential(
            Encoding(channels=in_channels, num_codes=num_codes),
            build_norm_layer(encoding_norm_cfg, num_codes)[1],
            nn.ReLU(inplace=True))
        self.fc = nn.Sequential(
            nn.Linear(in_channels, in_channels), nn.Sigmoid())

    def forward(self, x):
        """Forward function."""
        encoding_projection = self.encoding_project(x)
        # Pool over the code words to get a (batch, in_channels) feature.
        encoding_feat = self.encoding(encoding_projection).mean(dim=1)
        batch_size, channels, _, _ = x.size()
        # SE-style gating: predict channel-wise scaling factors in (0, 1).
        gamma = self.fc(encoding_feat)
        y = gamma.view(batch_size, channels, 1, 1)
        output = F.relu_(x + x * y)
        return encoding_feat, output


@HEADS.register_module()
class EncHead(BaseDecodeHead):
    """Context Encoding for Semantic Segmentation.

    This head is the implementation of `EncNet
    <https://arxiv.org/abs/1803.08904>`_.

    Args:
        num_codes (int): Number of code words. Default: 32.
        use_se_loss (bool): Whether to use Semantic Encoding Loss (SE-loss)
            to regularize the training. Default: True.
        add_lateral (bool): Whether to use lateral connections to fuse
            lower-level features. Default: False.
        loss_se_decode (dict): Config of the SE-loss.
            Default: dict(type='CrossEntropyLoss', use_sigmoid=True).
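
    Example:
        A minimal shape sketch (the channel sizes, spatial sizes and class
        count below are illustrative, not from a shipped config):

        >>> import torch
        >>> head = EncHead(in_channels=[128, 256], in_index=(0, 1),
        ...                channels=64, num_codes=32, num_classes=19)
        >>> inputs = [torch.rand(2, 128, 32, 32), torch.rand(2, 256, 16, 16)]
        >>> seg_logit, se_logit = head(inputs)
        >>> seg_logit.shape  # per-pixel logits at the resolution of inputs[-1]
        torch.Size([2, 19, 16, 16])
        >>> se_logit.shape  # image-level logits supervised by the SE-loss
        torch.Size([2, 19])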
    """

    def __init__(self,
                 num_codes=32,
                 use_se_loss=True,
                 add_lateral=False,
                 loss_se_decode=dict(
                     type='CrossEntropyLoss',
                     use_sigmoid=True,
                     loss_weight=0.2),
                 **kwargs):
        super(EncHead, self).__init__(
            input_transform='multiple_select', **kwargs)
        self.use_se_loss = use_se_loss
        self.add_lateral = add_lateral
        self.num_codes = num_codes
        self.bottleneck = ConvModule(
            self.in_channels[-1],
            self.channels,
            3,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if add_lateral:
            self.lateral_convs = nn.ModuleList()
            for in_channels in self.in_channels[:-1]:  # skip the last one
                self.lateral_convs.append(
                    ConvModule(
                        in_channels,
                        self.channels,
                        1,
                        conv_cfg=self.conv_cfg,
                        norm_cfg=self.norm_cfg,
                        act_cfg=self.act_cfg))
            self.fusion = ConvModule(
                len(self.in_channels) * self.channels,
                self.channels,
                3,
                padding=1,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=self.act_cfg)
        self.enc_module = EncModule(
            self.channels,
            num_codes=num_codes,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=self.act_cfg)
        if self.use_se_loss:
            self.loss_se_decode = build_loss(loss_se_decode)
            self.se_layer = nn.Linear(self.channels, self.num_classes)

    def forward(self, inputs):
        """Forward function."""
        inputs = self._transform_inputs(inputs)
        feat = self.bottleneck(inputs[-1])
        if self.add_lateral:
            # Resize the lateral features to the size of the last-stage
            # feature map and fuse them by concatenation + 3x3 conv.
            laterals = [
                resize(
                    lateral_conv(inputs[i]),
                    size=feat.shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners)
                for i, lateral_conv in enumerate(self.lateral_convs)
            ]
            feat = self.fusion(torch.cat([feat, *laterals], 1))
        encode_feat, output = self.enc_module(feat)
        output = self.cls_seg(output)
        if self.use_se_loss:
            # Image-level logits predicting which classes are present,
            # supervised by the SE-loss.
            se_output = self.se_layer(encode_feat)
            return output, se_output
        else:
            return output

    def forward_test(self, inputs, img_metas, test_cfg):
        """Forward function for testing; the SE-loss branch is ignored."""
        if self.use_se_loss:
            return self.forward(inputs)[0]
        else:
            return self.forward(inputs)

    @staticmethod
    def _convert_to_onehot_labels(seg_label, num_classes):
        """Convert segmentation label to onehot (multi-hot) labels.

        Each output vector marks which classes appear in the corresponding
        image.

        Args:
            seg_label (Tensor): Segmentation label of shape (N, H, W).
            num_classes (int): Number of classes.

        Returns:
            Tensor: Onehot labels of shape (N, num_classes).
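
        Example:
            A small worked example. Label 255 (the usual ignore index) lies
            outside ``[0, num_classes - 1]``, so ``histc`` drops it:

            >>> seg_label = torch.tensor([[[0, 2], [2, 255]]])
            >>> EncHead._convert_to_onehot_labels(seg_label, num_classes=4)
            tensor([[1, 0, 1, 0]])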
        """

        batch_size = seg_label.size(0)
        onehot_labels = seg_label.new_zeros((batch_size, num_classes))
        for i in range(batch_size):
            # Histogram over the valid class ids; labels outside
            # [0, num_classes - 1] (e.g. the ignore index) are dropped.
            hist = seg_label[i].float().histc(
                bins=num_classes, min=0, max=num_classes - 1)
            onehot_labels[i] = hist > 0
        return onehot_labels

    def losses(self, seg_logit, seg_label):
        """Compute segmentation and semantic encoding loss."""
        # When `use_se_loss` is enabled, `forward` returns the tuple
        # (seg_logit, se_seg_logit); unpack it here.
        seg_logit, se_seg_logit = seg_logit
        loss = dict()
        loss.update(super(EncHead, self).losses(seg_logit, seg_label))
        se_loss = self.loss_se_decode(
            se_seg_logit,
            self._convert_to_onehot_labels(seg_label, self.num_classes))
        loss['loss_se'] = se_loss
        return loss
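

# A hypothetical config snippet showing how this head is typically wired
# into an EncNet model. The values below are illustrative; see
# configs/encnet/* in mmsegmentation for the shipped settings.
#
# decode_head=dict(
#     type='EncHead',
#     in_channels=[512, 1024, 2048],
#     in_index=(1, 2, 3),
#     channels=512,
#     num_codes=32,
#     use_se_loss=True,
#     add_lateral=False,
#     num_classes=19,
#     norm_cfg=dict(type='SyncBN', requires_grad=True),
#     align_corners=False,
#     loss_decode=dict(
#         type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0),
#     loss_se_decode=dict(
#         type='CrossEntropyLoss', use_sigmoid=True, loss_weight=0.2))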