import copy

import pytest
import torch

from mmcv.cnn.bricks.drop import DropPath
from mmcv.cnn.bricks.transformer import (FFN, AdaptivePadding,
                                         BaseTransformerLayer,
                                         MultiheadAttention, PatchEmbed,
                                         PatchMerging,
                                         TransformerLayerSequence)
from mmcv.runner import ModuleList


def test_adaptive_padding():
    """Check AdaptivePadding output sizes for 'same' and 'corner' modes.

    Also checks that a dilated kernel pads exactly like the equivalent
    undilated kernel, and that a non-string padding value is rejected.
    """
    for padding in ('same', 'corner'):
        kernel_size = 16
        stride = 16
        dilation = 1
        x = torch.rand(1, 1, 15, 17)
        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        out = adap_pad(x)
        # padding to divisible by 16
        assert (out.shape[2], out.shape[3]) == (16, 32)
        x = torch.rand(1, 1, 16, 17)
        out = adap_pad(x)
        # padding to divisible by 16
        assert (out.shape[2], out.shape[3]) == (16, 32)

        kernel_size = (2, 2)
        stride = (2, 2)
        dilation = (1, 1)
        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        x = torch.rand(1, 1, 11, 13)
        out = adap_pad(x)
        # padding to divisible by 2
        assert (out.shape[2], out.shape[3]) == (12, 14)

        kernel_size = (2, 2)
        stride = (10, 10)
        dilation = (1, 1)
        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        x = torch.rand(1, 1, 10, 13)
        out = adap_pad(x)
        # no padding
        assert (out.shape[2], out.shape[3]) == (10, 13)

        kernel_size = (11, 11)
        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        x = torch.rand(1, 1, 11, 13)
        out = adap_pad(x)
        # all padding
        assert (out.shape[2], out.shape[3]) == (21, 21)

        # test padding as kernel is (7,9)
        x = torch.rand(1, 1, 11, 13)
        stride = (3, 4)
        kernel_size = (4, 5)
        dilation = (2, 2)
        # actually (7, 9)
        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        dilation_out = adap_pad(x)
        assert (dilation_out.shape[2], dilation_out.shape[3]) == (16, 21)

        kernel_size = (7, 9)
        dilation = (1, 1)
        adap_pad = AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=padding)
        kernel79_out = adap_pad(x)
        assert (kernel79_out.shape[2], kernel79_out.shape[3]) == (16, 21)
        assert kernel79_out.shape == dilation_out.shape

    # assert only support "same" "corner"
    with pytest.raises(AssertionError):
        AdaptivePadding(
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            padding=1)


