Fixed unfused attn2d scale
parent
851e0746a9
commit
2d734d9058
|
@ -1,7 +1,8 @@
|
|||
import pytest
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn
|
||||
from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d
|
||||
|
||||
import importlib
|
||||
import os
|
||||
|
@ -119,3 +120,27 @@ def test_get_act_fn_none():
|
|||
assert get_act_fn(None) is None
|
||||
assert get_act_fn('') is None
|
||||
|
||||
|
||||
@pytest.mark.parametrize("bias", [True, False])
|
||||
@pytest.mark.parametrize("expand_first", [True, False])
|
||||
@pytest.mark.parametrize("head_first", [True, False])
|
||||
@pytest.mark.parametrize("attn_mask", [True, False])
|
||||
def test_attn2d(bias, expand_first, head_first, attn_mask):
|
||||
x = torch.randn(1, 128, 32, 48)
|
||||
attn = Attention2d(
|
||||
128, 128, num_heads=4, bias=bias, expand_first=expand_first, head_first=head_first
|
||||
)
|
||||
|
||||
if attn_mask:
|
||||
mask = torch.randint(0, 1, size=(32 * 48, 32 * 48), dtype=torch.float32)
|
||||
else:
|
||||
mask = None
|
||||
|
||||
o1 = attn(x, mask)
|
||||
attn.fused_attn = False
|
||||
o2 = attn(x, mask)
|
||||
|
||||
assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"
|
||||
|
||||
|
||||
|
|
@ -312,7 +312,6 @@ class Attention2d(nn.Module):
|
|||
self.num_heads = num_heads
|
||||
self.dim_head = dim_attn // num_heads
|
||||
self.head_first = head_first
|
||||
self.scale = num_heads ** -0.5
|
||||
self.fused_attn = use_fused_attn()
|
||||
|
||||
self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
|
||||
|
@ -337,14 +336,15 @@ class Attention2d(nn.Module):
|
|||
dropout_p=self.attn_drop.p if self.training else 0.,
|
||||
).transpose(-1, -2).reshape(B, -1, H, W)
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q.transpose(-2, -1) @ k
|
||||
q = q.transpose(-1, -2)
|
||||
v = v.transpose(-1, -2)
|
||||
attn = q @ k * q.size(-1) ** -0.5
|
||||
if attn_mask is not None:
|
||||
# NOTE: assumes mask is float and in correct shape
|
||||
attn = attn + attn_mask
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
|
||||
x = (attn @ v).transpose(-1, -2).reshape(B, -1, H, W)
|
||||
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
|
|
Loading…
Reference in New Issue