Fix tracing of attention module with attn_mask support

Ross Wightman 2025-05-24 21:13:01 -07:00
parent 162f49295e
commit dd2c1418d0
2 changed files with 8 additions and 7 deletions

timm/layers/__init__.py

@@ -1,7 +1,7 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
-from .attention import Attention, AttentionRope
+from .attention import Attention, AttentionRope, maybe_add_mask
 from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding

timm/layers/attention.py

@@ -8,6 +8,11 @@ from .config import use_fused_attn
 from .pos_embed_sincos import apply_rot_embed_cat
 
 
+@torch.fx.wrap
+def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+    return scores if attn_mask is None else scores + attn_mask
+
+
 class Attention(nn.Module):
     """Standard Multi-head Self Attention module with QKV projection.
@@ -74,8 +79,7 @@ class Attention(nn.Module):
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
-            if attn_mask is not None:
-                attn = attn + attn_mask
+            attn = maybe_add_mask(attn, attn_mask)
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
             x = attn @ v
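
With the mask handling factored into an fx leaf, tracing the module once should yield a graph that works both with and without a mask at call time. A rough usage sketch (untested; the constructor arguments and mask shape are assumptions, not part of the commit):

import torch
from timm.layers import Attention

attn = Attention(dim=64, num_heads=4)
gm = torch.fx.symbolic_trace(attn)  # the non-fused path previously baked the mask add into the graph

x = torch.randn(2, 16, 64)          # (batch, tokens, dim)
bias = torch.zeros(16, 16)          # additive mask, broadcast to (batch, heads, tokens, tokens)

out_plain = gm(x)                   # attn_mask defaults to None; the check happens inside the leaf call
out_masked = gm(x, attn_mask=bias)  # same traced graph, mask applied at runtime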
@@ -196,10 +200,7 @@ class AttentionRope(nn.Module):
         else:
             q = q * self.scale
             attn = (q @ k.transpose(-2, -1))
-            if attn_mask is not None:
-                attn_mask = attn_mask.to(torch.bool)
-                attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
+            attn = maybe_add_mask(attn, attn_mask)
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
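
Note the AttentionRope hunk also changes the mask contract: the old code converted attn_mask to bool and masked_fill'ed missing keys with -inf, while the shared helper now expects an additive float mask, matching Attention. A hedged sketch of adapting a caller that still holds a boolean key-padding mask (variable names here are illustrative):

import torch

# (batch, keys) boolean mask, True = valid token, as the old AttentionRope path assumed
bool_key_mask = torch.tensor([[True, True, True, False]])

# Equivalent additive mask for the new code path: 0.0 where attention is allowed,
# -inf where it must be blocked, shaped to broadcast over (batch, heads, queries, keys)
additive_mask = torch.zeros(1, 1, 1, 4).masked_fill(
    ~bool_key_mask[:, None, None, :], float('-inf'))

# rope_attn(x, rope=rope_embed, attn_mask=additive_mask)  # instead of the bool mask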