diff --git a/dinov2/layers/attention.py b/dinov2/layers/attention.py
index c789ebd..1f9b0c9 100644
--- a/dinov2/layers/attention.py
+++ b/dinov2/layers/attention.py
@@ -73,11 +73,7 @@ class MemEffAttention(Attention):
 
         q, k, v = unbind(qkv, 2)
 
-        if attn_bias is not None:
-            self_att_op = fmha.MemoryEfficientAttentionFlashAttentionOp
-        else:
-            self_att_op = None
-        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias, op=self_att_op)
+        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
         x = x.reshape([B, N, C])
 
         x = self.proj(x)
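
Note (not part of the patch): with the explicit `op` argument dropped, xFormers' `memory_efficient_attention` is left to choose its backend automatically instead of being forced onto the FlashAttention op. Below is a minimal, hedged sketch of the simplified call path; it assumes xFormers and a CUDA device are available, and the tensor shapes are arbitrary illustrative values, not anything taken from the repository.

    # Sketch of the call as it reads after this patch; shapes are assumptions.
    import torch
    from xformers.ops import memory_efficient_attention

    B, N, num_heads, head_dim = 2, 197, 6, 64  # illustrative sizes only
    q = torch.randn(B, N, num_heads, head_dim, device="cuda", dtype=torch.float16)
    k = torch.randn_like(q)
    v = torch.randn_like(q)

    # Omitting `op` lets xFormers pick a suitable kernel on its own, which is
    # why the manual FlashAttention selection removed above is no longer needed.
    x = memory_efficient_attention(q, k, v, attn_bias=None)
    x = x.reshape(B, N, num_heads * head_dim)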