mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
A few more maybe_add_mask situations
This commit is contained in:
parent
dd2c1418d0
commit
842a786626
@ -4,6 +4,7 @@ import torch
|
|||||||
import torch.nn as nn
|
import torch.nn as nn
|
||||||
import torch.nn.functional as F
|
import torch.nn.functional as F
|
||||||
|
|
||||||
|
from .attention import maybe_add_mask
|
||||||
from .config import use_fused_attn
|
from .config import use_fused_attn
|
||||||
from .mlp import Mlp
|
from .mlp import Mlp
|
||||||
from .weight_init import trunc_normal_tf_
|
from .weight_init import trunc_normal_tf_
|
||||||
@ -95,8 +96,7 @@ class AttentionPoolLatent(nn.Module):
|
|||||||
else:
|
else:
|
||||||
q = q * self.scale
|
q = q * self.scale
|
||||||
attn = q @ k.transpose(-2, -1)
|
attn = q @ k.transpose(-2, -1)
|
||||||
if attn_mask is not None:
|
attn = maybe_add_mask(attn, attn_mask)
|
||||||
attn = attn + attn_mask
|
|
||||||
attn = attn.softmax(dim=-1)
|
attn = attn.softmax(dim=-1)
|
||||||
x = attn @ v
|
x = attn @ v
|
||||||
x = x.transpose(1, 2).reshape(B, self.latent_len, C)
|
x = x.transpose(1, 2).reshape(B, self.latent_len, C)
|
||||||
|
@ -43,7 +43,7 @@ from timm.data import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCE
|
|||||||
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
OPENAI_CLIP_MEAN, OPENAI_CLIP_STD
|
||||||
from timm.layers import Attention, PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, \
|
from timm.layers import Attention, PatchEmbed, Mlp, DropPath, AttentionPoolLatent, RmsNorm, PatchDropout, \
|
||||||
SwiGLUPacked, SwiGLU, trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
|
SwiGLUPacked, SwiGLU, trunc_normal_, lecun_normal_, resample_patch_embed, resample_abs_pos_embed, use_fused_attn, \
|
||||||
get_act_layer, get_norm_layer, LayerType
|
get_act_layer, get_norm_layer, LayerType, maybe_add_mask
|
||||||
from ._builder import build_model_with_cfg
|
from ._builder import build_model_with_cfg
|
||||||
from ._features import feature_take_indices
|
from ._features import feature_take_indices
|
||||||
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
|
from ._manipulate import named_apply, checkpoint_seq, adapt_input_conv
|
||||||
@ -256,8 +256,7 @@ class ParallelScalingBlock(nn.Module):
|
|||||||
else:
|
else:
|
||||||
q = q * self.scale
|
q = q * self.scale
|
||||||
attn = q @ k.transpose(-2, -1)
|
attn = q @ k.transpose(-2, -1)
|
||||||
if attn_mask is not None:
|
attn = maybe_add_mask(attn, attn_mask)
|
||||||
attn = attn + attn_mask
|
|
||||||
attn = attn.softmax(dim=-1)
|
attn = attn.softmax(dim=-1)
|
||||||
attn = self.attn_drop(attn)
|
attn = self.attn_drop(attn)
|
||||||
x_attn = attn @ v
|
x_attn = attn @ v
|
||||||
|
@ -823,7 +823,7 @@ class VisionTransformerFlex(nn.Module):
|
|||||||
attn_mask: Optional[torch.Tensor] = None,
|
attn_mask: Optional[torch.Tensor] = None,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
|
|
||||||
if attn_mask is None and patch_valid is not None:
|
if attn_mask is None:
|
||||||
attn_mask = create_attention_mask(
|
attn_mask = create_attention_mask(
|
||||||
patch_valid,
|
patch_valid,
|
||||||
num_prefix_tokens=self.num_prefix_tokens,
|
num_prefix_tokens=self.num_prefix_tokens,
|
||||||
|
Loading…
x
Reference in New Issue
Block a user