mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
[Feature] Support BiSeNetV2 (#804)
* BiSeNetV2 first commit * BiSeNetV2 unittest * remove pytest * add pytest module * fix ConvModule input name * fix pytest error * fix unittest * refactor * BiSeNetV2 Refactory * fix docstrings and add some small changes * use_sigmoid=False * fix potential bugs about upsampling * Use ConvModule instead * Use ConvModule instead * fix typos * fix typos * fix typos * discard nn.conv2d * discard nn.conv2d * discard nn.conv2d * delete **kwargs * uploading markdown and model * final commit * BiSeNetV2 adding Unittest for its modules * BiSeNetV2 adding Unittest for its modules * BiSeNetV2 adding Unittest for its modules * BiSeNetV2 adding Unittest for its modules * BiSeNetV2 adding Unittest for its modules * BiSeNetV2 adding Unittest for its modules * BiSeNetV2 adding Unittest for its modules * Fix README conflict * Fix unittest problem * Fix unittest problem * BiSeNetV2 * Fixing fps * Fixing typpos * bisenetv2
This commit is contained in:
parent
96b369bd55
commit
f82e4d6fc9
@ -94,6 +94,7 @@ Supported methods:
|
||||
- [x] [PointRend (CVPR'2020)](configs/point_rend)
|
||||
- [x] [CGNet (TIP'2020)](configs/cgnet)
|
||||
- [x] [SETR (CVPR'2021)](configs/setr)
|
||||
- [x] [BiSeNetV2 (IJCV'2021)](configs/bisenetv2)
|
||||
- [x] [SegFormer (ArXiv'2021)](configs/segformer)
|
||||
|
||||
Supported datasets:
|
||||
|
@ -93,6 +93,7 @@ MMSegmentation 是一个基于 PyTorch 的语义分割开源工具箱。它是 O
|
||||
- [x] [PointRend (CVPR'2020)](configs/point_rend)
|
||||
- [x] [CGNet (TIP'2020)](configs/cgnet)
|
||||
- [x] [SETR (CVPR'2021)](configs/setr)
|
||||
- [x] [BiSeNetV2 (IJCV'2021)](configs/bisenetv2)
|
||||
- [x] [SegFormer (ArXiv'2021)](configs/segformer)
|
||||
|
||||
已支持的数据集:
|
||||
|
35
configs/_base_/datasets/cityscapes_1024x1024.py
Normal file
35
configs/_base_/datasets/cityscapes_1024x1024.py
Normal file
@ -0,0 +1,35 @@
|
||||
_base_ = './cityscapes.py'
|
||||
img_norm_cfg = dict(
|
||||
mean=[123.675, 116.28, 103.53], std=[58.395, 57.12, 57.375], to_rgb=True)
|
||||
crop_size = (1024, 1024)
|
||||
train_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(type='LoadAnnotations'),
|
||||
dict(type='Resize', img_scale=(2048, 1024), ratio_range=(0.5, 2.0)),
|
||||
dict(type='RandomCrop', crop_size=crop_size, cat_max_ratio=0.75),
|
||||
dict(type='RandomFlip', prob=0.5),
|
||||
dict(type='PhotoMetricDistortion'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='Pad', size=crop_size, pad_val=0, seg_pad_val=255),
|
||||
dict(type='DefaultFormatBundle'),
|
||||
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
||||
]
|
||||
test_pipeline = [
|
||||
dict(type='LoadImageFromFile'),
|
||||
dict(
|
||||
type='MultiScaleFlipAug',
|
||||
img_scale=(2048, 1024),
|
||||
# img_ratios=[0.5, 0.75, 1.0, 1.25, 1.5, 1.75],
|
||||
flip=False,
|
||||
transforms=[
|
||||
dict(type='Resize', keep_ratio=True),
|
||||
dict(type='RandomFlip'),
|
||||
dict(type='Normalize', **img_norm_cfg),
|
||||
dict(type='ImageToTensor', keys=['img']),
|
||||
dict(type='Collect', keys=['img']),
|
||||
])
|
||||
]
|
||||
data = dict(
|
||||
train=dict(pipeline=train_pipeline),
|
||||
val=dict(pipeline=test_pipeline),
|
||||
test=dict(pipeline=test_pipeline))
|
80
configs/_base_/models/bisenetv2.py
Normal file
80
configs/_base_/models/bisenetv2.py
Normal file
@ -0,0 +1,80 @@
|
||||
# model settings
|
||||
norm_cfg = dict(type='SyncBN', requires_grad=True)
|
||||
model = dict(
|
||||
type='EncoderDecoder',
|
||||
pretrained=None,
|
||||
backbone=dict(
|
||||
type='BiSeNetV2',
|
||||
detail_channels=(64, 64, 128),
|
||||
semantic_channels=(16, 32, 64, 128),
|
||||
semantic_expansion_ratio=6,
|
||||
bga_channels=128,
|
||||
out_indices=(0, 1, 2, 3, 4),
|
||||
init_cfg=None,
|
||||
align_corners=False),
|
||||
decode_head=dict(
|
||||
type='FCNHead',
|
||||
in_channels=128,
|
||||
in_index=0,
|
||||
channels=1024,
|
||||
num_convs=1,
|
||||
concat_input=False,
|
||||
dropout_ratio=0.1,
|
||||
num_classes=19,
|
||||
norm_cfg=norm_cfg,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
|
||||
auxiliary_head=[
|
||||
dict(
|
||||
type='FCNHead',
|
||||
in_channels=16,
|
||||
channels=16,
|
||||
num_convs=2,
|
||||
num_classes=19,
|
||||
in_index=1,
|
||||
norm_cfg=norm_cfg,
|
||||
concat_input=False,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
|
||||
dict(
|
||||
type='FCNHead',
|
||||
in_channels=32,
|
||||
channels=64,
|
||||
num_convs=2,
|
||||
num_classes=19,
|
||||
in_index=2,
|
||||
norm_cfg=norm_cfg,
|
||||
concat_input=False,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
|
||||
dict(
|
||||
type='FCNHead',
|
||||
in_channels=64,
|
||||
channels=256,
|
||||
num_convs=2,
|
||||
num_classes=19,
|
||||
in_index=3,
|
||||
norm_cfg=norm_cfg,
|
||||
concat_input=False,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
|
||||
dict(
|
||||
type='FCNHead',
|
||||
in_channels=128,
|
||||
channels=1024,
|
||||
num_convs=2,
|
||||
num_classes=19,
|
||||
in_index=4,
|
||||
norm_cfg=norm_cfg,
|
||||
concat_input=False,
|
||||
align_corners=False,
|
||||
loss_decode=dict(
|
||||
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
|
||||
],
|
||||
# model training and testing settings
|
||||
train_cfg=dict(),
|
||||
test_cfg=dict(mode='whole'))
|
33
configs/bisenetv2/README.md
Normal file
33
configs/bisenetv2/README.md
Normal file
@ -0,0 +1,33 @@
|
||||
# Bisenet v2: Bilateral Network with Guided Aggregation for Real-time Semantic Segmentation
|
||||
|
||||
## Introduction
|
||||
|
||||
<!-- [ALGORITHM] -->
|
||||
|
||||
```latex
|
||||
@article{yu2021bisenet,
|
||||
title={Bisenet v2: Bilateral network with guided aggregation for real-time semantic segmentation},
|
||||
author={Yu, Changqian and Gao, Changxin and Wang, Jingbo and Yu, Gang and Shen, Chunhua and Sang, Nong},
|
||||
journal={International Journal of Computer Vision},
|
||||
pages={1--18},
|
||||
year={2021},
|
||||
publisher={Springer}
|
||||
}
|
||||
```
|
||||
|
||||
## Results and models
|
||||
|
||||
### Cityscapes
|
||||
|
||||
| Method | Backbone | Crop Size | Lr schd | Mem (GB) | Inf time (fps) | mIoU | mIoU(ms+flip) | config | download |
|
||||
| ------ | -------- | --------- | ------: | -------- | -------------- | ----: | ------------: | --------------------------------------------------------------------------------------------------------------------- | -------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- |
|
||||
| BiSeNetV2 | BiSeNetV2 | 1024x1024 | 160000 | 7.64 | 31.77 | 73.21 | 75.74 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes_20210902_015551-bcf10f09.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes_20210902_015551.log.json) |
|
||||
| BiSeNetV2 (OHEM) | BiSeNetV2 | 1024x1024 | 160000 | 7.64 | - | 73.57 | 75.80 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes_20210902_112947-5f8103b4.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes_20210902_112947.log.json) |
|
||||
| BiSeNetV2 (4x8) | BiSeNetV2 | 1024x1024 | 160000 | 15.05 | - | 75.76 | 77.79 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes_20210903_000032-e1a2eed6.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes_20210903_000032.log.json) |
|
||||
| BiSeNetV2 (FP16) | BiSeNetV2 | 1024x1024 | 160000 | 5.77 | 36.65 | 73.07 | 75.13 | [config](https://github.com/open-mmlab/mmsegmentation/blob/master/configs/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes.py) | [model](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes_20210902_045942-b979777b.pth) | [log](https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes_20210902_045942.log.json) |
|
||||
|
||||
Note:
|
||||
|
||||
- `OHEM` means Online Hard Example Mining (OHEM) is adopted in training.
|
||||
- `FP16` means Mixed Precision (FP16) is adopted in training.
|
||||
- `4x8` means 4 GPUs with 8 samples per GPU in training.
|
80
configs/bisenetv2/bisenetv2.yml
Normal file
80
configs/bisenetv2/bisenetv2.yml
Normal file
@ -0,0 +1,80 @@
|
||||
Collections:
|
||||
- Metadata:
|
||||
Training Data:
|
||||
- Cityscapes
|
||||
Name: bisenetv2
|
||||
Models:
|
||||
- Config: configs/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py
|
||||
In Collection: bisenetv2
|
||||
Metadata:
|
||||
backbone: BiSeNetV2
|
||||
crop size: (1024,1024)
|
||||
inference time (ms/im):
|
||||
- backend: PyTorch
|
||||
batch size: 1
|
||||
hardware: V100
|
||||
mode: FP32
|
||||
resolution: (1024,1024)
|
||||
value: 31.48
|
||||
lr schd: 160000
|
||||
memory (GB): 7.64
|
||||
Name: bisenetv2_fcn_4x4_1024x1024_160k_cityscapes
|
||||
Results:
|
||||
Dataset: Cityscapes
|
||||
Metrics:
|
||||
mIoU: 73.21
|
||||
mIoU(ms+flip): 75.74
|
||||
Task: Semantic Segmentation
|
||||
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_4x4_1024x1024_160k_cityscapes_20210902_015551-bcf10f09.pth
|
||||
- Config: configs/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes.py
|
||||
In Collection: bisenetv2
|
||||
Metadata:
|
||||
backbone: BiSeNetV2
|
||||
crop size: (1024,1024)
|
||||
lr schd: 160000
|
||||
memory (GB): 7.64
|
||||
Name: bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes
|
||||
Results:
|
||||
Dataset: Cityscapes
|
||||
Metrics:
|
||||
mIoU: 73.57
|
||||
mIoU(ms+flip): 75.8
|
||||
Task: Semantic Segmentation
|
||||
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_ohem_4x4_1024x1024_160k_cityscapes_20210902_112947-5f8103b4.pth
|
||||
- Config: configs/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes.py
|
||||
In Collection: bisenetv2
|
||||
Metadata:
|
||||
backbone: BiSeNetV2
|
||||
crop size: (1024,1024)
|
||||
lr schd: 160000
|
||||
memory (GB): 15.05
|
||||
Name: bisenetv2_fcn_4x8_1024x1024_160k_cityscapes
|
||||
Results:
|
||||
Dataset: Cityscapes
|
||||
Metrics:
|
||||
mIoU: 75.76
|
||||
mIoU(ms+flip): 77.79
|
||||
Task: Semantic Segmentation
|
||||
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes/bisenetv2_fcn_4x8_1024x1024_160k_cityscapes_20210903_000032-e1a2eed6.pth
|
||||
- Config: configs/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes.py
|
||||
In Collection: bisenetv2
|
||||
Metadata:
|
||||
backbone: BiSeNetV2
|
||||
crop size: (1024,1024)
|
||||
inference time (ms/im):
|
||||
- backend: PyTorch
|
||||
batch size: 1
|
||||
hardware: V100
|
||||
mode: FP32
|
||||
resolution: (1024,1024)
|
||||
value: 27.29
|
||||
lr schd: 160000
|
||||
memory (GB): 5.77
|
||||
Name: bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes
|
||||
Results:
|
||||
Dataset: Cityscapes
|
||||
Metrics:
|
||||
mIoU: 73.07
|
||||
mIoU(ms+flip): 75.13
|
||||
Task: Semantic Segmentation
|
||||
Weights: https://download.openmmlab.com/mmsegmentation/v0.5/bisenetv2/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes/bisenetv2_fcn_fp16_4x4_1024x1024_160k_cityscapes_20210902_045942-b979777b.pth
|
@ -0,0 +1,11 @@
|
||||
_base_ = [
|
||||
'../_base_/models/bisenetv2.py',
|
||||
'../_base_/datasets/cityscapes_1024x1024.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
|
||||
]
|
||||
lr_config = dict(warmup='linear', warmup_iters=1000)
|
||||
optimizer = dict(lr=0.05)
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
)
|
@ -0,0 +1,11 @@
|
||||
_base_ = [
|
||||
'../_base_/models/bisenetv2.py',
|
||||
'../_base_/datasets/cityscapes_1024x1024.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
|
||||
]
|
||||
lr_config = dict(warmup='linear', warmup_iters=1000)
|
||||
optimizer = dict(lr=0.05)
|
||||
data = dict(
|
||||
samples_per_gpu=8,
|
||||
workers_per_gpu=8,
|
||||
)
|
@ -0,0 +1,5 @@
|
||||
_base_ = './bisenetv2_fcn_4x4_1024x1024_160k_cityscapes.py'
|
||||
# fp16 settings
|
||||
optimizer_config = dict(type='Fp16OptimizerHook', loss_scale=512.)
|
||||
# fp16 placeholder
|
||||
fp16 = dict()
|
@ -0,0 +1,12 @@
|
||||
_base_ = [
|
||||
'../_base_/models/bisenetv2.py',
|
||||
'../_base_/datasets/cityscapes_1024x1024.py',
|
||||
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
|
||||
]
|
||||
sampler = dict(type='OHEMPixelSampler', thresh=0.7, min_kept=10000)
|
||||
lr_config = dict(warmup='linear', warmup_iters=1000)
|
||||
optimizer = dict(lr=0.05)
|
||||
data = dict(
|
||||
samples_per_gpu=4,
|
||||
workers_per_gpu=4,
|
||||
)
|
@ -1,4 +1,5 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
from .bisenetv2 import BiSeNetV2
|
||||
from .cgnet import CGNet
|
||||
from .fast_scnn import FastSCNN
|
||||
from .hrnet import HRNet
|
||||
@ -15,5 +16,5 @@ from .vit import VisionTransformer
|
||||
__all__ = [
|
||||
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
|
||||
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
|
||||
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer'
|
||||
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer', 'BiSeNetV2'
|
||||
]
|
||||
|
622
mmseg/models/backbones/bisenetv2.py
Normal file
622
mmseg/models/backbones/bisenetv2.py
Normal file
@ -0,0 +1,622 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule,
|
||||
build_activation_layer, build_norm_layer)
|
||||
from mmcv.runner import BaseModule
|
||||
|
||||
from mmseg.ops import resize
|
||||
from ..builder import BACKBONES
|
||||
|
||||
|
||||
class DetailBranch(BaseModule):
|
||||
"""Detail Branch with wide channels and shallow layers to capture low-level
|
||||
details and generate high-resolution feature representation.
|
||||
|
||||
Args:
|
||||
detail_channels (Tuple[int]): Size of channel numbers of each stage
|
||||
in Detail Branch, in paper it has 3 stages.
|
||||
Default: (64, 64, 128).
|
||||
in_channels (int): Number of channels of input image. Default: 3.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
x (torch.Tensor): Feature map of Detail Branch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
detail_channels=(64, 64, 128),
|
||||
in_channels=3,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super(DetailBranch, self).__init__(init_cfg=init_cfg)
|
||||
detail_branch = []
|
||||
for i in range(len(detail_channels)):
|
||||
if i == 0:
|
||||
detail_branch.append(
|
||||
nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
in_channels=detail_channels[i],
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)))
|
||||
else:
|
||||
detail_branch.append(
|
||||
nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=detail_channels[i - 1],
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
in_channels=detail_channels[i],
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
in_channels=detail_channels[i],
|
||||
out_channels=detail_channels[i],
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)))
|
||||
self.detail_branch = nn.ModuleList(detail_branch)
|
||||
|
||||
def forward(self, x):
|
||||
for stage in self.detail_branch:
|
||||
x = stage(x)
|
||||
return x
|
||||
|
||||
|
||||
class StemBlock(BaseModule):
|
||||
"""Stem Block at the beginning of Semantic Branch.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
Default: 3.
|
||||
out_channels (int): Number of output channels.
|
||||
Default: 16.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
x (torch.Tensor): First feature map in Semantic Branch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
out_channels=16,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super(StemBlock, self).__init__(init_cfg=init_cfg)
|
||||
|
||||
self.conv_first = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
self.convs = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=out_channels,
|
||||
out_channels=out_channels // 2,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
ConvModule(
|
||||
in_channels=out_channels // 2,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.pool = nn.MaxPool2d(
|
||||
kernel_size=3, stride=2, padding=1, ceil_mode=False)
|
||||
self.fuse_last = ConvModule(
|
||||
in_channels=out_channels * 2,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv_first(x)
|
||||
x_left = self.convs(x)
|
||||
x_right = self.pool(x)
|
||||
x = self.fuse_last(torch.cat([x_left, x_right], dim=1))
|
||||
return x
|
||||
|
||||
|
||||
class GELayer(BaseModule):
|
||||
"""Gather-and-Expansion Layer.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
out_channels (int): Number of output channels.
|
||||
exp_ratio (int): Expansion ratio for middle channels.
|
||||
Default: 6.
|
||||
stride (int): Stride of GELayer. Default: 1
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
x (torch.Tensor): Intermidiate feature map in
|
||||
Semantic Branch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
exp_ratio=6,
|
||||
stride=1,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super(GELayer, self).__init__(init_cfg=init_cfg)
|
||||
mid_channel = in_channels * exp_ratio
|
||||
self.conv1 = ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=in_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
if stride == 1:
|
||||
self.dwconv = nn.Sequential(
|
||||
# ReLU in ConvModule not shown in paper
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=mid_channel,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
groups=in_channels,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg))
|
||||
self.shortcut = None
|
||||
else:
|
||||
self.dwconv = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=mid_channel,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
groups=in_channels,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None),
|
||||
# ReLU in ConvModule not shown in paper
|
||||
ConvModule(
|
||||
in_channels=mid_channel,
|
||||
out_channels=mid_channel,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
groups=mid_channel,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg),
|
||||
)
|
||||
self.shortcut = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
in_channels=in_channels,
|
||||
out_channels=out_channels,
|
||||
kernel_size=3,
|
||||
stride=stride,
|
||||
padding=1,
|
||||
dw_norm_cfg=norm_cfg,
|
||||
dw_act_cfg=None,
|
||||
pw_norm_cfg=norm_cfg,
|
||||
pw_act_cfg=None,
|
||||
))
|
||||
|
||||
self.conv2 = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=mid_channel,
|
||||
out_channels=out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None,
|
||||
))
|
||||
|
||||
self.act = build_activation_layer(act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.conv1(x)
|
||||
x = self.dwconv(x)
|
||||
x = self.conv2(x)
|
||||
if self.shortcut is not None:
|
||||
shortcut = self.shortcut(identity)
|
||||
x = x + shortcut
|
||||
else:
|
||||
x = x + identity
|
||||
x = self.act(x)
|
||||
return x
|
||||
|
||||
|
||||
class CEBlock(BaseModule):
|
||||
"""Context Embedding Block for large receptive filed in Semantic Branch.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of input channels.
|
||||
Default: 3.
|
||||
out_channels (int): Number of output channels.
|
||||
Default: 16.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
x (torch.Tensor): Last feature map in Semantic Branch.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
out_channels=16,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super(CEBlock, self).__init__(init_cfg=init_cfg)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.gap = nn.Sequential(
|
||||
nn.AdaptiveAvgPool2d((1, 1)),
|
||||
build_norm_layer(norm_cfg, self.in_channels)[1])
|
||||
self.conv_gap = ConvModule(
|
||||
in_channels=self.in_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
# Note: in paper here is naive conv2d, no bn-relu
|
||||
self.conv_last = ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg)
|
||||
|
||||
def forward(self, x):
|
||||
identity = x
|
||||
x = self.gap(x)
|
||||
x = self.conv_gap(x)
|
||||
x = identity + x
|
||||
x = self.conv_last(x)
|
||||
return x
|
||||
|
||||
|
||||
class SemanticBranch(BaseModule):
|
||||
"""Semantic Branch which is lightweight with narrow channels and deep
|
||||
layers to obtain high-level semantic context.
|
||||
|
||||
Args:
|
||||
semantic_channels(Tuple[int]): Size of channel numbers of
|
||||
various stages in Semantic Branch.
|
||||
Default: (16, 32, 64, 128).
|
||||
in_channels (int): Number of channels of input image. Default: 3.
|
||||
exp_ratio (int): Expansion ratio for middle channels.
|
||||
Default: 6.
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
semantic_outs (List[torch.Tensor]): List of several feature maps
|
||||
for auxiliary heads (Booster) and Bilateral
|
||||
Guided Aggregation Layer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
semantic_channels=(16, 32, 64, 128),
|
||||
in_channels=3,
|
||||
exp_ratio=6,
|
||||
init_cfg=None):
|
||||
super(SemanticBranch, self).__init__(init_cfg=init_cfg)
|
||||
self.in_channels = in_channels
|
||||
self.semantic_channels = semantic_channels
|
||||
self.semantic_stages = []
|
||||
for i in range(len(semantic_channels)):
|
||||
stage_name = f'stage{i + 1}'
|
||||
self.semantic_stages.append(stage_name)
|
||||
if i == 0:
|
||||
self.add_module(
|
||||
stage_name,
|
||||
StemBlock(self.in_channels, semantic_channels[i]))
|
||||
elif i == (len(semantic_channels) - 1):
|
||||
self.add_module(
|
||||
stage_name,
|
||||
nn.Sequential(
|
||||
GELayer(semantic_channels[i - 1], semantic_channels[i],
|
||||
exp_ratio, 2),
|
||||
GELayer(semantic_channels[i], semantic_channels[i],
|
||||
exp_ratio, 1),
|
||||
GELayer(semantic_channels[i], semantic_channels[i],
|
||||
exp_ratio, 1),
|
||||
GELayer(semantic_channels[i], semantic_channels[i],
|
||||
exp_ratio, 1)))
|
||||
else:
|
||||
self.add_module(
|
||||
stage_name,
|
||||
nn.Sequential(
|
||||
GELayer(semantic_channels[i - 1], semantic_channels[i],
|
||||
exp_ratio, 2),
|
||||
GELayer(semantic_channels[i], semantic_channels[i],
|
||||
exp_ratio, 1)))
|
||||
|
||||
self.add_module(f'stage{len(semantic_channels)}_CEBlock',
|
||||
CEBlock(semantic_channels[-1], semantic_channels[-1]))
|
||||
self.semantic_stages.append(f'stage{len(semantic_channels)}_CEBlock')
|
||||
|
||||
def forward(self, x):
|
||||
semantic_outs = []
|
||||
for stage_name in self.semantic_stages:
|
||||
semantic_stage = getattr(self, stage_name)
|
||||
x = semantic_stage(x)
|
||||
semantic_outs.append(x)
|
||||
return semantic_outs
|
||||
|
||||
|
||||
class BGALayer(BaseModule):
|
||||
"""Bilateral Guided Aggregation Layer to fuse the complementary information
|
||||
from both Detail Branch and Semantic Branch.
|
||||
|
||||
Args:
|
||||
out_channels (int): Number of output channels.
|
||||
Default: 128.
|
||||
align_corners (bool): align_corners argument of F.interpolate.
|
||||
Default: False.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
Returns:
|
||||
output (torch.Tensor): Output feature map for Segment heads.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
out_channels=128,
|
||||
align_corners=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
super(BGALayer, self).__init__(init_cfg=init_cfg)
|
||||
self.out_channels = out_channels
|
||||
self.align_corners = align_corners
|
||||
self.detail_dwconv = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dw_norm_cfg=norm_cfg,
|
||||
dw_act_cfg=None,
|
||||
pw_norm_cfg=None,
|
||||
pw_act_cfg=None,
|
||||
))
|
||||
self.detail_down = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=2,
|
||||
padding=1,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None),
|
||||
nn.AvgPool2d(kernel_size=3, stride=2, padding=1, ceil_mode=False))
|
||||
self.semantic_conv = nn.Sequential(
|
||||
ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
bias=False,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=None))
|
||||
self.semantic_dwconv = nn.Sequential(
|
||||
DepthwiseSeparableConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
dw_norm_cfg=norm_cfg,
|
||||
dw_act_cfg=None,
|
||||
pw_norm_cfg=None,
|
||||
pw_act_cfg=None,
|
||||
))
|
||||
self.conv = ConvModule(
|
||||
in_channels=self.out_channels,
|
||||
out_channels=self.out_channels,
|
||||
kernel_size=3,
|
||||
stride=1,
|
||||
padding=1,
|
||||
inplace=True,
|
||||
conv_cfg=conv_cfg,
|
||||
norm_cfg=norm_cfg,
|
||||
act_cfg=act_cfg,
|
||||
)
|
||||
|
||||
def forward(self, x_d, x_s):
|
||||
detail_dwconv = self.detail_dwconv(x_d)
|
||||
detail_down = self.detail_down(x_d)
|
||||
semantic_conv = self.semantic_conv(x_s)
|
||||
semantic_dwconv = self.semantic_dwconv(x_s)
|
||||
semantic_conv = resize(
|
||||
input=semantic_conv,
|
||||
size=detail_dwconv.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
fuse_1 = detail_dwconv * torch.sigmoid(semantic_conv)
|
||||
fuse_2 = detail_down * torch.sigmoid(semantic_dwconv)
|
||||
fuse_2 = resize(
|
||||
input=fuse_2,
|
||||
size=fuse_1.shape[2:],
|
||||
mode='bilinear',
|
||||
align_corners=self.align_corners)
|
||||
output = self.conv(fuse_1 + fuse_2)
|
||||
return output
|
||||
|
||||
|
||||
@BACKBONES.register_module()
|
||||
class BiSeNetV2(BaseModule):
|
||||
"""BiSeNetV2: Bilateral Network with Guided Aggregation for
|
||||
Real-time Semantic Segmentation.
|
||||
|
||||
This backbone is the implementation of
|
||||
`BiSeNetV2 <https://arxiv.org/abs/2004.02147>`_.
|
||||
|
||||
Args:
|
||||
in_channels (int): Number of channel of input image. Default: 3.
|
||||
detail_channels (Tuple[int], optional): Channels of each stage
|
||||
in Detail Branch. Default: (64, 64, 128).
|
||||
semantic_channels (Tuple[int], optional): Channels of each stage
|
||||
in Semantic Branch. Default: (16, 32, 64, 128).
|
||||
See Table 1 and Figure 3 of paper for more details.
|
||||
semantic_expansion_ratio (int, optional): The expansion factor
|
||||
expanding channel number of middle channels in Semantic Branch.
|
||||
Default: 6.
|
||||
bga_channels (int, optional): Number of middle channels in
|
||||
Bilateral Guided Aggregation Layer. Default: 128.
|
||||
out_indices (Tuple[int] | int, optional): Output from which stages.
|
||||
Default: (0, 1, 2, 3, 4).
|
||||
align_corners (bool, optional): The align_corners argument of
|
||||
resize operation in Bilateral Guided Aggregation Layer.
|
||||
Default: False.
|
||||
conv_cfg (dict | None): Config of conv layers.
|
||||
Default: None.
|
||||
norm_cfg (dict | None): Config of norm layers.
|
||||
Default: dict(type='BN').
|
||||
act_cfg (dict): Config of activation layers.
|
||||
Default: dict(type='ReLU').
|
||||
init_cfg (dict or list[dict], optional): Initialization config dict.
|
||||
Default: None.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
in_channels=3,
|
||||
detail_channels=(64, 64, 128),
|
||||
semantic_channels=(16, 32, 64, 128),
|
||||
semantic_expansion_ratio=6,
|
||||
bga_channels=128,
|
||||
out_indices=(0, 1, 2, 3, 4),
|
||||
align_corners=False,
|
||||
conv_cfg=None,
|
||||
norm_cfg=dict(type='BN'),
|
||||
act_cfg=dict(type='ReLU'),
|
||||
init_cfg=None):
|
||||
if init_cfg is None:
|
||||
init_cfg = [
|
||||
dict(type='Kaiming', layer='Conv2d'),
|
||||
dict(
|
||||
type='Constant', val=1, layer=['_BatchNorm', 'GroupNorm'])
|
||||
]
|
||||
super(BiSeNetV2, self).__init__(init_cfg=init_cfg)
|
||||
self.in_channels = in_channels
|
||||
self.out_indices = out_indices
|
||||
self.detail_channels = detail_channels
|
||||
self.semantic_channels = semantic_channels
|
||||
self.semantic_expansion_ratio = semantic_expansion_ratio
|
||||
self.bga_channels = bga_channels
|
||||
self.align_corners = align_corners
|
||||
self.conv_cfg = conv_cfg
|
||||
self.norm_cfg = norm_cfg
|
||||
self.act_cfg = act_cfg
|
||||
|
||||
self.detail = DetailBranch(self.detail_channels, self.in_channels)
|
||||
self.semantic = SemanticBranch(self.semantic_channels,
|
||||
self.in_channels,
|
||||
self.semantic_expansion_ratio)
|
||||
self.bga = BGALayer(self.bga_channels, self.align_corners)
|
||||
|
||||
def forward(self, x):
|
||||
# stole refactoring code from Coin Cheung, thanks
|
||||
x_detail = self.detail(x)
|
||||
x_semantic_lst = self.semantic(x)
|
||||
x_head = self.bga(x_detail, x_semantic_lst[-1])
|
||||
outs = [x_head] + x_semantic_lst[:-1]
|
||||
outs = [outs[i] for i in self.out_indices]
|
||||
return tuple(outs)
|
@ -25,6 +25,7 @@ class PatchMerging(BaseModule):
|
||||
|
||||
This layer use nn.Unfold to group feature map by kernel_size, and use norm
|
||||
and linear layer to embed grouped feature map.
|
||||
|
||||
Args:
|
||||
in_channels (int): The num of input channels.
|
||||
out_channels (int): The num of output channels.
|
||||
|
@ -1,6 +1,7 @@
|
||||
Import:
|
||||
- configs/ann/ann.yml
|
||||
- configs/apcnet/apcnet.yml
|
||||
- configs/bisenetv2/bisenetv2.yml
|
||||
- configs/ccnet/ccnet.yml
|
||||
- configs/cgnet/cgnet.yml
|
||||
- configs/danet/danet.yml
|
||||
|
57
tests/test_models/test_backbones/test_bisenetv2.py
Normal file
57
tests/test_models/test_backbones/test_bisenetv2.py
Normal file
@ -0,0 +1,57 @@
|
||||
# Copyright (c) OpenMMLab. All rights reserved.
|
||||
import torch
|
||||
from mmcv.cnn import ConvModule
|
||||
|
||||
from mmseg.models.backbones import BiSeNetV2
|
||||
from mmseg.models.backbones.bisenetv2 import (BGALayer, DetailBranch,
|
||||
SemanticBranch)
|
||||
|
||||
|
||||
def test_bisenetv2_backbone():
|
||||
# Test BiSeNetV2 Standard Forward
|
||||
model = BiSeNetV2()
|
||||
model.init_weights()
|
||||
model.train()
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 512, 1024)
|
||||
feat = model(imgs)
|
||||
|
||||
assert len(feat) == 5
|
||||
# output for segment Head
|
||||
assert feat[0].shape == torch.Size([batch_size, 128, 64, 128])
|
||||
# for auxiliary head 1
|
||||
assert feat[1].shape == torch.Size([batch_size, 16, 128, 256])
|
||||
# for auxiliary head 2
|
||||
assert feat[2].shape == torch.Size([batch_size, 32, 64, 128])
|
||||
# for auxiliary head 3
|
||||
assert feat[3].shape == torch.Size([batch_size, 64, 32, 64])
|
||||
# for auxiliary head 4
|
||||
assert feat[4].shape == torch.Size([batch_size, 128, 16, 32])
|
||||
|
||||
# Test input with rare shape
|
||||
batch_size = 2
|
||||
imgs = torch.randn(batch_size, 3, 527, 952)
|
||||
feat = model(imgs)
|
||||
assert len(feat) == 5
|
||||
|
||||
|
||||
def test_bisenetv2_DetailBranch():
|
||||
x = torch.randn(1, 3, 512, 1024)
|
||||
detail_branch = DetailBranch(detail_channels=(64, 64, 128))
|
||||
assert isinstance(detail_branch.detail_branch[0][0], ConvModule)
|
||||
x_out = detail_branch(x)
|
||||
assert x_out.shape == torch.Size([1, 128, 64, 128])
|
||||
|
||||
|
||||
def test_bisenetv2_SemanticBranch():
|
||||
semantic_branch = SemanticBranch(semantic_channels=(16, 32, 64, 128))
|
||||
assert semantic_branch.stage1.pool.stride == 2
|
||||
|
||||
|
||||
def test_bisenetv2_BGALayer():
|
||||
x_a = torch.randn(1, 128, 64, 128)
|
||||
x_b = torch.randn(1, 128, 16, 32)
|
||||
bga = BGALayer()
|
||||
assert isinstance(bga.conv, ConvModule)
|
||||
x_out = bga(x_a, x_b)
|
||||
assert x_out.shape == torch.Size([1, 128, 64, 128])
|
Loading…
x
Reference in New Issue
Block a user