F.sdpa for visformer fails w/o contiguous on qkv, make experimental

pull/1804/head
Ross Wightman 2023-05-11 11:37:10 -07:00
parent cf1884bfeb
commit 3eaf729f3f
1 changed file with 2 additions and 2 deletions

timm/models/visformer.py

@@ -80,7 +80,7 @@ class Attention(nn.Module):
         head_dim = round(dim // num_heads * head_dim_ratio)
         self.head_dim = head_dim
         self.scale = head_dim ** -0.5
-        self.fused_attn = use_fused_attn()
+        self.fused_attn = use_fused_attn(experimental=True)

         self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -94,7 +94,7 @@ class Attention(nn.Module):

         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
-                q, k, v,
+                q.contiguous(), k.contiguous(), v.contiguous(),
                 dropout_p=self.attn_drop.p,
             )
         else:
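
Why the .contiguous() calls matter, as a minimal sketch (not part of the commit; sizes are illustrative, not a specific visformer config): the Attention module produces qkv with a 1x1 Conv2d, then reshapes and permutes before unbinding, so q, k, and v reach F.scaled_dot_product_attention as non-contiguous strided views, which the commit title reports failing for this model.

import torch
import torch.nn.functional as F

# Illustrative sizes only: batch 2, 6 heads, head_dim 64,
# a 7x7 spatial grid flattened to 49 tokens.
B, num_heads, head_dim = 2, 6, 64
x = torch.randn(B, 3 * num_heads * head_dim, 7, 7)  # output of the 1x1 qkv Conv2d

# Same reshape/permute pattern as the model: permute() returns a strided
# view, so the unbound q/k/v tensors are non-contiguous.
qkv = x.reshape(B, 3, num_heads, head_dim, -1).permute(1, 0, 2, 4, 3)
q, k, v = qkv.unbind(0)
assert not q.is_contiguous()

# The workaround applied in this commit: force a contiguous copy of each
# tensor before handing it to SDPA.
out = F.scaled_dot_product_attention(
    q.contiguous(), k.contiguous(), v.contiguous(),
    dropout_p=0.0,
)
assert out.shape == (B, num_heads, 49, head_dim)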
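
On the experimental=True switch (paraphrasing timm's use_fused_attn() gate in timm.layers, not this diff): as I read the gate, tested models enable fused attention by default on PyTorch builds that provide F.scaled_dot_product_attention, while models flagged experimental stay on the unfused path unless the TIMM_FUSED_ATTN environment variable raises the opt-in level, so this commit effectively disables SDPA for visformer by default.

from timm.layers import use_fused_attn

# Assuming a PyTorch build with F.scaled_dot_product_attention available
# and no TIMM_FUSED_ATTN override set:
print(use_fused_attn())                   # True  (tested models use SDPA)
print(use_fused_attn(experimental=True))  # False (experimental models stay off)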