F.sdpa for visformer fails w/o contiguous on qkv, make experimental
parent cf1884bfeb
commit 3eaf729f3f
@@ -80,7 +80,7 @@ class Attention(nn.Module):
         head_dim = round(dim // num_heads * head_dim_ratio)
         self.head_dim = head_dim
         self.scale = head_dim ** -0.5
-        self.fused_attn = use_fused_attn()
+        self.fused_attn = use_fused_attn(experimental=True)

         self.qkv = nn.Conv2d(dim, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
         self.attn_drop = nn.Dropout(attn_drop)
@@ -94,7 +94,7 @@ class Attention(nn.Module):

         if self.fused_attn:
             x = torch.nn.functional.scaled_dot_product_attention(
-                q, k, v,
+                q.contiguous(), k.contiguous(), v.contiguous(),
                 dropout_p=self.attn_drop.p,
             )
         else:
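For context, a minimal standalone sketch of why the .contiguous() calls matter, assuming the usual reshape/permute head split used in this Attention module and made-up shapes: the 1x1 conv output is split into heads via reshape and permute, so the unbound q, k, v are non-contiguous views, and some scaled_dot_product_attention backends fail on such inputs.

import torch
import torch.nn.functional as F

# Hypothetical shapes; head_dim * num_heads need not equal C when head_dim_ratio != 1.
B, C, H, W = 2, 64, 7, 7
num_heads, head_dim = 8, 24

# 1x1 conv projecting to qkv, mirroring the nn.Conv2d shown in the diff above.
qkv_proj = torch.nn.Conv2d(C, head_dim * num_heads * 3, 1, stride=1, padding=0, bias=False)
x = torch.randn(B, C, H, W)

# Head split: the permute leaves the unbound q, k, v as non-contiguous views.
qkv = qkv_proj(x).reshape(B, 3, num_heads, head_dim, -1).permute(1, 0, 2, 4, 3)
q, k, v = qkv.unbind(0)            # each (B, num_heads, H*W, head_dim)
print(q.is_contiguous())           # False

# The workaround from the diff: force contiguous memory before calling F.sdpa.
out = F.scaled_dot_product_attention(
    q.contiguous(), k.contiguous(), v.contiguous(),
    dropout_p=0.0,
)

The switch to use_fused_attn(experimental=True) additionally gates this path behind timm's experimental fused-attention opt-in, so the fused branch stays off by default until the contiguity issue is sorted out.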