# Copyright (c) OpenMMLab. All rights reserved.
"""Modified from https://github.com/MichaelFan01/STDC-Seg."""
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule
from mmcv.runner.base_module import BaseModule, ModuleList, Sequential

from mmseg.ops import resize
from ..builder import BACKBONES, build_backbone
from .bisenetv1 import AttentionRefinementModule


class STDCModule(BaseModule):
    """STDCModule.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels before scaling.
        stride (int): The stride of the first conv layer.
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): The activation config for conv layers.
            Default: None.
        num_convs (int): Number of conv layers. Default: 4.
        fusion_type (str): Type of fusion operation. Default: 'add'.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
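
    Example (a shape-level sketch of the ``'cat'`` fusion path; the
    channel settings below are illustrative assumptions, not taken from
    any released config):
        >>> import torch
        >>> self = STDCModule(
        ...     32, 64, stride=2, norm_cfg=dict(type='BN'),
        ...     act_cfg=dict(type='ReLU'), fusion_type='cat').eval()
        >>> inputs = torch.rand(1, 32, 64, 64)
        >>> self.forward(inputs).shape
        torch.Size([1, 64, 32, 32])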
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride,
                 norm_cfg=None,
                 act_cfg=None,
                 num_convs=4,
                 fusion_type='add',
                 init_cfg=None):
        super(STDCModule, self).__init__(init_cfg=init_cfg)
        assert num_convs > 1
        assert fusion_type in ['add', 'cat']
        self.stride = stride
        self.with_downsample = True if self.stride == 2 else False
        self.fusion_type = fusion_type

        self.layers = ModuleList()
        conv_0 = ConvModule(
            in_channels, out_channels // 2, kernel_size=1, norm_cfg=norm_cfg)

        if self.with_downsample:
            self.downsample = ConvModule(
                out_channels // 2,
                out_channels // 2,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=out_channels // 2,
                norm_cfg=norm_cfg,
                act_cfg=None)

            if self.fusion_type == 'add':
                self.layers.append(nn.Sequential(conv_0, self.downsample))
                self.skip = Sequential(
                    ConvModule(
                        in_channels,
                        in_channels,
                        kernel_size=3,
                        stride=2,
                        padding=1,
                        groups=in_channels,
                        norm_cfg=norm_cfg,
                        act_cfg=None),
                    ConvModule(
                        in_channels,
                        out_channels,
                        1,
                        norm_cfg=norm_cfg,
                        act_cfg=None))
            else:
                self.layers.append(conv_0)
                self.skip = nn.AvgPool2d(kernel_size=3, stride=2, padding=1)
        else:
            self.layers.append(conv_0)
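
        # The remaining convs halve the channel width step by step
        # (out_channels // 4, // 8, ...), and the last conv repeats the
        # previous width, so that together with conv_0's out_channels // 2
        # the concatenated outputs sum back to exactly out_channels;
        # `out_factor` below encodes this schedule.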
        for i in range(1, num_convs):
            out_factor = 2**(i + 1) if i != num_convs - 1 else 2**i
            self.layers.append(
                ConvModule(
                    out_channels // 2**i,
                    out_channels // out_factor,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg))

    def forward(self, inputs):
        if self.fusion_type == 'add':
            out = self.forward_add(inputs)
        else:
            out = self.forward_cat(inputs)
        return out

    def forward_add(self, inputs):
        layer_outputs = []
        x = inputs.clone()
        for layer in self.layers:
            x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            inputs = self.skip(inputs)

        return torch.cat(layer_outputs, dim=1) + inputs

    def forward_cat(self, inputs):
        x0 = self.layers[0](inputs)
        layer_outputs = [x0]
        for i, layer in enumerate(self.layers[1:]):
            if i == 0:
                if self.with_downsample:
                    x = layer(self.downsample(x0))
                else:
                    x = layer(x0)
            else:
                x = layer(x)
            layer_outputs.append(x)
        if self.with_downsample:
            layer_outputs[0] = self.skip(x0)
        return torch.cat(layer_outputs, dim=1)


class FeatureFusionModule(BaseModule):
    """Feature Fusion Module. This module differs from the
    FeatureFusionModule in BiSeNetV1: it uses two ConvModules in
    `self.attention`, whose intermediate channel number is determined by
    the given `scale_factor`, while the FeatureFusionModule in BiSeNetV1
    uses only one ConvModule in `self.conv_atten`.

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        scale_factor (int): The channel scale factor. Default: 4.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): The activation config for conv layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
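
    Example (an illustrative sketch; the channel numbers are assumptions
    for demonstration only, not from any released config):
        >>> import torch
        >>> self = FeatureFusionModule(384, 256).eval()
        >>> spatial = torch.rand(1, 256, 32, 64)
        >>> context = torch.rand(1, 128, 32, 64)
        >>> self.forward(spatial, context).shape
        torch.Size([1, 256, 32, 64])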
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scale_factor=4,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super(FeatureFusionModule, self).__init__(init_cfg=init_cfg)
        channels = out_channels // scale_factor
        self.conv0 = ConvModule(
            in_channels, out_channels, 1, norm_cfg=norm_cfg, act_cfg=act_cfg)
        self.attention = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            ConvModule(
                out_channels,
                channels,
                1,
                norm_cfg=None,
                bias=False,
                act_cfg=act_cfg),
            ConvModule(
                channels,
                out_channels,
                1,
                norm_cfg=None,
                bias=False,
                act_cfg=None), nn.Sigmoid())

    def forward(self, spatial_inputs, context_inputs):
        inputs = torch.cat([spatial_inputs, context_inputs], dim=1)
        x = self.conv0(inputs)
        attn = self.attention(x)
        x_attn = x * attn
        return x_attn + x


@BACKBONES.register_module()
class STDCNet(BaseModule):
    """This backbone is the implementation of `Rethinking BiSeNet For
    Real-time Semantic Segmentation <https://arxiv.org/abs/2104.13188>`_.

    Args:
        stdc_type (str): The type of backbone structure,
            'STDCNet1' and 'STDCNet2' denote the two main backbones in
            the paper, whose FLOPs are 813M and 1446M, respectively.
        in_channels (int): The number of input channels.
        channels (tuple[int]): The output channels for each stage.
        bottleneck_type (str): The type of STDC Module, the value must
            be 'add' or 'cat'.
        norm_cfg (dict): Config dict for normalization layer.
        act_cfg (dict): The activation config for conv layers.
        num_convs (int): Number of conv layers in each STDC Module.
            Default: 4.
        with_final_conv (bool): Whether to add a conv layer at the module
            output. Default: False.
        pretrained (str, optional): Model pretrained path. Default: None.
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Example:
        >>> import torch
        >>> stdc_type = 'STDCNet1'
        >>> in_channels = 3
        >>> channels = (32, 64, 256, 512, 1024)
        >>> bottleneck_type = 'cat'
        >>> inputs = torch.rand(1, 3, 1024, 2048)
        >>> self = STDCNet(stdc_type, in_channels,
        ...                channels, bottleneck_type,
        ...                norm_cfg=dict(type='BN'),
        ...                act_cfg=dict(type='ReLU')).eval()
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 256, 128, 256])
        outputs[1].shape = torch.Size([1, 512, 64, 128])
        outputs[2].shape = torch.Size([1, 1024, 32, 64])
    """
    arch_settings = {
        'STDCNet1': [(2, 1), (2, 1), (2, 1)],
        'STDCNet2': [(2, 1, 1, 1), (2, 1, 1, 1, 1), (2, 1, 1)]
    }

    def __init__(self,
                 stdc_type,
                 in_channels,
                 channels,
                 bottleneck_type,
                 norm_cfg,
                 act_cfg,
                 num_convs=4,
                 with_final_conv=False,
                 pretrained=None,
                 init_cfg=None):
        super(STDCNet, self).__init__(init_cfg=init_cfg)
        assert stdc_type in self.arch_settings, \
            f'invalid structure {stdc_type} for STDCNet.'
        assert bottleneck_type in ['add', 'cat'], \
            f'bottleneck_type must be `add` or `cat`, got {bottleneck_type}'

        assert len(channels) == 5, \
            f'invalid channels length {len(channels)} for STDCNet.'

        self.in_channels = in_channels
        self.channels = channels
        self.stage_strides = self.arch_settings[stdc_type]
        self.pretrained = pretrained
        self.num_convs = num_convs
        self.with_final_conv = with_final_conv

        self.stages = ModuleList([
            ConvModule(
                self.in_channels,
                self.channels[0],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                self.channels[0],
                self.channels[1],
                kernel_size=3,
                stride=2,
                padding=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)
        ])
        # `self.num_shallow_features` is the number of shallow modules in
        # `STDCNet`, which are noted as `Stage1` and `Stage2` in the
        # original paper. Neither of them is used by the following
        # modules, i.e. the Attention Refinement Module and the Feature
        # Fusion Module, so they are cut from `outs`. Please refer to
        # Figure 4 of the original paper for more details.
        self.num_shallow_features = len(self.stages)

        for strides in self.stage_strides:
            idx = len(self.stages) - 1
            self.stages.append(
                self._make_stage(self.channels[idx], self.channels[idx + 1],
                                 strides, norm_cfg, act_cfg, bottleneck_type))
        # After appending, `self.stages` is a ModuleList that includes the
        # shallow modules followed by the STDCModule stages, i.e.
        # len(self.stages) ==
        # self.num_shallow_features + len(self.stage_strides).
        if self.with_final_conv:
            self.final_conv = ConvModule(
                self.channels[-1],
                max(1024, self.channels[-1]),
                1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg)

    def _make_stage(self, in_channels, out_channels, strides, norm_cfg,
                    act_cfg, bottleneck_type):
        layers = []
        for i, stride in enumerate(strides):
            layers.append(
                STDCModule(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    stride,
                    norm_cfg,
                    act_cfg,
                    num_convs=self.num_convs,
                    fusion_type=bottleneck_type))
        return Sequential(*layers)

    def forward(self, x):
        outs = []
        for stage in self.stages:
            x = stage(x)
            outs.append(x)
        if self.with_final_conv:
            outs[-1] = self.final_conv(outs[-1])
        outs = outs[self.num_shallow_features:]
        return tuple(outs)


@BACKBONES.register_module()
class STDCContextPathNet(BaseModule):
    """STDCNet with Context Path. The `outs` below is a list of three
    feature maps from shallow to deep, whose height and width go from big
    to small, respectively. The biggest feature map of `outs` is fed to
    `STDCHead`, where the Detail Loss is calculated against the detail
    ground truth. The other two feature maps are each passed through an
    Attention Refinement Module. Besides, the biggest feature map of
    `outs` and the last output of the Attention Refinement Modules are
    fused by the Feature Fusion Module, and the fused feature map
    `feat_fuse` is fed to the `decode_head`. For more details, please
    refer to Figure 4 of the original paper.

    Args:
        backbone_cfg (dict): Config dict for stdc backbone.
        last_in_channels (tuple(int)): The number of channels of the last
            two feature maps from the stdc backbone. Default: (1024, 512).
        out_channels (int): The number of channels of the output feature
            maps. Default: 128.
        ffm_cfg (dict): Config dict for Feature Fusion Module. Default:
            `dict(in_channels=512, out_channels=256, scale_factor=4)`.
        upsample_mode (str): Algorithm used for upsampling:
            ``'nearest'`` | ``'linear'`` | ``'bilinear'`` | ``'bicubic'`` |
            ``'trilinear'``. Default: ``'nearest'``.
        align_corners (bool, optional): align_corners argument of
            F.interpolate. It must be `None` if upsample_mode is
            ``'nearest'``. Default: None.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.

    Returns:
        outputs (tuple): A tuple of output feature maps for the auxiliary
            heads and the decode head.
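
    Example (a sketch assuming the `STDCNet1` settings from the `STDCNet`
    example; `ffm_cfg` is adjusted so that its `in_channels` equals the
    channels of outs[0] (256) plus `out_channels` (128), i.e. 384):
        >>> import torch
        >>> backbone_cfg = dict(
        ...     type='STDCNet',
        ...     stdc_type='STDCNet1',
        ...     in_channels=3,
        ...     channels=(32, 64, 256, 512, 1024),
        ...     bottleneck_type='cat',
        ...     norm_cfg=dict(type='BN'),
        ...     act_cfg=dict(type='ReLU'))
        >>> self = STDCContextPathNet(
        ...     backbone_cfg,
        ...     ffm_cfg=dict(in_channels=384, out_channels=256,
        ...                  scale_factor=4)).eval()
        >>> inputs = torch.rand(1, 3, 1024, 2048)
        >>> outputs = self.forward(inputs)
        >>> for i in range(len(outputs)):
        ...     print(f'outputs[{i}].shape = {outputs[i].shape}')
        outputs[0].shape = torch.Size([1, 256, 128, 256])
        outputs[1].shape = torch.Size([1, 128, 64, 128])
        outputs[2].shape = torch.Size([1, 128, 128, 256])
        outputs[3].shape = torch.Size([1, 256, 128, 256])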
    """

    def __init__(self,
                 backbone_cfg,
                 last_in_channels=(1024, 512),
                 out_channels=128,
                 ffm_cfg=dict(
                     in_channels=512, out_channels=256, scale_factor=4),
                 upsample_mode='nearest',
                 align_corners=None,
                 norm_cfg=dict(type='BN'),
                 init_cfg=None):
        super(STDCContextPathNet, self).__init__(init_cfg=init_cfg)
        self.backbone = build_backbone(backbone_cfg)
        self.arms = ModuleList()
        self.convs = ModuleList()
        for channels in last_in_channels:
            self.arms.append(AttentionRefinementModule(channels,
                                                       out_channels))
            self.convs.append(
                ConvModule(
                    out_channels,
                    out_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg))
        self.conv_avg = ConvModule(
            last_in_channels[0], out_channels, 1, norm_cfg=norm_cfg)

        self.ffm = FeatureFusionModule(**ffm_cfg)

        self.upsample_mode = upsample_mode
        self.align_corners = align_corners

    def forward(self, x):
        outs = list(self.backbone(x))
        avg = F.adaptive_avg_pool2d(outs[-1], 1)
        avg_feat = self.conv_avg(avg)

        feature_up = resize(
            avg_feat,
            size=outs[-1].shape[2:],
            mode=self.upsample_mode,
            align_corners=self.align_corners)
        arms_out = []
        for i in range(len(self.arms)):
            x_arm = self.arms[i](outs[len(outs) - 1 - i]) + feature_up
            feature_up = resize(
                x_arm,
                size=outs[len(outs) - 1 - i - 1].shape[2:],
                mode=self.upsample_mode,
                align_corners=self.align_corners)
            feature_up = self.convs[i](feature_up)
            arms_out.append(feature_up)

        feat_fuse = self.ffm(outs[0], arms_out[1])

        # `outputs` has four feature maps:
        # `outs[0]` goes to the `STDCHead` auxiliary head,
        # the two feature maps in `arms_out` go to auxiliary heads, and
        # `feat_fuse` goes to the decode head.
        outputs = [outs[0]] + list(arms_out) + [feat_fuse]
        return tuple(outputs)