A few more maybe_add_mask situations

Ross Wightman 2025-05-25 08:51:56 -07:00
parent dd2c1418d0
commit 842a786626
3 changed files with 5 additions and 6 deletions
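
All three files replace the open-coded two-line pattern "if attn_mask is not None: attn = attn + attn_mask" with a single maybe_add_mask(attn, attn_mask) call. Going only by the code it replaces, the helper is presumably a one-line wrapper along these lines (a sketch, not the actual timm source):

    from typing import Optional

    import torch


    def maybe_add_mask(scores: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        # Apply an additive attention mask only when one is provided; no-op otherwise.
        return scores if attn_mask is None else scores + attn_mask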

@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
+from .attention import maybe_add_mask
 from .config import use_fused_attn
 from .mlp import Mlp
 from .weight_init import trunc_normal_tf_
@@ -95,8 +96,7 @@ class AttentionPoolLatent(nn.Module):
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
-            if attn_mask is not None:
-                attn = attn + attn_mask
+            attn = maybe_add_mask(attn, attn_mask)
             attn = attn.softmax(dim=-1)
             x = attn @ v
         x = x.transpose(1, 2).reshape(B, self.latent_len, C)

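AttentionPoolLatent above and ParallelScalingBlock below make the identical substitution on their non-fused (manual) attention paths. A minimal standalone illustration of that path, using the public timm.layers export added in the next file; the shapes and the half-blocked mask are made up for the example:

    import torch
    from timm.layers import maybe_add_mask

    B, H, Q, K, D = 2, 4, 8, 16, 32          # batch, heads, query len, key len, head dim
    q = torch.randn(B, H, Q, D)
    k = torch.randn(B, H, K, D)
    v = torch.randn(B, H, K, D)

    # Additive mask: 0.0 where attention is allowed, -inf where it is blocked.
    attn_mask = torch.zeros(B, H, Q, K)
    attn_mask[..., K // 2:] = float('-inf')  # block the second half of the keys

    attn = (q * D ** -0.5) @ k.transpose(-2, -1)
    attn = maybe_add_mask(attn, attn_mask)   # same result as attn + attn_mask; no-op for None
    attn = attn.softmax(dim=-1)
    x = attn @ v                             # (B, H, Q, D)
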
@@ -43,7 +43,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import Attention, PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, \
     SwiGLUPacked, SwiGLU, trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    get_act_layer, get_norm_layer, LayerType
+    get_act_layer, get_norm_layer, LayerType, maybe_add_mask
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
@@ -256,8 +256,7 @@ class ParallelScalingBlock(nn.Module):
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
-            if attn_mask is not None:
-                attn = attn + attn_mask
+            attn = maybe_add_mask(attn, attn_mask)
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
             x_attn = attn @ v

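ParallelScalingBlock applies attn_drop after the softmax, so the only behavioral requirement on the substitution is that maybe_add_mask match the deleted branch exactly, including the attn_mask=None case. A quick standalone equivalence check (the random mask is illustrative):

    import torch
    from timm.layers import maybe_add_mask

    scores = torch.randn(2, 4, 8, 8)
    mask = torch.zeros(2, 4, 8, 8).masked_fill_(torch.rand(2, 4, 8, 8) > 0.5, float('-inf'))

    # Masked case: must equal the deleted `attn = attn + attn_mask` line.
    assert torch.equal(maybe_add_mask(scores, mask), scores + mask)
    # Unmasked case: must leave the scores unchanged, as the deleted `if` guard did.
    assert torch.equal(maybe_add_mask(scores, None), scores)
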
@@ -823,7 +823,7 @@ class VisionTransformerFlex(nn.Module):
             attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-        if attn_mask is None and patch_valid is not None:
+        if attn_mask is None:
             attn_mask = create_attention_mask(
                 patch_valid,
                 num_prefix_tokens=self.num_prefix_tokens,
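
The VisionTransformerFlex hunk drops the "and patch_valid is not None" guard, which stays correct only if create_attention_mask itself tolerates a missing patch_valid, presumably by returning None so attn_mask simply remains unset. A sketch of that assumed contract; everything beyond the names visible in the diff is hypothetical:

    from typing import Optional

    import torch


    def create_attention_mask(
            patch_valid: Optional[torch.Tensor],  # (B, N) bool, True where a patch is real
            num_prefix_tokens: int = 0,
    ) -> Optional[torch.Tensor]:
        # Assumed contract: no validity info means nothing to mask, which is
        # what makes the dropped `patch_valid is not None` guard redundant.
        if patch_valid is None:
            return None
        if num_prefix_tokens:
            prefix = patch_valid.new_ones(patch_valid.shape[0], num_prefix_tokens)
            patch_valid = torch.cat([prefix, patch_valid], dim=1)
        # Additive mask broadcastable over heads and queries: 0 for valid keys, -inf for padding.
        mask = torch.zeros(patch_valid.shape, dtype=torch.float32, device=patch_valid.device)
        mask.masked_fill_(~patch_valid, float('-inf'))
        return mask[:, None, None, :]             # (B, 1, 1, num_prefix_tokens + N)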