Mirror of https://github.com/huggingface/pytorch-image-models.git
Fix MQA V2 scale and out shape
parent 851e0746a9
commit 6171e756d3
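Summarized from the diff below, this commit fixes three issues in MultiQueryAttentionV2.forward: the attention logits are now multiplied by self.scale before the softmax, the output projection einsum emits a channels-first [b, d, n] layout so the result reshapes cleanly to NCHW with dim_out channels, and the optional memory input m is checked with `is not None` rather than `or`, which trips PyTorch's tensor-truthiness error. A parametrized test exercising dim_out != dim and the optional m input is added.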
@@ -1,7 +1,8 @@
 import pytest
 import torch
 import torch.nn as nn
 
-from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn
+from timm.layers import create_act_layer, set_layer_config, get_act_layer, get_act_fn, MultiQueryAttentionV2
+
 import importlib
 import os
@@ -119,3 +120,18 @@ def test_get_act_fn_none():
     assert get_act_fn(None) is None
     assert get_act_fn('') is None
 
+
+@pytest.mark.parametrize("dim", [128])
+@pytest.mark.parametrize("dim_out", [128, 256])
+@pytest.mark.parametrize("use_m", [True, False])
+def test_mqa_v2(dim, dim_out, use_m):
+    mqa = MultiQueryAttentionV2(dim, dim_out)
+
+    x = torch.randn(1, dim, 32, 48)
+    if use_m:
+        m = torch.randn(1, dim, 16, 24)
+    else:
+        m = None
+
+    y = mqa(x, m=m)
+
+    assert y.shape == (1, dim_out, 32, 48)
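For a quick manual check outside pytest, a minimal standalone sketch of the same assertion (assuming MultiQueryAttentionV2 is exported from timm.layers, as the import change above shows):

import torch
from timm.layers import MultiQueryAttentionV2

dim, dim_out = 128, 256
mqa = MultiQueryAttentionV2(dim, dim_out)

x = torch.randn(1, dim, 32, 48)   # query input, NCHW
m = torch.randn(1, dim, 16, 24)   # optional memory tensor at a different resolution
y = mqa(x, m=m)

# Output keeps x's spatial size; the channel count follows dim_out.
assert y.shape == (1, dim_out, 32, 48)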
@@ -59,8 +59,8 @@ class MultiQueryAttentionV2(nn.Module):
 
     def forward(self, x, m: Optional[torch.Tensor] = None):
         """Run layer computation."""
-        s = x.shape
-        m = m or x
+        b, _, h, w = x.shape
+        m = m if m is not None else x
 
         reshaped_x = self._reshape_input(x)
         reshaped_m = self._reshape_input(m)
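The `m = m or x` line replaced above was more than a style issue: `or` calls bool() on the tensor, which PyTorch rejects for anything with more than one element. A minimal sketch of the failure and the fix:

import torch

m = torch.randn(1, 128, 16, 24)
x = torch.randn(1, 128, 32, 48)

try:
    m = m or x  # `or` evaluates bool(m) on a multi-element tensor
except RuntimeError as err:
    print(err)  # Boolean value of Tensor with more than one element is ambiguous

# The explicit None check never touches tensor truthiness:
m = m if m is not None else x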
@@ -68,15 +68,15 @@ class MultiQueryAttentionV2(nn.Module):
         q = torch.einsum('bnd,hkd->bnhk', reshaped_x, self.query_proj)
         k = torch.einsum('bmd,dk->bmk', reshaped_m, self.key_proj)
 
-        attn = torch.einsum('bnhk,bmk->bnhm', q, k)
+        attn = torch.einsum('bnhk,bmk->bnhm', q, k) * self.scale
         attn = attn.softmax(dim=-1)
         attn = self.attn_drop(attn)
 
         v = torch.einsum('bmd,dv->bmv', reshaped_m, self.value_proj)
         o = torch.einsum('bnhm,bmv->bnhv', attn, v)
-        result = torch.einsum('bnhv,dhv->bnd', o, self.out_proj)
+        result = torch.einsum('bnhv,dhv->bdn', o, self.out_proj)
         result = self.proj_drop(result)
-        return result.reshape(s)
+        return result.reshape(b, -1, h, w)
 
 
 class MultiQueryAttention2d(nn.Module):
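Two fixes land in this hunk. First, the logits now pick up self.scale (presumably the usual key_dim ** -0.5) so the softmax input does not grow with the key dimension. Second, the output einsum switches from tokens-first 'bnd' to channels-first 'bdn', which is what lets the final reshape produce an NCHW tensor with dim_out channels; the old `reshape(s)` back to the input shape failed whenever dim_out != dim. A sketch of the layout change, with hypothetical sizes:

import torch

b, heads, v, d = 1, 8, 16, 256   # d plays the role of dim_out
H, W = 32, 48
n = H * W                        # number of query tokens

o = torch.randn(b, n, heads, v)       # per-head attention output
out_proj = torch.randn(d, heads, v)   # learned output projection

old = torch.einsum('bnhv,dhv->bnd', o, out_proj)  # tokens-first: [b, n, d]
new = torch.einsum('bnhv,dhv->bdn', o, out_proj)  # channels-first: [b, d, n]

# Channels-first reshapes straight into NCHW:
y = new.reshape(b, -1, H, W)
assert y.shape == (b, d, H, W)

# The tokens-first result needs a transpose before any NCHW reshape;
# reshaping [b, n, d] directly would scramble channels against positions.
assert torch.allclose(old.transpose(1, 2).reshape(b, -1, H, W), y)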