[Feature] Support BiSeNetV2 (#804)

* BiSeNetV2 first commit

* BiSeNetV2 unittest

* remove pytest

* add pytest module

* fix ConvModule input name

* fix pytest error

* fix unittest

* refactor

* BiSeNetV2 Refactory

* fix docstrings and add some small changes

* use_sigmoid=False

* fix potential bugs about upsampling

* Use ConvModule instead

* Use ConvModule instead

* fix typos

* fix typos

* fix typos

* discard nn.conv2d

* discard nn.conv2d

* discard nn.conv2d

* delete **kwargs

* uploading markdown and model

* final commit

* BiSeNetV2 adding Unittest for its modules

* BiSeNetV2 adding Unittest for its modules

* BiSeNetV2 adding Unittest for its modules

* BiSeNetV2 adding Unittest for its modules

* BiSeNetV2 adding Unittest for its modules

* BiSeNetV2 adding Unittest for its modules

* BiSeNetV2 adding Unittest for its modules

* Fix README conflict

* Fix unittest problem

* Fix unittest problem

* BiSeNetV2

* Fixing fps

* Fixing typpos

* bisenetv2
This commit is contained in:
MengzhangLI 2021-09-26 18:52:16 +08:00 committed by GitHub
parent 29c82eaf13
commit 4003b8f421
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
15 changed files with 952 additions and 1 deletions

View File

@ -94,6 +94,7 @@ Supported methods:
- [x] [PointRend (CVPR'2020)](configs/point_rend)
- [x] [CGNet (TIP'2020)](configs/cgnet)
- [x] [SETR (CVPR'2021)](configs/setr)
- [x] [BiSeNetV2 (IJCV'2021)](configs/bisenetv2)
- [x] [SegFormer (ArXiv'2021)](configs/segformer)
Supported datasets:

View File

@ -93,6 +93,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
- [x] [PointRend (CVPR'2020)](configs/point_rend)
- [x] [CGNet (TIP'2020)](configs/cgnet)
- [x] [SETR (CVPR'2021)](configs/setr)
- [x] [BiSeNetV2 (IJCV'2021)](configs/bisenetv2)
- [x] [SegFormer (ArXiv'2021)](configs/segformer)
已支持的数据集:

View File

@ -0,0 +1,35 @@
_base_ = './cityscapes.py'
img_norm_cfg = dict(
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
crop_size = (1024, 1024)
train_pipeline = [
dict(type='LoadImageFromFile'),
dict(type='LoadAnnotations'),
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
dict(type='RandomFlip', prob=0.5),
dict(type='PhotoMetricDistortion'),
dict(type='Normalize', **img_norm_cfg),
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
dict(type='DefaultFormatBundle'),
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
]
test_pipeline = [
dict(type='LoadImageFromFile'),
dict(
type='MultiScaleFlipAug',
img_scale=(2048, 1024),
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
flip=False,
transforms=[
dict(type='Resize', keep_ratio=True),
dict(type='RandomFlip'),
dict(type='Normalize', **img_norm_cfg),
dict(type='ImageToTensor', keys=['img']),
dict(type='Collect', keys=['img']),
])
]
data = dict(
train=dict(pipeline=train_pipeline),
val=dict(pipeline=test_pipeline),
test=dict(pipeline=test_pipeline))

View File

@ -0,0 +1,80 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
model = dict(
type='EncoderDecoder',
pretrained=None,
backbone=dict(
type='BiSeNetV2',
detail_channels=(64, 64, 128),
semantic_channels=(16, 32, 64, 128),
semantic_expansion_ratio=6,
bga_channels=128,
out_indices=(0, 1, 2, 3, 4),
init_cfg=None,
align_corners=False),
decode_head=dict(
type='FCNHead',
in_channels=128,
in_index=0,
channels=1024,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=[
dict(
type='FCNHead',
in_channels=16,
channels=16,
num_convs=2,
num_classes=19,
in_index=1,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
dict(
type='FCNHead',
in_channels=32,
channels=64,
num_convs=2,
num_classes=19,
in_index=2,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
dict(
type='FCNHead',
in_channels=64,
channels=256,
num_convs=2,
num_classes=19,
in_index=3,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
dict(
type='FCNHead',
in_channels=128,
channels=1024,
num_convs=2,
num_classes=19,
in_index=4,
norm_cfg=norm_cfg,
concat_input=False,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
],
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))

View File

@ -0,0 +1,33 @@
# Bisenet v2: Bilateral Network with Guided Aggregation for Real-time Semantic Segmentation
## Introduction
<!-- [ALGORITHM] -->
```latex
@article{yu2021bisenet,
title={Bisenet v2: Bilateral network with guided aggregation for real-time semantic segmentation},
author={Yu, Changqian and Gao, Changxin and Wang, Jingbo and Yu, Gang and Shen, Chunhua and Sang, Nong},
journal={International Journal of Computer Vision},
pages={1--18},
year={2021},
publisher={Springer}
}
```
## Results and models
### Cityscapes
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
| ------ | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | --------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
| BiSeNetV2 | BiSeNetV2 | 1024x1024 | 160000 | 7.64 | 31.77 | 73.21 | 75.74 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes_20210902_015551-bcf10f09.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes_20210902_015551.log.json) |
| BiSeNetV2 (OHEM) | BiSeNetV2 | 1024x1024 | 160000 | 7.64 | - | 73.57 | 75.80 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes_20210902_112947-5f8103b4.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes_20210902_112947.log.json) |
| BiSeNetV2 (4x8) | BiSeNetV2 | 1024x1024 | 160000 | 15.05 | - | 75.76 | 77.79 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes_20210903_000032-e1a2eed6.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes_20210903_000032.log.json) |
| BiSeNetV2 (FP16) | BiSeNetV2 | 1024x1024 | 160000 | 5.77 | 36.65 | 73.07 | 75.13 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes_20210902_045942-b979777b.pth) &#124; [log](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes_20210902_045942.log.json) |
Note:
- `OHEM` means Online Hard Example Mining (OHEM) is adopted in training.
- `FP16` means Mixed Precision (FP16) is adopted in training.
- `4x8` means 4 GPUs with 8 samples per GPU in training.

View File

@ -0,0 +1,80 @@
Collections:
- Metadata:
Training Data:
- Cityscapes
Name: bisenetv2
Models:
- Config: configs/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py
In Collection: bisenetv2
Metadata:
backbone: BiSeNetV2
crop size: (1024,1024)
inference time (ms/im):
- backend: PyTorch
batch size: 1
hardware: V100
mode: FP32
resolution: (1024,1024)
value: 31.48
lr schd: 160000
memory (GB): 7.64
Name: bisenetv2_fcn_4x4_1024x1024_160k_cityscapes
Results:
Dataset: Cityscapes
Metrics:
mIoU: 73.21
mIoU(ms+flip): 75.74
Task: Semantic Segmentation
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes_20210902_015551-bcf10f09.pth
- Config: configs/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes.py
In Collection: bisenetv2
Metadata:
backbone: BiSeNetV2
crop size: (1024,1024)
lr schd: 160000
memory (GB): 7.64
Name: bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes
Results:
Dataset: Cityscapes
Metrics:
mIoU: 73.57
mIoU(ms+flip): 75.8
Task: Semantic Segmentation
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes_20210902_112947-5f8103b4.pth
- Config: configs/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes.py
In Collection: bisenetv2
Metadata:
backbone: BiSeNetV2
crop size: (1024,1024)
lr schd: 160000
memory (GB): 15.05
Name: bisenetv2_fcn_4x8_1024x1024_160k_cityscapes
Results:
Dataset: Cityscapes
Metrics:
mIoU: 75.76
mIoU(ms+flip): 77.79
Task: Semantic Segmentation
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes_20210903_000032-e1a2eed6.pth
- Config: configs/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes.py
In Collection: bisenetv2
Metadata:
backbone: BiSeNetV2
crop size: (1024,1024)
inference time (ms/im):
- backend: PyTorch
batch size: 1
hardware: V100
mode: FP32
resolution: (1024,1024)
value: 27.29
lr schd: 160000
memory (GB): 5.77
Name: bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes
Results:
Dataset: Cityscapes
Metrics:
mIoU: 73.07
mIoU(ms+flip): 75.13
Task: Semantic Segmentation
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes_20210902_045942-b979777b.pth

View File

@ -0,0 +1,11 @@
_base_ = [
'../_base_/models/bisenetv2.py',
'../_base_/datasets/cityscapes_1024x1024.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
lr_config = dict(warmup='linear', warmup_iters=1000)
optimizer = dict(lr=0.05)
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
)

View File

@ -0,0 +1,11 @@
_base_ = [
'../_base_/models/bisenetv2.py',
'../_base_/datasets/cityscapes_1024x1024.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
lr_config = dict(warmup='linear', warmup_iters=1000)
optimizer = dict(lr=0.05)
data = dict(
samples_per_gpu=8,
workers_per_gpu=8,
)

View File

@ -0,0 +1,5 @@
_base_ = './bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py'
# fp16 settings
optimizer_config = dict(type='Fp16OptimizerHook', loss_scale=512.)
# fp16 placeholder
fp16 = dict()

View File

@ -0,0 +1,12 @@
_base_ = [
'../_base_/models/bisenetv2.py',
'../_base_/datasets/cityscapes_1024x1024.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
sampler = dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000)
lr_config = dict(warmup='linear', warmup_iters=1000)
optimizer = dict(lr=0.05)
data = dict(
samples_per_gpu=4,
workers_per_gpu=4,
)

View File

@ -1,4 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from .bisenetv2 import BiSeNetV2
from .cgnet import CGNet
from .fast_scnn import FastSCNN
from .hrnet import HRNet
@ -15,5 +16,5 @@ from .vit import VisionTransformer
__all__ = [
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer'
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', 'BiSeNetV2'
]

View File

@ -0,0 +1,622 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
build_activation_layer, build_norm_layer)
from mmcv.runner import BaseModule
from mmseg.ops import resize
from ..builder import BACKBONES
class DetailBranch(BaseModule):
"""Detail Branch with wide channels and shallow layers to capture low-level
details and generate high-resolution feature representation.
Args:
detail_channels (Tuple[int]): Size of channel numbers of each stage
in Detail Branch, in paper it has 3 stages.
Default: (64, 64, 128).
in_channels (int): Number of channels of input image. Default: 3.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
x (torch.Tensor): Feature map of Detail Branch.
"""
def __init__(self,
detail_channels=(64, 64, 128),
in_channels=3,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super(DetailBranch, self).__init__(init_cfg=init_cfg)
detail_branch = []
for i in range(len(detail_channels)):
if i == 0:
detail_branch.append(
nn.Sequential(
ConvModule(
in_channels=in_channels,
out_channels=detail_channels[i],
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
in_channels=detail_channels[i],
out_channels=detail_channels[i],
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)))
else:
detail_branch.append(
nn.Sequential(
ConvModule(
in_channels=detail_channels[i - 1],
out_channels=detail_channels[i],
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
in_channels=detail_channels[i],
out_channels=detail_channels[i],
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
in_channels=detail_channels[i],
out_channels=detail_channels[i],
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)))
self.detail_branch = nn.ModuleList(detail_branch)
def forward(self, x):
for stage in self.detail_branch:
x = stage(x)
return x
class StemBlock(BaseModule):
"""Stem Block at the beginning of Semantic Branch.
Args:
in_channels (int): Number of input channels.
Default: 3.
out_channels (int): Number of output channels.
Default: 16.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
x (torch.Tensor): First feature map in Semantic Branch.
"""
def __init__(self,
in_channels=3,
out_channels=16,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super(StemBlock, self).__init__(init_cfg=init_cfg)
self.conv_first = ConvModule(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
self.convs = nn.Sequential(
ConvModule(
in_channels=out_channels,
out_channels=out_channels // 2,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
ConvModule(
in_channels=out_channels // 2,
out_channels=out_channels,
kernel_size=3,
stride=2,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.pool = nn.MaxPool2d(
kernel_size=3, stride=2, padding=1, ceil_mode=False)
self.fuse_last = ConvModule(
in_channels=out_channels * 2,
out_channels=out_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, x):
x = self.conv_first(x)
x_left = self.convs(x)
x_right = self.pool(x)
x = self.fuse_last(torch.cat([x_left, x_right], dim=1))
return x
class GELayer(BaseModule):
"""Gather-and-Expansion Layer.
Args:
in_channels (int): Number of input channels.
out_channels (int): Number of output channels.
exp_ratio (int): Expansion ratio for middle channels.
Default: 6.
stride (int): Stride of GELayer. Default: 1
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
x (torch.Tensor): Intermidiate feature map in
Semantic Branch.
"""
def __init__(self,
in_channels,
out_channels,
exp_ratio=6,
stride=1,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super(GELayer, self).__init__(init_cfg=init_cfg)
mid_channel = in_channels * exp_ratio
self.conv1 = ConvModule(
in_channels=in_channels,
out_channels=in_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
if stride == 1:
self.dwconv = nn.Sequential(
# ReLU in ConvModule not shown in paper
ConvModule(
in_channels=in_channels,
out_channels=mid_channel,
kernel_size=3,
stride=stride,
padding=1,
groups=in_channels,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg))
self.shortcut = None
else:
self.dwconv = nn.Sequential(
ConvModule(
in_channels=in_channels,
out_channels=mid_channel,
kernel_size=3,
stride=stride,
padding=1,
groups=in_channels,
bias=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None),
# ReLU in ConvModule not shown in paper
ConvModule(
in_channels=mid_channel,
out_channels=mid_channel,
kernel_size=3,
stride=1,
padding=1,
groups=mid_channel,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg),
)
self.shortcut = nn.Sequential(
DepthwiseSeparableConvModule(
in_channels=in_channels,
out_channels=out_channels,
kernel_size=3,
stride=stride,
padding=1,
dw_norm_cfg=norm_cfg,
dw_act_cfg=None,
pw_norm_cfg=norm_cfg,
pw_act_cfg=None,
))
self.conv2 = nn.Sequential(
ConvModule(
in_channels=mid_channel,
out_channels=out_channels,
kernel_size=1,
stride=1,
padding=0,
bias=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None,
))
self.act = build_activation_layer(act_cfg)
def forward(self, x):
identity = x
x = self.conv1(x)
x = self.dwconv(x)
x = self.conv2(x)
if self.shortcut is not None:
shortcut = self.shortcut(identity)
x = x + shortcut
else:
x = x + identity
x = self.act(x)
return x
class CEBlock(BaseModule):
"""Context Embedding Block for large receptive filed in Semantic Branch.
Args:
in_channels (int): Number of input channels.
Default: 3.
out_channels (int): Number of output channels.
Default: 16.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
x (torch.Tensor): Last feature map in Semantic Branch.
"""
def __init__(self,
in_channels=3,
out_channels=16,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super(CEBlock, self).__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_channels = out_channels
self.gap = nn.Sequential(
nn.AdaptiveAvgPool2d((1, 1)),
build_norm_layer(norm_cfg, self.in_channels)[1])
self.conv_gap = ConvModule(
in_channels=self.in_channels,
out_channels=self.out_channels,
kernel_size=1,
stride=1,
padding=0,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
# Note: in paper here is naive conv2d, no bn-relu
self.conv_last = ConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg)
def forward(self, x):
identity = x
x = self.gap(x)
x = self.conv_gap(x)
x = identity + x
x = self.conv_last(x)
return x
class SemanticBranch(BaseModule):
"""Semantic Branch which is lightweight with narrow channels and deep
layers to obtain high-level semantic context.
Args:
semantic_channels(Tuple[int]): Size of channel numbers of
various stages in Semantic Branch.
Default: (16, 32, 64, 128).
in_channels (int): Number of channels of input image. Default: 3.
exp_ratio (int): Expansion ratio for middle channels.
Default: 6.
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
semantic_outs (List[torch.Tensor]): List of several feature maps
for auxiliary heads (Booster) and Bilateral
Guided Aggregation Layer.
"""
def __init__(self,
semantic_channels=(16, 32, 64, 128),
in_channels=3,
exp_ratio=6,
init_cfg=None):
super(SemanticBranch, self).__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.semantic_channels = semantic_channels
self.semantic_stages = []
for i in range(len(semantic_channels)):
stage_name = f'stage{i + 1}'
self.semantic_stages.append(stage_name)
if i == 0:
self.add_module(
stage_name,
StemBlock(self.in_channels, semantic_channels[i]))
elif i == (len(semantic_channels) - 1):
self.add_module(
stage_name,
nn.Sequential(
GELayer(semantic_channels[i - 1], semantic_channels[i],
exp_ratio, 2),
GELayer(semantic_channels[i], semantic_channels[i],
exp_ratio, 1),
GELayer(semantic_channels[i], semantic_channels[i],
exp_ratio, 1),
GELayer(semantic_channels[i], semantic_channels[i],
exp_ratio, 1)))
else:
self.add_module(
stage_name,
nn.Sequential(
GELayer(semantic_channels[i - 1], semantic_channels[i],
exp_ratio, 2),
GELayer(semantic_channels[i], semantic_channels[i],
exp_ratio, 1)))
self.add_module(f'stage{len(semantic_channels)}_CEBlock',
CEBlock(semantic_channels[-1], semantic_channels[-1]))
self.semantic_stages.append(f'stage{len(semantic_channels)}_CEBlock')
def forward(self, x):
semantic_outs = []
for stage_name in self.semantic_stages:
semantic_stage = getattr(self, stage_name)
x = semantic_stage(x)
semantic_outs.append(x)
return semantic_outs
class BGALayer(BaseModule):
"""Bilateral Guided Aggregation Layer to fuse the complementary information
from both Detail Branch and Semantic Branch.
Args:
out_channels (int): Number of output channels.
Default: 128.
align_corners (bool): align_corners argument of F.interpolate.
Default: False.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
Returns:
output (torch.Tensor): Output feature map for Segment heads.
"""
def __init__(self,
out_channels=128,
align_corners=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
super(BGALayer, self).__init__(init_cfg=init_cfg)
self.out_channels = out_channels
self.align_corners = align_corners
self.detail_dwconv = nn.Sequential(
DepthwiseSeparableConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
dw_norm_cfg=norm_cfg,
dw_act_cfg=None,
pw_norm_cfg=None,
pw_act_cfg=None,
))
self.detail_down = nn.Sequential(
ConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=2,
padding=1,
bias=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None),
nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False))
self.semantic_conv = nn.Sequential(
ConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
bias=False,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=None))
self.semantic_dwconv = nn.Sequential(
DepthwiseSeparableConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
dw_norm_cfg=norm_cfg,
dw_act_cfg=None,
pw_norm_cfg=None,
pw_act_cfg=None,
))
self.conv = ConvModule(
in_channels=self.out_channels,
out_channels=self.out_channels,
kernel_size=3,
stride=1,
padding=1,
inplace=True,
conv_cfg=conv_cfg,
norm_cfg=norm_cfg,
act_cfg=act_cfg,
)
def forward(self, x_d, x_s):
detail_dwconv = self.detail_dwconv(x_d)
detail_down = self.detail_down(x_d)
semantic_conv = self.semantic_conv(x_s)
semantic_dwconv = self.semantic_dwconv(x_s)
semantic_conv = resize(
input=semantic_conv,
size=detail_dwconv.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv)
fuse_2 = detail_down * torch.sigmoid(semantic_dwconv)
fuse_2 = resize(
input=fuse_2,
size=fuse_1.shape[2:],
mode='bilinear',
align_corners=self.align_corners)
output = self.conv(fuse_1 + fuse_2)
return output
@BACKBONES.register_module()
class BiSeNetV2(BaseModule):
"""BiSeNetV2: Bilateral Network with Guided Aggregation for
Real-time Semantic Segmentation.
This backbone is the implementation of
`BiSeNetV2 <https://arxiv.org/abs/2004.02147>`_.
Args:
in_channels (int): Number of channel of input image. Default: 3.
detail_channels (Tuple[int], optional): Channels of each stage
in Detail Branch. Default: (64, 64, 128).
semantic_channels (Tuple[int], optional): Channels of each stage
in Semantic Branch. Default: (16, 32, 64, 128).
See Table 1 and Figure 3 of paper for more details.
semantic_expansion_ratio (int, optional): The expansion factor
expanding channel number of middle channels in Semantic Branch.
Default: 6.
bga_channels (int, optional): Number of middle channels in
Bilateral Guided Aggregation Layer. Default: 128.
out_indices (Tuple[int] | int, optional): Output from which stages.
Default: (0, 1, 2, 3, 4).
align_corners (bool, optional): The align_corners argument of
resize operation in Bilateral Guided Aggregation Layer.
Default: False.
conv_cfg (dict | None): Config of conv layers.
Default: None.
norm_cfg (dict | None): Config of norm layers.
Default: dict(type='BN').
act_cfg (dict): Config of activation layers.
Default: dict(type='ReLU').
init_cfg (dict or list[dict], optional): Initialization config dict.
Default: None.
"""
def __init__(self,
in_channels=3,
detail_channels=(64, 64, 128),
semantic_channels=(16, 32, 64, 128),
semantic_expansion_ratio=6,
bga_channels=128,
out_indices=(0, 1, 2, 3, 4),
align_corners=False,
conv_cfg=None,
norm_cfg=dict(type='BN'),
act_cfg=dict(type='ReLU'),
init_cfg=None):
if init_cfg is None:
init_cfg = [
dict(type='Kaiming', layer='Conv2d'),
dict(
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
]
super(BiSeNetV2, self).__init__(init_cfg=init_cfg)
self.in_channels = in_channels
self.out_indices = out_indices
self.detail_channels = detail_channels
self.semantic_channels = semantic_channels
self.semantic_expansion_ratio = semantic_expansion_ratio
self.bga_channels = bga_channels
self.align_corners = align_corners
self.conv_cfg = conv_cfg
self.norm_cfg = norm_cfg
self.act_cfg = act_cfg
self.detail = DetailBranch(self.detail_channels, self.in_channels)
self.semantic = SemanticBranch(self.semantic_channels,
self.in_channels,
self.semantic_expansion_ratio)
self.bga = BGALayer(self.bga_channels, self.align_corners)
def forward(self, x):
# stole refactoring code from Coin Cheung, thanks
x_detail = self.detail(x)
x_semantic_lst = self.semantic(x)
x_head = self.bga(x_detail, x_semantic_lst[-1])
outs = [x_head] + x_semantic_lst[:-1]
outs = [outs[i] for i in self.out_indices]
return tuple(outs)

View File

@ -25,6 +25,7 @@ class PatchMerging(BaseModule):
This layer use nn.Unfold to group feature map by kernel_size, and use norm
and linear layer to embed grouped feature map.
Args:
in_channels (int): The num of input channels.
out_channels (int): The num of output channels.

View File

@ -1,6 +1,7 @@
Import:
- configs/ann/ann.yml
- configs/apcnet/apcnet.yml
- configs/bisenetv2/bisenetv2.yml
- configs/ccnet/ccnet.yml
- configs/cgnet/cgnet.yml
- configs/danet/danet.yml

View File

@ -0,0 +1,57 @@
# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ConvModule
from mmseg.models.backbones import BiSeNetV2
from mmseg.models.backbones.bisenetv2 import (BGALayer, DetailBranch,
SemanticBranch)
def test_bisenetv2_backbone():
# Test BiSeNetV2 Standard Forward
model = BiSeNetV2()
model.init_weights()
model.train()
batch_size = 2
imgs = torch.randn(batch_size, 3, 512, 1024)
feat = model(imgs)
assert len(feat) == 5
# output for segment Head
assert feat[0].shape == torch.Size([batch_size, 128, 64, 128])
# for auxiliary head 1
assert feat[1].shape == torch.Size([batch_size, 16, 128, 256])
# for auxiliary head 2
assert feat[2].shape == torch.Size([batch_size, 32, 64, 128])
# for auxiliary head 3
assert feat[3].shape == torch.Size([batch_size, 64, 32, 64])
# for auxiliary head 4
assert feat[4].shape == torch.Size([batch_size, 128, 16, 32])
# Test input with rare shape
batch_size = 2
imgs = torch.randn(batch_size, 3, 527, 952)
feat = model(imgs)
assert len(feat) == 5
def test_bisenetv2_DetailBranch():
x = torch.randn(1, 3, 512, 1024)
detail_branch = DetailBranch(detail_channels=(64, 64, 128))
assert isinstance(detail_branch.detail_branch[0][0], ConvModule)
x_out = detail_branch(x)
assert x_out.shape == torch.Size([1, 128, 64, 128])
def test_bisenetv2_SemanticBranch():
semantic_branch = SemanticBranch(semantic_channels=(16, 32, 64, 128))
assert semantic_branch.stage1.pool.stride == 2
def test_bisenetv2_BGALayer():
x_a = torch.randn(1, 128, 64, 128)
x_b = torch.randn(1, 128, 16, 32)
bga = BGALayer()
assert isinstance(bga.conv, ConvModule)
x_out = bga(x_a, x_b)
assert x_out.shape == torch.Size([1, 128, 64, 128])