294 lines
10 KiB
Python
294 lines
10 KiB
Python
import math
|
|
|
|
import torch
|
|
import torch.nn as nn
|
|
from mmcv.cnn import ConvModule, Linear, build_activation_layer
|
|
from mmcv.runner import BaseModule
|
|
|
|
from mmseg.ops import resize
|
|
from ..builder import HEADS
|
|
from .decode_head import BaseDecodeHead
|
|
|
|
|
|
class ReassembleBlocks(BaseModule):
    """ViTPostProcessBlock: merge the cls_token of each ViT output into its
    feature map, then give every stage its own channel width and scale.

    Each stage is passed through a 1x1 projection conv followed by a fixed
    resize layer: a stride-4 transposed conv, a stride-2 transposed conv,
    an identity and a stride-2 3x3 conv for stages 0 to 3 respectively.

    Args:
        in_channels (int): ViT feature channels. Default: 768.
        out_channels (List): output channels of each stage.
            Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation, one of 'ignore',
            'add' or 'project'. Default: 'ignore'.
        patch_size (int): The patch size. Stored on the module but not
            otherwise used by this class. Default: 16.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels=768,
                 out_channels=[96, 192, 384, 768],
                 readout_type='ignore',
                 patch_size=16,
                 init_cfg=None):
        super().__init__(init_cfg)

        assert readout_type in ('ignore', 'add', 'project')
        self.readout_type = readout_type
        self.patch_size = patch_size

        # Per-stage 1x1 conv mapping the ViT width to the stage width.
        stage_projections = []
        for stage_width in out_channels:
            stage_projections.append(
                ConvModule(
                    in_channels=in_channels,
                    out_channels=stage_width,
                    kernel_size=1,
                    act_cfg=None))
        self.projects = nn.ModuleList(stage_projections)

        # Per-stage spatial rescaling; channel counts are left untouched.
        upsample_x4 = nn.ConvTranspose2d(
            in_channels=out_channels[0],
            out_channels=out_channels[0],
            kernel_size=4,
            stride=4,
            padding=0)
        upsample_x2 = nn.ConvTranspose2d(
            in_channels=out_channels[1],
            out_channels=out_channels[1],
            kernel_size=2,
            stride=2,
            padding=0)
        keep_scale = nn.Identity()
        downsample = nn.Conv2d(
            in_channels=out_channels[3],
            out_channels=out_channels[3],
            kernel_size=3,
            stride=2,
            padding=1)
        self.resize_layers = nn.ModuleList(
            [upsample_x4, upsample_x2, keep_scale, downsample])

        # Only the 'project' readout needs learnable layers: a linear map
        # from [patch token ; cls token] back to the token width, then GELU.
        if self.readout_type == 'project':
            self.readout_projects = nn.ModuleList([
                nn.Sequential(
                    Linear(2 * in_channels, in_channels),
                    build_activation_layer(dict(type='GELU')))
                for _ in self.projects
            ])

    def forward(self, inputs):
        """Reassemble every stage of the backbone output.

        Args:
            inputs (list): One ``(feature_map, cls_token)`` pair per stage.
                Presumably the feature map is (B, C, H, W) and the
                cls_token is (B, C) -- TODO confirm against the backbone.

        Returns:
            list[Tensor]: The projected and resized feature map of each
            stage, in input order.
        """
        assert isinstance(inputs, list)
        outputs = []
        for idx, stage in enumerate(inputs):
            assert len(stage) == 2
            feat, cls_token = stage[0], stage[1]
            original_shape = feat.shape

            if self.readout_type == 'add':
                # Broadcast the cls token over all spatial positions.
                feat = feat.flatten(2) + cls_token.unsqueeze(-1)
                feat = feat.reshape(original_shape)
            elif self.readout_type == 'project':
                # Tokens-last layout so the linear layer acts per position.
                tokens = feat.flatten(2).permute((0, 2, 1))
                readout = cls_token.unsqueeze(1).expand_as(tokens)
                tokens = self.readout_projects[idx](
                    torch.cat((tokens, readout), -1))
                feat = tokens.permute(0, 2, 1).reshape(original_shape)
            # 'ignore': the cls token is simply dropped.

            feat = self.projects[idx](feat)
            feat = self.resize_layers[idx](feat)
            outputs.append(feat)
        return outputs
|
|
|
|
|
|
class PreActResidualConvUnit(BaseModule):
    """ResidualConvUnit in pre-activation form.

    Two 3x3 ``ConvModule`` layers, each ordered activation -> conv -> norm,
    whose output is added back onto the unit's input.

    Args:
        in_channels (int): number of channels in the input feature map;
            both convs keep this channel count.
        act_cfg (dict): dictionary to construct and config activation layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        stride (int): stride of the first conv. Default: 1
        dilation (int): dilation rate (and padding) of the first conv.
            The second conv always uses padding 1 and default dilation.
            Default: 1.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels,
                 act_cfg,
                 norm_cfg,
                 stride=1,
                 dilation=1,
                 init_cfg=None):
        super().__init__(init_cfg)

        # Keyword arguments common to both convolutions.
        shared_kwargs = dict(
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            bias=False,
            order=('act', 'conv', 'norm'))

        self.conv1 = ConvModule(
            in_channels,
            in_channels,
            3,
            stride=stride,
            padding=dilation,
            dilation=dilation,
            **shared_kwargs)

        self.conv2 = ConvModule(
            in_channels, in_channels, 3, padding=1, **shared_kwargs)

    def forward(self, inputs):
        """Run both convs and add the result to a copy of the input.

        Args:
            inputs (Tensor): Input feature map.

        Returns:
            Tensor: ``conv2(conv1(inputs))`` plus the cloned input.
        """
        # Snapshot taken before conv1 runs; presumably this guards the skip
        # path against an in-place activation overwriting `inputs` --
        # TODO confirm against the configured act_cfg.
        identity = inputs.clone()
        residual = self.conv2(self.conv1(inputs))
        return residual + identity
