[Fix] Fix device mismatch in Swin-v2. (#976)
parent
ec71d071d8
commit
517bd3d34b
|
@ -261,7 +261,10 @@ class WindowMSAV2(BaseModule):
|
|||
attn = (
|
||||
F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
|
||||
logit_scale = torch.clamp(
|
||||
self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
|
||||
self.logit_scale,
|
||||
max=torch.log(
|
||||
torch.tensor(1. / 0.01,
|
||||
device=self.logit_scale.device))).exp()
|
||||
attn = attn * logit_scale
|
||||
|
||||
relative_position_bias_table = self.cpb_mlp(
|
||||
|
|
Loading…
Reference in New Issue