fix all SDPA dropouts

Author: Yassine, 2023-10-04 14:30:19 -07:00 (committed by Ross Wightman)
commit 884ef88818, parent b500cae4c5
14 changed files with 22 additions and 22 deletions
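
Why the change is needed: torch.nn.functional.scaled_dot_product_attention takes dropout_p as a plain float and applies it unconditionally; unlike nn.Dropout, it has no notion of the module's training/eval state. Passing self.attn_drop.p directly therefore kept dropping attention weights at inference time, so every fused call site now gates the probability on self.training. The sketch below (an illustrative module, not any specific timm class) shows the corrected fused-attention pattern:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FusedAttention(nn.Module):
    # Illustrative attention block using the gated dropout_p pattern.
    def __init__(self, dim, num_heads=8, attn_drop=0.):
        super().__init__()
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.qkv = nn.Linear(dim, dim * 3)
        self.attn_drop = nn.Dropout(attn_drop)  # still used by a non-fused fallback path
        self.proj = nn.Linear(dim, dim)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        # SDPA is a plain function and cannot see self.training,
        # so the caller must zero out dropout_p in eval mode.
        x = F.scaled_dot_product_attention(
            q, k, v,
            dropout_p=self.attn_drop.p if self.training else 0.,
        )
        x = x.transpose(1, 2).reshape(B, N, C)
        return self.proj(x)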

@@ -155,7 +155,7 @@ class Attention(nn.Module):
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=rel_pos_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -50,7 +50,7 @@ class ClassAttn(nn.Module):
         if self.fused_attn:
             x_cls = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -126,7 +126,7 @@ class EvaAttention(nn.Module):
             x = F.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -514,7 +514,7 @@ class Attention(nn.Module):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p if self.training else 0.0,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -190,7 +190,7 @@ class Attention2d(nn.Module):
                 k.transpose(-1, -2).contiguous(),
                 v.transpose(-1, -2).contiguous(),
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             ).transpose(-1, -2).reshape(B, -1, H, W)
         else:
             q = q * self.scale
@@ -259,7 +259,7 @@ class AttentionCl(nn.Module):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -198,7 +198,7 @@ class Attention(nn.Module):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             attn = (q @ k.transpose(-2, -1)) * self.scale

@@ -66,7 +66,7 @@ class Attention(nn.Module):
         q, k, v = qkv.unbind(0)  # make torchscript happy (cannot use tensor as tuple)
         if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)  # (B, H, T, N, N)

@@ -130,7 +130,7 @@ class Attention(nn.Module):
         k, v = kv.unbind(0)
         if self.fused_attn:
-            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p)
+            x = F.scaled_dot_product_attention(q, k, v, dropout_p=self.attn_drop.p if self.training else 0.)
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)

@@ -164,7 +164,7 @@ class WindowAttention(nn.Module):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_mask,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -75,7 +75,7 @@ class LocallyGroupedAttn(nn.Module):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -172,7 +172,7 @@ class GlobalSubSampleAttn(nn.Module):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -95,7 +95,7 @@ class Attention(nn.Module):
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
                 q.contiguous(), k.contiguous(), v.contiguous(),
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             attn = (q @ k.transpose(-2, -1)) * self.scale

@@ -85,7 +85,7 @@ class Attention(nn.Module):
         if self.fused_attn:
             x = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
@@ -285,7 +285,7 @@ class ParallelScalingBlock(nn.Module):
         if self.fused_attn:
             x_attn = F.scaled_dot_product_attention(
                 q, k, v,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -71,7 +71,7 @@ class RelPosAttention(nn.Module):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale

@@ -168,7 +168,7 @@ class Attention(nn.Module):
             x = torch.nn.functional.scaled_dot_product_attention(
                 q, k, v,
                 attn_mask=attn_bias,
-                dropout_p=self.attn_drop.p,
+                dropout_p=self.attn_drop.p if self.training else 0.,
             )
         else:
             q = q * self.scale
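
A quick way to sanity-check the behavior at a single call site (a hypothetical smoke test, not part of this commit): with the ungated probability, two identical calls differ because dropout still fires, while the gated form is deterministic once training is False.

import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = k = v = torch.randn(1, 4, 16, 32)
p, training = 0.1, False

# Ungated: dropout is applied even though we mean to be in eval mode.
a = F.scaled_dot_product_attention(q, k, v, dropout_p=p)
b = F.scaled_dot_product_attention(q, k, v, dropout_p=p)
print(torch.allclose(a, b))  # expected: False (stochastic)

# Gated as in this commit: the probability collapses to 0. outside training.
a = F.scaled_dot_product_attention(q, k, v, dropout_p=p if training else 0.)
b = F.scaled_dot_product_attention(q, k, v, dropout_p=p if training else 0.)
print(torch.allclose(a, b))  # expected: True (deterministic)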