Mirror of https://github.com/huggingface/pytorch-image-models.git
Fix tracing of attention module with attn_mask support
commit dd2c1418d0
parent 162f49295e
@@ -1,7 +1,7 @@
 from .activations import *
 from .adaptive_avgmax_pool import \
     adaptive_avgmax_pool2d, select_adaptive_pool2d, AdaptiveAvgMaxPool2d, SelectAdaptivePool2d
-from .attention import Attention, AttentionRope
+from .attention import Attention, AttentionRope, maybe_add_mask
 from .attention2d import MultiQueryAttention2d, Attention2d, MultiQueryAttentionV2
 from .attention_pool import AttentionPoolLatent
 from .attention_pool2d import AttentionPool2d, RotAttentionPool2d, RotaryEmbedding
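This hunk re-exports the new helper from the layers package so it can be imported alongside Attention and AttentionRope. A quick smoke test, assuming a timm build that contains this commit:

import torch
from timm.layers import maybe_add_mask

scores = torch.randn(1, 2, 4, 4)
mask = torch.full_like(scores, float("-inf")).triu(1)        # causal-style additive mask
print(torch.equal(maybe_add_mask(scores, None), scores))     # no mask -> scores unchanged
print(maybe_add_mask(scores, mask)[0, 0, 0, 1].item())       # masked position -> -inf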
@@ -8,6 +8,11 @@ from .config import use_fused_attn
 from .pos_embed_sincos import apply_rot_embed_cat
 
 
+@torch.fx.wrap
+def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
+    return scores if attn_mask is None else scores + attn_mask
+
+
 class Attention(nn.Module):
     """Standard Multi-head Self Attention module with QKV projection.
 
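Why the wrapper matters: torch.fx symbolic tracing feeds Proxy objects through forward(), so a Python-level `if attn_mask is not None:` check evaluates against a Proxy and bakes a single branch into the traced graph. Decorating the helper with @torch.fx.wrap registers it as a leaf; tracing records one call_function node and the None-check runs at execution time instead. A minimal, self-contained sketch of that mechanism (the helper is re-declared locally and the TinyScores module is illustrative, not part of timm):

from typing import Optional

import torch
import torch.nn as nn


@torch.fx.wrap  # leaf for symbolic tracing: FX records the call, it does not trace inside
def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
    return scores if attn_mask is None else scores + attn_mask


class TinyScores(nn.Module):
    def forward(self, scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None):
        # The None-check lives inside the wrapped leaf, so one traced graph
        # serves both the masked and the unmasked call.
        return maybe_add_mask(scores, attn_mask).softmax(dim=-1)


gm = torch.fx.symbolic_trace(TinyScores())
s = torch.randn(1, 1, 4, 4)
m = torch.zeros(1, 1, 4, 4)
print(torch.allclose(gm(s, None), s.softmax(dim=-1)))   # True: unmasked path still works
print(gm(s, m).shape)                                   # torch.Size([1, 1, 4, 4])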
@@ -74,8 +79,7 @@ class Attention(nn.Module):
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
-            if attn_mask is not None:
-                attn = attn + attn_mask
+            attn = maybe_add_mask(attn, attn_mask)
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
             x = attn @ v
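With the non-fused path routed through the wrapped helper, Attention treats attn_mask as an additive mask whether or not the fused SDPA kernel is in use, and the module traces cleanly. A usage sketch; the constructor and forward arguments shown here are read off the diff and assumed rather than taken from further documentation:

import torch
from timm.layers import Attention

attn = Attention(dim=64, num_heads=4).eval()
x = torch.randn(2, 16, 64)                    # (batch, tokens, dim)
mask = torch.zeros(2, 1, 16, 16)              # additive mask, 0 = attend
mask[:, :, :, 8:] = float("-inf")             # hide the last 8 keys

with torch.no_grad():
    out = attn(x, attn_mask=mask)             # masked forward pass
    gm = torch.fx.symbolic_trace(attn)        # tracing no longer trips on the mask check
    print(out.shape, torch.allclose(gm(x, mask), out))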
@@ -196,10 +200,7 @@ class AttentionRope(nn.Module):
         else:
             q = q * self.scale
             attn = (q @ k.transpose(-2, -1))
-
-            if attn_mask is not None:
-                attn_mask = attn_mask.to(torch.bool)
-                attn = attn.masked_fill(~attn_mask[:, None, None, :], float("-inf"))
+            attn = maybe_add_mask(attn, attn_mask)
             attn = attn.softmax(dim=-1)
 
             attn = self.attn_drop(attn)
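Note the behavioural nuance visible in this hunk: the old non-fused AttentionRope path treated attn_mask as a boolean keep-mask over keys and applied masked_fill, while the new code (like Attention and F.scaled_dot_product_attention) expects an additive float mask. A small sketch of converting a boolean key mask into the equivalent additive form, for callers that relied on the old interpretation:

import torch

bool_keep = torch.tensor([[True, True, True, False]])      # (B, N): True = attend to this key
additive = torch.zeros(bool_keep.shape, dtype=torch.float32)
additive = additive.masked_fill(~bool_keep, float("-inf"))  # 0 where kept, -inf where masked
additive = additive[:, None, None, :]                       # broadcast to (B, 1, 1, N) over heads/queries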