[Fix] Fix device mismatch in Swin-v2. (#976)

pull/1034/head
Andrey Moskalenko 2022-09-01 13:03:49 +03:00 committed by GitHub
parent ec71d071d8
commit 517bd3d34b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 4 additions and 1 deletions

View File

@ -261,7 +261,10 @@ class WindowMSAV2(BaseModule):
attn = (
F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
logit_scale = torch.clamp(
self.logit_scale, max=torch.log(torch.tensor(1. / 0.01))).exp()
self.logit_scale,
max=torch.log(
torch.tensor(1. / 0.01,
device=self.logit_scale.device))).exp()
attn = attn * logit_scale
relative_position_bias_table = self.cpb_mlp(