[Fix] Fix attention clamp max params (#1034)

Hubert 2022-09-26 14:12:51 +08:00 committed by GitHub
parent 0b4a67dd31
commit 6ebb3f77ad
1 changed file with 1 addition and 4 deletions


@@ -261,10 +261,7 @@ class WindowMSAV2(BaseModule):
         attn = (
             F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
         logit_scale = torch.clamp(
-            self.logit_scale,
-            max=torch.log(
-                torch.tensor(1. / 0.01,
-                             device=self.logit_scale.device))).exp()
+            self.logit_scale, max=np.log(1. / 0.01)).exp()
         attn = attn * logit_scale
         relative_position_bias_table = self.cpb_mlp(
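
For reference, below is a minimal standalone sketch of what the patched lines compute: the logit-scaled cosine attention used by Swin Transformer V2's window attention. The `num_heads` value, tensor shapes, and random q/k inputs are illustrative assumptions; only the clamp expression itself comes from this commit.

    import numpy as np
    import torch
    import torch.nn.functional as F

    num_heads = 4  # illustrative value, not taken from the repository

    # Per-head learnable log scale, initialised as in Swin Transformer V2.
    logit_scale = torch.nn.Parameter(
        torch.log(10 * torch.ones((num_heads, 1, 1))))

    # Random q/k standing in for window-attention queries/keys:
    # shape (batch, heads, tokens, head_dim).
    q = torch.randn(1, num_heads, 49, 32)
    k = torch.randn(1, num_heads, 49, 32)

    # Cosine attention with the clamped scale, as in the patched code.
    attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
    # torch.clamp accepts a plain Python number for `max`, so the constant
    # np.log(1. / 0.01) suffices; no device-bound tensor is needed.
    scale = torch.clamp(logit_scale, max=np.log(1. / 0.01)).exp()
    attn = attn * scale
    print(attn.shape)  # torch.Size([1, 4, 49, 49])

Since the clamp bound is a constant, computing it once on the host with np.log avoids allocating a tensor on self.logit_scale.device at every forward pass, which is the point of the one-line replacement above.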