[Enhancement] Support value_proj_ratio in MultiScaleDeformableAttention (#2452)
* add ratio in ms_deform_attn
* Update mmcv/ops/multi_scale_deform_attn.py
* Update tests/test_ops/test_ms_deformable_attn.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
parent fb39e1e568
commit 433607030a
mmcv/ops/multi_scale_deform_attn.py:

@@ -182,6 +182,8 @@ class MultiScaleDeformableAttention(BaseModule):
             Default: None.
         init_cfg (obj:`mmcv.ConfigDict`): The Config for initialization.
             Default: None.
+        value_proj_ratio (float): The expansion ratio of value_proj.
+            Default: 1.0.
     """
 
     def __init__(self,
@@ -193,7 +195,8 @@ class MultiScaleDeformableAttention(BaseModule):
                  dropout: float = 0.1,
                  batch_first: bool = False,
                  norm_cfg: Optional[dict] = None,
-                 init_cfg: Optional[mmengine.ConfigDict] = None):
+                 init_cfg: Optional[mmengine.ConfigDict] = None,
+                 value_proj_ratio: float = 1.0):
         super().__init__(init_cfg)
         if embed_dims % num_heads != 0:
             raise ValueError(f'embed_dims must be divisible by num_heads, '
@@ -228,8 +231,9 @@ class MultiScaleDeformableAttention(BaseModule):
             embed_dims, num_heads * num_levels * num_points * 2)
         self.attention_weights = nn.Linear(embed_dims,
                                            num_heads * num_levels * num_points)
-        self.value_proj = nn.Linear(embed_dims, embed_dims)
-        self.output_proj = nn.Linear(embed_dims, embed_dims)
+        value_proj_size = int(embed_dims * value_proj_ratio)
+        self.value_proj = nn.Linear(embed_dims, value_proj_size)
+        self.output_proj = nn.Linear(value_proj_size, embed_dims)
         self.init_weights()
 
     def init_weights(self) -> None:
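For readers skimming the hunks above: with value_proj_ratio = r, value_proj now maps embed_dims to int(embed_dims * r), and output_proj maps the ratio-sized feature back to embed_dims, so the module's output width is unchanged. Below is a minimal standalone sketch of that projection pair in plain PyTorch (not the mmcv class itself; the batch and sequence sizes are illustrative assumptions):

import torch
from torch import nn

# Standalone sketch of the resized projection pair from the hunk above.
embed_dims, value_proj_ratio = 256, 0.5
value_proj_size = int(embed_dims * value_proj_ratio)  # 128

value_proj = nn.Linear(embed_dims, value_proj_size)   # 256 -> 128
output_proj = nn.Linear(value_proj_size, embed_dims)  # 128 -> 256

value = torch.rand(2, 100, embed_dims)  # (bs, num_value, embed_dims); assumed shape
out = output_proj(value_proj(value))
assert out.shape == value.shape  # output width matches embed_dims, as before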
tests/test_ops/test_ms_deformable_attn.py:

@@ -54,6 +54,26 @@ def test_multiscale_deformable_attention(device):
         spatial_shapes=spatial_shapes,
         level_start_index=level_start_index)
 
+    # test with value_proj_ratio
+    embed_dims = 6
+    value_proj_ratio = 0.5
+    query = torch.rand(num_query, bs, embed_dims).to(device)
+    key = torch.rand(num_query, bs, embed_dims).to(device)
+    msda = MultiScaleDeformableAttention(
+        embed_dims=embed_dims,
+        num_levels=2,
+        num_heads=3,
+        value_proj_ratio=value_proj_ratio)
+    msda.init_weights()
+    msda.to(device)
+    msda(
+        query,
+        key,
+        key,
+        reference_points=reference_points,
+        spatial_shapes=spatial_shapes,
+        level_start_index=level_start_index)
+
 
 def test_forward_multi_scale_deformable_attn_pytorch():
     N, M, D = 1, 2, 2
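A hedged usage sketch mirroring the new unit test above: construct the module with a halved value projection and check the resulting layer sizes. This assumes mmcv.ops exports MultiScaleDeformableAttention (its usual import path). Note that the forward pass splits the projected value across heads, so int(embed_dims * value_proj_ratio) should also be divisible by num_heads; the test's int(6 * 0.5) = 3 with num_heads=3 satisfies this.

from mmcv.ops import MultiScaleDeformableAttention

# Same configuration as the new test: 6-dim embeddings, ratio 0.5.
msda = MultiScaleDeformableAttention(
    embed_dims=6, num_levels=2, num_heads=3, value_proj_ratio=0.5)

assert msda.value_proj.out_features == 3   # int(6 * 0.5)
assert msda.output_proj.in_features == 3   # consumes the ratio-sized value
assert msda.output_proj.out_features == 6  # output width is still embed_dims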