diff --git a/mmcv/ops/multi_scale_deform_attn.py b/mmcv/ops/multi_scale_deform_attn.py
index 509ae5f98..c1d415621 100644
--- a/mmcv/ops/multi_scale_deform_attn.py
+++ b/mmcv/ops/multi_scale_deform_attn.py
@@ -182,6 +182,8 @@ class MultiScaleDeformableAttention(BaseModule):
             Default: None.
         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
             Default: None.
+        value_proj_ratio (float): The expansion ratio of value_proj.
+            Default: 1.0.
     """

     def __init__(self,
@@ -193,7 +195,8 @@ class MultiScaleDeformableAttention(BaseModule):
                  dropout: float = 0.1,
                  batch_first: bool = False,
                  norm_cfg: Optional[dict] = None,
-                 init_cfg: Optional[mmengine.ConfigDict] = None):
+                 init_cfg: Optional[mmengine.ConfigDict] = None,
+                 value_proj_ratio: float = 1.0):
         super().__init__(init_cfg)
         if embed_dims % num_heads != 0:
             raise ValueError(f'embed_dims must be divisible by num_heads, '
@@ -228,8 +231,9 @@ class MultiScaleDeformableAttention(BaseModule):
             embed_dims, num_heads * num_levels * num_points * 2)
         self.attention_weights = nn.Linear(embed_dims,
                                            num_heads * num_levels * num_points)
-        self.value_proj = nn.Linear(embed_dims, embed_dims)
-        self.output_proj = nn.Linear(embed_dims, embed_dims)
+        value_proj_size = int(embed_dims * value_proj_ratio)
+        self.value_proj = nn.Linear(embed_dims, value_proj_size)
+        self.output_proj = nn.Linear(value_proj_size, embed_dims)
         self.init_weights()

     def init_weights(self) -> None:
diff --git a/tests/test_ops/test_ms_deformable_attn.py b/tests/test_ops/test_ms_deformable_attn.py
index 94223a642..a29380552 100644
--- a/tests/test_ops/test_ms_deformable_attn.py
+++ b/tests/test_ops/test_ms_deformable_attn.py
@@ -54,6 +54,26 @@ def test_multiscale_deformable_attention(device):
         spatial_shapes=spatial_shapes,
         level_start_index=level_start_index)

+    # test with value_proj_ratio
+    embed_dims = 6
+    value_proj_ratio = 0.5
+    query = torch.rand(num_query, bs, embed_dims).to(device)
+    key = torch.rand(num_query, bs, embed_dims).to(device)
+    msda = MultiScaleDeformableAttention(
+        embed_dims=embed_dims,
+        num_levels=2,
+        num_heads=3,
+        value_proj_ratio=value_proj_ratio)
+    msda.init_weights()
+    msda.to(device)
+    msda(
+        query,
+        key,
+        key,
+        reference_points=reference_points,
+        spatial_shapes=spatial_shapes,
+        level_start_index=level_start_index)
+

 def test_forward_multi_scale_deformable_attn_pytorch():
     N, M, D = 1, 2, 2
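
For context, a minimal sketch (not part of the patch) of what the new value_proj_ratio argument does to the projection layers. It assumes MultiScaleDeformableAttention is importable from mmcv.ops, as in the test file touched above; the chosen sizes are illustrative only.

    from mmcv.ops import MultiScaleDeformableAttention

    # value_proj_ratio only scales the width of the internal value projection;
    # output_proj maps back to embed_dims, so the module's input/output shapes
    # are unchanged.
    attn = MultiScaleDeformableAttention(
        embed_dims=256, num_heads=8, num_levels=4, value_proj_ratio=0.5)
    print(attn.value_proj)   # Linear(in_features=256, out_features=128, ...)
    print(attn.output_proj)  # Linear(in_features=128, out_features=256, ...)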