Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00
A few more maybe_add_mask situations
commit 842a786626
parent dd2c1418d0
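The hunks below replace the open-coded masking pattern (if attn_mask is not None: attn = attn + attn_mask) with the maybe_add_mask helper imported from .attention / timm.layers. Judging only from the call sites in this diff, the helper presumably reduces to something like the minimal sketch below; the signature is assumed from its usage here, not copied from the timm source.

from typing import Optional

import torch


def maybe_add_mask(attn: torch.Tensor, attn_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
    # Apply an additive attention mask to the pre-softmax scores, if one was given.
    # Shapes are assumed broadcastable, e.g. attn (B, num_heads, Q, K) and attn_mask (B, 1, Q, K).
    return attn if attn_mask is None else attn + attn_mask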
@@ -4,6 +4,7 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 
+from .attention import maybe_add_mask
 from .config import use_fused_attn
 from .mlp import Mlp
 from .weight_init import trunc_normal_tf_
@@ -95,8 +96,7 @@ class AttentionPoolLatent(nn.Module):
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
-            if attn_mask is not None:
-                attn = attn + attn_mask
+            attn = maybe_add_mask(attn, attn_mask)
             attn = attn.softmax(dim=-1)
             x = attn @ v
         x = x.transpose(1, 2).reshape(B, self.latent_len, C)
@@ -43,7 +43,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
     OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
 from timm.layers import Attention, PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, \
     SwiGLUPacked, SwiGLU, trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
-    get_act_layer, get_norm_layer, LayerType
+    get_act_layer, get_norm_layer, LayerType, maybe_add_mask
 from ._builder import build_model_with_cfg
 from ._features import feature_take_indices
 from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
@@ -256,8 +256,7 @@ class ParallelScalingBlock(nn.Module):
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
-            if attn_mask is not None:
-                attn = attn + attn_mask
+            attn = maybe_add_mask(attn, attn_mask)
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
             x_attn = attn @ v
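For both attention hunks above the change is behavior-preserving: the additive mask is applied to the pre-softmax scores either way. A small self-contained check (toy shapes, with a local stand-in for the helper rather than the timm import) illustrates the equivalence:

import torch


def maybe_add_mask(attn, attn_mask=None):
    # Local stand-in matching the usage in the diff.
    return attn if attn_mask is None else attn + attn_mask


B, H, Q, K = 2, 4, 5, 7
attn = torch.randn(B, H, Q, K)
# Additive mask: 0.0 where attention is allowed, -inf where it is blocked.
attn_mask = torch.zeros(B, 1, Q, K)
attn_mask[..., -2:] = float('-inf')

# Old pattern (the removed lines).
old = attn.clone()
if attn_mask is not None:
    old = old + attn_mask

# New pattern (the added line).
new = maybe_add_mask(attn, attn_mask)

assert torch.equal(old, new)
assert torch.allclose(old.softmax(dim=-1), new.softmax(dim=-1))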
@@ -823,7 +823,7 @@ class VisionTransformerFlex(nn.Module):
             attn_mask: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
 
-        if attn_mask is None and patch_valid is not None:
+        if attn_mask is None:
             attn_mask = create_attention_mask(
                 patch_valid,
                 num_prefix_tokens=self.num_prefix_tokens,
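The last hunk changes when the mask is synthesized from patch_valid. The exact behavior of create_attention_mask is not shown in this diff, but the general idea of turning a boolean patch-validity mask into an additive attention mask looks roughly like the following sketch (illustrative only, with an assumed helper name build_additive_mask; this is not the timm implementation):

import torch


def build_additive_mask(
        patch_valid: torch.Tensor,   # (B, N) bool, True where a patch holds real (non-padding) content
        num_prefix_tokens: int = 0,
        dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    B, N = patch_valid.shape
    if num_prefix_tokens:
        # Prefix tokens (e.g. class / register tokens) are always attendable.
        prefix = patch_valid.new_ones(B, num_prefix_tokens)
        patch_valid = torch.cat([prefix, patch_valid], dim=1)
    # (B, 1, 1, num_prefix_tokens + N): 0 where attention is allowed, -inf where it is masked out.
    mask = torch.zeros(B, 1, 1, num_prefix_tokens + N, dtype=dtype)
    mask.masked_fill_(~patch_valid[:, None, None, :], float('-inf'))
    return mask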