mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Add 'qkv_bias_separate' flag for EVA/beit/swinv2 attn modules to allow an override for easy quantization wrappers. Fix #2098
This commit is contained in:
parent
83c2c2f0c5
commit
f81b094aaa
@ -86,6 +86,7 @@ class Attention(nn.Module):
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qkv_bias_separate: bool = False,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
window_size: Optional[Tuple[int, int]] = None,
|
||||
@ -99,6 +100,7 @@ class Attention(nn.Module):
|
||||
all_head_dim = head_dim * self.num_heads
|
||||
self.scale = head_dim ** -0.5
|
||||
self.fused_attn = use_fused_attn()
|
||||
self.qkv_bias_separate = qkv_bias_separate
|
||||
|
||||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
||||
if qkv_bias:
|
||||
@ -136,8 +138,15 @@ class Attention(nn.Module):
|
||||
def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
|
||||
B, N, C = x.shape
|
||||
|
||||
qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
if self.q_bias is None:
|
||||
qkv = self.qkv(x)
|
||||
else:
|
||||
qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
|
||||
if self.qkv_bias_separate:
|
||||
qkv = self.qkv(x)
|
||||
qkv += qkv_bias
|
||||
else:
|
||||
qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
|
||||
|
||||
|
@ -54,6 +54,7 @@ class EvaAttention(nn.Module):
|
||||
qkv_bias: bool = True,
|
||||
qkv_fused: bool = True,
|
||||
num_prefix_tokens: int = 1,
|
||||
qkv_bias_separate: bool = False,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
attn_head_dim: Optional[int] = None,
|
||||
@ -80,6 +81,7 @@ class EvaAttention(nn.Module):
|
||||
self.scale = head_dim ** -0.5
|
||||
self.num_prefix_tokens = num_prefix_tokens
|
||||
self.fused_attn = use_fused_attn()
|
||||
self.qkv_bias_separate = qkv_bias_separate
|
||||
|
||||
if qkv_fused:
|
||||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
||||
@ -111,8 +113,15 @@ class EvaAttention(nn.Module):
|
||||
B, N, C = x.shape
|
||||
|
||||
if self.qkv is not None:
|
||||
qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
if self.q_bias is None:
|
||||
qkv = self.qkv(x)
|
||||
else:
|
||||
qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
|
||||
if self.qkv_bias_separate:
|
||||
qkv = self.qkv(x)
|
||||
qkv += qkv_bias
|
||||
else:
|
||||
qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0) # B, num_heads, N, head_dim
|
||||
else:
|
||||
|
@ -86,6 +86,7 @@ class WindowAttention(nn.Module):
|
||||
window_size: Tuple[int, int],
|
||||
num_heads: int,
|
||||
qkv_bias: bool = True,
|
||||
qkv_bias_separate: bool = False,
|
||||
attn_drop: float = 0.,
|
||||
proj_drop: float = 0.,
|
||||
pretrained_window_size: Tuple[int, int] = (0, 0),
|
||||
@ -95,6 +96,7 @@ class WindowAttention(nn.Module):
|
||||
self.window_size = window_size # Wh, Ww
|
||||
self.pretrained_window_size = pretrained_window_size
|
||||
self.num_heads = num_heads
|
||||
self.qkv_bias_separate = qkv_bias_separate
|
||||
|
||||
self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
|
||||
|
||||
@ -156,10 +158,16 @@ class WindowAttention(nn.Module):
|
||||
mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
|
||||
"""
|
||||
B_, N, C = x.shape
|
||||
qkv_bias = None
|
||||
if self.q_bias is not None:
|
||||
|
||||
if self.q_bias is None:
|
||||
qkv = self.qkv(x)
|
||||
else:
|
||||
qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
if self.qkv_bias_separate:
|
||||
qkv = self.qkv(x)
|
||||
qkv += qkv_bias
|
||||
else:
|
||||
qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user