# Copyright (c) OpenMMLab. All rights reserved. import torch from mmcv.cnn import ConvModule from mmseg.models.backbones import BiSeNetV2 from mmseg.models.backbones.bisenetv2 import (BGALayer, DetailBranch, SemanticBranch) def test_bisenetv2_backbone(): # Test BiSeNetV2 Standard Forward model = BiSeNetV2() model.init_weights() model.train() batch_size = 2 imgs = torch.randn(batch_size, 3, 512, 1024) feat = model(imgs) assert len(feat) == 5 # output for segment Head assert feat[0].shape == torch.Size([batch_size, 128, 64, 128]) # for auxiliary head 1 assert feat[1].shape == torch.Size([batch_size, 16, 128, 256]) # for auxiliary head 2 assert feat[2].shape == torch.Size([batch_size, 32, 64, 128]) # for auxiliary head 3 assert feat[3].shape == torch.Size([batch_size, 64, 32, 64]) # for auxiliary head 4 assert feat[4].shape == torch.Size([batch_size, 128, 16, 32]) # Test input with rare shape batch_size = 2 imgs = torch.randn(batch_size, 3, 527, 952) feat = model(imgs) assert len(feat) == 5 def test_bisenetv2_DetailBranch(): x = torch.randn(1, 3, 512, 1024) detail_branch = DetailBranch(detail_channels=(64, 64, 128)) assert isinstance(detail_branch.detail_branch[0][0], ConvModule) x_out = detail_branch(x) assert x_out.shape == torch.Size([1, 128, 64, 128]) def test_bisenetv2_SemanticBranch(): semantic_branch = SemanticBranch(semantic_channels=(16, 32, 64, 128)) assert semantic_branch.stage1.pool.stride == 2 def test_bisenetv2_BGALayer(): x_a = torch.randn(1, 128, 64, 128) x_b = torch.randn(1, 128, 16, 32) bga = BGALayer() assert isinstance(bga.conv, ConvModule) x_out = bga(x_a, x_b) assert x_out.shape == torch.Size([1, 128, 64, 128])