[Enhancement] Support value_proj_ratio in MultiScaleDeformableAttention (#2452)

* add ratio in ms_deform_attn

* add ratio in ms_deform_attn

* Update mmcv/ops/multi_scale_deform_attn.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* Update tests/test_ops/test_ms_deformable_attn.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

* add ratio in ms_deform_attn

* add ratio in ms_deform_attn

* add ratio in ms_deform_attn

* add ratio in ms_deform_attn

* add ratio in ms_deform_attn

* add ratio in ms_deform_attn

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
takuoko 2022-12-11 18:48:21 +09:00 committed by GitHub
parent fb39e1e568
commit 433607030a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 27 additions and 3 deletions

View File

@ -182,6 +182,8 @@ class MultiScaleDeformableAttention(BaseModule):
Default: None.
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
Default: None.
value_proj_ratio (float): The expansion ratio of value_proj.
Default: 1.0.
"""
def __init__(self,
@ -193,7 +195,8 @@ class MultiScaleDeformableAttention(BaseModule):
dropout: float = 0.1,
batch_first: bool = False,
norm_cfg: Optional[dict] = None,
init_cfg: Optional[mmengine.ConfigDict] = None):
init_cfg: Optional[mmengine.ConfigDict] = None,
value_proj_ratio: float = 1.0):
super().__init__(init_cfg)
if embed_dims % num_heads != 0:
raise ValueError(f'embed_dims must be divisible by num_heads, '
@ -228,8 +231,9 @@ class MultiScaleDeformableAttention(BaseModule):
embed_dims, num_heads * num_levels * num_points * 2)
self.attention_weights = nn.Linear(embed_dims,
num_heads * num_levels * num_points)
self.value_proj = nn.Linear(embed_dims, embed_dims)
self.output_proj = nn.Linear(embed_dims, embed_dims)
value_proj_size = int(embed_dims * value_proj_ratio)
self.value_proj = nn.Linear(embed_dims, value_proj_size)
self.output_proj = nn.Linear(value_proj_size, embed_dims)
self.init_weights()
def init_weights(self) -> None:

View File

@ -54,6 +54,26 @@ def test_multiscale_deformable_attention(device):
spatial_shapes=spatial_shapes,
level_start_index=level_start_index)
# test with value_proj_ratio
embed_dims = 6
value_proj_ratio = 0.5
query = torch.rand(num_query, bs, embed_dims).to(device)
key = torch.rand(num_query, bs, embed_dims).to(device)
msda = MultiScaleDeformableAttention(
embed_dims=embed_dims,
num_levels=2,
num_heads=3,
value_proj_ratio=value_proj_ratio)
msda.init_weights()
msda.to(device)
msda(
query,
key,
key,
reference_points=reference_points,
spatial_shapes=spatial_shapes,
level_start_index=level_start_index)
def test_forward_multi_scale_deformable_attn_pytorch():
N, M, D = 1, 2, 2