# mmclassification/tests/test_backbones/test_swin_transformer.py
from math import ceil
import numpy as np
import pytest
import torch
from mmcls.models.backbones import SwinTransformer
def test_swin_transformer():
    """Test Swin Transformer backbone.

    Covers: arch validation, forward shapes for the predefined archs,
    absolute position embedding, automatic padding of attention windows,
    and stochastic-depth (drop path) rate decay across blocks.
    """
    with pytest.raises(AssertionError):
        # Swin Transformer arch string should be one of the predefined
        # arch names; an unknown name must be rejected.
        SwinTransformer(arch='unknown')

    with pytest.raises(AssertionError):
        # Swin Transformer arch dict should include 'embed_dims',
        # 'depths' and 'num_head' keys; 'num_head' is missing here.
        SwinTransformer(arch=dict(embed_dims=96, depths=[2, 2, 18, 2]))

    # Test tiny arch forward. Note arch names are matched
    # case-insensitively ('Tiny', 'small', 'B', 'l' are all accepted).
    model = SwinTransformer(arch='Tiny')
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    output = model(imgs)
    # Expected output: (batch, channels, tokens) = (1, 768, 7*7).
    assert output.shape == (1, 768, 49)

    # Test small arch forward (same final embed dims as tiny: 768).
    model = SwinTransformer(arch='small')
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    output = model(imgs)
    assert output.shape == (1, 768, 49)

    # Test base arch forward (abbreviated name 'B').
    model = SwinTransformer(arch='B')
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    output = model(imgs)
    assert output.shape == (1, 1024, 49)

    # Test large arch forward (abbreviated name 'l').
    model = SwinTransformer(arch='l')
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)
    output = model(imgs)
    assert output.shape == (1, 1536, 49)

    # Test base arch with window_size=12, image_size=384.
    # Larger input with a larger window still yields 12*12 = 144 tokens.
    model = SwinTransformer(
        arch='base',
        img_size=384,
        stage_cfgs=dict(block_cfgs=dict(window_size=12)))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 384, 384)
    output = model(imgs)
    assert output.shape == (1, 1024, 144)

    # Test small with use_abs_pos_embed = True.
    # 3136 = (224/4)^2 patch tokens, 96 = stage-0 embed dims.
    model = SwinTransformer(arch='small', use_abs_pos_embed=True)
    model.init_weights()
    model.train()
    assert model.absolute_pos_embed.shape == (1, 3136, 96)

    # Test small with use_abs_pos_embed = False: the attribute must
    # not exist at all, not merely be None.
    with pytest.raises(AttributeError):
        model = SwinTransformer(arch='small', use_abs_pos_embed=False)
        model.absolute_pos_embed

    # Test small with auto_pad = True. The asymmetric downsample kernel
    # (3, 2) produces feature maps not divisible by the window size, so
    # the attention modules must pad on the bottom/right.
    model = SwinTransformer(
        arch='small',
        auto_pad=True,
        stage_cfgs=dict(
            block_cfgs={'window_size': 7},
            downsample_cfg={
                'kernel_size': (3, 2),
            }))
    model.init_weights()
    model.train()
    imgs = torch.randn(1, 3, 224, 224)

    # stage 1: height shrunk by 3, width by 2 after the first downsample.
    input_h = int(224 / 4 / 3)
    expect_h = ceil(input_h / 7) * 7
    input_w = int(224 / 4 / 2)
    expect_w = ceil(input_w / 7) * 7
    assert model.stages[1].blocks[0].attn.pad_b == expect_h - input_h
    assert model.stages[1].blocks[0].attn.pad_r == expect_w - input_w

    # stage 2
    input_h = int(224 / 4 / 3 / 3)
    # input_h is smaller than window_size, shrink the window_size to input_h.
    expect_h = input_h
    input_w = int(224 / 4 / 2 / 2)
    expect_w = ceil(input_w / input_h) * input_h
    assert model.stages[2].blocks[0].attn.pad_b == expect_h - input_h
    assert model.stages[2].blocks[0].attn.pad_r == expect_w - input_w

    # stage 3: window size again clamped to the (smaller) input height.
    input_h = int(224 / 4 / 3 / 3 / 3)
    expect_h = input_h
    input_w = int(224 / 4 / 2 / 2 / 2)
    expect_w = ceil(input_w / input_h) * input_h
    assert model.stages[3].blocks[0].attn.pad_b == expect_h - input_h
    assert model.stages[3].blocks[0].attn.pad_r == expect_w - input_w

    # Test small with auto_pad = False: the same indivisible feature
    # sizes must be rejected at construction time.
    with pytest.raises(AssertionError):
        model = SwinTransformer(
            arch='small',
            auto_pad=False,
            stage_cfgs=dict(
                block_cfgs={'window_size': 7},
                downsample_cfg={
                    'kernel_size': (3, 2),
                }))

    # Test drop_path_rate decay: the rate grows linearly with the
    # block's global position, from 0 up to drop_path_rate, and is
    # applied to both the FFN and the attention dropout layers.
    model = SwinTransformer(
        arch='small',
        drop_path_rate=0.2,
    )
    depths = model.arch_settings['depths']
    pos = 0
    for i, depth in enumerate(depths):
        for j in range(depth):
            block = model.stages[i].blocks[j]
            expect_prob = 0.2 / (sum(depths) - 1) * pos
            assert np.isclose(block.ffn.dropout_layer.drop_prob, expect_prob)
            assert np.isclose(block.attn.drop.drop_prob, expect_prob)
            pos += 1