Merge pull request #2236 from NightMachinery/patch-1

eva.py: fixed bug in applying attention mask
Ross Wightman committed 2024-07-19 08:09:56 -07:00
commit 474c9cf768

@@ -134,10 +134,12 @@ class EvaAttention(nn.Module):
         else:
             q = q * self.scale
             attn = (q @ k.transpose(-2, -1))
-            attn = attn.softmax(dim=-1)
             if attn_mask is not None:
                 attn_mask = attn_mask.to(torch.bool)
                 attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
+            attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
             x = attn @ v
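
The fix moves the softmax to after the mask is applied. Masking must precede softmax: filling scores with -inf before normalization gives masked positions exactly zero weight, whereas filling an already-normalized distribution with -inf leaves invalid rows that propagate -inf/NaN through dropout and the attn @ v product. A minimal standalone sketch of the difference (plain PyTorch; the toy shapes and mask values are hypothetical, this is not the timm module itself):

import torch

# Toy shapes: batch B, heads H, sequence length N (hypothetical values).
B, H, N = 1, 1, 4
attn = torch.randn(B, H, N, N)                           # raw scores, i.e. (q @ k^T) * scale
attn_mask = torch.tensor([[True, True, False, False]])   # True = attend, False = mask out
attn_mask = attn_mask.to(torch.bool)

# Buggy order (pre-fix): softmax first, then masked_fill(-inf).
buggy = attn.softmax(dim=-1)
buggy = buggy.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
# Rows now hold -inf instead of probabilities; attn_drop and attn @ v
# propagate -inf/NaN, and the surviving weights no longer sum to 1.

# Fixed order (post-fix): masked_fill(-inf) first, then softmax.
fixed = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
fixed = fixed.softmax(dim=-1)
# Masked positions get exactly zero weight; each row is a valid
# probability distribution over the unmasked positions.

print(buggy[0, 0, 0])  # e.g. tensor([0.31, 0.52, -inf, -inf]) -> invalid weights
print(fixed[0, 0, 0])  # e.g. tensor([0.63, 0.37, 0.00, 0.00]) -> valid distribution

Applying the boolean mask as a -inf fill before the softmax also matches how F.scaled_dot_product_attention treats a boolean attn_mask on the fused path, so both branches of the module now behave consistently.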