# mirror of https://github.com/open-mmlab/mmcv.git
import torch

from mmcv.cnn.bricks import GeneralizedAttention
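
# GeneralizedAttention's attention_type is a 4-character binary indicator
# string. Per the mmcv docstring, each position enables one term of the
# generalized attention formulation:
#   '1000'  query and key content               (appr - appr)
#   '0100'  query content and relative position (appr - position)
#   '0010'  key content only                    (bias - appr)
#   '0001'  relative position only              (bias - position)
# The cases below exercise each term in isolation, then the stride and
# spatial-range options.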


def test_context_block():
    # test attention_type='1000'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='1000')
    assert gen_attention_block.query_conv.in_channels == 16
    assert gen_attention_block.key_conv.in_channels == 16
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape

    # test attention_type='0100'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0100')
    assert gen_attention_block.query_conv.in_channels == 16
    assert gen_attention_block.appr_geom_fc_x.in_features == 8
    assert gen_attention_block.appr_geom_fc_y.in_features == 8
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape
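    # (note) in_features == 8 because position_embedding_dim defaults to
    # in_channels (16) and the embedding is split evenly between the x- and
    # y-offset branches.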

    # test attention_type='0010'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0010')
    assert gen_attention_block.key_conv.in_channels == 16
    assert hasattr(gen_attention_block, 'appr_bias')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape
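    # (note) appr_bias is the learned, query-independent bias used by the
    # key-content-only term, so this case needs no query_conv.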

    # test attention_type='0001'
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, attention_type='0001')
    assert gen_attention_block.appr_geom_fc_x.in_features == 8
    assert gen_attention_block.appr_geom_fc_y.in_features == 8
    assert hasattr(gen_attention_block, 'geom_bias')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape
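    # (note) geom_bias plays the same role for the relative-position-only
    # term: a learned bias stands in for the query content.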

    # test spatial_range >= 0
    imgs = torch.randn(2, 256, 20, 20)
    gen_attention_block = GeneralizedAttention(256, spatial_range=10)
    assert hasattr(gen_attention_block, 'local_constraint_map')
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape
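    # (note) a non-negative spatial_range restricts each query to a local
    # window; local_constraint_map is the precomputed mask enforcing that.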

    # test q_stride > 1
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, q_stride=2)
    assert gen_attention_block.q_downsample is not None
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape
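    # (note) q_stride and kv_stride (next case) subsample the query and the
    # key/value maps respectively; the block still returns the input
    # resolution, which is all these two cases assert.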

    # test kv_stride > 1
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16, kv_stride=2)
    assert gen_attention_block.kv_downsample is not None
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape
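

# A minimal extra sketch, not part of the upstream suite: the attention_type
# bits can be combined, and mmcv's default '1111' enables all four terms at
# once. The test name below is ours; the shapes mirror the cases above.
def test_generalized_attention_default():
    imgs = torch.randn(2, 16, 20, 20)
    gen_attention_block = GeneralizedAttention(16)  # attention_type='1111'
    out = gen_attention_block(imgs)
    assert out.shape == imgs.shape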