# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmseg.models.backbones import MixVisionTransformer
from mmseg.models.backbones.mit import (EfficientMultiheadAttention, MixFFN,
                                        TransformerEncoderLayer)


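# These tests exercise the MiT (SegFormer) backbone and its building blocks:
# MixFFN, EfficientMultiheadAttention and TransformerEncoderLayer.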
def test_mit():
    with pytest.raises(TypeError):
        # pretrained must be a str (a checkpoint path or URL) or None.
        MixVisionTransformer(pretrained=123)

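    # With embed_dims=32 and num_heads=[1, 2, 5, 8], stage i of the backbone
    # should output 32 * num_heads[i] channels at stride 4 * 2**i, as the
    # asserts below check.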
    # Test normal input
    H, W = (224, 224)
    temp = torch.randn((1, 3, H, W))
    model = MixVisionTransformer(
        embed_dims=32, num_heads=[1, 2, 5, 8], out_indices=(0, 1, 2, 3))
    model.init_weights()
    outs = model(temp)
    assert outs[0].shape == (1, 32, H // 4, W // 4)
    assert outs[1].shape == (1, 64, H // 8, W // 8)
    assert outs[2].shape == (1, 160, H // 16, W // 16)
    assert outs[3].shape == (1, 256, H // 32, W // 32)

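    # A non-square input verifies that H and W are not transposed when
    # tokens are reshaped to and from 2D feature maps.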
    # Test non-square input
    H, W = (224, 256)
    temp = torch.randn((1, 3, H, W))
    outs = model(temp)
    assert outs[0].shape == (1, 32, H // 4, W // 4)
    assert outs[1].shape == (1, 64, H // 8, W // 8)
    assert outs[2].shape == (1, 160, H // 16, W // 16)
    assert outs[3].shape == (1, 256, H // 32, W // 32)

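    # MixFFN folds the token sequence back into a 2D map (via hw_shape) so it
    # can apply a depth-wise 3x3 conv between its two linear layers, which is
    # why forward() needs hw_shape in addition to the tokens.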
    # Test MixFFN
    FFN = MixFFN(64, 128)
    hw_shape = (32, 32)
    token_len = 32 * 32
    temp = torch.randn((1, token_len, 64))
    # Self identity: the residual defaults to the input
    out = FFN(temp, hw_shape)
    assert out.shape == (1, token_len, 64)
    # Out identity: the residual is passed explicitly
    outs = FFN(temp, hw_shape, temp)
    assert outs.shape == (1, token_len, 64)

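    # EfficientMultiheadAttention(64, 2): 64 embed dims, 2 heads; with the
    # default sr_ratio of 1 the key/value tokens are not spatially reduced.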
    # Test EfficientMHA
    MHA = EfficientMultiheadAttention(64, 2)
    hw_shape = (32, 32)
    token_len = 32 * 32
    temp = torch.randn((1, token_len, 64))
    # Self identity: the residual defaults to the input
    out = MHA(temp, hw_shape)
    assert out.shape == (1, token_len, 64)
    # Out identity: the residual is passed explicitly
    outs = MHA(temp, hw_shape, temp)
    assert outs.shape == (1, token_len, 64)

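    # with_cp=True routes the block through torch.utils.checkpoint, which
    # recomputes activations in the backward pass to save memory; the output
    # must have the same shape as a plain forward.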
    # Test TransformerEncoderLayer with checkpoint forward
    block = TransformerEncoderLayer(
        embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
    assert block.with_cp
    x = torch.randn(1, 56 * 56, 64)
    x_out = block(x, (56, 56))
    assert x_out.shape == torch.Size([1, 56 * 56, 64])


def test_mit_init():
    path = 'PATH_THAT_DOES_NOT_EXIST'

    # Test all combinations of pretrained and init_cfg.
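    # Expected behaviour, verified case by case below: a str pretrained is
    # converted to init_cfg=dict(type='Pretrained', checkpoint=pretrained),
    # and passing both pretrained and init_cfg at once is rejected.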

    # pretrained=None, init_cfg=None
    model = MixVisionTransformer(pretrained=None, init_cfg=None)
    assert model.init_cfg is None
    model.init_weights()

    # pretrained=None
    # init_cfg loads the pretrained weights from a non-existent file
    model = MixVisionTransformer(
        pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Loading a checkpoint from a non-existent file raises OSError.
    with pytest.raises(OSError):
        model.init_weights()

    # pretrained=None
    # init_cfg=123, whose type is unsupported
    model = MixVisionTransformer(pretrained=None, init_cfg=123)
    with pytest.raises(TypeError):
        model.init_weights()

    # pretrained loads the pretrained weights from a non-existent file
    # init_cfg=None
    model = MixVisionTransformer(pretrained=path, init_cfg=None)
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Loading a checkpoint from a non-existent file raises OSError.
    with pytest.raises(OSError):
        model.init_weights()

    # pretrained loads the pretrained weights from a non-existent file
    # init_cfg loads the pretrained weights from a non-existent file
    with pytest.raises(AssertionError):
        MixVisionTransformer(
            pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
    with pytest.raises(AssertionError):
        MixVisionTransformer(pretrained=path, init_cfg=123)

    # pretrained=123, whose type is unsupported
    # init_cfg=None
    with pytest.raises(TypeError):
        MixVisionTransformer(pretrained=123, init_cfg=None)

    # pretrained=123, whose type is unsupported
    # init_cfg loads the pretrained weights from a non-existent file
    with pytest.raises(AssertionError):
        MixVisionTransformer(
            pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))

    # pretrained=123, whose type is unsupported
    # init_cfg=123, whose type is unsupported
    with pytest.raises(AssertionError):
        MixVisionTransformer(pretrained=123, init_cfg=123)