mirror of
https://github.com/huggingface/pytorch-image-models.git
synced 2025-06-03 15:01:08 +08:00
Merge pull request #933 from t-vi/unbind
use .unbind instead of explicitly listing the indices
This commit is contained in:
commit
7da1b0b61c
@ -136,7 +136,7 @@ class Attention(nn.Module):
|
||||
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
|
||||
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
|
||||
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
|
@ -81,7 +81,7 @@ class Attention(nn.Module):
|
||||
B, T, N, C = x.shape
|
||||
# result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
|
||||
qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, N, N)
|
||||
attn = attn.softmax(dim=-1)
|
||||
|
@ -172,7 +172,7 @@ class WindowAttention(nn.Module):
|
||||
"""
|
||||
B_, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
q = q * self.scale
|
||||
attn = (q @ k.transpose(-2, -1))
|
||||
@ -649,4 +649,4 @@ def swin_large_patch4_window7_224_in22k(pretrained=False, **kwargs):
|
||||
"""
|
||||
model_kwargs = dict(
|
||||
patch_size=4, window_size=7, embed_dim=192, depths=(2, 2, 18, 2), num_heads=(6, 12, 24, 48), **kwargs)
|
||||
return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
|
||||
return _create_swin_transformer('swin_large_patch4_window7_224_in22k', pretrained=pretrained, **model_kwargs)
|
||||
|
@ -61,7 +61,7 @@ class Attention(nn.Module):
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
|
||||
q, k = qk[0], qk[1] # make torchscript happy (cannot use tensor as tuple)
|
||||
q, k = qk.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
|
@ -190,7 +190,7 @@ class Attention(nn.Module):
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
attn = (q @ k.transpose(-2, -1)) * self.scale
|
||||
attn = attn.softmax(dim=-1)
|
||||
@ -893,4 +893,4 @@ def vit_base_patch16_224_miil(pretrained=False, **kwargs):
|
||||
"""
|
||||
model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12, qkv_bias=False, **kwargs)
|
||||
model = _create_vision_transformer('vit_base_patch16_224_miil', pretrained=pretrained, **model_kwargs)
|
||||
return model
|
||||
return model
|
||||
|
@ -267,7 +267,7 @@ class XCA(nn.Module):
|
||||
B, N, C = x.shape
|
||||
# Result of next line is (qkv, B, num (H)eads, (C')hannels per head, N)
|
||||
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
|
||||
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
|
||||
q, k, v = qkv.unbind(0) # make torchscript happy (cannot use tensor as tuple)
|
||||
|
||||
# Paper section 3.2 l2-Normalization and temperature scaling
|
||||
q = torch.nn.functional.normalize(q, dim=-1)
|
||||
|
Loading…
x
Reference in New Issue
Block a user