Do not force using FlashAttention (#58)

Remove hardcoded selection of operator implementation and use xFormers fMHA dispatcher instead.
pull/62/head
Patrick Labatut 2023-04-26 02:26:24 +02:00 committed by GitHub
parent ca58ffcd87
commit c0ffb6ed71
1 changed file with 1 addition and 5 deletions


@@ -73,11 +73,7 @@ class MemEffAttention(Attention):
         q, k, v = unbind(qkv, 2)
-        if attn_bias is not None:
-            self_att_op = fmha.MemoryEfficientAttentionFlashAttentionOp
-        else:
-            self_att_op = None
-        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=self_att_op)
+        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
         x = x.reshape([B, N, C])
         x = self.proj(x)
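
For context, a minimal sketch (not part of this commit) of the patched call: when the op argument is omitted, xformers.ops.memory_efficient_attention defers to the fMHA dispatcher, which selects a supported backend per input instead of forcing FlashAttention. The tensor shapes and dtypes below are illustrative assumptions, not values from the repository.

import torch
from xformers.ops import memory_efficient_attention

# Hypothetical shapes: batch, sequence length, heads, head dim.
B, N, H, D = 2, 197, 8, 64
q = torch.randn(B, N, H, D, device="cuda", dtype=torch.float16)
k = torch.randn(B, N, H, D, device="cuda", dtype=torch.float16)
v = torch.randn(B, N, H, D, device="cuda", dtype=torch.float16)

# Before this commit: op=fmha.MemoryEfficientAttentionFlashAttentionOp forced
# the FlashAttention backend, which rejects inputs it does not support
# (for example float32 tensors or certain attn_bias types).
# After: leaving op unset lets the dispatcher pick a compatible backend.
x = memory_efficient_attention(q, k, v, attn_bias=None)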