def test_patch_embed():
    """Check PatchEmbed token counts, output sizes and ``init_out_size``."""
    B = 2
    H = 3
    W = 4
    C = 3
    embed_dims = 10
    kernel_size = 3
    stride = 1
    dummy_input = torch.rand(B, C, H, W)
    patch_embed_1 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=1,
        norm_cfg=None)

    x1, shape = patch_embed_1(dummy_input)
    # test out shape
    assert x1.shape == (2, 2, 10)
    # test outsize is correct
    assert shape == (1, 2)
    # test L = out_h * out_w
    assert shape[0] * shape[1] == x1.shape[1]

    B = 2
    H = 10
    W = 10
    C = 3
    embed_dims = 10
    kernel_size = 5
    stride = 2
    dummy_input = torch.rand(B, C, H, W)
    # test dilation
    patch_embed_2 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=2,
        norm_cfg=None,
    )

    x2, shape = patch_embed_2(dummy_input)
    # test out shape
    assert x2.shape == (2, 1, 10)
    # test outsize is correct
    assert shape == (1, 1)
    # test L = out_h * out_w
    assert shape[0] * shape[1] == x2.shape[1]

    stride = 2
    input_size = (10, 10)
    dummy_input = torch.rand(B, C, H, W)
    # test stride and norm
    patch_embed_3 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=2,
        norm_cfg=dict(type='LN'),
        input_size=input_size)

    x3, shape = patch_embed_3(dummy_input)
    # test out shape
    assert x3.shape == (2, 1, 10)
    # test outsize is correct
    assert shape == (1, 1)
    # test L = out_h * out_w
    assert shape[0] * shape[1] == x3.shape[1]

    # test the init_out_size with nn.Unfold:
    # out = (in - dilation * (kernel - 1) - 1) // stride + 1, with
    # dilation * (kernel - 1) = 2 * 4 and stride = 2. Each spatial axis is
    # checked against its own input extent.
    expected_h = (input_size[0] - 2 * 4 - 1) // 2 + 1
    expected_w = (input_size[1] - 2 * 4 - 1) // 2 + 1
    assert patch_embed_3.init_out_size[0] == expected_h
    assert patch_embed_3.init_out_size[1] == expected_w

    H = 11
    W = 12
    input_size = (H, W)
    dummy_input = torch.rand(B, C, H, W)
    # test stride and norm with a non-square input
    patch_embed_3 = PatchEmbed(
        in_channels=C,
        embed_dims=embed_dims,
        kernel_size=kernel_size,
        stride=stride,
        padding=0,
        dilation=2,
        norm_cfg=dict(type='LN'),
        input_size=input_size)

    _, shape = patch_embed_3(dummy_input)
    # when input_size equal to real input
    # the out_size should be equal to `init_out_size`
    assert shape == patch_embed_3.init_out_size

    # test adap padding
    for padding in ('same', 'corner'):
        in_c = 2
        embed_dims = 3
        B = 2

        # test stride is 1
        input_size = (5, 5)
        kernel_size = (5, 5)
        stride = (1, 1)
        dilation = 1
        bias = False

        x = torch.rand(B, in_c, *input_size)
        patch_embed = PatchEmbed(
            in_channels=in_c,
            embed_dims=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_embed(x)
        assert x_out.size() == (B, 25, 3)
        assert out_size == (5, 5)
        assert x_out.size(1) == out_size[0] * out_size[1]

        # test kernel_size == stride
        input_size = (5, 5)
        kernel_size = (5, 5)
        stride = (5, 5)
        dilation = 1
        bias = False

        x = torch.rand(B, in_c, *input_size)
        patch_embed = PatchEmbed(
            in_channels=in_c,
            embed_dims=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_embed(x)
        assert x_out.size() == (B, 1, 3)
        assert out_size == (1, 1)
        assert x_out.size(1) == out_size[0] * out_size[1]

        # test kernel_size == stride with a non-divisible height
        input_size = (6, 5)
        kernel_size = (5, 5)
        stride = (5, 5)
        dilation = 1
        bias = False

        x = torch.rand(B, in_c, *input_size)
        patch_embed = PatchEmbed(
            in_channels=in_c,
            embed_dims=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_embed(x)
        assert x_out.size() == (B, 2, 3)
        assert out_size == (2, 1)
        assert x_out.size(1) == out_size[0] * out_size[1]

        # test different kernel_size with different stride
        input_size = (6, 5)
        kernel_size = (6, 2)
        stride = (6, 2)
        dilation = 1
        bias = False

        x = torch.rand(B, in_c, *input_size)
        patch_embed = PatchEmbed(
            in_channels=in_c,
            embed_dims=embed_dims,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_embed(x)
        assert x_out.size() == (B, 3, 3)
        assert out_size == (1, 3)
        assert x_out.size(1) == out_size[0] * out_size[1]


def test_patch_merging():
    """Check PatchMerging output sizes with int and adaptive padding."""
    # Test the model with int padding
    in_c = 3
    out_c = 4
    kernel_size = 3
    stride = 3
    padding = 1
    dilation = 1
    bias = False
    # test the case `pad_to_stride` is False
    patch_merge = PatchMerging(
        in_channels=in_c,
        out_channels=out_c,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias)
    B, L, C = 1, 100, 3
    input_size = (10, 10)
    x = torch.rand(B, L, C)
    x_out, out_size = patch_merge(x, input_size)
    assert x_out.size() == (1, 16, 4)
    assert out_size == (4, 4)
    # assert out size is consistent with real output
    assert x_out.size(1) == out_size[0] * out_size[1]

    in_c = 4
    out_c = 5
    kernel_size = 6
    stride = 3
    padding = 2
    dilation = 2
    bias = False
    patch_merge = PatchMerging(
        in_channels=in_c,
        out_channels=out_c,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=bias)
    B, L, C = 1, 100, 4
    input_size = (10, 10)
    x = torch.rand(B, L, C)
    x_out, out_size = patch_merge(x, input_size)
    assert x_out.size() == (1, 4, 5)
    assert out_size == (2, 2)
    # assert out size is consistent with real output
    assert x_out.size(1) == out_size[0] * out_size[1]

    # Test with adaptive padding
    for padding in ('same', 'corner'):
        in_c = 2
        out_c = 3
        B = 2

        # test stride is 1
        input_size = (5, 5)
        kernel_size = (5, 5)
        stride = (1, 1)
        dilation = 1
        bias = False
        L = input_size[0] * input_size[1]

        x = torch.rand(B, L, in_c)
        patch_merge = PatchMerging(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_merge(x, input_size)
        assert x_out.size() == (B, 25, 3)
        assert out_size == (5, 5)
        assert x_out.size(1) == out_size[0] * out_size[1]

        # test kernel_size == stride
        input_size = (5, 5)
        kernel_size = (5, 5)
        stride = (5, 5)
        dilation = 1
        bias = False
        L = input_size[0] * input_size[1]

        x = torch.rand(B, L, in_c)
        patch_merge = PatchMerging(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_merge(x, input_size)
        assert x_out.size() == (B, 1, 3)
        assert out_size == (1, 1)
        assert x_out.size(1) == out_size[0] * out_size[1]

        # test kernel_size == stride with a non-divisible height
        input_size = (6, 5)
        kernel_size = (5, 5)
        stride = (5, 5)
        dilation = 1
        bias = False
        L = input_size[0] * input_size[1]

        x = torch.rand(B, L, in_c)
        patch_merge = PatchMerging(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_merge(x, input_size)
        assert x_out.size() == (B, 2, 3)
        assert out_size == (2, 1)
        assert x_out.size(1) == out_size[0] * out_size[1]

        # test different kernel_size with different stride
        input_size = (6, 5)
        kernel_size = (6, 2)
        stride = (6, 2)
        dilation = 1
        bias = False
        L = input_size[0] * input_size[1]

        x = torch.rand(B, L, in_c)
        patch_merge = PatchMerging(
            in_channels=in_c,
            out_channels=out_c,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias)

        x_out, out_size = patch_merge(x, input_size)
        assert x_out.size() == (B, 3, 3)
        assert out_size == (1, 3)
        assert x_out.size(1) == out_size[0] * out_size[1]


def test_multiheadattention():
    """Check batch_first / query-first MultiheadAttention agree.

    Both modules share weights, so their summed outputs must match. The
    ``residual=`` and ``identity=`` keywords must shift the sum by
    ``identity.sum() - query.sum()``.
    """
    # construction with a Dropout dropout_layer should not raise
    MultiheadAttention(
        embed_dims=5,
        num_heads=5,
        attn_drop=0,
        proj_drop=0,
        dropout_layer=dict(type='Dropout', drop_prob=0.),
        batch_first=True)

    batch_dim = 2
    embed_dim = 5
    num_query = 100
    attn_batch_first = MultiheadAttention(
        embed_dims=embed_dim,
        num_heads=5,
        attn_drop=0,
        proj_drop=0,
        dropout_layer=dict(type='DropPath', drop_prob=0.),
        batch_first=True)

    attn_query_first = MultiheadAttention(
        embed_dims=embed_dim,
        num_heads=5,
        attn_drop=0,
        proj_drop=0,
        dropout_layer=dict(type='DropPath', drop_prob=0.),
        batch_first=False)

    # make both modules use identical weights
    param_dict = dict(attn_query_first.named_parameters())
    for name, param in attn_batch_first.named_parameters():
        param_dict[name].data = param.data

    input_batch_first = torch.rand(batch_dim, num_query, embed_dim)
    input_query_first = input_batch_first.transpose(0, 1)

    assert torch.allclose(
        attn_query_first(input_query_first).sum(),
        attn_batch_first(input_batch_first).sum())

    key_batch_first = torch.rand(batch_dim, num_query, embed_dim)
    key_query_first = key_batch_first.transpose(0, 1)

    assert torch.allclose(
        attn_query_first(input_query_first, key_query_first).sum(),
        attn_batch_first(input_batch_first, key_batch_first).sum())

    identity = torch.ones_like(input_query_first)

    # check deprecated arguments can be used normally
    assert torch.allclose(
        attn_query_first(
            input_query_first, key_query_first, residual=identity).sum(),
        attn_batch_first(input_batch_first, key_batch_first).sum() +
        identity.sum() - input_batch_first.sum())

    assert torch.allclose(
        attn_query_first(
            input_query_first, key_query_first, identity=identity).sum(),
        attn_batch_first(input_batch_first, key_batch_first).sum() +
        identity.sum() - input_batch_first.sum())


def test_ffn():
    """Check FFN argument validation, layout invariance and identity."""
    with pytest.raises(AssertionError):
        # num_fcs should be no less than 2
        FFN(num_fcs=1)
    # construction with the `dropout` / `add_residual` keywords should not
    # raise
    FFN(dropout=0, add_residual=True)
    ffn = FFN(dropout=0, add_identity=True)

    input_tensor = torch.rand(2, 20, 256)
    input_tensor_nbc = input_tensor.transpose(0, 1)
    assert torch.allclose(ffn(input_tensor).sum(), ffn(input_tensor_nbc).sum())

    # Supplying an explicit identity (via `residual=` or `identity=`)
    # should replace the default identity (the input itself), shifting the
    # summed output by `residual.sum() - input_tensor.sum()`.
    residual = torch.rand_like(input_tensor)
    assert torch.allclose(
        ffn(input_tensor, residual=residual).sum(),
        ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())

    assert torch.allclose(
        ffn(input_tensor, identity=residual).sum(),
        ffn(input_tensor).sum() + residual.sum() - input_tensor.sum())


@pytest.mark.skipif(not torch.cuda.is_available(), reason='Cuda not available')
def test_basetransformerlayer_cuda():
    # To test if the BaseTransformerLayer's behaviour remains
    # consistent after being deepcopied
    operation_order = ('self_attn', 'ffn')
    baselayer = BaseTransformerLayer(
        operation_order=operation_order,
        batch_first=True,
        attn_cfgs=dict(
            type='MultiheadAttention',
            embed_dims=256,
            num_heads=8,
        ),
    )
    baselayers = ModuleList([copy.deepcopy(baselayer) for _ in range(2)])
    baselayers.to('cuda')
    x = torch.rand(2, 10, 256).cuda()
    for m in baselayers:
        x = m(x)
        assert x.shape == torch.Size([2, 10, 256])


@pytest.mark.parametrize('embed_dims', [False, 256])
def test_basetransformerlayer(embed_dims):
    """Check BaseTransformerLayer construction, deprecated args, forward."""
    # attn_cfgs is deliberately a 1-tuple of configs, not a bare dict
    attn_cfgs = (dict(type='MultiheadAttention', embed_dims=256,
                      num_heads=8), )
    ffn_cfgs = dict(
        type='FFN',
        feedforward_channels=1024,
        num_fcs=2,
        ffn_drop=0.,
        act_cfg=dict(type='ReLU', inplace=True),
    )
    if embed_dims:
        # only pass `embed_dims` to the FFN config when parametrized with one
        ffn_cfgs['embed_dims'] = embed_dims

    feedforward_channels = 2048
    ffn_dropout = 0.1
    operation_order = ('self_attn', 'norm', 'ffn', 'norm')

    # test deprecated_args: the explicit `feedforward_channels` keyword is
    # expected to win over the value inside `ffn_cfgs`
    baselayer = BaseTransformerLayer(
        attn_cfgs=attn_cfgs,
        ffn_cfgs=ffn_cfgs,
        feedforward_channels=feedforward_channels,
        ffn_dropout=ffn_dropout,
        operation_order=operation_order)
    assert baselayer.batch_first is False
    assert baselayer.ffns[0].feedforward_channels == feedforward_channels

    # fresh config objects for the second layer
    attn_cfgs = (dict(type='MultiheadAttention', num_heads=8,
                      embed_dims=256), )
    feedforward_channels = 2048
    ffn_dropout = 0.1
    operation_order = ('self_attn', 'norm', 'ffn', 'norm')
    baselayer = BaseTransformerLayer(
        attn_cfgs=attn_cfgs,
        feedforward_channels=feedforward_channels,
        ffn_dropout=ffn_dropout,
        operation_order=operation_order,
        batch_first=True)
    assert baselayer.attentions[0].batch_first
    in_tensor = torch.rand(2, 10, 256)
    baselayer(in_tensor)


def test_transformerlayersequence():
    """Check TransformerLayerSequence layer count and list-length check."""
    layer_sequence = TransformerLayerSequence(
        num_layers=6,
        transformerlayers=dict(
            type='BaseTransformerLayer',
            attn_cfgs=[
                dict(
                    type='MultiheadAttention',
                    embed_dims=256,
                    num_heads=8,
                    dropout=0.1),
                dict(type='MultiheadAttention', embed_dims=256, num_heads=4)
            ],
            feedforward_channels=1024,
            ffn_dropout=0.1,
            operation_order=('self_attn', 'norm', 'cross_attn', 'norm',
                             'ffn', 'norm')))
    assert len(layer_sequence.layers) == 6
    assert layer_sequence.pre_norm is False

    with pytest.raises(AssertionError):
        # if transformerlayers is a list, len(transformerlayers)
        # should be equal to num_layers
        TransformerLayerSequence(
            num_layers=6,
            transformerlayers=[
                dict(
                    type='BaseTransformerLayer',
                    attn_cfgs=[
                        dict(
                            type='MultiheadAttention',
                            embed_dims=256,
                            num_heads=8,
                            dropout=0.1),
                        dict(type='MultiheadAttention', embed_dims=256)
                    ],
                    feedforward_channels=1024,
                    ffn_dropout=0.1,
                    operation_order=('self_attn', 'norm', 'cross_attn',
                                     'norm', 'ffn', 'norm'))
            ])


def test_drop_path():
    """Check DropPath is a no-op (same object) unless training with p > 0."""
    drop_path = DropPath(drop_prob=0)
    test_in = torch.rand(2, 3, 4, 5)
    assert test_in is drop_path(test_in)

    drop_path = DropPath(drop_prob=0.1)
    drop_path.eval()
    test_in = torch.rand(2, 3, 4, 5)
    assert test_in is drop_path(test_in)
    drop_path.train()
    assert test_in is not drop_path(test_in)