mirror of
https://github.com/open-mmlab/mmsegmentation.git
synced 2025-06-03 22:03:48 +08:00
* add Swin Transformer * add Swin Transformer * fixed import * Add some swin training settings. * Fix some filename error. * Fix attribute name: pretrain -> pretrained * Upload mmcls implementation of swin transformer. * Refactor Swin Transformer to follow mmcls style. * Refactor init_weigths of swin_transformer.py * Fix lint * Match inference precision * Add some comments * Add swin_convert to load official style ckpt * Remove arg: auto_pad * 1. Complete comments for each block; 2. Correct weight convert function; 3. Fix the pad of Patch Merging; * Clean function args. * Fix vit unit test. * 1. Add swin transformer unit tests; 2. Fix some pad bug; 3. Modify config to adapt new swin implementation; * Modify config arg * Update readme.md of swin * Fix config arg error and Add some swin benchmark msg. * Add MeM and ms test content for readme.md of swin transformer. * Fix doc string of swin module * 1. Register swin transformer to model list; 2. Modify pth url which keep meta attribute; * Update swin.py * Merge config settings. * Modify config style. * Update README.md Add ViT link * Modify main readme.md Co-authored-by: Jiarui XU <xvjiarui0826@gmail.com> Co-authored-by: sennnnn <201730271412@mail.scut.edu.cn> Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
65 lines
2.0 KiB
Python
65 lines
2.0 KiB
Python
import pytest
|
|
import torch
|
|
|
|
from mmseg.models.backbones import SwinTransformer
|
|
|
|
|
|
def test_swin_transformer():
    """Exercise the Swin Transformer backbone.

    Checks that invalid constructor arguments are rejected, that the
    absolute position embedding and patch-norm options survive a forward
    pass, and that the first four output feature maps have the expected
    shapes for both regular and irregular input sizes.
    """
    # --- Constructor argument validation ---------------------------------
    # We only support 'official' or 'mmcls' for this arg.
    with pytest.raises(AssertionError):
        SwinTransformer(pretrain_style='swin')

    # Pretrained arg must be str or None.
    with pytest.raises(TypeError):
        SwinTransformer(pretrained=123)

    # Swin uses a non-overlapping patch embed, so the stride of the patch
    # embed must be equal to the patch size.
    with pytest.raises(AssertionError):
        SwinTransformer(strides=(2, 2, 2, 2), patch_size=4)

    # --- Optional features -----------------------------------------------
    # Absolute position embedding: build, initialise weights, run forward.
    inputs = torch.randn((1, 3, 224, 224))
    backbone = SwinTransformer(pretrain_img_size=224, use_abs_pos_embed=True)
    backbone.init_weights()
    backbone(inputs)

    # Patch norm disabled: forward pass on the same input.
    backbone = SwinTransformer(patch_norm=False)
    backbone(inputs)

    # Pretrain image size: a one-element tuple must construct without
    # raising, while a three-element tuple must be rejected.
    backbone = SwinTransformer(pretrain_img_size=(224, ))
    with pytest.raises(AssertionError):
        SwinTransformer(pretrain_img_size=(224, 224, 224))

    # --- Output shapes -----------------------------------------------------
    # Each entry pairs an input shape with the expected shapes of the
    # first four feature maps returned by the default backbone.
    inference_cases = (
        # Normal inference: spatial size divides evenly.
        ((1, 3, 512, 512), (
            (1, 96, 128, 128),
            (1, 192, 64, 64),
            (1, 384, 32, 32),
            (1, 768, 16, 16),
        )),
        # Abnormal inference: odd square input.
        ((1, 3, 511, 511), (
            (1, 96, 128, 128),
            (1, 192, 64, 64),
            (1, 384, 32, 32),
            (1, 768, 16, 16),
        )),
        # Abnormal inference: small non-square input.
        ((1, 3, 112, 137), (
            (1, 96, 28, 35),
            (1, 192, 14, 18),
            (1, 384, 7, 9),
            (1, 768, 4, 5),
        )),
    )
    for input_shape, expected_shapes in inference_cases:
        # Draw the input before building the model, then use a fresh
        # default backbone for every case.
        inputs = torch.randn(input_shape)
        backbone = SwinTransformer()
        outs = backbone(inputs)
        for idx, expected in enumerate(expected_shapes):
            assert outs[idx].shape == expected
|