# mmsegmentation/tests/test_models/test_backbones/test_mit.py

import pytest
import torch
from mmseg.models.backbones import MixVisionTransformer
from mmseg.models.backbones.mit import EfficientMultiheadAttention, MixFFN


def test_mit():
    with pytest.raises(AssertionError):
        # Only the 'official' and 'mmcls' pretrain styles are supported now.
        MixVisionTransformer(pretrain_style='timm')
    with pytest.raises(TypeError):
        # `pretrained` represents a pretrain url and must be str or None.
        MixVisionTransformer(pretrained=123)
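
    # Sanity check (sketch): per the assertion above, 'official' and 'mmcls'
    # are the accepted pretrain styles, so constructing with 'mmcls' is
    # assumed to succeed without raising.
    MixVisionTransformer(pretrain_style='mmcls')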

    # Test normal input
    H, W = (224, 224)
    temp = torch.randn((1, 3, H, W))
    model = MixVisionTransformer(
        embed_dims=32, num_heads=[1, 2, 5, 8], out_indices=(0, 1, 2, 3))
    model.init_weights()
    outs = model(temp)
    assert outs[0].shape == (1, 32, H // 4, W // 4)
    assert outs[1].shape == (1, 64, H // 8, W // 8)
    assert outs[2].shape == (1, 160, H // 16, W // 16)
    assert outs[3].shape == (1, 256, H // 32, W // 32)
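
    # Extra shape check (sketch, assuming the backbone is batch-size
    # agnostic): a batch of two images should keep the same per-stage
    # channel and spatial sizes.
    temp = torch.randn((2, 3, H, W))
    outs = model(temp)
    assert outs[0].shape == (2, 32, H // 4, W // 4)
    assert outs[3].shape == (2, 256, H // 32, W // 32)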

    # Test non-square input
    H, W = (224, 320)
    temp = torch.randn((1, 3, H, W))
    outs = model(temp)
    assert outs[0].shape == (1, 32, H // 4, W // 4)
    assert outs[1].shape == (1, 64, H // 8, W // 8)
    assert outs[2].shape == (1, 160, H // 16, W // 16)
    assert outs[3].shape == (1, 256, H // 32, W // 32)
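
    # Extra check (sketch, assuming `out_indices` may select a subset of
    # stages): requesting only the last stage should yield a single output.
    model = MixVisionTransformer(
        embed_dims=32, num_heads=[1, 2, 5, 8], out_indices=(3, ))
    outs = model(torch.randn((1, 3, 224, 224)))
    assert len(outs) == 1
    assert outs[0].shape == (1, 256, 224 // 32, 224 // 32)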

    # Test MixFFN
    FFN = MixFFN(128, 512)
    hw_shape = (32, 32)
    token_len = 32 * 32
    temp = torch.randn((1, token_len, 128))
    # Self identity
    out = FFN(temp, hw_shape)
    assert out.shape == (1, token_len, 128)
    # Out identity
    out = FFN(temp, hw_shape, temp)
    assert out.shape == (1, token_len, 128)
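
    # Extra check (sketch): `ffn_drop` is assumed to be a supported dropout
    # argument of MixFFN; enabling it should not change the output shape.
    FFN = MixFFN(128, 512, ffn_drop=0.1)
    out = FFN(temp, hw_shape)
    assert out.shape == (1, token_len, 128)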

    # Test EfficientMHA
    MHA = EfficientMultiheadAttention(128, 2)
    hw_shape = (32, 32)
    token_len = 32 * 32
    temp = torch.randn((1, token_len, 128))
    # Self identity
    out = MHA(temp, hw_shape)
    assert out.shape == (1, token_len, 128)
    # Out identity
    out = MHA(temp, hw_shape, temp)
    assert out.shape == (1, token_len, 128)
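
    # Extra check (sketch): `sr_ratio` is assumed to control the spatial
    # reduction applied to keys/values; queries keep their full length, so
    # the output token length should be unchanged.
    MHA = EfficientMultiheadAttention(128, 2, sr_ratio=2)
    out = MHA(temp, hw_shape)
    assert out.shape == (1, token_len, 128)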