mmsegmentation/mmseg/models/backbones/bisenetv2.py

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
build_activation_layer, build_norm_layer)
from mmcv.runner import BaseModule
from mmseg.ops import resize
from ..builder import BACKBONES


class DetailBranch(BaseModule):
"""Detail Branch with wide channels and shallow layers to capture low-level
details and generate high-resolution feature representation.
Args:
        detail_channels (Tuple[int]): Number of channels of each stage
            in Detail Branch; the paper uses 3 stages.
            Default: (64, 64, 128).
in_channels (int): Number of channels of input image. Default: 3.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
x (torch.Tensor): Feature map of Detail Branch.
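
    Example:
        A minimal forward-pass sketch; the input size below is assumed
        purely for illustration (the branch downsamples by a factor of 8):

        >>> import torch
        >>> detail = DetailBranch(detail_channels=(64, 64, 128), in_channels=3)
        >>> x = torch.rand(1, 3, 64, 64)
        >>> detail(x).shape
        torch.Size([1, 128, 8, 8])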
"""
def __init__(self,
detail_channels=(64, 64, 128),
in_channels=3,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super(DetailBranch, self).__init__(init_cfg=init_cfg)
detail_branch = []
for i in range(len(detail_channels)):
if i == 0:
detail_branch.append(
nn.Sequential(
ConvModule(
in_channels=in_channels,
out_channels=detail_channels[i],
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
in_channels=detail_channels[i],
out_channels=detail_channels[i],
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)))
else:
detail_branch.append(
nn.Sequential(
ConvModule(
in_channels=detail_channels[i - 1],
out_channels=detail_channels[i],
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
in_channels=detail_channels[i],
out_channels=detail_channels[i],
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
in_channels=detail_channels[i],
out_channels=detail_channels[i],
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)))
self.detail_branch = nn.ModuleList(detail_branch)
def forward(self, x):
for stage in self.detail_branch:
x = stage(x)
return x


class StemBlock(BaseModule):
"""Stem Block at the beginning of Semantic Branch.
Args:
in_channels (int): Number of input channels.
Default: 3.
out_channels (int): Number of output channels.
Default: 16.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
x (torch.Tensor): First feature map in Semantic Branch.
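
    Example:
        A minimal forward-pass sketch; the input size is assumed for
        illustration (the block downsamples by a factor of 4):

        >>> import torch
        >>> stem = StemBlock(in_channels=3, out_channels=16)
        >>> x = torch.rand(1, 3, 64, 64)
        >>> stem(x).shape
        torch.Size([1, 16, 16, 16])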
"""
def __init__(self,
in_channels=3,
out_channels=16,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super(StemBlock, self).__init__(init_cfg=init_cfg)
self.conv_first = ConvModule(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.convs = nn.Sequential(
ConvModule(
in_channels=out_channels,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
in_channels=out_channels // 2,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.pool = nn.MaxPool2d(
kernel_size=3, stride=2, padding=1, ceil_mode=False)
self.fuse_last = ConvModule(
in_channels=out_channels * 2,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, x):
x = self.conv_first(x)
x_left = self.convs(x)
x_right = self.pool(x)
x = self.fuse_last(torch.cat([x_left, x_right], dim=1))
return x


class GELayer(BaseModule):
"""Gather-and-Expansion Layer.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
exp_ratio (int): Expansion ratio for middle channels.
Default: 6.
        stride (int): Stride of GELayer. Default: 1.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
x (torch.Tensor): Intermediate feature map in
Semantic Branch.
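
    Example:
        A minimal sketch of the two variants (sizes are illustrative):
        with ``stride=1`` the input is added back as an identity shortcut,
        while ``stride=2`` downsamples and uses a depthwise-separable
        shortcut branch.

        >>> import torch
        >>> x = torch.rand(1, 16, 32, 32)
        >>> GELayer(16, 16, exp_ratio=6, stride=1)(x).shape
        torch.Size([1, 16, 32, 32])
        >>> GELayer(16, 32, exp_ratio=6, stride=2)(x).shape
        torch.Size([1, 32, 16, 16])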
"""
def __init__(self,
in_channels,
out_channels,
exp_ratio=6,
stride=1,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super(GELayer, self).__init__(init_cfg=init_cfg)
mid_channel = in_channels * exp_ratio
self.conv1 = ConvModule(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
if stride == 1:
self.dwconv = nn.Sequential(
                # Note: the ReLU in this ConvModule is not shown in the paper.
ConvModule(
in_channels=in_channels,
out_channels=mid_channel,
kernel_size=3,
stride=stride,
padding=1,
groups=in_channels,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.shortcut = None
else:
self.dwconv = nn.Sequential(
ConvModule(
in_channels=in_channels,
out_channels=mid_channel,
kernel_size=3,
stride=stride,
padding=1,
groups=in_channels,
bias=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None),
                # Note: the ReLU in this ConvModule is not shown in the paper.
ConvModule(
in_channels=mid_channel,
out_channels=mid_channel,
kernel_size=3,
stride=1,
padding=1,
groups=mid_channel,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
)
self.shortcut = nn.Sequential(
DepthwiseSeparableConvModule(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
padding=1,
dw_norm_cfg=norm_cfg,
dw_act_cfg=None,
pw_norm_cfg=norm_cfg,
pw_act_cfg=None,
))
self.conv2 = nn.Sequential(
ConvModule(
in_channels=mid_channel,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None,
))
self.act = build_activation_layer(act_cfg)
def forward(self, x):
identity = x
x = self.conv1(x)
x = self.dwconv(x)
x = self.conv2(x)
if self.shortcut is not None:
shortcut = self.shortcut(identity)
x = x + shortcut
else:
x = x + identity
x = self.act(x)
return x


class CEBlock(BaseModule):
"""Context Embedding Block for large receptive filed in Semantic Branch.
Args:
in_channels (int): Number of input channels.
Default: 3.
out_channels (int): Number of output channels.
Default: 16.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
x (torch.Tensor): Last feature map in Semantic Branch.
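
    Example:
        A minimal forward-pass sketch. Note that the residual addition in
        ``forward`` assumes ``in_channels == out_channels``, which is how
        the block is instantiated in :class:`SemanticBranch`:

        >>> import torch
        >>> ce = CEBlock(in_channels=128, out_channels=128)
        >>> x = torch.rand(1, 128, 8, 8)
        >>> ce(x).shape
        torch.Size([1, 128, 8, 8])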
"""
def __init__(self,
in_channels=3,
out_channels=16,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super(CEBlock, self).__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.gap = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
build_norm_layer(norm_cfg, self.in_channels)[1])
self.conv_gap = ConvModule(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
        # Note: in the paper this is a plain conv2d without BN or ReLU.
self.conv_last = ConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, x):
identity = x
x = self.gap(x)
x = self.conv_gap(x)
x = identity + x
x = self.conv_last(x)
return x


class SemanticBranch(BaseModule):
"""Semantic Branch which is lightweight with narrow channels and deep
layers to obtain high-level semantic context.
Args:
        semantic_channels (Tuple[int]): Number of channels of each stage
            in Semantic Branch.
            Default: (16, 32, 64, 128).
in_channels (int): Number of channels of input image. Default: 3.
exp_ratio (int): Expansion ratio for middle channels.
Default: 6.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
semantic_outs (List[torch.Tensor]): List of several feature maps
for auxiliary heads (Booster) and Bilateral
Guided Aggregation Layer.
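
    Example:
        A minimal forward-pass sketch (input size assumed for
        illustration). The branch returns one feature map per stage plus
        the Context Embedding output, at strides 4, 8, 16, 32 and 32:

        >>> import torch
        >>> sem = SemanticBranch(semantic_channels=(16, 32, 64, 128))
        >>> outs = sem(torch.rand(1, 3, 128, 128))
        >>> len(outs)
        5
        >>> outs[-1].shape
        torch.Size([1, 128, 4, 4])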
"""
def __init__(self,
semantic_channels=(16, 32, 64, 128),
in_channels=3,
exp_ratio=6,
init_cfg=None):
super(SemanticBranch, self).__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.semantic_channels = semantic_channels
self.semantic_stages = []
for i in range(len(semantic_channels)):
stage_name = f'stage{i + 1}'
self.semantic_stages.append(stage_name)
if i == 0:
self.add_module(
stage_name,
StemBlock(self.in_channels, semantic_channels[i]))
elif i == (len(semantic_channels) - 1):
self.add_module(
stage_name,
nn.Sequential(
GELayer(semantic_channels[i - 1], semantic_channels[i],
exp_ratio, 2),
GELayer(semantic_channels[i], semantic_channels[i],
exp_ratio, 1),
GELayer(semantic_channels[i], semantic_channels[i],
exp_ratio, 1),
GELayer(semantic_channels[i], semantic_channels[i],
exp_ratio, 1)))
else:
self.add_module(
stage_name,
nn.Sequential(
GELayer(semantic_channels[i - 1], semantic_channels[i],
exp_ratio, 2),
GELayer(semantic_channels[i], semantic_channels[i],
exp_ratio, 1)))
self.add_module(f'stage{len(semantic_channels)}_CEBlock',
CEBlock(semantic_channels[-1], semantic_channels[-1]))
self.semantic_stages.append(f'stage{len(semantic_channels)}_CEBlock')
def forward(self, x):
semantic_outs = []
for stage_name in self.semantic_stages:
semantic_stage = getattr(self, stage_name)
x = semantic_stage(x)
semantic_outs.append(x)
return semantic_outs


class BGALayer(BaseModule):
"""Bilateral Guided Aggregation Layer to fuse the complementary information
from both Detail Branch and Semantic Branch.
Args:
out_channels (int): Number of output channels.
Default: 128.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
output (torch.Tensor): Output feature map for Segment heads.
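
    Example:
        A minimal forward-pass sketch with assumed input sizes: the detail
        feature is expected at 1/8 of the input resolution and the
        semantic feature at 1/32, both with ``out_channels`` channels:

        >>> import torch
        >>> bga = BGALayer(out_channels=128)
        >>> x_detail = torch.rand(1, 128, 16, 16)
        >>> x_semantic = torch.rand(1, 128, 4, 4)
        >>> bga(x_detail, x_semantic).shape
        torch.Size([1, 128, 16, 16])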
"""
def __init__(self,
out_channels=128,
align_corners=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super(BGALayer, self).__init__(init_cfg=init_cfg)
self.out_channels = out_channels
self.align_corners = align_corners
self.detail_dwconv = nn.Sequential(
DepthwiseSeparableConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
dw_norm_cfg=norm_cfg,
dw_act_cfg=None,
pw_norm_cfg=None,
pw_act_cfg=None,
))
self.detail_down = nn.Sequential(
ConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=2,
padding=1,
bias=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None),
nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False))
self.semantic_conv = nn.Sequential(
ConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None))
self.semantic_dwconv = nn.Sequential(
DepthwiseSeparableConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
dw_norm_cfg=norm_cfg,
dw_act_cfg=None,
pw_norm_cfg=None,
pw_act_cfg=None,
))
self.conv = ConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
inplace=True,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
)
def forward(self, x_d, x_s):
detail_dwconv = self.detail_dwconv(x_d)
detail_down = self.detail_down(x_d)
semantic_conv = self.semantic_conv(x_s)
semantic_dwconv = self.semantic_dwconv(x_s)
semantic_conv = resize(
input=semantic_conv,
size=detail_dwconv.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv)
fuse_2 = detail_down * torch.sigmoid(semantic_dwconv)
fuse_2 = resize(
input=fuse_2,
size=fuse_1.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
output = self.conv(fuse_1 + fuse_2)
return output


@BACKBONES.register_module()
class BiSeNetV2(BaseModule):
"""BiSeNetV2: Bilateral Network with Guided Aggregation for
Real-time Semantic Segmentation.
This backbone is the implementation of
`BiSeNetV2 <https://arxiv.org/abs/2004.02147>`_.
Args:
        in_channels (int): Number of channels of input image. Default: 3.
detail_channels (Tuple[int], optional): Channels of each stage
in Detail Branch. Default: (64, 64, 128).
        semantic_channels (Tuple[int], optional): Channels of each stage
            in Semantic Branch. Default: (16, 32, 64, 128).
            See Table 1 and Figure 3 of the paper for more details.
        semantic_expansion_ratio (int, optional): The expansion factor
            for the middle channels in Semantic Branch.
            Default: 6.
bga_channels (int, optional): Number of middle channels in
Bilateral Guided Aggregation Layer. Default: 128.
out_indices (Tuple[int] | int, optional): Output from which stages.
Default: (0, 1, 2, 3, 4).
align_corners (bool, optional): The align_corners argument of
resize operation in Bilateral Guided Aggregation Layer.
Default: False.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
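
    Example:
        A minimal forward-pass sketch with the default configuration (the
        input size below is assumed for illustration). The first output is
        the fused feature from the Bilateral Guided Aggregation Layer at
        1/8 resolution; the remaining outputs are Semantic Branch stage
        features used by the auxiliary heads:

        >>> import torch
        >>> model = BiSeNetV2()
        >>> outs = model(torch.rand(1, 3, 128, 128))
        >>> len(outs)
        5
        >>> outs[0].shape
        torch.Size([1, 128, 16, 16])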
"""
def __init__(self,
in_channels=3,
detail_channels=(64, 64, 128),
semantic_channels=(16, 32, 64, 128),
semantic_expansion_ratio=6,
bga_channels=128,
out_indices=(0, 1, 2, 3, 4),
align_corners=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
if init_cfg is None:
init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
]
super(BiSeNetV2, self).__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_indices = out_indices
self.detail_channels = detail_channels
self.semantic_channels = semantic_channels
self.semantic_expansion_ratio = semantic_expansion_ratio
self.bga_channels = bga_channels
self.align_corners = align_corners
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.detail = DetailBranch(self.detail_channels, self.in_channels)
self.semantic = SemanticBranch(self.semantic_channels,
self.in_channels,
self.semantic_expansion_ratio)
self.bga = BGALayer(self.bga_channels, self.align_corners)
def forward(self, x):
# stole refactoring code from Coin Cheung, thanks
x_detail = self.detail(x)
x_semantic_lst = self.semantic(x)
x_head = self.bga(x_detail, x_semantic_lst[-1])
outs = [x_head] + x_semantic_lst[:-1]
outs = [outs[i] for i in self.out_indices]
return tuple(outs)