Add 'qkv_bias_separate' flag for EVA/beit/swinv2 attn modules to allow an override that makes the qkv projection easier to wrap for quantization. Fix #2098

Author: Ross Wightman
Date:   2024-07-08 13:48:06 -07:00
Parent: 83c2c2f0c5
Commit: f81b094aaa
3 changed files with 33 additions and 7 deletions
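The flag exists because most quantization toolkits work by replacing nn.Linear submodules with wrapper modules whose forward() must run for observers or fake-quant logic to see the activations. The original fused path calls F.linear against self.qkv.weight directly, which skips any wrapper installed in place of self.qkv. A minimal sketch of the difference in plain PyTorch; ObservedLinear is a hypothetical stand-in for a quantization wrapper, not part of timm or any toolkit:

import torch
import torch.nn as nn
import torch.nn.functional as F


class ObservedLinear(nn.Linear):
    # Hypothetical quantization-style wrapper: records an activation statistic
    # in forward(). It only sees the input if forward() is actually invoked.
    def forward(self, x):
        self.last_abs_max = x.abs().max().item()
        return F.linear(x, self.weight, self.bias)


dim = 64
qkv = ObservedLinear(dim, dim * 3, bias=False)
qkv_bias = torch.zeros(dim * 3)  # concatenated q/k/v bias, as in the modules below
x = torch.randn(2, 16, dim)

# Fused-bias path (qkv_bias_separate=False): the functional call reads the raw
# weight tensor and never enters ObservedLinear.forward().
out_fused = F.linear(x, qkv.weight, qkv_bias)

# Separate-bias path (qkv_bias_separate=True): the module itself is called,
# then the bias is added, so the wrapper's forward() runs.
out_separate = qkv(x) + qkv_bias

assert torch.allclose(out_fused, out_separate, atol=1e-6)
assert hasattr(qkv, 'last_abs_max')  # observer only ran on the separate path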

File 1 of 3: beit attention module (class Attention)

@@ -86,6 +86,7 @@ class Attention(nn.Module):
             dim: int,
             num_heads: int = 8,
             qkv_bias: bool = False,
+            qkv_bias_separate: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             window_size: Optional[Tuple[int, int]] = None,
@@ -99,6 +100,7 @@ class Attention(nn.Module):
         all_head_dim = head_dim * self.num_heads
         self.scale = head_dim ** -0.5
         self.fused_attn = use_fused_attn()
+        self.qkv_bias_separate = qkv_bias_separate

         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
         if qkv_bias:
@@ -136,8 +138,15 @@ class Attention(nn.Module):
     def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
         B, N, C = x.shape
-        qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
-        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+        if self.q_bias is None:
+            qkv = self.qkv(x)
+        else:
+            qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
+            if self.qkv_bias_separate:
+                qkv = self.qkv(x)
+                qkv += qkv_bias
+            else:
+                qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
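Since self.qkv is constructed with bias=False and the concatenated q/k/v bias is simply added afterwards, the two settings should be numerically equivalent up to floating-point rounding. A quick check, assuming the Attention class above can be imported from timm.models.beit with the constructor arguments shown in the diff:

import torch
from timm.models.beit import Attention  # import path assumed from this commit

torch.manual_seed(0)
x = torch.randn(2, 197, 768)

attn_fused = Attention(dim=768, num_heads=12, qkv_bias=True)
attn_separate = Attention(dim=768, num_heads=12, qkv_bias=True, qkv_bias_separate=True)
attn_separate.load_state_dict(attn_fused.state_dict())  # copy identical weights

with torch.no_grad():
    y_fused = attn_fused(x)
    y_separate = attn_separate(x)

# Separate bias addition matches the fused F.linear path.
print(torch.allclose(y_fused, y_separate, atol=1e-5))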

File 2 of 3: eva attention module (class EvaAttention)

@@ -54,6 +54,7 @@ class EvaAttention(nn.Module):
             qkv_bias: bool = True,
             qkv_fused: bool = True,
             num_prefix_tokens: int = 1,
+            qkv_bias_separate: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             attn_head_dim: Optional[int] = None,
@@ -80,6 +81,7 @@ class EvaAttention(nn.Module):
         self.scale = head_dim ** -0.5
         self.num_prefix_tokens = num_prefix_tokens
         self.fused_attn = use_fused_attn()
+        self.qkv_bias_separate = qkv_bias_separate

         if qkv_fused:
             self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
@@ -111,8 +113,15 @@ class EvaAttention(nn.Module):
         B, N, C = x.shape
         if self.qkv is not None:
-            qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
-            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+            if self.q_bias is None:
+                qkv = self.qkv(x)
+            else:
+                qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
+                if self.qkv_bias_separate:
+                    qkv = self.qkv(x)
+                    qkv += qkv_bias
+                else:
+                    qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
             qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
             q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
         else:
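The hunk is cut off at the else: branch, which handles qkv_fused=False; there the module uses separate q_proj/k_proj/v_proj layers that carry their own bias, so the new flag only matters for the fused projection. With qkv_fused=True and qkv_bias_separate=True, the single qkv projection can be swapped for a wrapper module without breaking forward(). A sketch; QuantWrapper is a hypothetical placeholder for whatever module a quantization toolkit installs, and the EvaAttention import path is assumed from this commit:

import torch
import torch.nn as nn
from timm.models.eva import EvaAttention  # import path assumed


class QuantWrapper(nn.Module):
    # Hypothetical wrapper; a real one would fake-quantize x and/or the weight.
    def __init__(self, linear: nn.Linear):
        super().__init__()
        self.inner = linear

    def forward(self, x):
        return self.inner(x)


attn = EvaAttention(dim=768, num_heads=12, qkv_fused=True, qkv_bias_separate=True)
# Safe with the flag set: forward() only calls self.qkv(x) and adds the bias,
# instead of reaching into self.qkv.weight via F.linear.
attn.qkv = QuantWrapper(attn.qkv)

x = torch.randn(2, 197, 768)
with torch.no_grad():
    y = attn(x)
print(y.shape)  # torch.Size([2, 197, 768])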

File 3 of 3: swinv2 window attention module (class WindowAttention)

@@ -86,6 +86,7 @@ class WindowAttention(nn.Module):
             window_size: Tuple[int, int],
             num_heads: int,
             qkv_bias: bool = True,
+            qkv_bias_separate: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             pretrained_window_size: Tuple[int, int] = (0, 0),
@@ -95,6 +96,7 @@ class WindowAttention(nn.Module):
         self.window_size = window_size  # Wh, Ww
         self.pretrained_window_size = pretrained_window_size
         self.num_heads = num_heads
+        self.qkv_bias_separate = qkv_bias_separate

         self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
@@ -156,10 +158,16 @@ class WindowAttention(nn.Module):
             mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
         """
         B_, N, C = x.shape
-        qkv_bias = None
-        if self.q_bias is not None:
+        if self.q_bias is None:
+            qkv = self.qkv(x)
+        else:
             qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
-        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+            if self.qkv_bias_separate:
+                qkv = self.qkv(x)
+                qkv += qkv_bias
+            else:
+                qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)
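Because the new flag is just an attribute consulted at forward time, it can also be flipped on an already-built model before handing it to a quantization tool. A sketch of that usage (an assumption about how the flag might be used, not an API documented here; the model name is assumed to exist in the installed timm version):

import timm

model = timm.create_model('swinv2_tiny_window8_256', pretrained=False)

# Switch every attention module that exposes the flag to the separate-bias
# path, so later replacement of the qkv nn.Linear is exercised by forward().
for module in model.modules():
    if hasattr(module, 'qkv_bias_separate'):
        module.qkv_bias_separate = True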