[Fix] Fix attention clamp max params (#1034)
parent 0b4a67dd31
commit 6ebb3f77ad
@@ -261,10 +261,7 @@ class WindowMSAV2(BaseModule):
         attn = (
             F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
         logit_scale = torch.clamp(
-            self.logit_scale,
-            max=torch.log(
-                torch.tensor(1. / 0.01,
-                             device=self.logit_scale.device))).exp()
+            self.logit_scale, max=np.log(1. / 0.01)).exp()
         attn = attn * logit_scale
 
         relative_position_bias_table = self.cpb_mlp(
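For context, the change replaces a tensor-valued max with a plain Python float: np.log(1. / 0.01) hands torch.clamp a scalar upper bound, so no extra tensor has to be created on the device of self.logit_scale. Below is a minimal, self-contained sketch of the surrounding cosine-attention computation; the shapes and the logit_scale initialization are illustrative assumptions, and only the clamp line mirrors the diff.

import numpy as np
import torch
import torch.nn.functional as F

# Illustrative shapes: windows B_, heads, tokens N, head dim C (assumed values).
B_, num_heads, N, C = 4, 3, 49, 32
q = torch.randn(B_, num_heads, N, C)
k = torch.randn(B_, num_heads, N, C)

# Learnable per-head temperature; this init is an assumption, not from the diff.
logit_scale = torch.nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))

# Cosine attention: normalizing q and k bounds each raw logit to [-1, 1].
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)

# The fixed clamp: np.log(1. / 0.01) is a plain float, which torch.clamp
# accepts directly as max, so the effective scale exp(logit_scale) is
# capped at 1 / 0.01 = 100 without allocating a tensor on another device.
logit_scale_clamped = torch.clamp(logit_scale, max=np.log(1. / 0.01)).exp()
attn = attn * logit_scale_clamped

print(attn.shape)                        # torch.Size([4, 3, 49, 49])
print(logit_scale_clamped.max().item())  # always <= 100.0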