# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
from mmcv.runner import BaseModule

from mmseg.ops import resize
from ..builder import BACKBONES


class DownsamplerBlock(BaseModule):
    """Downsampler block of ERFNet.

    This module differs slightly from a basic ConvModule: the features
    from the Conv and MaxPool branches are concatenated before BatchNorm.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN', eps=1e-3).
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', eps=1e-3),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super(DownsamplerBlock, self).__init__(init_cfg=init_cfg)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        self.conv = build_conv_layer(
            self.conv_cfg,
            in_channels,
            out_channels - in_channels,
            kernel_size=3,
            stride=2,
            padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
        self.act = build_activation_layer(self.act_cfg)

    def forward(self, input):
        conv_out = self.conv(input)
        pool_out = self.pool(input)
        pool_out = resize(
            input=pool_out,
            size=conv_out.size()[2:],
            mode='bilinear',
            align_corners=False)
        output = torch.cat([conv_out, pool_out], 1)
        output = self.bn(output)
        output = self.act(output)
        return output
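

# A minimal usage sketch (illustrative, not part of the upstream file): the
# conv branch emits ``out_channels - in_channels`` feature maps and the
# pooled input supplies the rest, so concatenation yields ``out_channels``.
#   >>> block = DownsamplerBlock(in_channels=3, out_channels=16)
#   >>> block(torch.rand(1, 3, 64, 64)).shape
#   torch.Size([1, 16, 32, 32])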


class NonBottleneck1d(BaseModule):
    """Non-bottleneck block of ERFNet.

    The block stacks pairs of factorized 3x1 and 1x3 convolutions and adds
    the input back as a residual connection before the final activation.

    Args:
        channels (int): Number of channels in the Non-bottleneck block.
        drop_rate (float): Probability of an element to be zeroed.
            Default: 0.
        dilation (int): Dilation rate for the last two conv layers.
            Default: 1.
        num_conv_layer (int): Number of 3x1 and 1x3 convolution layers.
            Default: 2.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN', eps=1e-3).
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 channels,
                 drop_rate=0,
                 dilation=1,
                 num_conv_layer=2,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', eps=1e-3),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super(NonBottleneck1d, self).__init__(init_cfg=init_cfg)

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg
        self.act = build_activation_layer(self.act_cfg)

        self.convs_layers = nn.ModuleList()
        for conv_layer in range(num_conv_layer):
            # The first pair of factorized (3x1 + 1x3) convs uses
            # dilation 1; subsequent pairs use the given dilation rate.
            first_conv_padding = (1, 0) if conv_layer == 0 else (dilation, 0)
            first_conv_dilation = 1 if conv_layer == 0 else (dilation, 1)
            second_conv_padding = (0, 1) if conv_layer == 0 else (0, dilation)
            second_conv_dilation = 1 if conv_layer == 0 else (1, dilation)

            self.convs_layers.append(
                build_conv_layer(
                    self.conv_cfg,
                    channels,
                    channels,
                    kernel_size=(3, 1),
                    stride=1,
                    padding=first_conv_padding,
                    bias=True,
                    dilation=first_conv_dilation))
            self.convs_layers.append(self.act)
            self.convs_layers.append(
                build_conv_layer(
                    self.conv_cfg,
                    channels,
                    channels,
                    kernel_size=(1, 3),
                    stride=1,
                    padding=second_conv_padding,
                    bias=True,
                    dilation=second_conv_dilation))
            self.convs_layers.append(
                build_norm_layer(self.norm_cfg, channels)[1])
            if conv_layer == 0:
                self.convs_layers.append(self.act)
            else:
                self.convs_layers.append(nn.Dropout(p=drop_rate))

    def forward(self, input):
        output = input
        for conv in self.convs_layers:
            output = conv(output)
        # Residual connection followed by activation.
        output = self.act(output + input)
        return output
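

# A minimal usage sketch (illustrative): the block is shape-preserving, so
# input and output channel counts are both ``channels``.
#   >>> block = NonBottleneck1d(channels=64, drop_rate=0.1, dilation=2)
#   >>> block(torch.rand(1, 64, 32, 32)).shape
#   torch.Size([1, 64, 32, 32])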


class UpsamplerBlock(BaseModule):
    """Upsampler block of ERFNet.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN', eps=1e-3).
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', eps=1e-3),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):
        super(UpsamplerBlock, self).__init__(init_cfg=init_cfg)
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        self.conv = nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            output_padding=1,
            bias=True)
        self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
        self.act = build_activation_layer(self.act_cfg)

    def forward(self, input):
        output = self.conv(input)
        output = self.bn(output)
        output = self.act(output)
        return output
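

# A minimal usage sketch (illustrative): with kernel_size=3, stride=2,
# padding=1 and output_padding=1, the transposed conv exactly doubles
# the spatial resolution.
#   >>> up = UpsamplerBlock(in_channels=128, out_channels=64)
#   >>> up(torch.rand(1, 128, 32, 32)).shape
#   torch.Size([1, 64, 64, 64])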


@BACKBONES.register_module()
class ERFNet(BaseModule):
    """ERFNet backbone.

    This backbone is the implementation of `ERFNet: Efficient Residual
    Factorized ConvNet for Real-time Semantic Segmentation
    <https://ieeexplore.ieee.org/document/8063438>`_.

    Args:
        in_channels (int): The number of channels of the input
            image. Default: 3.
        enc_downsample_channels (Tuple[int]): Output channel numbers of
            the Downsampler blocks in the encoder.
            Default: (16, 64, 128).
        enc_stage_non_bottlenecks (Tuple[int]): Number of Non-bottleneck
            blocks in each encoder stage.
            Default: (5, 8).
        enc_non_bottleneck_dilations (Tuple[int]): Dilation rates of the
            Non-bottleneck blocks in the last encoder stage, applied
            cyclically. Default: (2, 4, 8, 16).
        enc_non_bottleneck_channels (Tuple[int]): Channel numbers of the
            Non-bottleneck blocks in the encoder.
            Default: (64, 128).
        dec_upsample_channels (Tuple[int]): Channel numbers of the
            Deconvolution (Upsampler) blocks in the decoder.
            Default: (64, 16).
        dec_stages_non_bottleneck (Tuple[int]): Number of Non-bottleneck
            blocks in each decoder stage.
            Default: (2, 2).
        dec_non_bottleneck_channels (Tuple[int]): Channel numbers of the
            Non-bottleneck blocks in the decoder.
            Default: (64, 16).
        dropout_ratio (float): Probability of an element to be zeroed.
            Default: 0.1.
        conv_cfg (dict | None): Config of conv layers.
            Default: None.
        norm_cfg (dict | None): Config of norm layers.
            Default: dict(type='BN', requires_grad=True).
        act_cfg (dict): Config of activation layers.
            Default: dict(type='ReLU').
        init_cfg (dict or list[dict], optional): Initialization config dict.
            Default: None.
    """

    def __init__(self,
                 in_channels=3,
                 enc_downsample_channels=(16, 64, 128),
                 enc_stage_non_bottlenecks=(5, 8),
                 enc_non_bottleneck_dilations=(2, 4, 8, 16),
                 enc_non_bottleneck_channels=(64, 128),
                 dec_upsample_channels=(64, 16),
                 dec_stages_non_bottleneck=(2, 2),
                 dec_non_bottleneck_channels=(64, 16),
                 dropout_ratio=0.1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='ReLU'),
                 init_cfg=None):

        super(ERFNet, self).__init__(init_cfg=init_cfg)
        assert len(enc_downsample_channels) \
            == len(dec_upsample_channels) + 1, \
            'Number of downsample blocks of encoder does not ' \
            'match number of upsample blocks of decoder!'
        assert len(enc_downsample_channels) \
            == len(enc_stage_non_bottlenecks) + 1, \
            'Number of downsample blocks of encoder does not ' \
            'match number of Non-bottleneck stages of encoder!'
        assert len(enc_downsample_channels) \
            == len(enc_non_bottleneck_channels) + 1, \
            'Number of downsample blocks of encoder does not match ' \
            'number of channels of Non-bottleneck blocks of encoder!'
        assert enc_stage_non_bottlenecks[-1] \
            % len(enc_non_bottleneck_dilations) == 0, \
            'Number of Non-bottleneck blocks in the last encoder stage ' \
            'must be a multiple of the number of dilation rates!'
        assert len(dec_upsample_channels) \
            == len(dec_stages_non_bottleneck), \
            'Number of upsample blocks of decoder does not ' \
            'match number of Non-bottleneck stages of decoder!'
        assert len(dec_stages_non_bottleneck) \
            == len(dec_non_bottleneck_channels), \
            'Number of Non-bottleneck stages of decoder does not match ' \
            'number of channels of Non-bottleneck blocks of decoder!'
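
        # A worked check of the default config (illustrative): the encoder
        # has len((16, 64, 128)) == 3 downsample channels, one more than the
        # len((64, 16)) == 2 upsample channels, and the last encoder stage
        # has 8 Non-bottleneck blocks, a multiple of the 4 dilation rates.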

        self.in_channels = in_channels
        self.enc_downsample_channels = enc_downsample_channels
        self.enc_stage_non_bottlenecks = enc_stage_non_bottlenecks
        self.enc_non_bottleneck_dilations = enc_non_bottleneck_dilations
        self.enc_non_bottleneck_channels = enc_non_bottleneck_channels
        self.dec_upsample_channels = dec_upsample_channels
        self.dec_stages_non_bottleneck = dec_stages_non_bottleneck
        self.dec_non_bottleneck_channels = dec_non_bottleneck_channels
        self.dropout_ratio = dropout_ratio

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.act_cfg = act_cfg

        self.encoder.append(
            DownsamplerBlock(self.in_channels, enc_downsample_channels[0]))

        for i in range(len(enc_downsample_channels) - 1):
            self.encoder.append(
                DownsamplerBlock(enc_downsample_channels[i],
                                 enc_downsample_channels[i + 1]))
            # The last stage of the encoder consists of dilated
            # NonBottleneck1d blocks.
            if i == len(enc_downsample_channels) - 2:
                iteration_times = int(enc_stage_non_bottlenecks[-1] /
                                      len(enc_non_bottleneck_dilations))
                for j in range(iteration_times):
                    for k in range(len(enc_non_bottleneck_dilations)):
                        self.encoder.append(
                            NonBottleneck1d(enc_downsample_channels[-1],
                                            self.dropout_ratio,
                                            enc_non_bottleneck_dilations[k]))
            else:
                for j in range(enc_stage_non_bottlenecks[i]):
                    self.encoder.append(
                        NonBottleneck1d(enc_downsample_channels[i + 1],
                                        self.dropout_ratio))

        for i in range(len(dec_upsample_channels)):
            if i == 0:
                self.decoder.append(
                    UpsamplerBlock(enc_downsample_channels[-1],
                                   dec_non_bottleneck_channels[i]))
            else:
                self.decoder.append(
                    UpsamplerBlock(dec_non_bottleneck_channels[i - 1],
                                   dec_non_bottleneck_channels[i]))
            for j in range(dec_stages_non_bottleneck[i]):
                self.decoder.append(
                    NonBottleneck1d(dec_non_bottleneck_channels[i]))

    def forward(self, x):
        for enc in self.encoder:
            x = enc(x)
        for dec in self.decoder:
            x = dec(x)
        return [x]
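

# A minimal usage sketch (illustrative): with the default config the three
# Downsampler blocks reduce resolution 8x and the two Upsampler blocks
# restore 4x, so the single output feature map sits at 1/2 input resolution
# with dec_non_bottleneck_channels[-1] == 16 channels.
#   >>> model = ERFNet()
#   >>> feats = model(torch.rand(1, 3, 512, 1024))
#   >>> feats[0].shape
#   torch.Size([1, 16, 256, 512])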