mirror of https://github.com/open-mmlab/mmcv.git
[Fix] Fix init weights of MultiScaleDeformableAttention (#2158)
* fix tensors on different device * fix lint * Update mmcv/ops/multi_scale_deform_attn.py Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com> Co-authored-by: Zaida Zhou <58739961+zhouzaida@users.noreply.github.com>pull/2188/head
parent
b2ac245602
commit
f527e43c1a
|
@ -235,9 +235,10 @@ class MultiScaleDeformableAttention(BaseModule):
|
|||
def init_weights(self) -> None:
|
||||
"""Default initialization for Parameters of Module."""
|
||||
constant_init(self.sampling_offsets, 0.)
|
||||
device = next(self.parameters()).device
|
||||
thetas = torch.arange(
|
||||
self.num_heads,
|
||||
dtype=torch.float32) * (2.0 * math.pi / self.num_heads)
|
||||
self.num_heads, dtype=torch.float32,
|
||||
device=device) * (2.0 * math.pi / self.num_heads)
|
||||
grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
|
||||
grid_init = (grid_init /
|
||||
grid_init.abs().max(-1, keepdim=True)[0]).view(
|
||||
|
|
Loading…
Reference in New Issue