mirror of https://github.com/huggingface/pytorch-image-models.git
Merge pull request #933 from t-vi/unbind
use .unbind instead of explicitly listing the indices
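As a quick illustration of the change (a minimal sketch, not part of this commit; the shapes below are made up for the example), torch.Tensor.unbind(0) yields the same per-index tensors as indexing each slice explicitly, just in a single call:

import torch

# Hypothetical shapes, chosen only for the demo.
B, num_heads, N, head_dim = 2, 4, 16, 8
qkv = torch.randn(3, B, num_heads, N, head_dim)  # stacked q, k, v, as produced by the reshape/permute in the hunks below

q, k, v = qkv.unbind(0)              # new form: one call, returns a tuple of views
q0, k0, v0 = qkv[0], qkv[1], qkv[2]  # old form: three explicit index operations

assert torch.equal(q, q0) and torch.equal(k, k0) and torch.equal(v, v0)

Both forms avoid tuple-unpacking a tensor directly (q, k, v = qkv), which TorchScript rejects; unbind simply expresses the workaround in one call.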
commit 7da1b0b61c
@@ -136,7 +136,7 @@ class Attention(nn.Module):
         qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
         qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
         qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
 
         q = q * self.scale
         attn = (q @ k.transpose(-2, -1))
@@ -81,7 +81,7 @@ class Attention(nn.Module):
         B, T, N, C = x.shape
         # result of next line is (qkv, B, num (H)eads, T, N, (C')hannels per head)
         qkv = self.qkv(x).reshape(B, T, N, 3, self.num_heads, C // self.num_heads).permute(3, 0, 4, 1, 2, 5)
-        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
 
         attn = (q @ k.transpose(-2, -1)) * self.scale  # (B, H, T, N, N)
         attn = attn.softmax(dim=-1)
@@ -172,7 +172,7 @@ class WindowAttention(nn.Module):
         """
         B_, N, C = x.shape
         qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
 
         q = q * self.scale
         attn = (q @ k.transpose(-2, -1))
@@ -61,7 +61,7 @@ class Attention(nn.Module):
     def forward(self, x):
         B, N, C = x.shape
         qk = self.qk(x).reshape(B, N, 2, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
-        q, k = qk[0], qk[1]   # make torchscript happy (cannot use tensor as tuple)
+        q, k = qk.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
         v = self.v(x).reshape(B, N, self.num_heads, -1).permute(0, 2, 1, 3)
 
         attn = (q @ k.transpose(-2, -1)) * self.scale
@@ -190,7 +190,7 @@ class Attention(nn.Module):
     def forward(self, x):
         B, N, C = x.shape
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
 
         attn = (q @ k.transpose(-2, -1)) * self.scale
         attn = attn.softmax(dim=-1)
@@ -267,7 +267,7 @@ class XCA(nn.Module):
         B, N, C = x.shape
         # Result of next line is (qkv, B, num (H)eads, (C')hannels per head, N)
         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 4, 1)
-        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
+        q, k, v = qkv.unbind(0)   # make torchscript happy (cannot use tensor as tuple)
 
         # Paper section 3.2 l2-Normalization and temperature scaling
         q = torch.nn.functional.normalize(q, dim=-1)
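For context, an end-to-end sketch of the pattern the "make torchscript happy" comments refer to (TinyAttention below is a made-up toy module, not code from the PR), showing that the unbind form compiles under torch.jit.script:

import torch
import torch.nn as nn

class TinyAttention(nn.Module):
    # Hypothetical minimal module mirroring the qkv reshape/permute/unbind pattern above.
    def __init__(self, dim: int = 32, num_heads: int = 4):
        super().__init__()
        self.num_heads = num_heads
        self.scale = (dim // num_heads) ** -0.5
        self.qkv = nn.Linear(dim, dim * 3)

    def forward(self, x):
        B, N, C = x.shape
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)  # scriptable, unlike unpacking the tensor directly (q, k, v = qkv)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        return (attn @ v).transpose(1, 2).reshape(B, N, C)

scripted = torch.jit.script(TinyAttention())  # compiles with the unbind form
out = scripted(torch.randn(2, 16, 32))        # output shape (2, 16, 32)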