Update attention.py

pull/1754/head
Duo Li 2023-08-08 17:47:08 +08:00 committed by GitHub
parent 827a216155
commit b64bba71e1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
1 changed files with 1 additions and 1 deletions

View File

@ -38,7 +38,7 @@ def scaled_dot_product_attention_pyimpl(query,
attn_mask = torch.ones(
query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0)
if attn_mask is not None and attn_mask.dtype == torch.bool:
attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf'))
attn_mask = attn_mask.masked_fill(~attn_mask, -float('inf'))
attn_weight = query @ key.transpose(-2, -1) / scale
if attn_mask is not None: