Mirror of https://github.com/open-mmlab/mmpretrain.git
[Fix] Fix attention clamp max params (#1034)
commit 6ebb3f77ad
parent 0b4a67dd31
@@ -261,10 +261,7 @@ class WindowMSAV2(BaseModule):
         attn = (
             F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1))
         logit_scale = torch.clamp(
-            self.logit_scale,
-            max=torch.log(
-                torch.tensor(1. / 0.01,
-                             device=self.logit_scale.device))).exp()
+            self.logit_scale, max=np.log(1. / 0.01)).exp()
         attn = attn * logit_scale
 
         relative_position_bias_table = self.cpb_mlp(