145 lines
4.2 KiB
Python
145 lines
4.2 KiB
Python
from math import ceil
|
|
|
|
import numpy as np
|
|
import pytest
|
|
import torch
|
|
|
|
from mmcls.models.backbones import SwinTransformer
|
|
|
|
|
|
def test_swin_transformer():
|
|
"""Test Swin Transformer backbone."""
|
|
with pytest.raises(AssertionError):
|
|
# Swin Transformer arch string should be in
|
|
SwinTransformer(arch='unknown')
|
|
|
|
with pytest.raises(AssertionError):
|
|
# Swin Transformer arch dict should include 'embed_dims',
|
|
# 'depths' and 'num_head' keys.
|
|
SwinTransformer(arch=dict(embed_dims=96, depths=[2, 2, 18, 2]))
|
|
|
|
# Test tiny arch forward
|
|
model = SwinTransformer(arch='Tiny')
|
|
model.init_weights()
|
|
model.train()
|
|
|
|
imgs = torch.randn(1, 3, 224, 224)
|
|
output = model(imgs)
|
|
assert output.shape == (1, 768, 49)
|
|
|
|
# Test small arch forward
|
|
model = SwinTransformer(arch='small')
|
|
model.init_weights()
|
|
model.train()
|
|
|
|
imgs = torch.randn(1, 3, 224, 224)
|
|
output = model(imgs)
|
|
assert output.shape == (1, 768, 49)
|
|
|
|
# Test base arch forward
|
|
model = SwinTransformer(arch='B')
|
|
model.init_weights()
|
|
model.train()
|
|
|
|
imgs = torch.randn(1, 3, 224, 224)
|
|
output = model(imgs)
|
|
assert output.shape == (1, 1024, 49)
|
|
|
|
# Test large arch forward
|
|
model = SwinTransformer(arch='l')
|
|
model.init_weights()
|
|
model.train()
|
|
|
|
imgs = torch.randn(1, 3, 224, 224)
|
|
output = model(imgs)
|
|
assert output.shape == (1, 1536, 49)
|
|
|
|
# Test base arch with window_size=12, image_size=384
|
|
model = SwinTransformer(
|
|
arch='base',
|
|
img_size=384,
|
|
stage_cfgs=dict(block_cfgs=dict(window_size=12)))
|
|
model.init_weights()
|
|
model.train()
|
|
|
|
imgs = torch.randn(1, 3, 384, 384)
|
|
output = model(imgs)
|
|
assert output.shape == (1, 1024, 144)
|
|
|
|
# Test small with use_abs_pos_embed = True
|
|
model = SwinTransformer(arch='small', use_abs_pos_embed=True)
|
|
model.init_weights()
|
|
model.train()
|
|
|
|
assert model.absolute_pos_embed.shape == (1, 3136, 96)
|
|
|
|
# Test small with use_abs_pos_embed = False
|
|
with pytest.raises(AttributeError):
|
|
model = SwinTransformer(arch='small', use_abs_pos_embed=False)
|
|
model.absolute_pos_embed
|
|
|
|
# Test small with auto_pad = True
|
|
model = SwinTransformer(
|
|
arch='small',
|
|
auto_pad=True,
|
|
stage_cfgs=dict(
|
|
block_cfgs={'window_size': 7},
|
|
downsample_cfg={
|
|
'kernel_size': (3, 2),
|
|
}))
|
|
model.init_weights()
|
|
model.train()
|
|
|
|
imgs = torch.randn(1, 3, 224, 224)
|
|
|
|
# stage 1
|
|
input_h = int(224 / 4 / 3)
|
|
expect_h = ceil(input_h / 7) * 7
|
|
input_w = int(224 / 4 / 2)
|
|
expect_w = ceil(input_w / 7) * 7
|
|
assert model.stages[1].blocks[0].attn.pad_b == expect_h - input_h
|
|
assert model.stages[1].blocks[0].attn.pad_r == expect_w - input_w
|
|
|
|
# stage 2
|
|
input_h = int(224 / 4 / 3 / 3)
|
|
# input_h is smaller than window_size, shrink the window_size to input_h.
|
|
expect_h = input_h
|
|
input_w = int(224 / 4 / 2 / 2)
|
|
expect_w = ceil(input_w / input_h) * input_h
|
|
assert model.stages[2].blocks[0].attn.pad_b == expect_h - input_h
|
|
assert model.stages[2].blocks[0].attn.pad_r == expect_w - input_w
|
|
|
|
# stage 3
|
|
input_h = int(224 / 4 / 3 / 3 / 3)
|
|
expect_h = input_h
|
|
input_w = int(224 / 4 / 2 / 2 / 2)
|
|
expect_w = ceil(input_w / input_h) * input_h
|
|
assert model.stages[3].blocks[0].attn.pad_b == expect_h - input_h
|
|
assert model.stages[3].blocks[0].attn.pad_r == expect_w - input_w
|
|
|
|
# Test small with auto_pad = False
|
|
with pytest.raises(AssertionError):
|
|
model = SwinTransformer(
|
|
arch='small',
|
|
auto_pad=False,
|
|
stage_cfgs=dict(
|
|
block_cfgs={'window_size': 7},
|
|
downsample_cfg={
|
|
'kernel_size': (3, 2),
|
|
}))
|
|
|
|
# Test drop_path_rate decay
|
|
model = SwinTransformer(
|
|
arch='small',
|
|
drop_path_rate=0.2,
|
|
)
|
|
depths = model.arch_settings['depths']
|
|
pos = 0
|
|
for i, depth in enumerate(depths):
|
|
for j in range(depth):
|
|
block = model.stages[i].blocks[j]
|
|
expect_prob = 0.2 / (sum(depths) - 1) * pos
|
|
assert np.isclose(block.ffn.dropout_layer.drop_prob, expect_prob)
|
|
assert np.isclose(block.attn.drop.drop_prob, expect_prob)
|
|
pos += 1
|