# Copyright (c) OpenMMLab. All rights reserved.
import pytest
import torch

from mmseg.models.backbones import MixVisionTransformer
from mmseg.models.backbones.mit import (EfficientMultiheadAttention, MixFFN,
                                        TransformerEncoderLayer)


def test_mit():
    with pytest.raises(TypeError):
        # pretrained represents a pretrained checkpoint URL and must be
        # str or None.
        MixVisionTransformer(pretrained=123)

    # Test normal input
    H, W = (224, 224)
    temp = torch.randn((1, 3, H, W))
    model = MixVisionTransformer(
        embed_dims=32, num_heads=[1, 2, 5, 8], out_indices=(0, 1, 2, 3))
    model.init_weights()
    outs = model(temp)
    assert outs[0].shape == (1, 32, H // 4, W // 4)
    assert outs[1].shape == (1, 64, H // 8, W // 8)
    assert outs[2].shape == (1, 160, H // 16, W // 16)
    assert outs[3].shape == (1, 256, H // 32, W // 32)

    # Test non-squared input
    H, W = (224, 256)
    temp = torch.randn((1, 3, H, W))
    outs = model(temp)
    assert outs[0].shape == (1, 32, H // 4, W // 4)
    assert outs[1].shape == (1, 64, H // 8, W // 8)
    assert outs[2].shape == (1, 160, H // 16, W // 16)
    assert outs[3].shape == (1, 256, H // 32, W // 32)

    # Test MixFFN
    FFN = MixFFN(64, 128)
    hw_shape = (32, 32)
    token_len = 32 * 32
    temp = torch.randn((1, token_len, 64))
    # Self identity
    out = FFN(temp, hw_shape)
    assert out.shape == (1, token_len, 64)
    # Out identity
    outs = FFN(temp, hw_shape, temp)
    assert outs.shape == (1, token_len, 64)

    # Test EfficientMHA
    MHA = EfficientMultiheadAttention(64, 2)
    hw_shape = (32, 32)
    token_len = 32 * 32
    temp = torch.randn((1, token_len, 64))
    # Self identity
    out = MHA(temp, hw_shape)
    assert out.shape == (1, token_len, 64)
    # Out identity
    outs = MHA(temp, hw_shape, temp)
    assert outs.shape == (1, token_len, 64)

    # Test TransformerEncoderLayer with checkpoint forward
    block = TransformerEncoderLayer(
        embed_dims=64, num_heads=4, feedforward_channels=256, with_cp=True)
    assert block.with_cp
    x = torch.randn(1, 56 * 56, 64)
    x_out = block(x, (56, 56))
    assert x_out.shape == torch.Size([1, 56 * 56, 64])
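

# A minimal extra sketch (an addition, not part of the original suite):
# EfficientMultiheadAttention downsamples keys/values by a spatial-reduction
# ratio before attention, while the output keeps the query length, so the
# output shape is unchanged. Assumes sr_ratio is a supported keyword
# argument of EfficientMultiheadAttention, as in mmseg's mit.py.
def test_efficient_mha_sr_ratio_sketch():
    MHA = EfficientMultiheadAttention(64, 2, sr_ratio=2)
    hw_shape = (32, 32)
    temp = torch.randn((1, 32 * 32, 64))
    # Keys/values are reduced to (32 // 2) * (32 // 2) tokens internally,
    # but the output still has one token per query position.
    out = MHA(temp, hw_shape)
    assert out.shape == (1, 32 * 32, 64)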


def test_mit_init():
    path = 'PATH_THAT_DO_NOT_EXIST'
    # Test all combinations of pretrained and init_cfg

    # pretrained=None, init_cfg=None
    model = MixVisionTransformer(pretrained=None, init_cfg=None)
    assert model.init_cfg is None
    model.init_weights()

    # pretrained=None
    # init_cfg loads a pretrained checkpoint from a non-existent file
    model = MixVisionTransformer(
        pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Loading a checkpoint from a non-existent file raises an error
    with pytest.raises(OSError):
        model.init_weights()

    # pretrained=None
    # init_cfg=123, whose type is unsupported
    model = MixVisionTransformer(pretrained=None, init_cfg=123)
    with pytest.raises(TypeError):
        model.init_weights()

    # pretrained loads a pretrained checkpoint from a non-existent file
    # init_cfg=None
    model = MixVisionTransformer(pretrained=path, init_cfg=None)
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Loading a checkpoint from a non-existent file raises an error
    with pytest.raises(OSError):
        model.init_weights()

    # pretrained loads a pretrained checkpoint from a non-existent file
    # init_cfg also loads a pretrained checkpoint from a non-existent file;
    # specifying both is rejected
    with pytest.raises(AssertionError):
        MixVisionTransformer(
            pretrained=path, init_cfg=dict(type='Pretrained', checkpoint=path))
    with pytest.raises(AssertionError):
        MixVisionTransformer(pretrained=path, init_cfg=123)

    # pretrained=123, whose type is unsupported
    # init_cfg=None
    with pytest.raises(TypeError):
        MixVisionTransformer(pretrained=123, init_cfg=None)

    # pretrained=123, whose type is unsupported
    # init_cfg loads a pretrained checkpoint from a non-existent file
    with pytest.raises(AssertionError):
        MixVisionTransformer(
            pretrained=123, init_cfg=dict(type='Pretrained', checkpoint=path))

    # pretrained=123, whose type is unsupported
    # init_cfg=123, whose type is unsupported
    with pytest.raises(AssertionError):
        MixVisionTransformer(pretrained=123, init_cfg=123)
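

# Convenience entry point (an addition, not part of the upstream module):
# run this file directly to execute the tests above, assuming pytest is
# installed in the environment.
if __name__ == '__main__':
    pytest.main([__file__])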