mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Cleanup qkv_bias cat in beit model so it can be traced
This commit is contained in:
parent
1076a65df1
commit
f2006b2437
@ -407,7 +407,6 @@ def test_model_backward_fx(model_name, batch_size):
|
||||
|
||||
# reason: model is scripted after fx tracing, but beit has torch.jit.is_scripting() control flow
|
||||
EXCLUDE_FX_JIT_FILTERS = [
|
||||
'beit_*',
|
||||
'deit_*_distilled_patch16_224',
|
||||
'levit*',
|
||||
'pit_*_distilled_224',
|
||||
|
@ -86,9 +86,11 @@ class Attention(nn.Module):
|
||||
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
|
||||
if qkv_bias:
|
||||
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
self.register_buffer('k_bias', torch.zeros(all_head_dim), persistent=False)
|
||||
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
|
||||
else:
|
||||
self.q_bias = None
|
||||
self.k_bias = None
|
||||
self.v_bias = None
|
||||
|
||||
if window_size:
|
||||
@ -127,13 +129,7 @@ class Attention(nn.Module):
|
||||
|
||||
def forward(self, x, rel_pos_bias: Optional[torch.Tensor] = None):
|
||||
B, N, C = x.shape
|
||||
qkv_bias = None
|
||||
if self.q_bias is not None:
|
||||
if torch.jit.is_scripting():
|
||||
# FIXME requires_grad breaks w/ torchscript
|
||||
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias), self.v_bias))
|
||||
else:
|
||||
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
||||
qkv_bias = torch.cat((self.q_bias, self.k_bias, self.v_bias)) if self.q_bias is not None else None
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
Loading…
x
Reference in New Issue
Block a user