diff --git a/timm/models/beit.py b/timm/models/beit.py
index 57007cd7..90068a18 100644
--- a/timm/models/beit.py
+++ b/timm/models/beit.py
@@ -86,6 +86,7 @@ class Attention(nn.Module):
             dim: int,
             num_heads: int = 8,
             qkv_bias: bool = False,
+            qkv_bias_separate: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             window_size: Optional[Tuple[int, int]] = None,
@@ -99,6 +100,7 @@ class Attention(nn.Module):
         all_head_dim = head_dim * self.num_heads
         self.scale = head_dim ** -0.5
         self.fused_attn = use_fused_attn()
+        self.qkv_bias_separate = qkv_bias_separate
 
         self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
         if qkv_bias:
@@ -136,8 +138,15 @@ class Attention(nn.Module):
     def forward(self, x, shared_rel_pos_bias: Optional[torch.Tensor] = None):
         B, N, C = x.shape
 
-        qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
-        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+        if self.q_bias is None:
+            qkv = self.qkv(x)
+        else:
+            qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
+            if self.qkv_bias_separate:
+                qkv = self.qkv(x)
+                qkv += qkv_bias
+            else:
+                qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
 
diff --git a/timm/models/eva.py b/timm/models/eva.py
index 7a1b67e1..ea29b103 100644
--- a/timm/models/eva.py
+++ b/timm/models/eva.py
@@ -54,6 +54,7 @@ class EvaAttention(nn.Module):
             qkv_bias: bool = True,
             qkv_fused: bool = True,
             num_prefix_tokens: int = 1,
+            qkv_bias_separate: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             attn_head_dim: Optional[int] = None,
@@ -80,6 +81,7 @@ class EvaAttention(nn.Module):
         self.scale = head_dim ** -0.5
         self.num_prefix_tokens = num_prefix_tokens
         self.fused_attn = use_fused_attn()
+        self.qkv_bias_separate = qkv_bias_separate
 
         if qkv_fused:
             self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
@@ -111,8 +113,15 @@ class EvaAttention(nn.Module):
         B, N, C = x.shape
 
         if self.qkv is not None:
-            qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
-            qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+            if self.q_bias is None:
+                qkv = self.qkv(x)
+            else:
+                qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
+                if self.qkv_bias_separate:
+                    qkv = self.qkv(x)
+                    qkv += qkv_bias
+                else:
+                    qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
             qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
             q, k, v = qkv.unbind(0)  # B, num_heads, N, head_dim
         else:
diff --git a/timm/models/swin_transformer_v2.py b/timm/models/swin_transformer_v2.py
index 7bf91032..9651189c 100644
--- a/timm/models/swin_transformer_v2.py
+++ b/timm/models/swin_transformer_v2.py
@@ -86,6 +86,7 @@ class WindowAttention(nn.Module):
             window_size: Tuple[int, int],
             num_heads: int,
             qkv_bias: bool = True,
+            qkv_bias_separate: bool = False,
             attn_drop: float = 0.,
             proj_drop: float = 0.,
             pretrained_window_size: Tuple[int, int] = (0, 0),
@@ -95,6 +96,7 @@ class WindowAttention(nn.Module):
         self.window_size = window_size  # Wh, Ww
         self.pretrained_window_size = pretrained_window_size
         self.num_heads = num_heads
+        self.qkv_bias_separate = qkv_bias_separate
 
         self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))))
 
@@ -156,10 +158,16 @@ class WindowAttention(nn.Module):
             mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
         """
         B_, N, C = x.shape
-        qkv_bias = None
-        if self.q_bias is not None:
+
+        if self.q_bias is None:
+            qkv = self.qkv(x)
+        else:
             qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias))
-        qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+            if self.qkv_bias_separate:
+                qkv = self.qkv(x)
+                qkv += qkv_bias
+            else:
+                qkv = F.linear(x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         q, k, v = qkv.unbind(0)
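
Note (not part of the patch): the two paths selected by the new qkv_bias_separate flag are numerically equivalent, since x @ W.T + b equals (x @ W.T) + b; the separate path only changes where the concatenated bias is applied, so existing checkpoints behave the same either way. A minimal standalone sketch below illustrates this; the tensor sizes and variable names are invented for the demo.

    import torch
    import torch.nn.functional as F

    # Illustrative equivalence check for the two qkv bias paths (demo values only).
    torch.manual_seed(0)
    dim, tokens = 16, 4
    x = torch.randn(2, tokens, dim)            # (B, N, C)
    weight = torch.randn(dim * 3, dim)         # fused qkv projection weight
    q_bias = torch.randn(dim)
    k_bias = torch.zeros(dim)                  # k bias kept at zero, as in these models
    v_bias = torch.randn(dim)
    qkv_bias = torch.cat((q_bias, k_bias, v_bias))

    fused = F.linear(x, weight, qkv_bias)      # qkv_bias_separate=False path
    separate = F.linear(x, weight) + qkv_bias  # qkv_bias_separate=True path
    print(torch.allclose(fused, separate, atol=1e-6))  # True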