Update attention.py
parent 827a216155
commit b64bba71e1
@@ -38,7 +38,7 @@ def scaled_dot_product_attention_pyimpl(query,
         attn_mask = torch.ones(
             query.size(-2), key.size(-2), dtype=torch.bool).tril(diagonal=0)
     if attn_mask is not None and attn_mask.dtype == torch.bool:
-        attn_mask = attn_mask.masked_fill(not attn_mask, -float('inf'))
+        attn_mask = attn_mask.masked_fill(~attn_mask, -float('inf'))
 
     attn_weight = query @ key.transpose(-2, -1) / scale
     if attn_mask is not None:
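The fix swaps Python's boolean `not` for the elementwise invert operator `~`. `not` asks the tensor for a single truth value, which raises a RuntimeError for any mask with more than one element, while `~` inverts each entry so `masked_fill` writes `-inf` exactly at the positions the mask disallows. A minimal sketch of the difference (the mask values and shapes below are illustrative, not from this commit):

import torch

# Hypothetical 2x2 boolean attention mask; True marks allowed positions.
mask = torch.tensor([[True, False],
                     [True, True]])

# Python's `not` needs a single truth value, so on a multi-element
# tensor it raises:
#   RuntimeError: Boolean value of Tensor with more than one element
#   is ambiguous
try:
    not mask
except RuntimeError as e:
    print(e)

# `~` inverts elementwise: True -> False, False -> True.
print(~mask)  # tensor([[False,  True], [False, False]])

# So masked_fill(~mask, -inf) writes -inf where the mask is False,
# which is what the corrected line relies on.
scores = torch.zeros(2, 2)
print(scores.masked_fill(~mask, -float('inf')))
# tensor([[0., -inf],
#         [0., 0.]])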