|
|
|
|
|
|
class FeatureFusionBlock(BaseModule):
    """FeatureFusionBlock, merge feature map from different stages.

    Optionally adds a second feature map (passed through a residual unit)
    onto the first, refines the sum with another residual unit, upsamples
    it by a factor of 2 and applies a 1x1 projection conv.

    Args:
        in_channels (int): Input channels.
        act_cfg (dict): The activation config for ResidualConvUnit.
        norm_cfg (dict): Config dict for normalization layer.
        expand (bool): If True the 1x1 projection halves the channel
            count (``in_channels // 2``); otherwise it keeps it.
            Default: False.
        align_corners (bool): align_corner setting for the final bilinear
            upsample. Default: True.
        init_cfg (dict, optional): Initialization config dict. Default: None.
    """

    def __init__(self,
                 in_channels,
                 act_cfg,
                 norm_cfg,
                 expand=False,
                 align_corners=True,
                 init_cfg=None):
        super().__init__(init_cfg)

        self.in_channels = in_channels
        self.expand = expand
        self.align_corners = align_corners
        self.out_channels = in_channels // 2 if self.expand else in_channels

        self.project = ConvModule(
            self.in_channels,
            self.out_channels,
            kernel_size=1,
            act_cfg=None,
            bias=True)

        unit_kwargs = dict(
            in_channels=self.in_channels, act_cfg=act_cfg, norm_cfg=norm_cfg)
        self.res_conv_unit1 = PreActResidualConvUnit(**unit_kwargs)
        self.res_conv_unit2 = PreActResidualConvUnit(**unit_kwargs)

    def forward(self, *inputs):
        """Fuse one or two feature maps.

        Args:
            *inputs (Tensor): The main feature map, optionally followed by
                a second feature map to merge into it.

        Returns:
            Tensor: The fused, 2x upsampled and projected feature map.
        """
        fused = inputs[0]
        if len(inputs) == 2:
            skip = inputs[1]
            if fused.shape != skip.shape:
                # Match the main path's spatial size; this resize always
                # uses align_corners=False, independent of the attribute.
                height, width = fused.shape[2], fused.shape[3]
                skip = resize(
                    skip,
                    size=(height, width),
                    mode='bilinear',
                    align_corners=False)
            fused = fused + self.res_conv_unit1(skip)

        fused = self.res_conv_unit2(fused)
        fused = resize(
            fused,
            scale_factor=2,
            mode='bilinear',
            align_corners=self.align_corners)
        return self.project(fused)
|
|
|
|
|
|
@HEADS.register_module()
class DPTHead(BaseDecodeHead):
    """Vision Transformers for Dense Prediction.

    This head is the implementation of
    `DPT <https://arxiv.org/abs/2103.13413>`_.

    Args:
        embed_dims (int): The embed dimension of the ViT backbone.
            Default: 768.
        post_process_channels (List): Out channels of post process conv
            layers. Default: [96, 192, 384, 768].
        readout_type (str): Type of readout operation. Default: 'ignore'.
        patch_size (int): The patch size. Default: 16.
        expand_channels (bool): Whether expand the channels in post process
            block. When True, stage ``i`` of ``post_process_channels`` is
            multiplied by ``2**i``. Default: False.
        act_cfg (dict): The activation config for residual conv unit.
            Default dict(type='ReLU').
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
    """

    def __init__(self,
                 embed_dims=768,
                 post_process_channels=[96, 192, 384, 768],
                 readout_type='ignore',
                 patch_size=16,
                 expand_channels=False,
                 act_cfg=dict(type='ReLU'),
                 norm_cfg=dict(type='BN'),
                 **kwargs):
        super(DPTHead, self).__init__(**kwargs)

        # Self-assignment kept from the original code; it has no visible
        # effect here.
        self.in_channels = self.in_channels
        self.expand_channels = expand_channels
        self.reassemble_blocks = ReassembleBlocks(embed_dims,
                                                  post_process_channels,
                                                  readout_type, patch_size)

        # Integer exponentiation keeps the channel counts as ints; the
        # previous math.pow() produced floats (e.g. 96.0), which are not
        # valid channel counts for the conv layers built below.
        # NOTE(review): reassemble_blocks above is still built from the
        # unexpanded list, so with expand_channels=True the convs of stages
        # i >= 1 expect more input channels than that block is configured
        # to emit -- confirm the intended wiring of this option.
        self.post_process_channels = [
            channel * 2**i if expand_channels else channel
            for i, channel in enumerate(post_process_channels)
        ]

        # One 3x3 conv per stage, mapping the stage width to self.channels.
        self.convs = nn.ModuleList()
        for channel in self.post_process_channels:
            self.convs.append(
                ConvModule(
                    channel,
                    self.channels,
                    kernel_size=3,
                    padding=1,
                    act_cfg=None,
                    bias=False))

        self.fusion_blocks = nn.ModuleList()
        for _ in range(len(self.convs)):
            self.fusion_blocks.append(
                FeatureFusionBlock(self.channels, act_cfg, norm_cfg))
        # The first fusion block is only ever called with a single input
        # (see forward), so its skip-branch residual unit is dropped.
        self.fusion_blocks[0].res_conv_unit1 = None

        self.project = ConvModule(
            self.channels,
            self.channels,
            kernel_size=3,
            padding=1,
            norm_cfg=norm_cfg)

        # Sanity check: one fusion block, one reassemble layer and one
        # post-process channel entry per stage.
        self.num_fusion_blocks = len(self.fusion_blocks)
        self.num_reassemble_blocks = len(self.reassemble_blocks.resize_layers)
        self.num_post_process_channels = len(self.post_process_channels)
        assert self.num_fusion_blocks == self.num_reassemble_blocks
        assert self.num_reassemble_blocks == self.num_post_process_channels

    def forward(self, inputs):
        """Decode the multi-stage backbone output into segmentation logits.

        Args:
            inputs (list): Backbone outputs, one entry per reassemble
                stage.

        Returns:
            Tensor: Output of ``self.cls_seg`` applied to the fused and
            projected features.
        """
        assert len(inputs) == self.num_reassemble_blocks
        x = self._transform_inputs(inputs)
        x = self.reassemble_blocks(x)
        x = [self.convs[i](feature) for i, feature in enumerate(x)]

        # Fuse from the last stage towards the first: fusion block i takes
        # the running result together with stage -(i + 1).
        out = self.fusion_blocks[0](x[-1])
        for i in range(1, len(self.fusion_blocks)):
            out = self.fusion_blocks[i](out, x[-(i + 1)])

        out = self.project(out)
        out = self.cls_seg(out)
        return out
|