mirror of https://github.com/open-mmlab/mmsegmentation.git
[Feature] SegFormer re-implementation

* Use act_cfg and norm_cfg to control activation and normalization
* Split this PR into several smaller PRs
* Fix lint errors
* Remove SegFormerHead
* Refactor parameter initialization of the SegFormer backbone; remove redundant functions and unit tests
* Remove redundant code and rename modules
* Refactor the SegFormer backbone using mmcv.cnn.bricks.transformer
* Fix some code logic bugs
* Add mit_convert.py to match the pretrained keys of SegFormer
* Resolve review comments
* Add asserts to ensure correct parameters; support a flexible peconv position
* Add a pe_index assert and fix the unit test
* Add docstrings and unit tests for MixVisionTransformer
* Use hw_shape to pass the shape of the feature map
* Fix the MixVisionTransformer docstring; simplify MixFFN; rename H, W to hw_shape
* Add more unit tests and docstrings for the shape-conversion functions
* Add unit tests to improve code coverage
* Fix the SegFormer backbone pretrained-weight matching bug
* Resolve the shape-conversion functions' docstrings
* Add the pad_to_patch_size arg and modify its default value
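The commit notes above mention that activation and normalization are now driven by act_cfg and norm_cfg. As a rough sketch of how the backbone exercised by this test can be constructed (the embed_dims/num_heads/out_indices settings mirror the unit test below; the act_cfg/norm_cfg values are assumptions taken from the commit description, so check mmseg/models/backbones/mit.py for the actual defaults):

import torch
from mmseg.models.backbones import MixVisionTransformer

# Sketch only: MiT-B0-like settings copied from the unit test below;
# the act_cfg / norm_cfg values are assumed, not verified defaults.
backbone = MixVisionTransformer(
    embed_dims=32,
    num_heads=[1, 2, 5, 8],
    out_indices=(0, 1, 2, 3),
    act_cfg=dict(type='GELU'),           # activation controlled via act_cfg
    norm_cfg=dict(type='LN', eps=1e-6))  # normalization controlled via norm_cfg
backbone.init_weights()
feats = backbone(torch.randn(1, 3, 224, 224))  # four multi-scale feature maps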
61 lines
1.9 KiB
Python
import pytest
import torch

from mmseg.models.backbones import MixVisionTransformer
from mmseg.models.backbones.mit import EfficientMultiheadAttention, MixFFN


def test_mit():
    with pytest.raises(AssertionError):
        # Only the 'official' and 'mmcls' pretrain styles are supported for now.
        MixVisionTransformer(pretrain_style='timm')

    with pytest.raises(TypeError):
        # `pretrained` is a pretrain URL/path and must be a str or None.
        MixVisionTransformer(pretrained=123)

    # Test normal (square) input
    H, W = (224, 224)
    temp = torch.randn((1, 3, H, W))
    model = MixVisionTransformer(
        embed_dims=32, num_heads=[1, 2, 5, 8], out_indices=(0, 1, 2, 3))
    model.init_weights()
    outs = model(temp)
    # The four stages output strides 4, 8, 16, 32 with 32, 64, 160, 256 channels.
    assert outs[0].shape == (1, 32, H // 4, W // 4)
    assert outs[1].shape == (1, 64, H // 8, W // 8)
    assert outs[2].shape == (1, 160, H // 16, W // 16)
    assert outs[3].shape == (1, 256, H // 32, W // 32)

    # Test non-square input
    H, W = (224, 320)
    temp = torch.randn((1, 3, H, W))
    outs = model(temp)
    assert outs[0].shape == (1, 32, H // 4, W // 4)
    assert outs[1].shape == (1, 64, H // 8, W // 8)
    assert outs[2].shape == (1, 160, H // 16, W // 16)
    assert outs[3].shape == (1, 256, H // 32, W // 32)

    # Test MixFFN
    FFN = MixFFN(128, 512)
    hw_shape = (32, 32)
    token_len = 32 * 32
    temp = torch.randn((1, token_len, 128))
    # Identity defaults to the input tensor itself
    out = FFN(temp, hw_shape)
    assert out.shape == (1, token_len, 128)
    # Explicitly passed identity tensor
    out = FFN(temp, hw_shape, temp)
    assert out.shape == (1, token_len, 128)

    # Test EfficientMultiheadAttention
    MHA = EfficientMultiheadAttention(128, 2)
    hw_shape = (32, 32)
    token_len = 32 * 32
    temp = torch.randn((1, token_len, 128))
    # Identity defaults to the input tensor itself
    out = MHA(temp, hw_shape)
    assert out.shape == (1, token_len, 128)
    # Explicitly passed identity tensor
    out = MHA(temp, hw_shape, temp)
    assert out.shape == (1, token_len, 128)
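The attention check above only covers the default configuration. A possible follow-up check for the spatial-reduction path is sketched below; it assumes EfficientMultiheadAttention exposes an sr_ratio argument as described in the SegFormer paper, so verify the keyword name against mmseg/models/backbones/mit.py before relying on it.

# Sketch only: sr_ratio is assumed to be the spatial-reduction kwarg.
hw_shape = (32, 32)
token_len = 32 * 32
temp = torch.randn((1, token_len, 128))
MHA_sr = EfficientMultiheadAttention(128, 2, sr_ratio=2)
# Spatial reduction shrinks only the keys/values, so the output keeps
# the query token count and embedding dimension.
out = MHA_sr(temp, hw_shape)
assert out.shape == (1, token_len, 128)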