From f527e43c1a1e52e5903410255aea1f71f005a161 Mon Sep 17 00:00:00 2001
From: Cedric Luo
Date: Mon, 1 Aug 2022 14:12:45 +0800
Subject: [PATCH] [Fix] Fix init weights of MultiScaleDeformableAttention
 (#2158)

* fix tensors on different device

* fix lint

* Update mmcv/ops/multi_scale_deform_attn.py

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>

Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>
---
 mmcv/ops/multi_scale_deform_attn.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/mmcv/ops/multi_scale_deform_attn.py b/mmcv/ops/multi_scale_deform_attn.py
index a06466fa5..3ac343d8a 100644
--- a/mmcv/ops/multi_scale_deform_attn.py
+++ b/mmcv/ops/multi_scale_deform_attn.py
@@ -235,9 +235,10 @@ class MultiScaleDeformableAttention(BaseModule):
     def init_weights(self) -> None:
         """Default initialization for Parameters of Module."""
         constant_init(self.sampling_offsets, 0.)
+        device = next(self.parameters()).device
         thetas = torch.arange(
-            self.num_heads,
-            dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
+            self.num_heads, dtype=torch.float32,
+            device=device) * (2.0 * math.pi / self.num_heads)
         grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
         grid_init = (grid_init /
                      grid_init.abs().max(-1, keepdim=True)[0]).view(