# mmpretrain/tests/test_models/test_utils/test_attention.py
# Copyright (c) OpenMMLab. All rights reserved.
from unittest import TestCase
from unittest.mock import ANY, MagicMock

import pytest
import torch

from mmpretrain.models.utils.attention import (ShiftWindowMSA, WindowMSA,
torch_meshgrid)


def get_relative_position_index(window_size):
"""Method from original code of Swin-Transformer."""
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch_meshgrid([coords_h, coords_w])) # 2, Wh, Ww
coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
# 2, Wh*Ww, Wh*Ww
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
# Wh*Ww, Wh*Ww, 2
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += window_size[0] - 1 # shift to start from 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1
relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
    return relative_position_index


class TestWindowMSA(TestCase):

    def test_forward(self):
attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=4)
inputs = torch.rand((16, 7 * 7, 96))
output = attn(inputs)
self.assertEqual(output.shape, inputs.shape)
# test non-square window_size
attn = WindowMSA(embed_dims=96, window_size=(6, 7), num_heads=4)
inputs = torch.rand((16, 6 * 7, 96))
output = attn(inputs)
self.assertEqual(output.shape, inputs.shape)

    def test_relative_pos_embed(self):
attn = WindowMSA(embed_dims=96, window_size=(7, 8), num_heads=4)
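        # the bias table has one entry per 2D relative offset, i.e.
        # (2 * Wh - 1) * (2 * Ww - 1) rows, with one column per head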
self.assertEqual(attn.relative_position_bias_table.shape,
((2 * 7 - 1) * (2 * 8 - 1), 4))
# test relative_position_index
expected_rel_pos_index = get_relative_position_index((7, 8))
self.assertTrue(
torch.allclose(attn.relative_position_index,
expected_rel_pos_index))
# test default init
self.assertTrue(
torch.allclose(attn.relative_position_bias_table,
torch.tensor(0.)))
attn.init_weights()
self.assertFalse(
torch.allclose(attn.relative_position_bias_table,
torch.tensor(0.)))

    def test_qkv_bias(self):
# test qkv_bias=True
attn = WindowMSA(
embed_dims=96, window_size=(7, 7), num_heads=4, qkv_bias=True)
self.assertEqual(attn.qkv.bias.shape, (96 * 3, ))
# test qkv_bias=False
attn = WindowMSA(
embed_dims=96, window_size=(7, 7), num_heads=4, qkv_bias=False)
self.assertIsNone(attn.qkv.bias)

    def test_qk_scale(self):
# test default qk_scale
attn = WindowMSA(
embed_dims=96, window_size=(7, 7), num_heads=4, qk_scale=None)
head_dims = 96 // 4
self.assertAlmostEqual(attn.scale, head_dims**-0.5)
# test specified qk_scale
attn = WindowMSA(
embed_dims=96, window_size=(7, 7), num_heads=4, qk_scale=0.3)
self.assertEqual(attn.scale, 0.3)

    def test_attn_drop(self):
inputs = torch.rand(16, 7 * 7, 96)
attn = WindowMSA(
embed_dims=96, window_size=(7, 7), num_heads=4, attn_drop=1.0)
        # drop all the attention weights; the output should equal proj.bias
self.assertTrue(torch.allclose(attn(inputs), attn.proj.bias))

    def test_proj_drop(self):
inputs = torch.rand(16, 7 * 7, 96)
attn = WindowMSA(
embed_dims=96, window_size=(7, 7), num_heads=4, proj_drop=1.0)
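        # proj_drop is applied after the projection, so dropping everything
        # should leave an all-zero result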
self.assertTrue(torch.allclose(attn(inputs), torch.tensor(0.)))

    def test_mask(self):
inputs = torch.rand(16, 7 * 7, 96)
attn = WindowMSA(embed_dims=96, window_size=(7, 7), num_heads=4)
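        # the mask is added to the attention logits before softmax, so a
        # large negative value suppresses attention at the masked positions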
mask = torch.zeros((4, 49, 49))
        # Mask attention to and from the first token in each window
mask[:, 0, :] = -100
mask[:, :, 0] = -100
outs = attn(inputs, mask=mask)
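        # re-randomize the masked token; outputs at the unmasked positions
        # should stay the same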
inputs[:, 0, :].normal_()
outs_with_mask = attn(inputs, mask=mask)
        torch.testing.assert_close(outs[:, 1:, :], outs_with_mask[:, 1:, :])


class TestShiftWindowMSA(TestCase):

    def test_forward(self):
inputs = torch.rand((1, 14 * 14, 96))
attn = ShiftWindowMSA(embed_dims=96, window_size=7, num_heads=4)
output = attn(inputs, (14, 14))
self.assertEqual(output.shape, inputs.shape)
self.assertEqual(attn.w_msa.relative_position_bias_table.shape,
((2 * 7 - 1)**2, 4))
# test forward with shift_size
attn = ShiftWindowMSA(
embed_dims=96, window_size=7, num_heads=4, shift_size=3)
output = attn(inputs, (14, 14))
        self.assertEqual(output.shape, inputs.shape)
# test irregular input shape
input_resolution = (19, 18)
attn = ShiftWindowMSA(embed_dims=96, num_heads=4, window_size=7)
inputs = torch.rand((1, 19 * 18, 96))
output = attn(inputs, input_resolution)
        self.assertEqual(output.shape, inputs.shape)
# test wrong input_resolution
input_resolution = (14, 14)
attn = ShiftWindowMSA(embed_dims=96, num_heads=4, window_size=7)
inputs = torch.rand((1, 14 * 14, 96))
with pytest.raises(AssertionError):
attn(inputs, (14, 15))

    def test_pad_small_map(self):
# test pad_small_map=True
inputs = torch.rand((1, 6 * 7, 96))
attn = ShiftWindowMSA(
embed_dims=96,
window_size=7,
num_heads=4,
shift_size=3,
pad_small_map=True)
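        # wrap get_attn_mask to check that the (6, 7) input is padded up to a
        # full (7, 7) window before the shifted attention mask is computed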
attn.get_attn_mask = MagicMock(wraps=attn.get_attn_mask)
output = attn(inputs, (6, 7))
self.assertEqual(output.shape, inputs.shape)
attn.get_attn_mask.assert_called_once_with((7, 7),
window_size=7,
shift_size=3,
device=ANY)
# test pad_small_map=False
inputs = torch.rand((1, 6 * 7, 96))
attn = ShiftWindowMSA(
embed_dims=96,
window_size=7,
num_heads=4,
shift_size=3,
pad_small_map=False)
with self.assertRaisesRegex(AssertionError, r'the window size \(7\)'):
attn(inputs, (6, 7))
# test pad_small_map=False, and the input size equals to window size
inputs = torch.rand((1, 7 * 7, 96))
attn.get_attn_mask = MagicMock(wraps=attn.get_attn_mask)
output = attn(inputs, (7, 7))
self.assertEqual(output.shape, inputs.shape)
attn.get_attn_mask.assert_called_once_with((7, 7),
window_size=7,
shift_size=0,
device=ANY)

    def test_drop_layer(self):
inputs = torch.rand((1, 14 * 14, 96))
attn = ShiftWindowMSA(
embed_dims=96,
window_size=7,
num_heads=4,
dropout_layer=dict(type='Dropout', drop_prob=1.0))
attn.init_weights()
        # drop the whole attention output, so the result should be all zeros
self.assertTrue(
torch.allclose(attn(inputs, (14, 14)), torch.tensor(0.)))