mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
fix all SDPA dropouts
This commit is contained in:
parent
b500cae4c5
commit
884ef88818
@ -155,7 +155,7 @@ class Attention(nn.Module):
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=rel_pos_bias,
|
||||
dropout_p=self.attn_drop.p,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
|
@ -50,7 +50,7 @@ class ClassAttn(nn.Module):
|
||||
if self.fused_attn:
|
||||
x_cls = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p=self.attn_drop.p,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
|
@ -126,7 +126,7 @@ class EvaAttention(nn.Module):
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=self.attn_drop.p,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
|
@ -514,7 +514,7 @@ class Attention(nn.Module):
|
||||
if self.fused_attn:
|
||||
x = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
|
@ -190,7 +190,7 @@ class Attention2d(nn.Module):
|
||||
k.transpose(-1, -2).contiguous(),
|
||||
v.transpose(-1, -2).contiguous(),
|
||||
attn_mask=attn_bias,
|
||||
dropout_p=self.attn_drop.p,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
).transpose(-1, -2).reshape(B, -1, H, W)
|
||||
else:
|
||||
q = q * self.scale
|
||||
@ -259,7 +259,7 @@ class AttentionCl(nn.Module):
|
||||
x = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=attn_bias,
|
||||
dropout_p=self.attn_drop.p,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
|
@ -198,7 +198,7 @@ class Attention(nn.Module):
|
||||
if self.fused_attn:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p=self.attn_drop.p,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
|
@ -59,14 +59,14 @@ class Attention(nn.Module):
|
||||
def forward(self, x):
|
||||
"""
|
||||
x is shape: B (batch_size), T (image blocks), N (seq length per image block), C (embed dim)
|
||||
"""
|
||||
"""
|
||||
B, T, N, C = x.shape
|
||||
# result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
|
||||
qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
if self.fused_attn:
|
||||
x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
|
||||
x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1) # (B, H, T, N, N)
|
||||
@ -330,7 +330,7 @@ class Nest(nn.Module):
|
||||
# Hint: (img_size // patch_size) gives number of patches along edge of image. sqrt(self.num_blocks[0]) is the
|
||||
# number of blocks along edge of image
|
||||
self.block_size = int((img_size // patch_size) // math.sqrt(self.num_blocks[0]))
|
||||
|
||||
|
||||
# Patch embedding
|
||||
self.patch_embed = PatchEmbed(
|
||||
img_size=img_size,
|
||||
|
@ -130,7 +130,7 @@ class Attention(nn.Module):
|
||||
k, v = kv.unbind(0)
|
||||
|
||||
if self.fused_attn:
|
||||
x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
|
||||
x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1)
|
||||
|
@ -164,7 +164,7 @@ class WindowAttention(nn.Module):
|
||||
x = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=attn_mask,
|
||||
dropout_p=self.attn_drop.p,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
|
@ -75,7 +75,7 @@ class LocallyGroupedAttn(nn.Module):
|
||||
if self.fused_attn:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p=self.attn_drop.p,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
@ -172,7 +172,7 @@ class GlobalSubSampleAttn(nn.Module):
|
||||
if self.fused_attn:
|
||||
x = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p=self.attn_drop.p,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
|
@ -95,7 +95,7 @@ class Attention(nn.Module):
|
||||
if self.fused_attn:
|
||||
x = torch.nn.functional.scaled_dot_product_attention(
|
||||
q.contiguous(), k.contiguous(), v.contiguous(),
|
||||
dropout_p=self.attn_drop.p,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
|
@ -85,7 +85,7 @@ class Attention(nn.Module):
|
||||
if self.fused_attn:
|
||||
x = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p=self.attn_drop.p,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
@ -285,7 +285,7 @@ class ParallelScalingBlock(nn.Module):
|
||||
if self.fused_attn:
|
||||
x_attn = F.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
dropout_p=self.attn_drop.p,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
@ -1151,7 +1151,7 @@ default_cfgs = generate_default_cfgs({
|
||||
url='https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth',
|
||||
hf_hub_id='timm/',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
|
||||
|
||||
|
||||
# DINOv2 pretrained - https://arxiv.org/abs/2304.07193 (no classifier head, for fine-tune/features only)
|
||||
'vit_small_patch14_dinov2.lvd142m': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/dinov2/dinov2_vits14/dinov2_vits14_pretrain.pth',
|
||||
@ -1471,7 +1471,7 @@ default_cfgs = generate_default_cfgs({
|
||||
hf_hub_id='timm/',
|
||||
license='cc-by-nc-4.0',
|
||||
mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD, num_classes=0),
|
||||
|
||||
|
||||
'vit_huge_patch14_224_ijepa.in1k': _cfg(
|
||||
url='https://dl.fbaipublicfiles.com/ijepa/IN1K-vit.h.14-300e.pth.tar',
|
||||
# hf_hub_id='timm/',
|
||||
@ -2080,7 +2080,7 @@ def vit_giant_patch14_dinov2(pretrained=False, **kwargs) -> VisionTransformer:
|
||||
# With SwiGLUPacked, we need to set hidden_features = 2 * 4096 = 8192
|
||||
|
||||
model_args = dict(
|
||||
patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
|
||||
patch_size=14, embed_dim=1536, depth=40, num_heads=24, init_values=1e-5,
|
||||
mlp_ratio=2.66667 * 2, mlp_layer=SwiGLUPacked, img_size=518, act_layer=nn.SiLU
|
||||
)
|
||||
model = _create_vision_transformer(
|
||||
|
@ -71,7 +71,7 @@ class RelPosAttention(nn.Module):
|
||||
x = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=attn_bias,
|
||||
dropout_p=self.attn_drop.p,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
|
@ -168,7 +168,7 @@ class Attention(nn.Module):
|
||||
x = torch.nn.functional.scaled_dot_product_attention(
|
||||
q, k, v,
|
||||
attn_mask=attn_bias,
|
||||
dropout_p=self.attn_drop.p,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
)
|
||||
else:
|
||||
q = q * self.scale
|
||||
|
Loading…
x
Reference in New Issue
Block a user