mmsegmentation/tests/test_models/test_backbones/test_bisenetv2.py

58 lines
1.8 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import torch
from mmcv.cnn import ConvModule
from mmseg.models.backbones import BiSeNetV2
from mmseg.models.backbones.bisenetv2 import (BGALayer, DetailBranch,
SemanticBranch)
def test_bisenetv2_backbone():
    """Check the number and shapes of BiSeNetV2 backbone outputs.

    Inputs are kept small so the test is quick on CPU. The standard input
    has sides divisible by 32, so every output stride divides it exactly.
    A second, odd-sized input checks that non-divisible shapes still run.
    """
    # Test BiSeNetV2 Standard Forward
    model = BiSeNetV2()
    model.init_weights()
    model.train()
    batch_size = 2
    imgs = torch.randn(batch_size, 3, 128, 256)
    feat = model(imgs)
    assert len(feat) == 5
    # output for segment head (1/8 of the input resolution)
    assert feat[0].shape == torch.Size([batch_size, 128, 16, 32])
    # for auxiliary head 1 (1/4 of the input resolution)
    assert feat[1].shape == torch.Size([batch_size, 16, 32, 64])
    # for auxiliary head 2 (1/8 of the input resolution)
    assert feat[2].shape == torch.Size([batch_size, 32, 16, 32])
    # for auxiliary head 3 (1/16 of the input resolution)
    assert feat[3].shape == torch.Size([batch_size, 64, 8, 16])
    # for auxiliary head 4 (1/32 of the input resolution)
    assert feat[4].shape == torch.Size([batch_size, 128, 4, 8])

    # Test input with rare shape: sides not divisible by 32 must still
    # produce all five outputs without shape-mismatch errors.
    imgs = torch.randn(batch_size, 3, 95, 27)
    feat = model(imgs)
    assert len(feat) == 5
    assert all(f.shape[0] == batch_size for f in feat)
def test_bisenetv2_DetailBranch():
    """DetailBranch is built from ConvModules and reduces H and W by 8.

    A 512x1024 input is expected to come out as 64x128 with the last
    configured channel count (128).
    """
    inputs = torch.randn(1, 3, 512, 1024)
    branch = DetailBranch(detail_channels=(64, 64, 128))

    first_layer = branch.detail_branch[0][0]
    assert isinstance(first_layer, ConvModule)

    outputs = branch(inputs)
    expected_shape = torch.Size([1, 128, 64, 128])
    assert outputs.shape == expected_shape
def test_bisenetv2_SemanticBranch():
    """The pooling layer in the semantic branch's first stage uses stride 2."""
    channels = (16, 32, 64, 128)
    branch = SemanticBranch(semantic_channels=channels)
    stage1_pool = branch.stage1.pool
    assert stage1_pool.stride == 2
def test_bisenetv2_BGALayer():
    """BGALayer output keeps the shape of its first (higher-resolution) input.

    The second input is 4x smaller in each spatial dimension; the result is
    expected to match the first input's (1, 128, 64, 128) shape.
    """
    high_res = torch.randn(1, 128, 64, 128)
    low_res = torch.randn(1, 128, 16, 32)

    layer = BGALayer()
    assert isinstance(layer.conv, ConvModule)

    fused = layer(high_res, low_res)
    expected_shape = torch.Size([1, 128, 64, 128])
    assert fused.shape == expected_shape