mmclassification/tests/test_models/test_utils/test_attention.py

179 lines
5.6 KiB
Python

# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
from mmcls.models.utils.attention import ShiftWindowMSA, WindowMSA
def get_relative_position_index(window_size):
    """Method from original code of Swin-Transformer.

    Build the table that maps every (query token, key token) pair inside one
    window to a row of the relative position bias table.

    Args:
        window_size (tuple[int, int]): Window height ``Wh`` and width ``Ww``.

    Returns:
        torch.Tensor: Integer tensor of shape (Wh*Ww, Wh*Ww) whose values lie
        in ``[0, (2*Wh - 1) * (2*Ww - 1))``.
    """
    coords_h = torch.arange(window_size[0])
    coords_w = torch.arange(window_size[1])
    # The flattening below requires row-major ('ij') grids: coords[0] must
    # hold each token's row and coords[1] its column. Ask for that explicitly
    # rather than relying on meshgrid's default (PyTorch >= 1.10 warns when
    # ``indexing`` is omitted). Older releases reject the keyword with a
    # TypeError, but their only behaviour is already 'ij'.
    try:
        grids = torch.meshgrid(coords_h, coords_w, indexing='ij')
    except TypeError:
        grids = torch.meshgrid(coords_h, coords_w)
    coords = torch.stack(grids)  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    # 2, Wh*Ww, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    # Wh*Ww, Wh*Ww, 2
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += window_size[1] - 1
    # Stride row offsets by the number of distinct column offsets so that
    # each (row offset, column offset) pair gets a unique flat index.
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1
    relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
    return relative_position_index
def test_window_msa():
    """Check WindowMSA output shapes, bias table, qkv bias, scale and
    dropout options."""
    batch_size = 1
    num_windows = (4, 4)
    embed_dims = 96
    num_heads = 4

    def build(window_size, **extra):
        # Every variant below shares embed_dims and num_heads.
        return WindowMSA(
            embed_dims=embed_dims,
            window_size=window_size,
            num_heads=num_heads,
            **extra)

    def rand_windows(window_size):
        # (number of windows, tokens per window, channels)
        total_windows = batch_size * num_windows[0] * num_windows[1]
        tokens_per_window = window_size[0] * window_size[1]
        return torch.rand((total_windows, tokens_per_window, embed_dims))

    # Square window: the forward pass keeps the input shape and the bias
    # table has one row per possible relative offset.
    square = (7, 7)
    attn = build(square)
    x = rand_windows(square)
    assert attn(x).shape == x.shape
    table_rows = (2 * square[0] - 1) * (2 * square[1] - 1)
    assert attn.relative_position_bias_table.shape == (table_rows, num_heads)

    # init_weights must leave a non-zero bias table.
    attn.init_weights()
    assert abs(attn.relative_position_bias_table).sum() > 0

    # Non-square window.
    window_size = (6, 7)
    attn = build(window_size)
    x = rand_windows(window_size)
    assert attn(x).shape == x.shape
    # The relative_position_index buffer must agree with the reference
    # Swin-Transformer computation.
    reference_index = get_relative_position_index(window_size)
    assert (attn.relative_position_index == reference_index).all()

    # qkv_bias toggles the bias of the fused qkv projection.
    biased = build(window_size, qkv_bias=True)
    assert biased.qkv.bias.shape == (embed_dims * 3, )
    unbiased = build(window_size, qkv_bias=False)
    assert unbiased.qkv.bias is None

    # qk_scale=None falls back to head_dims ** -0.5; an explicit value wins.
    head_dims = embed_dims // num_heads
    default_scaled = build(window_size, qk_scale=None)
    assert np.isclose(default_scaled.scale, head_dims**-0.5)
    custom_scaled = build(window_size, qk_scale=0.3)
    assert custom_scaled.scale == 0.3

    # attn_drop=1.0 drops every attention weight, so the output should equal
    # proj.bias alone.
    attn = build(window_size, attn_drop=1.0)
    x = rand_windows(window_size)
    assert torch.allclose(attn(x), attn.proj.bias)

    # proj_drop=1.0 drops the whole projected output.
    attn = build(window_size, proj_drop=1.0)
    assert (attn(x) == 0).all()
def test_shift_window_msa():
    """Check ShiftWindowMSA forward shapes (with and without shift), weight
    init, DropPath, auto padding and window clamping for small inputs."""
    batch_size = 1
    embed_dims = 96
    num_heads = 4
    window_size = 7

    def build(input_resolution, **extra):
        # Constructor arguments shared by every variant below.
        return ShiftWindowMSA(
            embed_dims=embed_dims,
            input_resolution=input_resolution,
            num_heads=num_heads,
            window_size=window_size,
            **extra)

    # Plain forward keeps the token layout; the inner WindowMSA owns a bias
    # table with one row per possible relative offset.
    resolution = (14, 14)
    attn = build(resolution)
    num_tokens = resolution[0] * resolution[1]
    x = torch.rand((batch_size, num_tokens, embed_dims))
    assert attn(x).shape == x.shape
    table_rows = (2 * window_size - 1)**2
    bias_table = attn.w_msa.relative_position_bias_table
    assert bias_table.shape == (table_rows, num_heads)

    # Forward with a cyclic shift.
    attn = build(resolution, shift_size=1)
    assert attn(x).shape == x.shape

    # init_weights must leave a non-zero bias table.
    attn.init_weights()
    assert abs(attn.w_msa.relative_position_bias_table).sum() > 0

    # DropPath: the assertion relies on seed 0 making DropPath drop the only
    # sample in the batch.
    drop_path = dict(type='DropPath', drop_prob=0.5)
    attn = build(resolution, dropout_layer=drop_path)
    torch.manual_seed(0)
    assert (attn(x) == 0).all()

    # auto_pad: (19, 18) needs 2 extra rows and 3 extra columns to reach
    # multiples of window_size (21 x 21).
    attn = build((19, 18), auto_pad=True)
    assert attn.pad_r == 3
    assert attn.pad_b == 2

    # A resolution smaller than window_size clamps the window to 5 and turns
    # shifting off.
    attn = build((5, 6), shift_size=3, auto_pad=True)
    assert attn.window_size == 5
    assert attn.shift_size == 0