mirror of
https://github.com/open-mmlab/mmcv.git
synced 2025-06-03 21:54:52 +08:00
[Enhancement] Support value_proj_ratio in MultiScaleDeformableAttention (#2452)
* add ratio in ms_deform_attn_ * add ratio in ms_deform_attn * Update mmcv/ops/multi_scale_deform_attn.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * Update tests/test_ops/test_ms_deformable_attn.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> * add ratio in ms_deform_attn * add ratio in ms_deform_attn * add ratio in ms_deform_attn * add ratio in ms_deform_attn * add ratio in ms_deform_attn * add ratio in ms_deform_attn Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
This commit is contained in:
parent
fb39e1e568
commit
433607030a
@ -182,6 +182,8 @@ class MultiScaleDeformableAttention(BaseModule):
|
||||
Default: None.
|
||||
init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
|
||||
Default: None.
|
||||
value_proj_ratio (float): The expansion ratio of value_proj.
|
||||
Default: 1.0.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
@ -193,7 +195,8 @@ class MultiScaleDeformableAttention(BaseModule):
|
||||
dropout: float = 0.1,
|
||||
batch_first: bool = False,
|
||||
norm_cfg: Optional[dict] = None,
|
||||
init_cfg: Optional[mmengine.ConfigDict] = None):
|
||||
init_cfg: Optional[mmengine.ConfigDict] = None,
|
||||
value_proj_ratio: float = 1.0):
|
||||
super().__init__(init_cfg)
|
||||
if embed_dims % num_heads != 0:
|
||||
raise ValueError(f'embed_dims must be divisible by num_heads, '
|
||||
@ -228,8 +231,9 @@ class MultiScaleDeformableAttention(BaseModule):
|
||||
embed_dims, num_heads * num_levels * num_points * 2)
|
||||
self.attention_weights = nn.Linear(embed_dims,
|
||||
num_heads * num_levels * num_points)
|
||||
self.value_proj = nn.Linear(embed_dims, embed_dims)
|
||||
self.output_proj = nn.Linear(embed_dims, embed_dims)
|
||||
value_proj_size = int(embed_dims * value_proj_ratio)
|
||||
self.value_proj = nn.Linear(embed_dims, value_proj_size)
|
||||
self.output_proj = nn.Linear(value_proj_size, embed_dims)
|
||||
self.init_weights()
|
||||
|
||||
def init_weights(self) -> None:
|
||||
|
@ -54,6 +54,26 @@ def test_multiscale_deformable_attention(device):
|
||||
spatial_shapes=spatial_shapes,
|
||||
level_start_index=level_start_index)
|
||||
|
||||
# test with value_proj_ratio
|
||||
embed_dims = 6
|
||||
value_proj_ratio = 0.5
|
||||
query = torch.rand(num_query, bs, embed_dims).to(device)
|
||||
key = torch.rand(num_query, bs, embed_dims).to(device)
|
||||
msda = MultiScaleDeformableAttention(
|
||||
embed_dims=embed_dims,
|
||||
num_levels=2,
|
||||
num_heads=3,
|
||||
value_proj_ratio=value_proj_ratio)
|
||||
msda.init_weights()
|
||||
msda.to(device)
|
||||
msda(
|
||||
query,
|
||||
key,
|
||||
key,
|
||||
reference_points=reference_points,
|
||||
spatial_shapes=spatial_shapes,
|
||||
level_start_index=level_start_index)
|
||||
|
||||
|
||||
def test_forward_multi_scale_deformable_attn_pytorch():
|
||||
N, M, D = 1, 2, 2
|
||||
|
Loading…
x
Reference in New Issue
Block a user