# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch

from mmcls.models.utils.attention import ShiftWindowMSA, WindowMSA


def get_relative_position_index(window_size):
    """Method from original code of Swin-Transformer."""
    coords_h = torch.arange(window_size[0])
    coords_w = torch.arange(window_size[1])
    coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    # 2, Wh*Ww, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    # Wh*Ww, Wh*Ww, 2
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += window_size[1] - 1
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1
    relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
    return relative_position_index
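# Worked example (computed by hand): for a 2x2 window the helper above
# returns the 4x4 index matrix
#   [[4, 3, 1, 0],
#    [5, 4, 2, 1],
#    [7, 6, 4, 3],
#    [8, 7, 5, 4]]
# where each entry indexes one of the (2*2-1) * (2*2-1) = 9 rows of the
# relative position bias table.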


def test_window_msa():
    batch_size = 1
    num_windows = (4, 4)
    embed_dims = 96
    window_size = (7, 7)
    num_heads = 4
    attn = WindowMSA(
        embed_dims=embed_dims, window_size=window_size, num_heads=num_heads)
    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
                         window_size[0] * window_size[1], embed_dims))

    # test forward
    output = attn(inputs)
    assert output.shape == inputs.shape
    assert attn.relative_position_bias_table.shape == (
        (2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
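    # For window_size=(7, 7) this is 13 * 13 = 169 rows, one per possible
    # relative (dh, dw) offset within the window, with one column per head.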

    # test relative_position_bias_table init
    attn.init_weights()
    assert abs(attn.relative_position_bias_table).sum() > 0

    # test non-square window_size
    window_size = (6, 7)
    attn = WindowMSA(
        embed_dims=embed_dims, window_size=window_size, num_heads=num_heads)
    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
                         window_size[0] * window_size[1], embed_dims))
    output = attn(inputs)
    assert output.shape == inputs.shape

    # test relative_position_index
    expected_rel_pos_index = get_relative_position_index(window_size)
    assert (attn.relative_position_index == expected_rel_pos_index).all()

    # test qkv_bias=True
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        qkv_bias=True)
    assert attn.qkv.bias.shape == (embed_dims * 3, )

    # test qkv_bias=False
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        qkv_bias=False)
    assert attn.qkv.bias is None

    # test default qk_scale
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        qk_scale=None)
    head_dims = embed_dims // num_heads
    assert np.isclose(attn.scale, head_dims**-0.5)
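    # With embed_dims=96 and num_heads=4, head_dims is 24, so the default
    # scale is 24 ** -0.5 (about 0.204).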

    # test specified qk_scale
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        qk_scale=0.3)
    assert attn.scale == 0.3

    # test attn_drop
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        attn_drop=1.0)
    inputs = torch.rand((batch_size * num_windows[0] * num_windows[1],
                         window_size[0] * window_size[1], embed_dims))
    # drop all attn output, so the output should equal proj.bias
    assert torch.allclose(attn(inputs), attn.proj.bias)
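    # (attn_drop=1.0 zeroes the attention map, so the projection layer sees an
    # all-zero input and only proj.bias remains, broadcast over the output.)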

    # test proj_drop
    attn = WindowMSA(
        embed_dims=embed_dims,
        window_size=window_size,
        num_heads=num_heads,
        proj_drop=1.0)
    assert (attn(inputs) == 0).all()


def test_shift_window_msa():
    batch_size = 1
    embed_dims = 96
    input_resolution = (14, 14)
    num_heads = 4
    window_size = 7

    # test forward
    attn = ShiftWindowMSA(
        embed_dims=embed_dims,
        input_resolution=input_resolution,
        num_heads=num_heads,
        window_size=window_size)
    inputs = torch.rand(
        (batch_size, input_resolution[0] * input_resolution[1], embed_dims))
    output = attn(inputs)
    assert output.shape == inputs.shape
    assert attn.w_msa.relative_position_bias_table.shape == (
        (2 * window_size - 1)**2, num_heads)

    # test forward with shift_size
    attn = ShiftWindowMSA(
        embed_dims=embed_dims,
        input_resolution=input_resolution,
        num_heads=num_heads,
        window_size=window_size,
        shift_size=1)
    output = attn(inputs)
    assert output.shape == inputs.shape

    # test relative_position_bias_table init
    attn.init_weights()
    assert abs(attn.w_msa.relative_position_bias_table).sum() > 0

    # test dropout_layer
    attn = ShiftWindowMSA(
        embed_dims=embed_dims,
        input_resolution=input_resolution,
        num_heads=num_heads,
        window_size=window_size,
        dropout_layer=dict(type='DropPath', drop_prob=0.5))
    torch.manual_seed(0)
    output = attn(inputs)
    assert (output == 0).all()
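    # (DropPath drops whole samples; with batch_size 1 and the seed above the
    # single sample is dropped, so the entire output is zero.)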

    # test auto_pad
    input_resolution = (19, 18)
    attn = ShiftWindowMSA(
        embed_dims=embed_dims,
        input_resolution=input_resolution,
        num_heads=num_heads,
        window_size=window_size,
        auto_pad=True)
    assert attn.pad_r == 3
    assert attn.pad_b == 2
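    # (Both 19 and 18 are padded up to 21, the next multiple of window_size=7,
    # giving pad_b = 21 - 19 = 2 and pad_r = 21 - 18 = 3.)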

    # test small input_resolution
    input_resolution = (5, 6)
    attn = ShiftWindowMSA(
        embed_dims=embed_dims,
        input_resolution=input_resolution,
        num_heads=num_heads,
        window_size=window_size,
        shift_size=3,
        auto_pad=True)
    assert attn.window_size == 5
    assert attn.shift_size == 0
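    # (Since window_size=7 exceeds the smallest input dimension, the window is
    # shrunk to 5 and the shift is disabled.)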