Mirror of https://github.com/huggingface/pytorch-image-models.git (synced 2025-06-03 15:01:08 +08:00)
Fix tracing of attention module with attn_mask support

Commit: dd2c1418d0
Parent: 162f49295e
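
Background on the fix (my reading, not part of the commit message): torch.fx symbolic tracing replaces tensor inputs with Proxy objects, so a Python-level branch such as `if attn_mask is not None:` inside `forward()` is evaluated once at trace time and frozen into the graph. Moving the check into a module-level helper decorated with `@torch.fx.wrap` keeps it as a single call node in the traced graph, so the None check happens at call time instead. A minimal, self-contained sketch of the pattern (toy module, not timm code):

from typing import Optional

import torch
import torch.fx
import torch.nn as nn


@torch.fx.wrap
def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
    # Left as a single leaf call in the FX graph; the None check runs at call time.
    return scores if attn_mask is None else scores + attn_mask


class ToyAttention(nn.Module):
    def forward(self, scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        scores = maybe_add_mask(scores, attn_mask)
        return scores.softmax(dim=-1)


traced = torch.fx.symbolic_trace(ToyAttention())
s = torch.randn(1, 4, 4)
print(torch.allclose(traced(s, None), s.softmax(dim=-1)))   # True: unmasked path
print(traced(s, torch.zeros(1, 4, 4)).shape)                # masked path through the same graph

The diff below applies exactly this pattern to timm's attention layers.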
@@ -1,7 +1,7 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
-from .attention import Attention, AttentionRope
+from .attention import Attention, AttentionRope, maybe_add_mask
 from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
@@ -8,6 +8,11 @@ from .config import use_fused_attn
 from .pos_embed_sincos import apply_rot_embed_cat
 
 
+@torch.fx.wrap
+def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+    return scores if attn_mask is None else scores + attn_mask
+
+
 class Attention(nn.Module):
     """Standard Multi-head Self Attention module with QKV projection.
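
As a quick illustration of the helper's semantics (my example, not from the commit): the mask is additive, with `-inf` at positions that should receive no attention, the same float-mask convention as `torch.nn.functional.scaled_dot_product_attention`. Assuming the edited `__init__` above is `timm/layers/__init__.py`, the helper is importable from `timm.layers`:

import torch
from timm.layers import maybe_add_mask   # re-exported by the import change above

scores = torch.zeros(1, 1, 2, 2)                  # pre-softmax logits (B, heads, Nq, Nk)
mask = torch.tensor([[0.0, float("-inf")]])       # block attention to the second key
masked = maybe_add_mask(scores, mask)             # broadcasts over batch, heads, queries
print(masked.softmax(dim=-1))                     # rows become [1.0, 0.0]
print(maybe_add_mask(scores, None) is scores)     # True: a None mask is a no-op

A None mask falls through unchanged, which is what lets a single traced graph serve both the masked and unmasked cases.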
@@ -74,8 +79,7 @@ class Attention(nn.Module):
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
-            if attn_mask is not None:
-                attn = attn + attn_mask
+            attn = maybe_add_mask(attn, attn_mask)
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
             x = attn @ v
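
For a sense of how this changed path is exercised, here is a sketch under assumptions not shown in the diff: that `Attention` takes `dim` and `num_heads` in its constructor, and that `forward` accepts `x` of shape (B, N, C) plus the optional `attn_mask` seen above.

import torch
from timm.layers import Attention

attn = Attention(dim=64, num_heads=4)
x = torch.randn(2, 16, 64)                 # (batch, tokens, dim)
mask = torch.zeros(2, 1, 16, 16)           # additive mask, broadcast over heads
mask[:, :, :, 8:] = float("-inf")          # hide the last 8 tokens as attention keys
out = attn(x, attn_mask=mask)
print(out.shape)                           # torch.Size([2, 16, 64])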
@@ -196,10 +200,7 @@ class AttentionRope(nn.Module):
         else:
             q = q * self.scale
             attn = (q @ k.transpose(-2, -1))
-
-            if attn_mask is not None:
-                attn_mask = attn_mask.to(torch.bool)
-                attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
+            attn = maybe_add_mask(attn, attn_mask)
             attn = attn.softmax(dim=-1)
 
             attn = self.attn_drop(attn)
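
One behavioral note on this last hunk: previously `AttentionRope` interpreted `attn_mask` as a boolean keep-mask over keys and applied `masked_fill`; after the change the mask is added to the logits directly, so a caller holding a boolean mask needs to convert it to additive form first. A small sketch of that conversion (hypothetical helper, mirroring the removed lines):

import torch

def bool_to_additive_mask(keep: torch.Tensor) -> torch.Tensor:
    """Turn a (B, N) boolean keep-mask into an additive (B, 1, 1, N) float mask."""
    mask = torch.zeros(keep.shape, dtype=torch.float32)
    mask = mask.masked_fill(~keep, float("-inf"))
    return mask[:, None, None, :]    # broadcast over heads and query positions

keep = torch.tensor([[True, True, False, False]])   # attend only to the first two keys
print(bool_to_additive_mask(keep))                  # [[[[0., 0., -inf, -inf]]]]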