Mirror of https://github.com/huggingface/pytorch-image-models.git, synced 2025-06-03 15:01:08 +08:00
fix all SDPA dropouts

commit 884ef88818
parent b500cae4c5
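
The hunks below all converge on the same pattern: torch.nn.functional.scaled_dot_product_attention takes dropout_p as a plain float and applies dropout whenever it is non-zero, with no awareness of the surrounding module's train/eval state, so passing self.attn_drop.p unconditionally kept dropping attention weights at inference time. Gating the value on self.training restores the usual nn.Dropout behaviour. A minimal sketch of the problem and the fix (tensor shapes, seed, and dropout rate are illustrative, not taken from the repository):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = k = v = torch.randn(1, 4, 16, 32)  # (batch, heads, tokens, head_dim)

# SDPA is a functional op: a non-zero dropout_p applies dropout unconditionally,
# even when the calling module is in eval mode.
a = F.scaled_dot_product_attention(q, k, v, dropout_p=0.5)
b = F.scaled_dot_product_attention(q, k, v, dropout_p=0.5)
print(torch.allclose(a, b))  # typically False: dropout fired on both calls

# The fix in this commit gates the probability on the module's training flag:
#     dropout_p=self.attn_drop.p if self.training else 0.
c = F.scaled_dot_product_attention(q, k, v, dropout_p=0.)
d = F.scaled_dot_product_attention(q, k, v, dropout_p=0.)
print(torch.allclose(c, d))  # True: dropout disabled, inference is deterministic
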
@@ -155,7 +155,7 @@ class Attention(nn.Module):
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=rel_pos_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -50,7 +50,7 @@ class ClassAttn(nn.Module):
         if self.fused_attn:
             x_cls = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -126,7 +126,7 @@ class EvaAttention(nn.Module):
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -514,7 +514,7 @@ class Attention(nn.Module):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p if self.training else 0.0,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -190,7 +190,7 @@ class Attention2d(nn.Module):
                 k.transpose(-1, -2).contiguous(),
                 v.transpose(-1, -2).contiguous(),
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             ).transpose(-1, -2).reshape(B, -1, H, W)
         else:
             q = q * self.scale

@@ -259,7 +259,7 @@ class AttentionCl(nn.Module):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -198,7 +198,7 @@ class Attention(nn.Module):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             attn = (q @ k.transpose(-2, -1)) * self.scale

@@ -66,7 +66,7 @@ class Attention(nn.Module):
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)

         if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)  # (B, H, T, N, N)

@@ -130,7 +130,7 @@ class Attention(nn.Module):
         k, v = kv.unbind(0)

         if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)

@@ -164,7 +164,7 @@ class WindowAttention(nn.Module):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -75,7 +75,7 @@ class LocallyGroupedAttn(nn.Module):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -172,7 +172,7 @@ class GlobalSubSampleAttn(nn.Module):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -95,7 +95,7 @@ class Attention(nn.Module):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q.contiguous(), k.contiguous(), v.contiguous(),
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             attn = (q @ k.transpose(-2, -1)) * self.scale

@@ -85,7 +85,7 @@ class Attention(nn.Module):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -285,7 +285,7 @@ class ParallelScalingBlock(nn.Module):
         if self.fused_attn:
             x_attn = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -71,7 +71,7 @@ class RelPosAttention(nn.Module):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -168,7 +168,7 @@ class Attention(nn.Module):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

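For reference, a self-contained attention module applying the same gated-dropout pattern could look like the sketch below; it is an illustration under assumed names and dimensions (TinyAttention, dim=64, and so on), not code from timm:

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyAttention(nn.Module):
    # Illustrative module (not timm code) showing the pattern applied in this commit.
    def __init__(self, dim: int, num_heads: int = 4, attn_drop: float = 0.1):
        super().__init__()
        assert dim % num_heads == 0
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # each (B, num_heads, N, head_dim)
        # dropout_p is a plain float: gate it on self.training so attention dropout
        # is active only in train mode, matching nn.Dropout semantics.
        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )
        x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj(x)

attn = TinyAttention(dim=64).eval()
with torch.no_grad():
    x = torch.randn(2, 16, 64)
    assert torch.allclose(attn(x), attn(x))  # deterministic in eval mode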