Mirror of https://github.com/huggingface/pytorch-image-models.git
Synced 2025-06-03 15:01:08 +08:00

Merge branch 'main' into fix-mqa-v2
Commit 2d5277e858
@@ -2,7 +2,7 @@ import pytest
 import torch
 import torch.nn as nn

-from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, MultiQueryAttentionV2
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, Attention2d, MultiQueryAttentionV2

 import importlib
 import os
@@ -120,6 +120,7 @@ def test_get_act_fn_none():
     assert get_act_fn(None) is None
     assert get_act_fn('') is None

+
 @pytest.mark.parametrize("dim", [128])
 @pytest.mark.parametrize("dim_out", [128, 256])
 @pytest.mark.parametrize("use_m", [True, False])
@@ -134,4 +135,26 @@ def test_mqa_v2(dim, dim_out, use_m):

     y = mqa(x, m=m)

     assert (y.shape) == (1, dim_out, 32, 48)
+
+
+@pytest.mark.parametrize("bias", [True, False])
+@pytest.mark.parametrize("expand_first", [True, False])
+@pytest.mark.parametrize("head_first", [True, False])
+@pytest.mark.parametrize("attn_mask", [True, False])
+def test_attn2d(bias, expand_first, head_first, attn_mask):
+    x = torch.randn(1, 128, 32, 48)
+    attn = Attention2d(
+        128, 128, num_heads=4, bias=bias, expand_first=expand_first, head_first=head_first
+    )
+
+    if attn_mask:
+        mask = torch.randint(0, 1, size=(32 * 48, 32 * 48), dtype=torch.float32)
+    else:
+        mask = None
+
+    o1 = attn(x, mask)
+    attn.fused_attn = False
+    o2 = attn(x, mask)
+
+    assert torch.allclose(o1, o2, atol=1e-5), f"{torch.abs(o1 - o2).max()}"
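The attn_mask exercised by test_attn2d (and added to the attention logits in the non-fused branch further down) follows the additive float-mask convention: 0. keeps a position, a large negative value drops it, so the all-zeros mask produced by torch.randint(0, 1, ...) above acts as a no-op mask. A minimal sketch of a non-trivial mask of the same (H*W, H*W) shape; the keep/drop pattern here is illustrative only, not part of this diff:

import torch

H, W = 32, 48
N = H * W  # one query/key per spatial position

# 0. keeps a position, -inf removes it from the softmax.
keep = torch.rand(N, N) > 0.1                      # illustrative keep/drop pattern
mask = torch.zeros(N, N, dtype=torch.float32)
mask = mask.masked_fill(~keep, float('-inf'))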
@@ -312,7 +312,6 @@ class Attention2d(nn.Module):
         self.num_heads = num_heads
         self.dim_head = dim_attn // num_heads
         self.head_first = head_first
-        self.scale = num_heads ** -0.5
         self.fused_attn = use_fused_attn()

         self.qkv = nn.Conv2d(dim, dim_attn * 3, 1, bias=bias)
@@ -337,14 +336,15 @@ class Attention2d(nn.Module):
                 dropout_p=self.attn_drop.p if self.training else 0.,
             ).transpose(-1, -2).reshape(B, -1, H, W)
         else:
-            q = q * self.scale
-            attn = q.transpose(-2, -1) @ k
+            q = q.transpose(-1, -2)
+            v = v.transpose(-1, -2)
+            attn = q @ k * q.size(-1) ** -0.5
             if attn_mask is not None:
                 # NOTE: assumes mask is float and in correct shape
                 attn = attn + attn_mask
             attn = attn.softmax(dim=-1)
             attn = self.attn_drop(attn)
-            x = (v @ attn.transpose(-2, -1)).view(B, -1, H, W)
+            x = (attn @ v).transpose(-1, -2).reshape(B, -1, H, W)

         x = self.proj(x)
         x = self.proj_drop(x)
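For reference, the rewritten non-fused branch above computes the same result as the fused torch.nn.functional.scaled_dot_product_attention path. A minimal standalone sketch, assuming q, k and v leave the 1x1 qkv conv with shape (B, num_heads, dim_head, N) where N = H * W; the sizes below are illustrative, not taken from the diff:

import torch
import torch.nn.functional as F

B, num_heads, dim_head, N = 1, 4, 32, 32 * 48
q = torch.randn(B, num_heads, dim_head, N)
k = torch.randn(B, num_heads, dim_head, N)
v = torch.randn(B, num_heads, dim_head, N)

# Manual path, mirroring the else: branch.
q_t = q.transpose(-1, -2)                  # (B, heads, N, dim_head)
v_t = v.transpose(-1, -2)                  # (B, heads, N, dim_head)
attn = q_t @ k * q_t.size(-1) ** -0.5      # (B, heads, N, N), scaled by dim_head ** -0.5
attn = attn.softmax(dim=-1)
out_manual = attn @ v_t                    # (B, heads, N, dim_head)

# Fused path, mirroring the if self.fused_attn: branch.
out_fused = F.scaled_dot_product_attention(
    q.transpose(-1, -2), k.transpose(-1, -2), v.transpose(-1, -2),
)

assert torch.allclose(out_manual, out_fused, atol=1e-5)

In the module itself, (attn @ v).transpose(-1, -2).reshape(B, -1, H, W) then folds the token dimension back onto the spatial grid before the output projection.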