F.sdpa for visformer fails w/o contiguous on qkv, make experimental
parent cf1884bfeb
commit 3eaf729f3f
@@ -80,7 +80,7 @@ class Attention(nn.Module):
         head_dim = round(dim // num_heads * head_dim_ratio)
         self.head_dim = head_dim
         self.scale = head_dim ** -0.5
-        self.fused_attn = use_fused_attn()
+        self.fused_attn = use_fused_attn(experimental=True)
 
         self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -94,7 +94,7 @@ class Attention(nn.Module):
 
         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
-                q, k, v,
+                q.contiguous(), k.contiguous(), v.contiguous(),
                 dropout_p=self.attn_drop.p,
             )
         else:
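
For context, a minimal repro sketch of why the .contiguous() calls are needed (sizes here are illustrative assumptions, not taken from the commit). Visformer's qkv is a 1x1 Conv2d whose output is reshaped and permuted into q/k/v, so all three tensors end up as non-contiguous views, and some fused SDPA backends fail on those strides:

import torch
import torch.nn as nn
import torch.nn.functional as F

# Illustrative sizes (assumed for the sketch)
B, C, H, W = 2, 192, 14, 14
num_heads, head_dim = 6, 32

# Mirrors the visformer attention input path: 1x1 conv, reshape, permute
qkv = nn.Conv2d(C, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
x = qkv(torch.randn(B, C, H, W))
x = x.reshape(B, 3, num_heads, head_dim, -1).permute(1, 0, 2, 4, 3)
q, k, v = x.unbind(0)  # each (B, num_heads, H*W, head_dim)

print(q.is_contiguous())  # False: the permute leaves non-standard strides

# Forcing dense layouts sidesteps kernels that reject such strides
out = F.scaled_dot_product_attention(
    q.contiguous(), k.contiguous(), v.contiguous(),
)
print(out.shape)  # torch.Size([2, 6, 196, 32])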
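The other half of the change makes the fused path opt-in. In timm, use_fused_attn() gates whether a module takes the F.scaled_dot_product_attention path; with experimental=True it only returns True at the experimental level. A sketch, assuming the TIMM_FUSED_ATTN environment variable semantics from timm's layers config (0 = off, 1 = on, 2 = experimental on); check timm/layers/config.py for the exact behavior:

import os
os.environ['TIMM_FUSED_ATTN'] = '2'  # assumed env var; read when timm.layers is first imported

from timm.layers import use_fused_attn
print(use_fused_attn(experimental=True))  # True only at the experimental level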