Do not force using FlashAttention (#58)
Remove the hardcoded selection of the operator implementation and use the xFormers fMHA dispatcher instead.
parent ca58ffcd87
commit c0ffb6ed71
@@ -73,11 +73,7 @@ class MemEffAttention(Attention):
 
         q, k, v = unbind(qkv, 2)
 
-        if attn_bias is not None:
-            self_att_op = fmha.MemoryEfficientAttentionFlashAttentionOp
-        else:
-            self_att_op = None
-        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=self_att_op)
+        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
         x = x.reshape([B, N, C])
 
         x = self.proj(x)
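
For context, a minimal sketch of the call path after this change (shapes and values are illustrative, and assume an xFormers install with a CUDA build): leaving the op argument at its default of None lets the fMHA dispatcher inspect the inputs and pick a suitable backend, rather than unconditionally forcing fmha.MemoryEfficientAttentionFlashAttentionOp, which cannot handle inputs FlashAttention does not support.

import torch
from xformers.ops import memory_efficient_attention

# Illustrative shapes only: batch, sequence length, heads, head dim.
B, N, H, D = 2, 197, 6, 64
q = torch.randn(B, N, H, D, device="cuda", dtype=torch.half)
k = torch.randn_like(q)
v = torch.randn_like(q)

# With the default op=None, the fMHA dispatcher selects an implementation
# based on the inputs (dtype, device, head dim, type of attn_bias), e.g.
# FlashAttention when the inputs support it, other kernels otherwise.
x = memory_efficient_attention(q, k, v, attn_bias=None)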