[Feature] Support Real-time model ERFNet (#960)
* first commit * Fixing Unittest Error * first refactory of ERFNet * Refactorying NonBottleneck1d Module * uploading models&logs * uploading models&logs * fix partial bugs & typos * ERFNet * add ERFNet with FCNHead * fix typos of ERFNet * add name on README.md cover * chane name to T-ITS'2017 * fix lint errorpull/1801/head
parent
defb21bf8a
commit
6a2cfea73b
|
@ -70,6 +70,7 @@ Supported backbones:
|
|||
Supported methods:
|
||||
|
||||
- [x] [FCN (CVPR'2015/TPAMI'2017)](configs/fcn)
|
||||
- [x] [ERFNet (T-ITS'2017)](configs/erfnet)
|
||||
- [x] [UNet (MICCAI'2016/Nat. Methods'2019)](configs/unet)
|
||||
- [x] [PSPNet (CVPR'2017)](configs/pspnet)
|
||||
- [x] [DeepLabV3 (ArXiv'2017)](configs/deeplabv3)
|
||||
|
|
|
@ -69,6 +69,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
|
|||
已支持的算法:
|
||||
|
||||
- [x] [FCN (CVPR'2015/TPAMI'2017)](configs/fcn)
|
||||
- [x] [ERFNet (T-ITS'2017)](configs/erfnet)
|
||||
- [x] [UNet (MICCAI'2016/Nat. Methods'2019)](configs/unet)
|
||||
- [x] [PSPNet (CVPR'2017)](configs/pspnet)
|
||||
- [x] [DeepLabV3 (ArXiv'2017)](configs/deeplabv3)
|
||||
|
|
|
@ -0,0 +1,32 @@
|
|||
# model settings
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='ERFNet',
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128),
|
||||
dec_upsample_channels=(64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2),
|
||||
dec_non_bottleneck_channels=(64, 16),
|
||||
dropout_ratio=0.1,
|
||||
init_cfg=None),
|
||||
decode_head=dict(
|
||||
type='FCNHead',
|
||||
in_channels=16,
|
||||
channels=128,
|
||||
num_convs=1,
|
||||
concat_input=False,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=19,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
|
||||
# model training and testing settings
|
||||
train_cfg=dict(),
|
||||
test_cfg=dict(mode='whole'))
|
|
@ -0,0 +1,50 @@
|
|||
# ERFNet: Efficient Residual Factorized ConvNet for Real-time Semantic Segmentation
|
||||
|
||||
## Introduction
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
<a href="https://github.com/Eromera/erfnet_pytorch">Official Repo</a>
|
||||
|
||||
<a href="https://github.com/open-mmlab/mmsegmentation/blob/v0.20.0/mmseg/models/backbones/erfnet.py#L321">Code Snippet</a>
|
||||
|
||||
## Abstract
|
||||
|
||||
Semantic segmentation is a challenging task that addresses most of the perception needs of intelligent vehicles (IVs) in an unified way. Deep neural networks excel at this task, as they can be trained end-to-end to accurately classify multiple object categories in an image at pixel level. However, a good tradeoff between high quality and computational resources is yet not present in the state-of-the-art semantic segmentation approaches, limiting their application in real vehicles. In this paper, we propose a deep architecture that is able to run in real time while providing accurate semantic segmentation. The core of our architecture is a novel layer that uses residual connections and factorized convolutions in order to remain efficient while retaining remarkable accuracy. Our approach is able to run at over 83 FPS in a single Titan X, and 7 FPS in a Jetson TX1 (embedded device). A comprehensive set of experiments on the publicly available Cityscapes data set demonstrates that our system achieves an accuracy that is similar to the state of the art, while being orders of magnitude faster to compute than other architectures that achieve top precision. The resulting tradeoff makes our model an ideal approach for scene understanding in IV applications. The code is publicly available at: https://github.com/Eromera/erfnet.
|
||||
|
||||
<!-- [IMAGE] -->
|
||||
<div align=center>
|
||||
<img src="https://user-images.githubusercontent.com/24582831/143479729-ea7951f6-1a3c-47d6-aaee-62c5759c0638.png" width="60%"/>
|
||||
</div>
|
||||
|
||||
<details>
|
||||
<summary align="right"><a href="http://www.robesafe.uah.es/personal/eduardo.romera/pdfs/Romera17tits.pdf">ERFNet (T-ITS'2017)</a></summary>
|
||||
|
||||
```latex
|
||||
@article{romera2017erfnet,
|
||||
title={Erfnet: Efficient residual factorized convnet for real-time semantic segmentation},
|
||||
author={Romera, Eduardo and Alvarez, Jos{\'e} M and Bergasa, Luis M and Arroyo, Roberto},
|
||||
journal={IEEE Transactions on Intelligent Transportation Systems},
|
||||
volume={19},
|
||||
number={1},
|
||||
pages={263--272},
|
||||
year={2017},
|
||||
publisher={IEEE}
|
||||
}
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
## Results and models
|
||||
|
||||
### Cityscapes
|
||||
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
|
||||
| --------- | --------- | --------- | ------: | -------- | -------------- | ----: | ------------- | --------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| FCN | ERFNet | 512x1024 | 160000 | 6.04 | 15.26 | 71.08 | 72.6 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes/erfnet_fcn_4x4_512x1024_160k_cityscapes_20211126_082056-03d333ed.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes/erfnet_fcn_4x4_512x1024_160k_cityscapes_20211126_082056.log.json) |
|
||||
|
||||
Note:
|
||||
|
||||
- The model is trained from scratch.
|
||||
|
||||
- Last deconvolution layer in the [original paper](https://github.com/Eromera/erfnet_pytorch/blob/master/train/erfnet.py#L123) is replaced by a naive `FCNHead` decoder head and a bilinear upsampling layer, found more effective and efficient.
|
|
@ -0,0 +1,37 @@
|
|||
Collections:
|
||||
- Name: erfnet
|
||||
Metadata:
|
||||
Training Data:
|
||||
- Cityscapes
|
||||
Paper:
|
||||
URL: http://www.robesafe.uah.es/personal/eduardo.romera/pdfs/Romera17tits.pdf
|
||||
Title: 'ERFNet: Efficient Residual Factorized ConvNet for Real-time Semantic Segmentation'
|
||||
README: configs/erfnet/README.md
|
||||
Code:
|
||||
URL: https://github.com/open-mmlab/mmsegmentation/blob/v0.20.0/mmseg/models/backbones/erfnet.py#L321
|
||||
Version: v0.20.0
|
||||
Converted From:
|
||||
Code: https://github.com/Eromera/erfnet_pytorch
|
||||
Models:
|
||||
- Name: erfnet_fcn_4x4_512x1024_160k_cityscapes
|
||||
In Collection: erfnet
|
||||
Metadata:
|
||||
backbone: ERFNet
|
||||
crop size: (512,1024)
|
||||
lr schd: 160000
|
||||
inference time (ms/im):
|
||||
- value: 65.53
|
||||
hardware: V100
|
||||
backend: PyTorch
|
||||
batch size: 1
|
||||
mode: FP32
|
||||
resolution: (512,1024)
|
||||
Training Memory (GB): 6.04
|
||||
Results:
|
||||
- Task: Semantic Segmentation
|
||||
Dataset: Cityscapes
|
||||
Metrics:
|
||||
mIoU: 71.08
|
||||
mIoU(ms+flip): 72.6
|
||||
Config: configs/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes.py
|
||||
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/erfnet/erfnet_fcn_4x4_512x1024_160k_cityscapes/erfnet_fcn_4x4_512x1024_160k_cityscapes_20211126_082056-03d333ed.pth
|
|
@ -0,0 +1,8 @@
|
|||
_base_ = [
|
||||
'../_base_/models/erfnet_fcn.py', '../_base_/datasets/cityscapes.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
|
||||
]
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
)
|
|
@ -2,6 +2,7 @@
|
|||
from .bisenetv1 import BiSeNetV1
|
||||
from .bisenetv2 import BiSeNetV2
|
||||
from .cgnet import CGNet
|
||||
from .erfnet import ERFNet
|
||||
from .fast_scnn import FastSCNN
|
||||
from .hrnet import HRNet
|
||||
from .icnet import ICNet
|
||||
|
@ -20,5 +21,5 @@ __all__ = [
|
|||
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
|
||||
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
|
||||
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
|
||||
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone'
|
||||
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet'
|
||||
]
|
||||
|
|
|
@ -0,0 +1,329 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import build_activation_layer, build_conv_layer, build_norm_layer
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import BACKBONES
|
||||
|
||||
|
||||
class DownsamplerBlock(BaseModule):
|
||||
"""Downsampler block of ERFNet.
|
||||
|
||||
This module is a little different from basical ConvModule.
|
||||
The features from Conv and MaxPool layers are
|
||||
concatenated before BatchNorm.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', eps=1e-3),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super(DownsamplerBlock, self).__init__(init_cfg=init_cfg)
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
self.conv = build_conv_layer(
|
||||
self.conv_cfg,
|
||||
in_channels,
|
||||
out_channels - in_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1)
|
||||
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
|
||||
self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
|
||||
self.act = build_activation_layer(self.act_cfg)
|
||||
|
||||
def forward(self, input):
|
||||
conv_out = self.conv(input)
|
||||
pool_out = self.pool(input)
|
||||
pool_out = resize(
|
||||
input=pool_out,
|
||||
size=conv_out.size()[2:],
|
||||
mode='bilinear',
|
||||
align_corners=False)
|
||||
output = torch.cat([conv_out, pool_out], 1)
|
||||
output = self.bn(output)
|
||||
output = self.act(output)
|
||||
return output
|
||||
|
||||
|
||||
class NonBottleneck1d(BaseModule):
|
||||
"""Non-bottleneck block of ERFNet.
|
||||
|
||||
Args:
|
||||
channels (int): Number of channels in Non-bottleneck block.
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Default 0.
|
||||
dilation (int): Dilation rate for last two conv layers.
|
||||
Default 1.
|
||||
num_conv_layer (int): Number of 3x1 and 1x3 convolution layers.
|
||||
Default 2.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
channels,
|
||||
drop_rate=0,
|
||||
dilation=1,
|
||||
num_conv_layer=2,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', eps=1e-3),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super(NonBottleneck1d, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
self.act = build_activation_layer(self.act_cfg)
|
||||
|
||||
self.convs_layers = nn.ModuleList()
|
||||
for conv_layer in range(num_conv_layer):
|
||||
first_conv_padding = (1, 0) if conv_layer == 0 else (dilation, 0)
|
||||
first_conv_dilation = 1 if conv_layer == 0 else (dilation, 1)
|
||||
second_conv_padding = (0, 1) if conv_layer == 0 else (0, dilation)
|
||||
second_conv_dilation = 1 if conv_layer == 0 else (1, dilation)
|
||||
|
||||
self.convs_layers.append(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=(3, 1),
|
||||
stride=1,
|
||||
padding=first_conv_padding,
|
||||
bias=True,
|
||||
dilation=first_conv_dilation))
|
||||
self.convs_layers.append(self.act)
|
||||
self.convs_layers.append(
|
||||
build_conv_layer(
|
||||
self.conv_cfg,
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=(1, 3),
|
||||
stride=1,
|
||||
padding=second_conv_padding,
|
||||
bias=True,
|
||||
dilation=second_conv_dilation))
|
||||
self.convs_layers.append(
|
||||
build_norm_layer(self.norm_cfg, channels)[1])
|
||||
if conv_layer == 0:
|
||||
self.convs_layers.append(self.act)
|
||||
else:
|
||||
self.convs_layers.append(nn.Dropout(p=drop_rate))
|
||||
|
||||
def forward(self, input):
|
||||
output = input
|
||||
for conv in self.convs_layers:
|
||||
output = conv(output)
|
||||
output = self.act(output + input)
|
||||
return output
|
||||
|
||||
|
||||
class UpsamplerBlock(BaseModule):
|
||||
"""Upsampler block of ERFNet.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', eps=1e-3),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super(UpsamplerBlock, self).__init__(init_cfg=init_cfg)
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
self.conv = nn.ConvTranspose2d(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
output_padding=1,
|
||||
bias=True)
|
||||
self.bn = build_norm_layer(self.norm_cfg, out_channels)[1]
|
||||
self.act = build_activation_layer(self.act_cfg)
|
||||
|
||||
def forward(self, input):
|
||||
output = self.conv(input)
|
||||
output = self.bn(output)
|
||||
output = self.act(output)
|
||||
return output
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class ERFNet(BaseModule):
|
||||
"""ERFNet backbone.
|
||||
|
||||
This backbone is the implementation of `ERFNet: Efficient Residual
|
||||
Factorized ConvNet for Real-time SemanticSegmentation
|
||||
<https://ieeexplore.ieee.org/document/8063438>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): The number of channels of input
|
||||
image. Default: 3.
|
||||
enc_downsample_channels (Tuple[int]): Size of channel
|
||||
numbers of various Downsampler block in encoder.
|
||||
Default: (16, 64, 128).
|
||||
enc_stage_non_bottlenecks (Tuple[int]): Number of stages of
|
||||
Non-bottleneck block in encoder.
|
||||
Default: (5, 8).
|
||||
enc_non_bottleneck_dilations (Tuple[int]): Dilation rate of each
|
||||
stage of Non-bottleneck block of encoder.
|
||||
Default: (2, 4, 8, 16).
|
||||
enc_non_bottleneck_channels (Tuple[int]): Size of channel
|
||||
numbers of various Non-bottleneck block in encoder.
|
||||
Default: (64, 128).
|
||||
dec_upsample_channels (Tuple[int]): Size of channel numbers of
|
||||
various Deconvolution block in decoder.
|
||||
Default: (64, 16).
|
||||
dec_stages_non_bottleneck (Tuple[int]): Number of stages of
|
||||
Non-bottleneck block in decoder.
|
||||
Default: (2, 2).
|
||||
dec_non_bottleneck_channels (Tuple[int]): Size of channel
|
||||
numbers of various Non-bottleneck block in decoder.
|
||||
Default: (64, 16).
|
||||
drop_rate (float): Probability of an element to be zeroed.
|
||||
Default 0.1.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128),
|
||||
dec_upsample_channels=(64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2),
|
||||
dec_non_bottleneck_channels=(64, 16),
|
||||
dropout_ratio=0.1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN', requires_grad=True),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
|
||||
super(ERFNet, self).__init__(init_cfg=init_cfg)
|
||||
assert len(enc_downsample_channels) \
|
||||
== len(dec_upsample_channels)+1, 'Number of downsample\
|
||||
block of encoder does not \
|
||||
match number of upsample block of decoder!'
|
||||
assert len(enc_downsample_channels) \
|
||||
== len(enc_stage_non_bottlenecks)+1, 'Number of \
|
||||
downsample block of encoder does not match \
|
||||
number of Non-bottleneck block of encoder!'
|
||||
assert len(enc_downsample_channels) \
|
||||
== len(enc_non_bottleneck_channels)+1, 'Number of \
|
||||
downsample block of encoder does not match \
|
||||
number of channels of Non-bottleneck block of encoder!'
|
||||
assert enc_stage_non_bottlenecks[-1] \
|
||||
% len(enc_non_bottleneck_dilations) == 0, 'Number of \
|
||||
Non-bottleneck block of encoder does not match \
|
||||
number of Non-bottleneck block of encoder!'
|
||||
assert len(dec_upsample_channels) \
|
||||
== len(dec_stages_non_bottleneck), 'Number of \
|
||||
upsample block of decoder does not match \
|
||||
number of Non-bottleneck block of decoder!'
|
||||
assert len(dec_stages_non_bottleneck) \
|
||||
== len(dec_non_bottleneck_channels), 'Number of \
|
||||
Non-bottleneck block of decoder does not match \
|
||||
number of channels of Non-bottleneck block of decoder!'
|
||||
|
||||
self.in_channels = in_channels
|
||||
self.enc_downsample_channels = enc_downsample_channels
|
||||
self.enc_stage_non_bottlenecks = enc_stage_non_bottlenecks
|
||||
self.enc_non_bottleneck_dilations = enc_non_bottleneck_dilations
|
||||
self.enc_non_bottleneck_channels = enc_non_bottleneck_channels
|
||||
self.dec_upsample_channels = dec_upsample_channels
|
||||
self.dec_stages_non_bottleneck = dec_stages_non_bottleneck
|
||||
self.dec_non_bottleneck_channels = dec_non_bottleneck_channels
|
||||
self.dropout_ratio = dropout_ratio
|
||||
|
||||
self.encoder = nn.ModuleList()
|
||||
self.decoder = nn.ModuleList()
|
||||
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
self.encoder.append(
|
||||
DownsamplerBlock(self.in_channels, enc_downsample_channels[0]))
|
||||
|
||||
for i in range(len(enc_downsample_channels) - 1):
|
||||
self.encoder.append(
|
||||
DownsamplerBlock(enc_downsample_channels[i],
|
||||
enc_downsample_channels[i + 1]))
|
||||
# Last part of encoder is some dilated NonBottleneck1d blocks.
|
||||
if i == len(enc_downsample_channels) - 2:
|
||||
iteration_times = int(enc_stage_non_bottlenecks[-1] /
|
||||
len(enc_non_bottleneck_dilations))
|
||||
for j in range(iteration_times):
|
||||
for k in range(len(enc_non_bottleneck_dilations)):
|
||||
self.encoder.append(
|
||||
NonBottleneck1d(enc_downsample_channels[-1],
|
||||
self.dropout_ratio,
|
||||
enc_non_bottleneck_dilations[k]))
|
||||
else:
|
||||
for j in range(enc_stage_non_bottlenecks[i]):
|
||||
self.encoder.append(
|
||||
NonBottleneck1d(enc_downsample_channels[i + 1],
|
||||
self.dropout_ratio))
|
||||
|
||||
for i in range(len(dec_upsample_channels)):
|
||||
if i == 0:
|
||||
self.decoder.append(
|
||||
UpsamplerBlock(enc_downsample_channels[-1],
|
||||
dec_non_bottleneck_channels[i]))
|
||||
else:
|
||||
self.decoder.append(
|
||||
UpsamplerBlock(dec_non_bottleneck_channels[i - 1],
|
||||
dec_non_bottleneck_channels[i]))
|
||||
for j in range(dec_stages_non_bottleneck[i]):
|
||||
self.decoder.append(
|
||||
NonBottleneck1d(dec_non_bottleneck_channels[i]))
|
||||
|
||||
def forward(self, x):
|
||||
for enc in self.encoder:
|
||||
x = enc(x)
|
||||
for dec in self.decoder:
|
||||
x = dec(x)
|
||||
return [x]
|
|
@ -13,6 +13,7 @@ Import:
|
|||
- configs/dpt/dpt.yml
|
||||
- configs/emanet/emanet.yml
|
||||
- configs/encnet/encnet.yml
|
||||
- configs/erfnet/erfnet.yml
|
||||
- configs/fastfcn/fastfcn.yml
|
||||
- configs/fastscnn/fastscnn.yml
|
||||
- configs/fcn/fcn.yml
|
||||
|
|
|
@ -0,0 +1,146 @@
|
|||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from mmseg.models.backbones import ERFNet
|
||||
from mmseg.models.backbones.erfnet import (DownsamplerBlock, NonBottleneck1d,
|
||||
UpsamplerBlock)
|
||||
|
||||
|
||||
def test_erfnet_backbone():
|
||||
# Test ERFNet Standard Forward.
|
||||
model = ERFNet(
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128),
|
||||
dec_upsample_channels=(64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2),
|
||||
dec_non_bottleneck_channels=(64, 16),
|
||||
dropout_ratio=0.1,
|
||||
)
|
||||
model.init_weights()
|
||||
model.train()
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 256, 512)
|
||||
output = model(imgs)
|
||||
|
||||
# output for segment Head
|
||||
assert output[0].shape == torch.Size([batch_size, 16, 128, 256])
|
||||
|
||||
# Test input with rare shape
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 527, 279)
|
||||
output = model(imgs)
|
||||
assert len(output[0]) == batch_size
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Number of encoder downsample block and decoder upsample block.
|
||||
ERFNet(
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128),
|
||||
dec_upsample_channels=(128, 64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2),
|
||||
dec_non_bottleneck_channels=(64, 16),
|
||||
dropout_ratio=0.1,
|
||||
)
|
||||
with pytest.raises(AssertionError):
|
||||
# Number of encoder downsample block and encoder Non-bottleneck block.
|
||||
ERFNet(
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8, 10),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128),
|
||||
dec_upsample_channels=(64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2),
|
||||
dec_non_bottleneck_channels=(64, 16),
|
||||
dropout_ratio=0.1,
|
||||
)
|
||||
with pytest.raises(AssertionError):
|
||||
# Number of encoder downsample block and
|
||||
# channels of encoder Non-bottleneck block.
|
||||
ERFNet(
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128, 256),
|
||||
dec_upsample_channels=(64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2),
|
||||
dec_non_bottleneck_channels=(64, 16),
|
||||
dropout_ratio=0.1,
|
||||
)
|
||||
|
||||
with pytest.raises(AssertionError):
|
||||
# Number of encoder Non-bottleneck block and number of its channels.
|
||||
ERFNet(
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8, 3),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128),
|
||||
dec_upsample_channels=(64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2),
|
||||
dec_non_bottleneck_channels=(64, 16),
|
||||
dropout_ratio=0.1,
|
||||
)
|
||||
with pytest.raises(AssertionError):
|
||||
# Number of decoder upsample block and decoder Non-bottleneck block.
|
||||
ERFNet(
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128),
|
||||
dec_upsample_channels=(64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2, 3),
|
||||
dec_non_bottleneck_channels=(64, 16),
|
||||
dropout_ratio=0.1,
|
||||
)
|
||||
with pytest.raises(AssertionError):
|
||||
# Number of decoder Non-bottleneck block and number of its channels.
|
||||
ERFNet(
|
||||
in_channels=3,
|
||||
enc_downsample_channels=(16, 64, 128),
|
||||
enc_stage_non_bottlenecks=(5, 8),
|
||||
enc_non_bottleneck_dilations=(2, 4, 8, 16),
|
||||
enc_non_bottleneck_channels=(64, 128),
|
||||
dec_upsample_channels=(64, 16),
|
||||
dec_stages_non_bottleneck=(2, 2),
|
||||
dec_non_bottleneck_channels=(64, 16, 8),
|
||||
dropout_ratio=0.1,
|
||||
)
|
||||
|
||||
|
||||
def test_erfnet_downsampler_block():
|
||||
x_db = DownsamplerBlock(16, 64)
|
||||
assert x_db.conv.in_channels == 16
|
||||
assert x_db.conv.out_channels == 48
|
||||
assert len(x_db.bn.weight) == 64
|
||||
assert x_db.pool.kernel_size == 2
|
||||
assert x_db.pool.stride == 2
|
||||
|
||||
|
||||
def test_erfnet_non_bottleneck_1d():
|
||||
x_nb1d = NonBottleneck1d(16, 0, 1)
|
||||
assert x_nb1d.convs_layers[0].in_channels == 16
|
||||
assert x_nb1d.convs_layers[0].out_channels == 16
|
||||
assert x_nb1d.convs_layers[2].in_channels == 16
|
||||
assert x_nb1d.convs_layers[2].out_channels == 16
|
||||
assert x_nb1d.convs_layers[5].in_channels == 16
|
||||
assert x_nb1d.convs_layers[5].out_channels == 16
|
||||
assert x_nb1d.convs_layers[7].in_channels == 16
|
||||
assert x_nb1d.convs_layers[7].out_channels == 16
|
||||
assert x_nb1d.convs_layers[9].p == 0
|
||||
|
||||
|
||||
def test_erfnet_upsampler_block():
|
||||
x_ub = UpsamplerBlock(64, 16)
|
||||
assert x_ub.conv.in_channels == 64
|
||||
assert x_ub.conv.out_channels == 16
|
||||
assert len(x_ub.bn.weight) == 16
|
Loading…
Reference in New Issue