Lxinyang cd246e3635 [Feature] Support Twins (NeurIPS2021) (#989)
* use config small/base/large

* fix pre-commit error

* fix error

* fix unittest error

* fix config errors

* fix twins2mmseg bug

* fix init_weights() in twins.py

* fix comment

* fix unit test coverage in TwinsPR

* Add Twins README

* twins refactor

* delete init_cfg in FFN

* Update mmseg/models/backbones/twins.py

* add conference name

Co-authored-by: linxinyang <linxinyang@meituan.com>
Co-authored-by: MengzhangLI <mcmong@pku.edu.cn>
Co-authored-by: Junjun2016 <hejunjun@sjtu.edu.cn>
2021-12-09 19:18:10 +08:00
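
For context, a minimal sketch of how one of the backbones added here could be wired into an MMSegmentation model config. The hyper-parameters simply mirror the small PCPVT instance used in the tests below; the checkpoint path and the omitted decode head are placeholders, not the released small/base/large settings.

# Hypothetical config snippet (not one of the released Twins configs).
model = dict(
    type='EncoderDecoder',
    backbone=dict(
        type='PCPVT',
        embed_dims=[32, 64, 160, 256],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        depths=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1],
        init_cfg=dict(type='Pretrained', checkpoint='PATH/TO/CHECKPOINT')),
    # decode_head, train_cfg and test_cfg are omitted in this sketch.
)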


import pytest
import torch

from mmseg.models.backbones.twins import (PCPVT, SVT,
                                          ConditionalPositionEncoding,
                                          LocallyGroupedSelfAttention)


def test_pcpvt():
    # Test normal input
    H, W = (224, 224)
    temp = torch.randn((1, 3, H, W))
    model = PCPVT(
        embed_dims=[32, 64, 160, 256],
        num_heads=[1, 2, 5, 8],
        mlp_ratios=[8, 8, 4, 4],
        qkv_bias=True,
        depths=[3, 4, 6, 3],
        sr_ratios=[8, 4, 2, 1],
        norm_after_stage=False)
    model.init_weights()
    outs = model(temp)
    assert outs[0].shape == (1, 32, H // 4, W // 4)
    assert outs[1].shape == (1, 64, H // 8, W // 8)
    assert outs[2].shape == (1, 160, H // 16, W // 16)
    assert outs[3].shape == (1, 256, H // 32, W // 32)
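    # The four stages produce the usual 1/4, 1/8, 1/16 and 1/32 resolution
    # feature maps, which is exactly what the shape asserts above verify.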


def test_svt():
    # Test normal input
    H, W = (224, 224)
    temp = torch.randn((1, 3, H, W))
    model = SVT(
        embed_dims=[32, 64, 128],
        num_heads=[1, 2, 4],
        mlp_ratios=[4, 4, 4],
        qkv_bias=False,
        depths=[4, 4, 4],
        windiow_sizes=[7, 7, 7],
        norm_after_stage=True)
    model.init_weights()
    outs = model(temp)
    assert outs[0].shape == (1, 32, H // 4, W // 4)
    assert outs[1].shape == (1, 64, H // 8, W // 8)
    assert outs[2].shape == (1, 128, H // 16, W // 16)


def test_svt_init():
    path = 'PATH_THAT_DO_NOT_EXIST'
    # Test all combinations of pretrained and init_cfg
    # pretrained=None, init_cfg=None
    model = SVT(pretrained=None, init_cfg=None)
    assert model.init_cfg is None
    model.init_weights()

    # pretrained=None
    # init_cfg loads pretrained weights from a non-existent file
    model = SVT(
        pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Test loading a checkpoint from a non-existent file
    with pytest.raises(OSError):
        model.init_weights()

    # pretrained=None
    # init_cfg=123, whose type is unsupported
    model = SVT(pretrained=None, init_cfg=123)
    with pytest.raises(TypeError):
        model.init_weights()

    # pretrained loads pretrained weights from a non-existent file
    # init_cfg=None
    model = SVT(pretrained=path, init_cfg=None)
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Test loading a checkpoint from a non-existent file
    with pytest.raises(OSError):
        model.init_weights()

    # pretrained loads pretrained weights from a non-existent file
    # init_cfg loads pretrained weights from a non-existent file
    with pytest.raises(AssertionError):
        model = SVT(
            pretrained=path,
            init_cfg=dict(type='Pretrained', checkpoint=path))
    with pytest.raises(AssertionError):
        model = SVT(pretrained=path, init_cfg=123)

    # pretrained=123, whose type is unsupported
    # init_cfg=None
    with pytest.raises(TypeError):
        model = SVT(pretrained=123, init_cfg=None)

    # pretrained=123, whose type is unsupported
    # init_cfg loads pretrained weights from a non-existent file
    with pytest.raises(AssertionError):
        model = SVT(
            pretrained=123,
            init_cfg=dict(type='Pretrained', checkpoint=path))

    # pretrained=123, whose type is unsupported
    # init_cfg=123, whose type is unsupported
    with pytest.raises(AssertionError):
        model = SVT(pretrained=123, init_cfg=123)


def test_pcpvt_init():
    path = 'PATH_THAT_DO_NOT_EXIST'
    # Test all combinations of pretrained and init_cfg
    # pretrained=None, init_cfg=None
    model = PCPVT(pretrained=None, init_cfg=None)
    assert model.init_cfg is None
    model.init_weights()

    # pretrained=None
    # init_cfg loads pretrained weights from a non-existent file
    model = PCPVT(
        pretrained=None, init_cfg=dict(type='Pretrained', checkpoint=path))
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Test loading a checkpoint from a non-existent file
    with pytest.raises(OSError):
        model.init_weights()

    # pretrained=None
    # init_cfg=123, whose type is unsupported
    model = PCPVT(pretrained=None, init_cfg=123)
    with pytest.raises(TypeError):
        model.init_weights()

    # pretrained loads pretrained weights from a non-existent file
    # init_cfg=None
    model = PCPVT(pretrained=path, init_cfg=None)
    assert model.init_cfg == dict(type='Pretrained', checkpoint=path)
    # Test loading a checkpoint from a non-existent file
    with pytest.raises(OSError):
        model.init_weights()

    # pretrained loads pretrained weights from a non-existent file
    # init_cfg loads pretrained weights from a non-existent file
    with pytest.raises(AssertionError):
        model = PCPVT(
            pretrained=path,
            init_cfg=dict(type='Pretrained', checkpoint=path))
    with pytest.raises(AssertionError):
        model = PCPVT(pretrained=path, init_cfg=123)

    # pretrained=123, whose type is unsupported
    # init_cfg=None
    with pytest.raises(TypeError):
        model = PCPVT(pretrained=123, init_cfg=None)

    # pretrained=123, whose type is unsupported
    # init_cfg loads pretrained weights from a non-existent file
    with pytest.raises(AssertionError):
        model = PCPVT(
            pretrained=123,
            init_cfg=dict(type='Pretrained', checkpoint=path))

    # pretrained=123, whose type is unsupported
    # init_cfg=123, whose type is unsupported
    with pytest.raises(AssertionError):
        model = PCPVT(pretrained=123, init_cfg=123)


def test_locallygrouped_self_attention_module():
    LSA = LocallyGroupedSelfAttention(embed_dims=32, window_size=3)
    outs = LSA(torch.randn(1, 3136, 32), (56, 56))
    assert outs.shape == torch.Size([1, 3136, 32])
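    # 3136 tokens correspond to the flattened 56x56 feature map described by
    # the (56, 56) spatial shape; the attention keeps the token layout, so
    # the output shape matches the input.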


def test_conditional_position_encoding_module():
    CPE = ConditionalPositionEncoding(in_channels=32, embed_dims=32, stride=2)
    outs = CPE(torch.randn(1, 3136, 32), (56, 56))
    assert outs.shape == torch.Size([1, 784, 32])
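    # With stride=2 the conditional position encoding downsamples the 56x56
    # token grid to 28x28, i.e. 28 * 28 = 784 output tokens.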