# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import ANY, MagicMock

import pytest
import torch

from mmpretrain.models.utils.attention import (ShiftWindowMSA, WindowMSA,
                                               torch_meshgrid)


def get_relative_position_index(window_size):
    """Method from original code of Swin-Transformer."""
    coords_h = torch.arange(window_size[0])
    coords_w = torch.arange(window_size[1])
    coords = torch.stack(torch_meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
    # 2, Wh*Ww, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
    # Wh*Ww, Wh*Ww, 2
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    relative_coords[:, :, 0] += window_size[0] - 1  # shift to start from 0
    relative_coords[:, :, 1] += window_size[1] - 1
    relative_coords[:, :, 0] *= 2 * window_size[1] - 1
    relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
    return relative_position_index


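# Note: for a (Wh, Ww) window the per-axis relative offsets range over
# [-(Wh - 1), Wh - 1] and [-(Ww - 1), Ww - 1], so WindowMSA is expected to
# keep a relative position bias table with (2 * Wh - 1) * (2 * Ww - 1)
# entries, one per distinct 2-D offset (see ``test_relative_pos_embed``).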
class TestWindowMSA(TestCase):

    def test_forward(self):
        attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=4)
        inputs = torch.rand((16, 7 * 7, 96))
        output = attn(inputs)
        self.assertEqual(output.shape, inputs.shape)

        # test non-square window_size
        attn = WindowMSA(embed_dims=96, window_size=(6, 7), num_heads=4)
        inputs = torch.rand((16, 6 * 7, 96))
        output = attn(inputs)
        self.assertEqual(output.shape, inputs.shape)

    def test_relative_pos_embed(self):
        attn = WindowMSA(embed_dims=96, window_size=(7, 8), num_heads=4)
        self.assertEqual(attn.relative_position_bias_table.shape,
                         ((2 * 7 - 1) * (2 * 8 - 1), 4))
        # test relative_position_index
        expected_rel_pos_index = get_relative_position_index((7, 8))
        self.assertTrue(
            torch.allclose(attn.relative_position_index,
                           expected_rel_pos_index))

        # test default init
        self.assertTrue(
            torch.allclose(attn.relative_position_bias_table,
                           torch.tensor(0.)))
        attn.init_weights()
        self.assertFalse(
            torch.allclose(attn.relative_position_bias_table,
                           torch.tensor(0.)))

    def test_qkv_bias(self):
        # test qkv_bias=True
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, qkv_bias=True)
        self.assertEqual(attn.qkv.bias.shape, (96 * 3, ))

        # test qkv_bias=False
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, qkv_bias=False)
        self.assertIsNone(attn.qkv.bias)

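    # WindowMSA scales the query-key logits by ``qk_scale`` when given, and
    # by the default ``head_dims ** -0.5`` (i.e. 1 / sqrt(d_k)) otherwise.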
    def test_qk_scale(self):
        # test default qk_scale
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, qk_scale=None)
        head_dims = 96 // 4
        self.assertAlmostEqual(attn.scale, head_dims**-0.5)

        # test specified qk_scale
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, qk_scale=0.3)
        self.assertEqual(attn.scale, 0.3)

    def test_attn_drop(self):
        inputs = torch.rand(16, 7 * 7, 96)
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, attn_drop=1.0)
        # drop all attn output, output should be equal to proj.bias
        self.assertTrue(torch.allclose(attn(inputs), attn.proj.bias))

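    # With ``proj_drop=1.0`` everything after the output projection is
    # dropped, so the module output should be all zeros.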
    def test_proj_drop(self):
        inputs = torch.rand(16, 7 * 7, 96)
        attn = WindowMSA(
            embed_dims=96, window_size=(7, 7), num_heads=4, proj_drop=1.0)
        self.assertTrue(torch.allclose(attn(inputs), torch.tensor(0.)))

    def test_mask(self):
        inputs = torch.rand(16, 7 * 7, 96)
        attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=4)
        mask = torch.zeros((4, 49, 49))
        # Mask out the attention from and to the first token
        # (the first row and the first column of the attention map).
        mask[:, 0, :] = -100
        mask[:, :, 0] = -100
        outs = attn(inputs, mask=mask)
        # Re-randomize the first token; with the mask applied, the outputs
        # of all other tokens should stay unchanged.
        inputs[:, 0, :].normal_()
        outs_with_mask = attn(inputs, mask=mask)
        torch.testing.assert_allclose(outs[:, 1:, :], outs_with_mask[:, 1:, :])


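# ShiftWindowMSA wraps WindowMSA with window partitioning, an optional cyclic
# shift (``shift_size``) and an output dropout layer, so its forward takes the
# whole feature map together with its spatial resolution.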
class TestShiftWindowMSA(TestCase):

    def test_forward(self):
        inputs = torch.rand((1, 14 * 14, 96))
        attn = ShiftWindowMSA(embed_dims=96, window_size=7, num_heads=4)
        output = attn(inputs, (14, 14))
        self.assertEqual(output.shape, inputs.shape)
        self.assertEqual(attn.w_msa.relative_position_bias_table.shape,
                         ((2 * 7 - 1)**2, 4))

        # test forward with shift_size
        attn = ShiftWindowMSA(
            embed_dims=96, window_size=7, num_heads=4, shift_size=3)
        output = attn(inputs, (14, 14))
        self.assertEqual(output.shape, inputs.shape)

        # test irregular input shape
        input_resolution = (19, 18)
        attn = ShiftWindowMSA(embed_dims=96, num_heads=4, window_size=7)
        inputs = torch.rand((1, 19 * 18, 96))
        output = attn(inputs, input_resolution)
        self.assertEqual(output.shape, inputs.shape)

        # test wrong input_resolution
        input_resolution = (14, 14)
        attn = ShiftWindowMSA(embed_dims=96, num_heads=4, window_size=7)
        inputs = torch.rand((1, 14 * 14, 96))
        with pytest.raises(AssertionError):
            attn(inputs, (14, 15))

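    # When the feature map is smaller than the window, ``pad_small_map=True``
    # pads it up to the window size and keeps the shifted attention mask,
    # while ``pad_small_map=False`` raises an assertion error unless the map
    # exactly matches the window size, in which case the shift is skipped.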
    def test_pad_small_map(self):
        # test pad_small_map=True
        inputs = torch.rand((1, 6 * 7, 96))
        attn = ShiftWindowMSA(
            embed_dims=96,
            window_size=7,
            num_heads=4,
            shift_size=3,
            pad_small_map=True)
        attn.get_attn_mask = MagicMock(wraps=attn.get_attn_mask)
        output = attn(inputs, (6, 7))
        self.assertEqual(output.shape, inputs.shape)
        attn.get_attn_mask.assert_called_once_with((7, 7),
                                                   window_size=7,
                                                   shift_size=3,
                                                   device=ANY)

        # test pad_small_map=False
        inputs = torch.rand((1, 6 * 7, 96))
        attn = ShiftWindowMSA(
            embed_dims=96,
            window_size=7,
            num_heads=4,
            shift_size=3,
            pad_small_map=False)
        with self.assertRaisesRegex(AssertionError, r'the window size \(7\)'):
            attn(inputs, (6, 7))

        # test pad_small_map=False, and the input size equals the window size
        inputs = torch.rand((1, 7 * 7, 96))
        attn.get_attn_mask = MagicMock(wraps=attn.get_attn_mask)
        output = attn(inputs, (7, 7))
        self.assertEqual(output.shape, inputs.shape)
        attn.get_attn_mask.assert_called_once_with((7, 7),
                                                   window_size=7,
                                                   shift_size=0,
                                                   device=ANY)

    def test_drop_layer(self):
        inputs = torch.rand((1, 14 * 14, 96))
        attn = ShiftWindowMSA(
            embed_dims=96,
            window_size=7,
            num_heads=4,
            dropout_layer=dict(type='Dropout', drop_prob=1.0))
        attn.init_weights()
        # drop the whole output, so it should be all zeros
        self.assertTrue(
            torch.allclose(attn(inputs, (14, 14)), torch.tensor(0.)))