[Fix] Fix init weights of MultiScaleDeformableAttention (#2158)

* fix tensors on different device

* fix lint

* Update mmcv/ops/multi_scale_deform_attn.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
pull/2188/head
Cedric Luo 2022-08-01 14:12:45 +08:00 committed by Zaida Zhou
parent b2ac245602
commit f527e43c1a
1 changed file with 3 additions and 2 deletions

View File

@ -235,9 +235,10 @@ class MultiScaleDeformableAttention(BaseModule):
def init_weights(self) -> None:
"""Default initialization for Parameters of Module."""
constant_init(self.sampling_offsets, 0.)
device = next(self.parameters()).device
thetas = torch.arange(
self.num_heads,
dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
self.num_heads, dtype=torch.float32,
device=device) * (2.0 * math.pi / self.num_heads)
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
grid_init = (grid_init /
grid_init.abs().max(-1, keepdim=True)[0]).view(