mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
* BiSeNetV2 first commit * BiSeNetV2 unittest * remove pytest * add pytest module * fix ConvModule input name * fix pytest error * fix unittest * refactor * BiSeNetV2 Refactory * fix docstrings and add some small changes * use_sigmoid=False * fix potential bugs about upsampling * Use ConvModule instead * Use ConvModule instead * fix typos * fix typos * fix typos * discard nn.conv2d * discard nn.conv2d * discard nn.conv2d * delete **kwargs * uploading markdown and model * final commit * BiSeNetV2 adding Unittest for its modules * BiSeNetV2 adding Unittest for its modules * BiSeNetV2 adding Unittest for its modules * BiSeNetV2 adding Unittest for its modules * BiSeNetV2 adding Unittest for its modules * BiSeNetV2 adding Unittest for its modules * BiSeNetV2 adding Unittest for its modules * Fix README conflict * Fix unittest problem * Fix unittest problem * BiSeNetV2 * Fixing fps * Fixing typpos * bisenetv2
58 lines
1.8 KiB
Python
58 lines
1.8 KiB
Python
# Copyright (c) OpenMMLab. All rights reserved.
|
|
import torch
|
|
from mmcv.cnn import ConvModule
|
|
|
|
from mmseg.models.backbones import BiSeNetV2
|
|
from mmseg.models.backbones.bisenetv2 import (BGALayer, DetailBranch,
|
|
SemanticBranch)
|
|
|
|
|
|
def test_bisenetv2_backbone():
    """Run BiSeNetV2 forward passes and check the returned feature maps.

    A default-configured backbone is put in training mode and fed a
    512x1024 batch; the number of outputs and the shape of each one are
    verified. A second batch with an irregular spatial size (527x952) is
    then passed through, checking only the number of outputs.
    """
    num_imgs = 2

    # Expected (channels, height, width) of each output for a 512x1024
    # input, listed in output order.
    expected_chw = (
        (128, 64, 128),  # output for the segmentation head
        (16, 128, 256),  # for auxiliary head 1
        (32, 64, 128),  # for auxiliary head 2
        (64, 32, 64),  # for auxiliary head 3
        (128, 16, 32),  # for auxiliary head 4
    )

    # Standard forward pass.
    backbone = BiSeNetV2()
    backbone.init_weights()
    backbone.train()

    outs = backbone(torch.randn(num_imgs, 3, 512, 1024))

    assert len(outs) == 5
    for idx, (channels, height, width) in enumerate(expected_chw):
        assert outs[idx].shape == torch.Size(
            [num_imgs, channels, height, width])

    # Forward pass with an unusual input resolution; only the number of
    # outputs is checked here.
    num_imgs = 2
    outs = backbone(torch.randn(num_imgs, 3, 527, 952))
    assert len(outs) == 5
|
|
|
|
|
|
def test_bisenetv2_DetailBranch():
    """Check DetailBranch's first layer type and its output shape.

    With ``detail_channels=(64, 64, 128)`` and a 1x3x512x1024 input, the
    branch is expected to return a 1x128x64x128 tensor.
    """
    inputs = torch.randn(1, 3, 512, 1024)
    branch = DetailBranch(detail_channels=(64, 64, 128))

    # The first layer of the first stage must be a ConvModule.
    first_layer = branch.detail_branch[0][0]
    assert isinstance(first_layer, ConvModule)

    outputs = branch(inputs)
    expected_shape = torch.Size([1, 128, 64, 128])
    assert outputs.shape == expected_shape
|
|
|
|
|
|
def test_bisenetv2_SemanticBranch():
    """Check the pooling stride in the first stage of SemanticBranch."""
    branch = SemanticBranch(semantic_channels=(16, 32, 64, 128))

    # The pooling layer in stage1 is expected to use a stride of 2.
    stage1_pool = branch.stage1.pool
    assert stage1_pool.stride == 2
|
|
|
|
|
|
def test_bisenetv2_BGALayer():
    """Check BGALayer's ``conv`` type and the shape of its output.

    Two 128-channel inputs of different spatial sizes (64x128 and 16x32)
    are passed to a default BGALayer; the result is expected to have the
    same shape as the first (larger) input.
    """
    high_res = torch.randn(1, 128, 64, 128)
    low_res = torch.randn(1, 128, 16, 32)

    layer = BGALayer()
    assert isinstance(layer.conv, ConvModule)

    merged = layer(high_res, low_res)
    expected_shape = torch.Size([1, 128, 64, 128])
    assert merged.shape == expected_shape
|